"""
The Curl CFFI WebSocket client implementation.
"""
from __future__ import annotations
import asyncio
import struct
import threading
import warnings
from asyncio import InvalidStateError
from collections.abc import Awaitable, Callable, Generator
from contextlib import suppress
from dataclasses import dataclass, field
from enum import IntEnum
from json import dumps as json_dumps
from json import loads as json_loads
from random import uniform
from select import select
from typing import TYPE_CHECKING, Final, Literal, TypeVar, cast, final
from ..aio import CURL_SOCKET_BAD, get_selector
from ..const import CurlECode, CurlFollow, CurlInfo, CurlOpt, CurlWsFlag
from ..curl import Curl, CurlError
from ..utils import CurlCffiWarning
from .exceptions import SessionClosed, Timeout
from .models import Response
from .utils import NOT_SET, NotSetType, set_curl_options
if TYPE_CHECKING:
    from typing_extensions import Self

    from ..const import CurlHttpVersion
    from ..curl import CurlWsFrame
    from ..fingerprints import Fingerprint
    from .cookies import CookieTypes
    from .headers import HeaderTypes
    from .impersonate import BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
    from .session import AsyncSession, ProxySpec

# Module-level type aliases. These are evaluated when the module is imported,
# so a name that is only imported under TYPE_CHECKING (``CurlWsFrame``) must be
# written as a string forward reference, otherwise the import fails with
# NameError.
T = TypeVar("T")

# Callback signatures accepted by ``WebSocket(on_data=..., on_message=..., ...)``.
ON_DATA_T = Callable[["WebSocket", bytes, "CurlWsFrame"], None]
ON_MESSAGE_T = Callable[["WebSocket", bytes | str], None]
ON_ERROR_T = Callable[["WebSocket", CurlError], None]
ON_OPEN_T = Callable[["WebSocket"], None]
ON_CLOSE_T = Callable[["WebSocket", int, str], None]

# Items stored in ``AsyncWebSocket``'s receive queue (payload, flags) and send
# queue (payload, frame flags).
RECV_QUEUE_ITEM = tuple[bytes, int]
SEND_QUEUE_ITEM = tuple[bytes | bytearray | memoryview, CurlWsFlag | int]
@dataclass
class WebSocketRetryStrategy:
    """Retry policy for WebSocket message receives that fail.

    With ``retry`` switched on, every failed receive attempt is retried using
    exponential backoff with jitter. The wait before a retry is::

        delay * 2^(attempt - 1) ± 10%

    where ``attempt`` is the 1-based number of the retry being made.

    Args:
        retry: Turn the receive-retry policy on or off.
        delay: Base value, in seconds, from which the retry delay is computed.
        count: Maximum number of times a receive is retried before giving up.
        codes: ``CurlECode`` values that make a failed receive eligible for a
            retry. Defaults to ``{CurlECode.RECV_ERROR}``.
    """

    # Master switch for the policy; disabled by default.
    retry: bool = False
    # Backoff base in seconds.
    delay: float = 0.3
    # Upper bound on retries per receive operation.
    count: int = 3
    # Retryable error codes; a fresh set is built for every instance.
    codes: set[CurlECode] = field(default_factory=lambda: {CurlECode.RECV_ERROR})
class WsCloseCode(IntEnum):
    """WebSocket close status codes.

    See: https://www.iana.org/assignments/websocket/websocket.xhtml

    Per RFC 6455 section 7.4.1, ``UNKNOWN`` (1005), ``ABNORMAL_CLOSURE`` (1006)
    and ``TLS_HANDSHAKE`` (1015) are reserved for local reporting and must not
    be sent in a close frame.
    """

    OK = 1000  # Normal closure.
    GOING_AWAY = 1001  # Endpoint is going away.
    PROTOCOL_ERROR = 1002  # Protocol error.
    UNSUPPORTED_DATA = 1003  # Received a data type that cannot be accepted.
    UNKNOWN = 1005  # IANA "No Status Rcvd": close frame carried no status code.
    ABNORMAL_CLOSURE = 1006  # Connection lost without a close frame.
    INVALID_DATA = 1007  # Payload inconsistent with its type (e.g. bad UTF-8).
    POLICY_VIOLATION = 1008  # Policy violation.
    MESSAGE_TOO_BIG = 1009  # Message too big to process.
    MANDATORY_EXTENSION = 1010  # A required extension was not negotiated.
    INTERNAL_ERROR = 1011  # Unexpected condition prevented fulfilling the request.
    SERVICE_RESTART = 1012  # Service restart.
    TRY_AGAIN_LATER = 1013  # Try again later.
    BAD_GATEWAY = 1014  # Bad gateway.
    TLS_HANDSHAKE = 1015  # TLS handshake failure.
    UNAUTHORIZED = 3000  # Unauthorized (IANA-registered application code).
    FORBIDDEN = 3003  # Forbidden (IANA-registered application code).
    TIMEOUT = 3008  # Timeout (IANA-registered application code).
class WebSocketError(CurlError):
    """WebSocket-specific error.

    ``code`` is a ``WsCloseCode``, a ``CurlECode``, or ``0`` when no specific
    code applies. It is exposed as ``e.code`` (inherited from ``CurlError``).
    """

    def __init__(
        self, message: str, code: WsCloseCode | CurlECode | Literal[0] = 0
    ) -> None:
        # Delegate storage of the message and code to CurlError.
        super().__init__(message, code)  # pyright: ignore[reportUnknownMemberType]
class WebSocketClosed(  # pyright: ignore[reportUnsafeMultipleInheritance]
    WebSocketError, SessionClosed
):
    """WebSocket is already closed.

    Derives from both ``WebSocketError`` and ``SessionClosed``, so handlers for
    either exception type will catch it.
    """
class WebSocketTimeout(  # pyright: ignore[reportUnsafeMultipleInheritance]
    WebSocketError, Timeout
):
    """WebSocket operation timed out.

    Derives from both ``WebSocketError`` and ``Timeout``, so handlers for
    either exception type will catch it.
    """
def _safe_set_result(fut: asyncio.Future[None]) -> None:
    """
    Resolve ``fut`` with ``None``, tolerating a future that is already finished.

    Used as the event-loop callback fired when a file descriptor becomes
    readable/writable. The future may already have been completed or
    cancelled concurrently, in which case ``set_result()`` raises
    ``InvalidStateError``. That case is ignored on purpose, so uvloop/asyncio
    do not print spurious 'Exception in callback' traces.

    try/except is used deliberately: it is cheaper than checking ``done()``
    before every call.
    """
    try:  # noqa: SIM105
        fut.set_result(None)
    except InvalidStateError:
        # Someone else already resolved or cancelled it; nothing to do.
        return
class BaseWebSocket:
    """State and close-frame helpers shared by the sync and async WebSocket clients."""

    __slots__: tuple[str, ...] = (
        "_curl",
        "autoclose",
        "_close_code",
        "_close_reason",
        "debug",
        "closed",
    )

    def __init__(
        self, curl: Curl | NotSetType, *, autoclose: bool = True, debug: bool = False
    ) -> None:
        """
        Args:
            curl: the ``Curl`` handle to use, or ``NOT_SET`` to have one created
                lazily the first time :attr:`curl` is accessed.
            autoclose: whether the connection should be closed after a close
                frame is received (acted upon by subclasses).
            debug: forwarded to ``Curl`` when the handle is created lazily.
        """
        self._curl: Curl | NotSetType = curl
        self.autoclose: bool = autoclose
        self._close_code: int | None = None
        self._close_reason: str | None = None
        self.debug: bool = debug
        self.closed: bool = False

    @property
    def curl(self) -> Curl:
        """Return reference to Curl associated with current WebSocket.

        If no handle was supplied, one is created (honouring ``debug``) on first
        access and reused afterwards.
        """
        if isinstance(self._curl, NotSetType):
            self._curl = Curl(debug=self.debug)
        return self._curl

    @property
    def close_code(self) -> int | None:
        """The WebSocket close code, if the connection has been closed."""
        return self._close_code

    @property
    def close_reason(self) -> str | None:
        """The WebSocket close reason, if the connection has been closed."""
        return self._close_reason

    @staticmethod
    def _pack_close_frame(code: int, reason: str | bytes) -> bytes:
        """Build a close-frame payload: the code as a big-endian uint16 + reason.

        Args:
            code: close status code.
            reason: close reason; ``str`` is encoded as UTF-8.
        """
        if isinstance(reason, str):
            reason = reason.encode("utf-8")
        return struct.pack("!H", code) + reason

    @staticmethod
    def _unpack_close_frame(frame: bytes) -> tuple[int, str]:
        """Parse a received close-frame payload into ``(code, reason)``.

        An empty payload is valid and yields ``(WsCloseCode.UNKNOWN, "")``.

        Raises:
            WebSocketError: with ``PROTOCOL_ERROR`` if the payload is one byte
                long or carries a code that may not appear in a close frame,
                or with ``INVALID_DATA`` if the reason is not valid UTF-8.
        """
        if not frame:
            # A close frame without a body carries no status code.
            return WsCloseCode.UNKNOWN, ""
        if len(frame) < 2:
            # RFC 6455 §5.5.1: a non-empty body must start with a 2-byte code.
            raise WebSocketError("Invalid close frame", WsCloseCode.PROTOCOL_ERROR)
        try:
            code: int = struct.unpack_from("!H", frame)[0]
            reason = frame[2:].decode()
        except UnicodeDecodeError as e:
            raise WebSocketError(
                "Invalid close message", WsCloseCode.INVALID_DATA
            ) from e
        except Exception as e:
            raise WebSocketError(
                "Invalid close frame", WsCloseCode.PROTOCOL_ERROR
            ) from e
        # RFC 6455 §7.4.1: 1005/1006/1015 are for local reporting only and must
        # never be sent on the wire; codes outside 1000-4999 are undefined.
        reserved = (
            WsCloseCode.UNKNOWN,
            WsCloseCode.ABNORMAL_CLOSURE,
            WsCloseCode.TLS_HANDSHAKE,
        )
        if code in reserved or code < 1000 or code >= 5000:
            raise WebSocketError(
                f"Invalid close code: {code}", WsCloseCode.PROTOCOL_ERROR
            )
        return code, reason

    def terminate(self) -> None:
        """Terminate the underlying connection.

        Marks the socket closed and releases the libcurl handle. If no handle was
        ever created, none is allocated just to be closed.
        """
        self.closed = True
        handle = self._curl
        if not isinstance(handle, NotSetType):
            handle.close()
# Names of the events ``WebSocket`` can dispatch to user callbacks; these are
# the keys of ``WebSocket._emitters`` and the ``event_type`` of ``_emit``.
EventTypeLiteral = Literal["open", "close", "data", "message", "error"]
[docs]
@final
class WebSocket(BaseWebSocket):
    """A blocking WebSocket client implemented with libcurl.

    Messages can be read with :meth:`recv` (and its ``recv_str``/``recv_json``
    helpers), by iterating over the instance, or by registering callbacks and
    calling :meth:`run_forever`.
    """

    __slots__ = (
        "skip_utf8_validation",
        "_emitters",
        "keep_running",
    )

    def __init__(
        self,
        curl: Curl | NotSetType = NOT_SET,
        *,
        autoclose: bool = True,
        skip_utf8_validation: bool = False,
        debug: bool = False,
        on_open: ON_OPEN_T | None = None,
        on_close: ON_CLOSE_T | None = None,
        on_data: ON_DATA_T | None = None,
        on_message: ON_MESSAGE_T | None = None,
        on_error: ON_ERROR_T | None = None,
    ) -> None:
        """
        Args:
            curl: an existing ``Curl`` handle to use. If omitted, one is created
                the first time it is needed.
            autoclose: whether to close the WebSocket after receiving a close frame.
            skip_utf8_validation: whether to skip UTF-8 validation for text frames in
                run_forever().
            debug: print extra curl debug info.
            on_open: open callback, ``def on_open(ws)``
            on_close: close callback, ``def on_close(ws, code, reason)``
            on_data: raw data receive callback, ``def on_data(ws, data, frame)``
            on_message: message receive callback, ``def on_message(ws, message)``
            on_error: error callback, ``def on_error(ws, exception)``
        """
        super().__init__(curl=curl, autoclose=autoclose, debug=debug)
        self.skip_utf8_validation: bool = skip_utf8_validation
        self.keep_running: bool = False
        self._emitters: dict[EventTypeLiteral, Callable[..., object]] = {}
        if on_open:
            self._emitters["open"] = on_open
        if on_close:
            self._emitters["close"] = on_close
        if on_data:
            self._emitters["data"] = on_data
        if on_message:
            self._emitters["message"] = on_message
        if on_error:
            self._emitters["error"] = on_error

    def __iter__(self) -> WebSocket:
        """Iterate over incoming message payloads until a close frame arrives.

        Raises:
            WebSocketClosed: the WebSocket is already closed.
        """
        if self.closed:
            raise WebSocketClosed("WebSocket is closed")
        return self

    def __next__(self) -> bytes:
        """Return the next message payload; stop iterating on a close frame."""
        msg, flags = self.recv()
        if flags & CurlWsFlag.CLOSE:
            raise StopIteration
        return msg

    def _emit(
        self,
        event_type: EventTypeLiteral,
        *args: str | bytes | int | CurlWsFrame | CurlError,
    ) -> None:
        """Invoke the callback registered for ``event_type``, if there is one.

        The callback is called as ``callback(self, *args)``. If it raises, the
        exception is handed to the ``error`` callback when one is registered,
        and otherwise reported as a ``CurlCffiWarning``.
        """
        callback: Callable[..., object] | None = self._emitters.get(event_type)
        if callback:
            try:
                _ = callback(self, *args)
            # pylint: disable-next=broad-exception-caught
            except Exception as e:
                error_callback: Callable[..., object] | None = self._emitters.get(
                    "error"
                )
                if error_callback:
                    _ = error_callback(self, e)
                else:
                    # Include the exception so the cause is not silently lost.
                    warnings.warn(
                        f"WebSocket callback '{event_type}' failed: {e!r}",
                        CurlCffiWarning,
                        stacklevel=2,
                    )

    def connect(
        self,
        url: str,
        params: (
            dict[str, object]
            | list[object]
            | tuple[str, int | list[str] | dict[str, str | int]]
            | None
        ) = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: tuple[str, str] | None = None,
        timeout: float | tuple[float, float] | object | None = NOT_SET,
        allow_redirects: bool | CurlFollow | str = True,
        max_redirects: int = 30,
        proxies: ProxySpec | None = None,
        proxy: str | None = None,
        proxy_auth: tuple[str, str] | None = None,
        verify: bool | None = None,
        referer: str | None = None,
        accept_encoding: str | None = "gzip, deflate, br",
        impersonate: BrowserTypeLiteral | str | Fingerprint | None = None,
        ja3: str | None = None,
        akamai: str | None = None,
        perk: str | None = None,
        extra_fp: ExtraFingerprints | ExtraFpDict | None = None,
        default_headers: bool = True,
        quote: str | Literal[False] = "",
        http_version: CurlHttpVersion | None = None,
        interface: str | None = None,
        cert: str | tuple[str, str] | None = None,
        max_recv_speed: int = 0,
        curl_options: dict[CurlOpt, str] | None = None,
    ) -> Self:
        """Connect to the WebSocket.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Args:
            url: url for the requests.
            params: query string for the requests.
            headers: headers to send.
            cookies: cookies to use.
            auth: HTTP basic auth, a tuple of (username, password), only basic auth is
                supported.
            timeout: how many seconds to wait before giving up.
            allow_redirects: whether to allow redirection. Can be a bool, a
                ``CurlFollow`` value, or the string ``"safe"``.
            max_redirects: max redirect counts, default 30, use -1 for unlimited.
            proxies: dict of proxies to use, prefer to use ``proxy`` if they are the
                same. format: ``{"http": proxy_url, "https": proxy_url}``.
            proxy: proxy to use, format: "http://user:pass@proxy_url".
                Can't be used with `proxies` parameter.
            proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
            verify: whether to verify https certs.
            referer: shortcut for setting referer header.
            accept_encoding: shortcut for setting accept-encoding header.
            impersonate: which browser version or fingerprint to impersonate.
            ja3: ja3 string to impersonate.
            akamai: akamai string to impersonate.
            perk: perk string to impersonate.
            extra_fp: extra fingerprints options, in complement to ja3 and akamai str.
            default_headers: whether to set default browser headers.
            quote: Set characters to be quoted (percent-encoded). Default safe
                string is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, the characters
                will be removed from the safe string. If set to ``False``, the URL
                is used as-is (you must encode it yourself).
            http_version: Limiting http version, defaults to http2.
            interface: interface name or local IP to bind to (bare IP = source address).
            cert: a tuple of (cert, key) filenames for client cert.
            max_recv_speed: maximum receive speed, bytes per second.
            curl_options: extra curl options to use.

        Returns:
            This WebSocket, so calls can be chained.
        """
        curl: Curl = self.curl
        _ = set_curl_options(
            curl=curl,
            method="GET",
            url=url,
            params_list=[None, params],
            headers_list=[None, headers],
            cookies_list=[None, cookies],
            auth=auth,
            timeout=timeout,
            allow_redirects=allow_redirects,
            max_redirects=max_redirects,
            proxies_list=[None, proxies],
            proxy=proxy,
            proxy_auth=proxy_auth,
            verify_list=[None, verify],
            referer=referer,
            accept_encoding=accept_encoding,
            impersonate=impersonate,
            ja3=ja3,
            akamai=akamai,
            perk=perk,
            extra_fp=extra_fp,
            default_headers=default_headers,
            quote=quote,
            http_version=http_version,
            interface=interface,
            max_recv_speed=max_recv_speed,
            cert=cert,
            curl_options=curl_options,
        )

        # Magic number defined in: https://curl.se/docs/websocket.html
        _ = curl.setopt(CurlOpt.CONNECT_ONLY, 2)
        curl.perform()
        return self

    def recv_fragment(self) -> tuple[bytes, CurlWsFrame]:
        """Receive a single curl websocket fragment as bytes.

        When the fragment is a close frame, its code and reason are recorded in
        :attr:`close_code`/:attr:`close_reason`, and the connection is closed if
        ``autoclose`` is set.

        Raises:
            WebSocketClosed: the WebSocket is already closed.
            WebSocketError: the close frame is malformed; the connection is
                closed with the corresponding code before raising.
        """
        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        chunk, frame = self.curl.ws_recv()
        if frame.flags & CurlWsFlag.CLOSE:
            try:
                self._close_code, self._close_reason = self._unpack_close_frame(chunk)
            except WebSocketError as e:
                # Follow the spec to close the connection
                # Errors do not respect autoclose
                self._close_code = e.code
                # A broken transport must not hide the protocol error itself.
                with suppress(CurlError):
                    self.close(e.code)
                raise
            if self.autoclose:
                self.close()

        return chunk, frame

    def recv(self) -> tuple[bytes, int]:
        """
        Receive a frame as bytes. libcurl splits frames into fragments, so we have to
        collect all the chunks for a frame.

        Ping and pong frames are skipped. Blocks until a complete frame arrives.

        Returns:
            ``(payload, flags)`` where ``flags`` are those of the final fragment.

        Raises:
            WebSocketClosed: the WebSocket is already closed.
            WebSocketError: the active socket is invalid.
        """
        # Check first: after terminate() the curl handle has been cleaned up.
        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        chunks: list[bytes] = []
        flags: int = 0

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        while True:
            try:
                # Try to receive the first fragment first
                chunk, frame = self.recv_fragment()

                # Ignore control frames during data assembly
                if frame.flags & (CurlWsFlag.PING | CurlWsFlag.PONG):
                    continue

                flags = frame.flags
                chunks.append(chunk)
                if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
                    break
            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    # According to https://curl.se/libcurl/c/curl_ws_recv.html
                    # > in real application: wait for socket here, e.g. using select()
                    _, _, _ = select([sock_fd], [], [], 0.5)
                else:
                    raise

        return b"".join(chunks), flags

    def recv_str(self) -> str:
        """Receive a text frame.

        Raises:
            WebSocketError: the frame is not a text frame.
            UnicodeDecodeError: the payload is not valid UTF-8.
        """
        data, flags = self.recv()
        if not (flags & CurlWsFlag.TEXT):
            raise WebSocketError("Not valid text frame", WsCloseCode.INVALID_DATA)
        return data.decode("utf-8")

    def recv_json(self, *, loads: Callable[[str], T] = json_loads) -> T:
        """Receive a JSON frame.

        Args:
            loads: JSON decoder, default is json.loads.
        """
        data = self.recv_str()
        return loads(data)

    def send(
        self,
        payload: str | bytes,
        flags: CurlWsFlag = CurlWsFlag.BINARY,
    ) -> int:
        """Send a data frame.

        Args:
            payload: data to send; ``str`` is encoded as UTF-8.
            flags: flags for the frame.

        Returns:
            The number of payload bytes sent.

        Raises:
            WebSocketClosed: the WebSocket is already closed.
            WebSocketError: the active socket is invalid, or the socket did not
                become writable within 0.5s after libcurl asked to retry.
        """
        if flags & CurlWsFlag.CLOSE:
            self.keep_running = False

        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        # curl expects bytes
        if isinstance(payload, str):
            payload = payload.encode()

        sock_fd: bytes | int | float | list[str | int] = self.curl.getinfo(
            CurlInfo.ACTIVESOCKET
        )
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        # Loop checks for CurlECode.Again
        # https://curl.se/libcurl/c/curl_ws_send.html
        # ws_send is always called at least once, so zero-length frames (e.g. an
        # empty ping or text message) are actually transmitted.
        total = len(payload)
        offset = 0
        while True:
            try:
                n_sent = self.curl.ws_send(payload[offset:], flags)
            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    _, writeable, _ = select([], [sock_fd], [], 0.5)
                    if not writeable:
                        raise WebSocketError("Socket write timeout") from e
                    continue
                raise
            offset += n_sent
            if offset >= total:
                return offset

    def send_binary(self, payload: bytes) -> int:
        """Send a binary frame.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_bytes(self, payload: bytes) -> int:
        """Send a binary frame, alias of :meth:`send_binary`.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_str(self, payload: str) -> int:
        """Send a text frame.

        Args:
            payload: text data to send.
        """
        return self.send(payload, CurlWsFlag.TEXT)

    def send_json(
        self, payload: object, *, dumps: Callable[..., str] = json_dumps
    ) -> int:
        """Send a JSON frame.

        Args:
            payload: data to send.
            dumps: JSON encoder, default is json.dumps.
        """
        if dumps is json_dumps:
            return self.send_str(json_dumps(payload, separators=(",", ":")))
        return self.send_str(dumps(payload))

    def ping(self, payload: str | bytes) -> int:
        """Send a ping frame.

        Args:
            payload: data to send, at most 125 bytes once encoded.

        Raises:
            WebSocketError: The payload length is outside specification.
        """
        if isinstance(payload, str):
            payload = payload.encode()
        # Control frames may carry at most 125 bytes (mirrors AsyncWebSocket.ping).
        if len(payload) > 125:
            raise WebSocketError(
                f"Ping frame has invalid length: {len(payload)}",
                CurlECode.TOO_LARGE,
            )
        return self.send(payload, CurlWsFlag.PING)

    def run_forever(self, url: str = "", **kwargs) -> None:
        """Run the WebSocket forever. See :meth:`connect` for details on parameters.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Every received fragment goes to the ``data`` callback. Fragments of data
        frames are reassembled and passed to the ``message`` callback; ping/pong
        frames never take part in reassembly. The loop ends on a close frame or
        when :attr:`keep_running` is cleared (e.g. by sending a close frame).

        Raises:
            WebSocketError: the active socket is invalid, or a text message is
                not valid UTF-8 (unless ``skip_utf8_validation`` is set).
            CurlError: any other receive error; the ``error`` callback is called
                and the connection is closed before it is re-raised.
        """
        if url:
            _ = self.connect(url, **kwargs)

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        self._emit("open")

        # Keep reading the messages and invoke callbacks
        # TODO: Reconnect logic
        chunks: list[bytes] = []
        self.keep_running = True
        while self.keep_running:
            try:
                chunk, frame = self.recv_fragment()
                flags = frame.flags
                self._emit("data", chunk, frame)

                # Control frames may arrive between the fragments of a data
                # message: they must neither be mixed into that message nor be
                # mistaken for its final fragment.
                if flags & (CurlWsFlag.PING | CurlWsFlag.PONG):
                    continue

                chunks.append(chunk)
                if not (frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0):
                    continue

                # Avoid unnecessary computation
                if "message" in self._emitters:
                    # Concatenate collected chunks with the final message
                    msg = b"".join(chunks)
                    if (flags & CurlWsFlag.TEXT) and not self.skip_utf8_validation:
                        try:
                            emit_msg: str | bytes = msg.decode()
                        except UnicodeDecodeError as e:
                            self._close_code = WsCloseCode.INVALID_DATA
                            self.close(WsCloseCode.INVALID_DATA)
                            raise WebSocketError(
                                "Invalid UTF-8", WsCloseCode.INVALID_DATA
                            ) from e
                    else:
                        emit_msg = msg
                    if (flags & CurlWsFlag.BINARY) or (flags & CurlWsFlag.TEXT):
                        self._emit("message", emit_msg)

                # Reset for the next message even without a message callback,
                # so chunks never accumulate across messages.
                chunks = []

                if flags & CurlWsFlag.CLOSE:
                    self.keep_running = False
                    self._emit("close", self._close_code or 0, self._close_reason or "")

            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    _, _, _ = select([sock_fd], [], [], 0.5)
                    continue
                self._emit("error", e)
                if not self.closed:
                    # The transport is probably broken; a failing close must not
                    # replace the original error (close() releases the handle).
                    with suppress(CurlError):
                        self.close(self._error_close_code(e))
                raise

    @staticmethod
    def _error_close_code(error: CurlError) -> int:
        """Pick the close code to send when :meth:`run_forever` aborts on ``error``.

        A ``WebSocketError`` whose code may legally appear in a close frame is
        echoed. Anything else (libcurl error numbers, or 1005/1006/1015, which
        RFC 6455 reserves for local use) falls back to ``INTERNAL_ERROR``.
        """
        code = error.code
        forbidden = (
            WsCloseCode.UNKNOWN,
            WsCloseCode.ABNORMAL_CLOSURE,
            WsCloseCode.TLS_HANDSHAKE,
        )
        if (
            isinstance(error, WebSocketError)
            and 1000 <= code < 5000
            and code not in forbidden
        ):
            return code
        return WsCloseCode.INTERNAL_ERROR

    def close(self, code: int = WsCloseCode.OK, message: bytes = b"") -> None:
        """Close the connection.

        Sends a close frame, then releases the libcurl handle. The handle is
        released even if sending the close frame fails.

        Args:
            code: close code.
            message: close reason.

        Raises:
            WebSocketClosed: the WebSocket was already closed.
        """
        # TODO: As per spec, we should wait for the server to close the connection
        # But this is not a requirement
        msg = self._pack_close_frame(code, message)
        try:
            _ = self.send(msg, CurlWsFlag.CLOSE)
        finally:
            # The only way to close the connection appears to be curl_easy_cleanup
            if not self.closed:
                self.terminate()
[docs]
@final
class AsyncWebSocket(BaseWebSocket):
"""
An asyncio WebSocket implementation using libcurl.
"""
__slots__ = (
"session",
"_loop",
"_sock_fd",
"_close_lock",
"_terminate_lock",
"_read_task",
"_write_task",
"_receive_queue",
"_send_queue",
"_max_send_batch_size",
"_coalesce_frames",
"_recv_time_slice",
"_send_time_slice",
"_terminated",
"_terminated_event",
"ws_retry",
"_transport_exception",
"_max_message_size",
"drain_on_error",
"_block_on_recv_queue_full",
)
_MAX_CURL_FRAME_SIZE: Final[int] = 65536
[docs]
def __init__(
    self,
    session: AsyncSession,
    curl: Curl,
    *,
    autoclose: bool = True,
    debug: bool = False,
    recv_queue_size: int = 128,
    send_queue_size: int = 128,
    max_send_batch_size: int = 64,
    coalesce_frames: bool = False,
    ws_retry: WebSocketRetryStrategy | None = None,
    recv_time_slice: float = 0.01,
    send_time_slice: float = 0.005,
    max_message_size: int = 4 * 1024 * 1024,
    drain_on_error: bool = False,
    block_on_recv_queue_full: bool = True,
) -> None:
    """Initializes an Async WebSocket session.

    Do not instantiate this class directly. Use ``AsyncSession.ws_connect``.

    This class implements an async context manager, closing the connection
    automatically on exit:

    ::

        async with AsyncSession() as session:
            async with session.ws_connect("wss://api.example.com") as ws:
                await ws.send("Hello")
                msg, flags = await ws.recv()

    Args:
        session (AsyncSession): The parent session object.
        curl (Curl): The underlying Curl handle.
        autoclose (bool): Automatically close on receiving a close frame.
        debug (bool): Enable verbose debug logging.
        recv_queue_size (int): Max number of incoming messages to buffer.
        send_queue_size (int): Max number of outgoing messages to buffer.
        max_send_batch_size (int): Max frames to coalesce per transmission.
        coalesce_frames (bool): Combine small frame payloads to improve throughput.
        ws_retry (WebSocketRetryStrategy): Retry configuration for failed receives.
        recv_time_slice (float): Max seconds to read messages before yielding.
        send_time_slice (float): Max seconds to write messages before yielding.
        max_message_size (int): Max size (bytes) of a single received message.
        drain_on_error (bool): Yield buffered messages before raising errors.
        block_on_recv_queue_full (bool): Behavior when the receive queue is full.
            If True (default), the reader blocks (may cause timeouts).
            If False, the connection fails immediately to prevent data loss.

    Note:
        Architecture: This uses a background I/O model. Network operations run in
        background tasks. Errors are raised in subsequent calls to send() or recv().

        Performance: The time_slice defaults (10ms read / 5ms write) favor reading
        to compensate for libcurl's overhead. Increase these values to allocate more
        CPU time to I/O operations at the cost of event loop latency.

    See also:
        - https://curl.se/libcurl/c/curl_ws_recv.html
        - https://curl.se/libcurl/c/curl_ws_send.html
    """
    super().__init__(curl=curl, autoclose=autoclose, debug=debug)
    self.session: AsyncSession[Response] = session
    # Resolved lazily by the ``loop`` property from inside a running loop.
    self._loop: asyncio.AbstractEventLoop | None = None
    self._sock_fd: int = -1
    self._terminated: bool = False
    self._close_lock: asyncio.Lock = asyncio.Lock()
    # A threading lock (not asyncio) because terminate() may run off-loop.
    self._terminate_lock: threading.Lock = threading.Lock()
    self._terminated_event: asyncio.Event = asyncio.Event()
    self._read_task: asyncio.Task[None] | None = None
    self._write_task: asyncio.Task[None] | None = None
    self._receive_queue: asyncio.Queue[RECV_QUEUE_ITEM] = asyncio.Queue(
        maxsize=recv_queue_size
    )
    self._send_queue: asyncio.Queue[SEND_QUEUE_ITEM] = asyncio.Queue(
        maxsize=send_queue_size
    )
    self._max_send_batch_size: int = max_send_batch_size
    self._coalesce_frames: bool = coalesce_frames
    self.ws_retry: WebSocketRetryStrategy = ws_retry or WebSocketRetryStrategy()
    self._recv_time_slice: float = recv_time_slice
    self._send_time_slice: float = send_time_slice
    # First fatal error seen by the I/O tasks; re-raised to send()/recv() callers.
    self._transport_exception: Exception | None = None
    self._max_message_size: int = max_message_size
    self.drain_on_error: bool = drain_on_error
    self._block_on_recv_queue_full: bool = block_on_recv_queue_full
@property
def loop(self) -> asyncio.AbstractEventLoop:
    """The event loop this WebSocket schedules its I/O on.

    On first access the running loop is passed through ``get_selector`` and
    the result is cached in ``_loop``. The first access must therefore happen
    while an event loop is running.
    """
    cached = self._loop
    if cached is None:
        cached = get_selector(asyncio.get_running_loop())
        self._loop = cached
    return cached
@property
def send_queue_size(self) -> int:
    """Returns the current number of items in the send queue.

    These are ``(payload, flags)`` items enqueued by :meth:`send` that have not
    been dequeued yet (presumably by the background writer task).
    """
    return self._send_queue.qsize()
def is_alive(self) -> bool:
    """
    Report whether the background read and write tasks are still running.

    ``False`` is returned once the socket is closed or terminated, or as soon
    as either I/O task has finished, whether through an error or a clean
    shutdown.

    Note: The answer is only a snapshot. ``True`` does not promise that the
    next network operation will succeed, whereas ``False`` means the
    connection is definitely no longer active.
    """
    if self.closed or self._terminated:
        return False
    io_tasks = (self._read_task, self._write_task)
    return not any(task and task.done() for task in io_tasks)
async def __aenter__(self) -> Self:
    """Enable context manager usage for automatic session management and closure.

    This cannot be used to initiate a WebSocket connection, that must be done
    beforehand using the :meth:`AsyncSession.ws_connect()` factory method.

    Returns:
        Self: The instantiated AsyncWebSocket object.
    """
    # Nothing to set up here: the connection was made by the factory already.
    return self
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: object | None,
) -> None:
    """
    Close the WebSocket connection when the ``async with`` block ends.

    If the block is exiting because of an exception, any ``CurlError`` raised
    while closing is swallowed so the original exception keeps propagating.
    """
    if exc_type is not None:
        with suppress(CurlError):
            await self.close()
        return
    await self.close()
def __aiter__(self) -> Self:
    """Return this WebSocket as an async iterator over incoming messages.

    Raises:
        WebSocketClosed: the WebSocket has already been closed.
    """
    if not self.closed:
        return self
    raise WebSocketClosed("WebSocket has been closed")
async def __anext__(self) -> bytes:
    """Return the next message payload.

    Iteration stops when a close frame is received or the WebSocket is closed.
    """
    try:
        payload, frame_flags = await self.recv()
    except WebSocketClosed:
        raise StopAsyncIteration from None
    if not frame_flags & CurlWsFlag.CLOSE:
        return payload
    raise StopAsyncIteration
def _finalize_connection(self, exc: Exception) -> None:
    """Move the connection into its terminal state (first call wins).

    Used for every terminal condition: a normal WebSocket closure, a protocol
    error, or a transport error. Afterwards no further messages are delivered
    and every ``recv()`` call fails. This method is meant to run in the event
    loop, while ``terminate()`` is thread-safe.

    Args:
        exc (Exception): The exception later raised to callers. It does not need
            to be an error: recording ``WebSocketClosed`` signals a closure.
    """
    already_final = self.closed or self._transport_exception is not None
    if not already_final:
        self._transport_exception = exc
        self.terminate()
def _start_io_tasks(self) -> None:  # pyright: ignore[reportUnusedFunction]
    """Launch the background reader and writer tasks (one-shot).

    NOTE: Meant to be called once by the factory right after construction.
    Once the tasks exist, further calls do nothing; they cannot be restarted.

    Raises:
        WebSocketClosed: The WebSocket was terminated before the tasks started.
        WebSocketError: The WebSocket FD was invalid.
    """
    if self._read_task is not None:
        return  # Already started.
    if self._terminated:
        # Terminated before start: refuse rather than return silently.
        raise WebSocketClosed("WebSocket already terminated")

    # Record the socket libcurl is currently using.
    sock_fd = cast(int, self.curl.getinfo(CurlInfo.ACTIVESOCKET))
    self._sock_fd = sock_fd
    if sock_fd == CURL_SOCKET_BAD:
        raise WebSocketError(
            "Invalid active socket.", code=CurlECode.NO_CONNECTION_AVAILABLE
        )

    # Name the tasks after this object's id so they can be told apart.
    prefix = f"WebSocket-{id(self):#x}"
    loop = self.loop
    self._read_task = loop.create_task(self._read_loop(), name=f"{prefix}-reader")
    self._write_task = loop.create_task(
        self._write_loop(), name=f"{prefix}-writer"
    )
[docs]
async def recv(self, *, timeout: float | None = None) -> tuple[bytes, int]:
    """Receive a WebSocket message.

    This method waits for and returns the next complete WebSocket message.

    Args:
        timeout: How many seconds to wait for a message before raising
            a timeout error.

    Returns:
        tuple[bytes, int]: A tuple with the received payload and flags.

    Raises:
        WebSocketTimeout: If the timeout expires.
        WebSocketClosed: If the connection is closed.
        WebSocketError: If a network-level transport error occurs.

    Notes:
        Message fragmentation and reassembly are handled automatically by the
        implementation, so callers will always receive complete messages.

        ``WebSocketError`` exceptions may have originated from prior
        ``send()`` or ``recv()`` operations, since all operations
        share the same transport state once a failure occurs.

        This method does not wait for additional messages after a transport
        error is detected. If ``drain_on_error=True``, subsequent calls to
        ``recv()`` will return any messages that were buffered in the receive
        queue at the time the reader failed, before the connection error is raised.

        Concurrent calls to ``recv()`` are supported and safe; each caller
        awaits the next available message and will receive distinct messages
        in FIFO order.

        If this coroutine is cancelled while a message is being received,
        that message may be dropped. Cancellation is treated as abandoning
        the receive operation.
    """
    # Fast-fail when transport already errored and we aren't draining.
    if self._transport_exception is not None and not self.drain_on_error:
        raise self._transport_exception

    # Hot path: immediate buffered item (zero allocation).
    try:
        return self._receive_queue.get_nowait()
    except asyncio.QueueEmpty:
        pass

    # Terminal checks when queue is empty.
    if self._transport_exception is not None:
        raise self._transport_exception
    if self._read_task is None or self._read_task.done():
        raise WebSocketClosed("WebSocket is closed")

    # Cold path: wait for data or for the reader task to finish.
    # The getter runs as its own task so it can be raced against the reader:
    # if the reader dies, no item will ever arrive and we must not hang.
    queue_waiter: asyncio.Task[RECV_QUEUE_ITEM] = asyncio.create_task(
        self._receive_queue.get()
    )
    try:
        # Wait for the first of: queue_waiter or _read_task
        done, _ = await asyncio.wait(
            (queue_waiter, self._read_task),
            return_when=asyncio.FIRST_COMPLETED,
            timeout=timeout,
        )
    # Caller cancelled — cancel the waiter and re-raise
    except asyncio.CancelledError:
        if not queue_waiter.done():
            _ = queue_waiter.cancel()
        with suppress(asyncio.CancelledError):
            await queue_waiter
        raise

    # Timeout occurred and no message
    if not done:
        _ = queue_waiter.cancel()
        # cancel() is a no-op if the getter completed in the meantime; awaiting
        # it then yields that item instead of dropping it.
        try:
            return await queue_waiter
        except asyncio.CancelledError:
            pass
        # Prefer transport error over timeout if both happen
        if self._transport_exception is not None:
            raise self._transport_exception
        raise WebSocketTimeout(
            "WebSocket recv() timed out", CurlECode.OPERATION_TIMEDOUT
        )

    # If queue_waiter completed first, return its result.
    if queue_waiter in done:
        return await queue_waiter

    # Reader task finished first. Cancel the waiter.
    _ = queue_waiter.cancel()
    with suppress(asyncio.CancelledError):
        await queue_waiter

    # Try to return one buffered item when drain is set
    if self.drain_on_error:
        with suppress(asyncio.QueueEmpty):
            return self._receive_queue.get_nowait()

    # Propagate any transport exception or raise closed.
    if self._transport_exception is not None:
        raise self._transport_exception
    raise WebSocketClosed("Connection closed")
[docs]
async def recv_str(self, *, timeout: float | None = None) -> str:
    """Receive a text frame and return it decoded as UTF-8.

    Args:
        timeout: how many seconds to wait before giving up.

    Raises:
        WebSocketError: the message is not a text frame, or is not valid UTF-8
            (both reported with ``WsCloseCode.INVALID_DATA``).
    """
    payload, frame_flags = await self.recv(timeout=timeout)
    if frame_flags & CurlWsFlag.TEXT:
        try:
            return payload.decode("utf-8")
        except UnicodeDecodeError as e:
            raise WebSocketError(
                "Invalid UTF-8 in text frame", WsCloseCode.INVALID_DATA
            ) from e
    raise WebSocketError("Not a valid text frame", WsCloseCode.INVALID_DATA)
[docs]
async def recv_json(
    self,
    *,
    loads: Callable[[str], T] = json_loads,
    timeout: float | None = None,
) -> T:
    """Receive a text frame and decode its contents as JSON.

    Args:
        loads: JSON decoder, default is :meth:`json.loads`.
        timeout: how many seconds to wait before giving up.

    Raises:
        WebSocketError: Received frame is invalid or failed to decode JSON.
    """
    text: str = await self.recv_str(timeout=timeout)
    if text:
        try:
            return loads(text)
        except UnicodeDecodeError as e:
            raise WebSocketError(
                "Invalid UTF-8 in JSON text frame", WsCloseCode.INVALID_DATA
            ) from e
        except Exception as e:
            raise WebSocketError(
                f"Invalid JSON payload: {e}", WsCloseCode.INVALID_DATA
            ) from e
    raise WebSocketError(
        "Received empty frame, cannot decode JSON", WsCloseCode.INVALID_DATA
    )
[docs]
async def send(
    self,
    payload: str | bytes | bytearray | memoryview,
    flags: CurlWsFlag | int = CurlWsFlag.BINARY,
    timeout: float | None = None,
) -> None:
    """Send a WebSocket message.

    Args:
        payload: Data to send (``str``/``bytes``/``bytearray``/``memoryview``).
        flags: Frame type flags (e.g., ``CurlWsFlag.TEXT`` / ``CurlWsFlag.BINARY``).
        timeout: Max seconds to wait if the send queue is full.

    Raises:
        CurlError: Network related exception occurred.
        WebSocketClosed: The WebSocket has been closed.
        WebSocketTimeout: The send operation timed out.

    Note:
        There are no limits on the size of the message that can be sent.
        Large outbound messages are seamlessly broken down into optimal
        fragments using the ``CURLWS_CONT`` flag, arriving as a single
        logical message to the server.

    Warning:
        This method does not wait for transmission: it only queues the message
        for the background writer, and suspends only while the send queue is
        full. Use ``await ws.flush()`` after sending if you need to guarantee
        that the data has actually reached the socket. ``bytearray`` and
        ``memoryview`` payloads are queued without copying.
    """
    # A failure already recorded by the I/O tasks is re-raised to every caller.
    if self._transport_exception is not None:
        raise self._transport_exception
    if self.closed:
        raise WebSocketClosed("WebSocket is closed")

    # Fail fast when writer is done
    if self._write_task is not None and self._write_task.done():
        raise WebSocketClosed("WebSocket writer terminated; cannot send")

    # cURL expects bytes
    if isinstance(payload, str):
        payload = payload.encode("utf-8")

    try:
        # Fast path: room in the queue, enqueue without suspending.
        self._send_queue.put_nowait((payload, flags))
    except asyncio.QueueFull as exc:
        if self._terminated:
            raise WebSocketClosed("WebSocket connection is terminated") from exc

        # The writer may have failed while the queue was full.
        if self._transport_exception is not None:
            raise self._transport_exception from exc

        # Slow path: wait for room, bounded by ``timeout`` if one was given.
        if timeout is not None:
            try:
                await asyncio.wait_for(
                    self._send_queue.put((payload, flags)), timeout
                )
            except asyncio.TimeoutError as e:
                raise WebSocketTimeout(
                    "Send queue full (network slow) - hit timeout enqueuing message"
                ) from e
        else:
            await self._send_queue.put((payload, flags))

        # If we woke up because terminate() drained the queue, fail now.
        if self._transport_exception is not None:
            raise self._transport_exception from exc

        if self.closed or self._terminated:
            raise WebSocketClosed(
                "Connection was terminated while waiting to send"
            ) from exc
[docs]
async def send_binary(self, payload: bytes) -> None:
    """Send *payload* as one binary WebSocket message.

    Args:
        payload: binary data to send.

    For more info, see the docstring for :meth:`send()`
    """
    binary_flag = CurlWsFlag.BINARY
    return await self.send(payload, binary_flag)
[docs]
async def send_bytes(self, payload: bytes) -> None:
    """Alias of :meth:`send_binary`: send *payload* as one binary message.

    Args:
        payload: binary data to send.

    For more info, see the docstring for :meth:`send()`
    """
    frame_kind = CurlWsFlag.BINARY
    return await self.send(payload, frame_kind)
[docs]
async def send_str(self, payload: str) -> None:
    """Send *payload* as one text WebSocket message.

    Args:
        payload: text data to send (UTF-8 encoded by :meth:`send`).

    For more info, see the docstring for :meth:`send()`
    """
    text_flag = CurlWsFlag.TEXT
    return await self.send(payload, text_flag)
[docs]
async def send_json(
    self, payload: object, *, dumps: Callable[..., str] = json_dumps
) -> None:
    """Serialize *payload* to JSON and send it as a text message.

    Args:
        payload: data to send.
        dumps: JSON encoder, default is :meth:`json.dumps()`. When the default
            encoder is used, compact ``(",", ":")`` separators are applied;
            a custom encoder is called with the payload only.

    For more info, see the docstring for :meth:`send()`
    """
    encoded: str = (
        json_dumps(payload, separators=(",", ":"))
        if dumps is json_dumps
        else dumps(payload)
    )
    return await self.send_str(encoded)
[docs]
async def ping(self, payload: str | bytes) -> None:
    """Send a ping control frame.

    Args:
        payload: data to carry in the ping; ``str`` is UTF-8 encoded first.

    Raises:
        WebSocketError: The encoded payload is longer than 125 bytes.

    For more info, see the docstring for :meth:`send()`
    """
    data: bytes = (
        payload.encode("utf-8") if isinstance(payload, str) else bytes(payload)
    )
    size: int = len(data)
    if size <= 125:
        return await self.send(data, CurlWsFlag.PING)
    raise WebSocketError(
        f"Ping frame has invalid length: {size}",
        CurlECode.TOO_LARGE,
    )
[docs]
async def close(
    self,
    code: int = WsCloseCode.OK,
    message: str | bytes = b"",
    timeout: float = 3.0,
) -> None:
    """
    Performs a graceful WebSocket closing handshake and terminates the connection.

    This method sends a WebSocket close frame to the peer, waits for queued
    outgoing messages to be sent, and then shuts down the connection. Calling
    it on a WebSocket that is already closed is a no-op.

    Args:
        code (int): Close code. Defaults to ``WsCloseCode.OK``.
        message (str | bytes): Close reason; ``str`` values are UTF-8 encoded.
            Reasons longer than 123 bytes are truncated. The cut is moved back
            to a UTF-8 character boundary so a multi-byte character is never
            split. Defaults to ``b""``.
        timeout (float): Time budget (in seconds), measured from the start of
            this call, for enqueueing the close frame, flushing the send queue
            and waiting for termination before giving up.
    """
    async with self._close_lock:
        if self.closed:
            return
        self.closed = True
        close_start: float = self.loop.time()
        try:
            # Only attempt the handshake if the writer can still deliver it.
            if (
                self._write_task
                and not self._write_task.done()
                and self._transport_exception is None
            ):
                if isinstance(message, str):
                    message = message.encode("utf-8")
                # 125 bytes (Spec) - 2 bytes for close code
                if len(message) > 123:
                    cut: int = 123
                    # Never end the reason mid-character: RFC 6455 requires it
                    # to be valid UTF-8. If the first dropped byte is a UTF-8
                    # continuation byte (0b10xxxxxx), back off to the lead byte.
                    # A UTF-8 sequence has at most 3 continuation bytes.
                    while cut > 120 and (message[cut] & 0xC0) == 0x80:
                        cut -= 1
                    message = message[:cut]
                # Send Close Frame and wait for queue to empty
                close_frame: bytes = self._pack_close_frame(code, message)
                await asyncio.wait_for(
                    self._send_queue.put((close_frame, CurlWsFlag.CLOSE)),
                    timeout=timeout,
                )
                # Subtract time already elapsed when flushing queue
                await self.flush(
                    max(0.0, timeout - (self.loop.time() - close_start))
                )
        except (asyncio.TimeoutError, WebSocketError):
            # Best effort: a slow or broken peer must not prevent cleanup below.
            pass
        finally:
            # Ensure resources are cleaned up
            self.terminate()
            with suppress(asyncio.TimeoutError):
                _ = await asyncio.wait_for(
                    self._terminated_event.wait(),
                    max(0.0, timeout - (self.loop.time() - close_start)),
                )
def terminate(self) -> None:  # pyright: ignore[reportImplicitOverride]
    """
    Immediately terminates the connection without a graceful handshake.

    This method is a forceful shutdown that cancels all background I/O tasks
    and cleans up resources. It should be used for final cleanup or after an
    unrecoverable error. Unlike ``close()``, it does not attempt to send a close
    frame or wait for pending messages. It schedules the cleanup to run on the
    event loop and returns immediately. It does not wait for cleanup completion.

    If no event loop is available, or scheduling onto it fails, the parent
    class's ``terminate()`` is called synchronously instead, and
    ``_terminated_event`` is set directly.

    This method is thread-safe, task-safe, and idempotent.
    """
    # Check-and-set under a threading lock so only the first caller, from any
    # thread, performs the shutdown; every later call returns immediately.
    with self._terminate_lock:
        if self._terminated:
            return
        self._terminated = True
    loop: asyncio.AbstractEventLoop | None = self._loop
    # Get the currently running event loop
    try:
        current_loop: asyncio.AbstractEventLoop | None = (
            asyncio.get_running_loop()
        )
    except RuntimeError:
        # Called from a thread with no running loop.
        current_loop = None
    try:
        if loop is None:
            raise RuntimeError("Event loop not available")
        # Run the termination task
        if current_loop is not None and current_loop is loop:
            # Already on the owning loop: schedule the helper as a task.
            _ = loop.create_task(self._terminate_helper())
        else:
            # Another thread (or another loop): hand off thread-safely.
            # NOTE(review): if `loop` is stopped but not closed, the helper is
            # queued but never runs and the curl handle is not released -
            # confirm whether that case can occur.
            _ = asyncio.run_coroutine_threadsafe(self._terminate_helper(), loop)
    # pylint: disable-next=broad-exception-caught
    except Exception:
        # No usable loop, or scheduling failed: release the curl handle
        # synchronously and unblock anything waiting on _terminated_event.
        # NOTE(review): asyncio.Event.set() is not thread-safe when called
        # off the loop thread - confirm this path only runs without a live loop.
        try:
            super().terminate()
        finally:
            self._terminated_event.set()
async def _read_loop(self) -> None:
    """
    The main asynchronous task for reading incoming WebSocket frames.

    Attempts to read immediately and only registers an event loop reader if
    the socket returns EAGAIN (empty). It waits for the underlying socket to
    become readable, and upon being woken by the event loop, it drains all
    buffered data from libcurl until it receives an EAGAIN error. This error
    signals that the buffer is empty, and the loop returns to an idle state,
    waiting for the next readability event. This is "optimistic reading".

    To ensure cooperative multitasking during high-volume message streams,
    the loop periodically yields control to the asyncio event loop. Yields are
    scheduled by time: once ``self._recv_time_slice`` seconds (per
    ``loop.time()``) have elapsed since the last yield.

    Data frames (TEXT / BINARY / CONT) are accumulated until the message is
    complete, meaning no CONT flag and no ``bytesleft``. The joined message is
    then queued as ``(message, flags)``. If the accumulated size exceeds
    ``self._max_message_size``, the connection fails with
    ``CurlECode.TOO_LARGE``. A CLOSE frame is queued, passed to
    ``_handle_close_frame()``, and ends the loop. Other frame types are not
    queued here.

    Error handling:

    - EAGAIN, including EAGAIN surfacing as ``RECV_ERROR`` ("errno 11"),
      waits for socket readability.
    - ``GOT_NOTHING`` finalizes the connection with ``WebSocketClosed``.
    - Codes listed in ``self.ws_retry.codes`` are retried up to
      ``self.ws_retry.count`` times with exponential backoff and +/-10%
      jitter. The retry counter resets after any successful receive.
    - Anything else finalizes the connection.

    If the receive queue becomes full and ``self._block_on_recv_queue_full``
    is set, ``await self._receive_queue.put()`` will block the reader loop and
    stall the socket read task. Otherwise, the connection is failed with
    ``CurlECode.OUT_OF_MEMORY`` rather than dropping a message. Thus,
    appropriate queue sizes should be set by the user, to match the speed at
    which they are expected to be consumed. If latency is a factor, a smaller
    queue size should be used. Conversely, a larger queue size provides burst
    message handling capacity.
    """
    # Cache locals to avoid repeated attribute lookups
    curl_ws_recv: Callable[[], tuple[bytes, CurlWsFrame]] = self.curl.ws_recv
    queue_put_nowait: Callable[[RECV_QUEUE_ITEM], None] = (
        self._receive_queue.put_nowait
    )
    queue_put: Callable[[RECV_QUEUE_ITEM], Awaitable[None]] = (
        self._receive_queue.put
    )
    loop: asyncio.AbstractEventLoop = self.loop
    loop_time: Callable[[], float] = loop.time
    create_future: Callable[[], asyncio.Future[None]] = loop.create_future
    add_reader: Callable[..., None] = loop.add_reader
    remove_reader: Callable[..., bool] = loop.remove_reader
    time_slice: float = self._recv_time_slice
    next_yield: float = loop_time() + time_slice
    retry_on_error: bool = self.ws_retry.retry
    retry_codes: set[CurlECode] = self.ws_retry.codes
    max_retries: int = self.ws_retry.count
    retry_base: float = float(self.ws_retry.delay)
    e_again: int = int(CurlECode.AGAIN)
    e_recv_err: int = int(CurlECode.RECV_ERROR)
    e_nothing: int = int(CurlECode.GOT_NOTHING)
    close_flag: int = int(CurlWsFlag.CLOSE)
    cont_flag: int = int(CurlWsFlag.CONT)
    # Any of these bits marks a frame that carries message data.
    data_mask: int = CurlWsFlag.BINARY | CurlWsFlag.TEXT | cont_flag
    max_msg_size: int = self._max_message_size
    block_on_recv: bool = self._block_on_recv_queue_full
    queue_full_err: str = (
        "Receive queue full; failing connection to preserve message integrity"
    )
    set_fut_result: Callable[[asyncio.Future[None]], None] = _safe_set_result
    # Message specific values
    recv_error_retries: int = 0
    # Chunks of the message currently being reassembled, and their total size.
    chunks: list[bytes] = []
    msg_size: int = 0
    chunks_append: Callable[[bytes], None] = chunks.append
    chunks_clear: Callable[[], None] = chunks.clear
    try:
        while not self.closed:
            try:
                chunk, frame = curl_ws_recv()
            except CurlError as e:
                should_retry: bool = False
                # Handle normal cURL EAGAINs
                if e.code == e_again:
                    should_retry = True
                # EAGAIN ("errno 11") bubbling up as RECV_ERROR from BoringSSL
                elif e.code == e_recv_err:
                    err_msg: str = str(e).lower()
                    if (
                        "errno 11" in err_msg
                        or "resource temporarily unavailable" in err_msg
                    ):
                        should_retry = True
                # Handle Server Disconnect (Empty Reply)
                elif e.code == e_nothing:
                    final_exc: WebSocketClosed = WebSocketClosed(
                        "Connection closed unexpectedly by server (EOF)",
                        WsCloseCode.ABNORMAL_CLOSURE,
                    )
                    # Chain the CurlError as the explicit cause only.
                    final_exc.__cause__ = e
                    final_exc.__suppress_context__ = True
                    self._finalize_connection(final_exc)
                    return
                if should_retry:
                    # Nothing buffered: park until the socket is readable.
                    read_future: asyncio.Future[None] = create_future()
                    try:
                        add_reader(self._sock_fd, set_fut_result, read_future)
                        await read_future
                    # pylint: disable-next=broad-exception-caught
                    except Exception as exc:
                        self._finalize_connection(
                            WebSocketError(
                                f"Socket closed unexpectedly: {exc}",
                                CurlECode.NO_CONNECTION_AVAILABLE,
                            )
                        )
                        return
                    finally:
                        # Always detach the reader; termination may already
                        # have reset the fd to -1.
                        if self._sock_fd != -1:
                            try:  # noqa: SIM105
                                _ = remove_reader(self._sock_fd)
                            # pylint: disable-next=broad-exception-caught
                            except Exception:
                                pass
                    # Loop back to the top to try reading again
                    continue
                # Apply the user-configured retry logic
                if (
                    retry_on_error
                    and e.code in retry_codes
                    and recv_error_retries < max_retries
                ):
                    recv_error_retries += 1
                    # Formula: base * (2 ^ (attempt - 1))
                    retry_delay: float = (  # pyright: ignore[reportAny]
                        retry_base * (2 ** (recv_error_retries - 1))
                    )
                    # Add Jitter: +/- 10%
                    jitter: float = retry_delay * 0.1
                    retry_delay += uniform(-jitter, jitter)
                    await asyncio.sleep(max(0.0, retry_delay))
                    continue
                # Fatal error - can't retry
                self._finalize_connection(e)
                return
            flags: int = frame.flags
            # A successful receive resets the user retry budget.
            if recv_error_retries > 0:
                recv_error_retries = 0
            # Data Frames (Text / Binary / Cont)
            if flags & data_mask:
                # Perform message size checks
                msg_size += len(chunk)
                if msg_size > max_msg_size:
                    chunks_clear()
                    self._finalize_connection(
                        WebSocketError(
                            (
                                f"Message too large: {msg_size} bytes "
                                f"(limit {max_msg_size} bytes). "
                                "Consider increasing max_message_size or "
                                "chunking the message."
                            ),
                            CurlECode.TOO_LARGE,
                        )
                    )
                    return
                # Collect the chunk
                chunks_append(chunk)
                # If the message is complete, process and dispatch it
                if not (flags & cont_flag or frame.bytesleft):
                    # Skip the join (and its copy) for single-chunk messages.
                    message: bytes = (
                        chunks[0] if len(chunks) == 1 else b"".join(chunks)
                    )
                    chunks_clear()
                    msg_size = 0
                    try:
                        queue_put_nowait((message, flags))
                    except asyncio.QueueFull:
                        if not block_on_recv:
                            self._finalize_connection(
                                WebSocketError(
                                    queue_full_err, CurlECode.OUT_OF_MEMORY
                                )
                            )
                            return
                        # Backpressure: stall reading until a consumer frees space.
                        await queue_put((message, flags))
                # Time-sliced cooperative yield so other tasks can run.
                if loop_time() >= next_yield:
                    await asyncio.sleep(0)
                    next_yield = loop_time() + time_slice
                continue
            # If a CLOSE frame is received, the reader is done.
            if flags & close_flag:
                # Discard any partially reassembled data message.
                chunks_clear()
                try:
                    queue_put_nowait((chunk, flags))
                except asyncio.QueueFull:
                    if not block_on_recv:
                        self._finalize_connection(
                            WebSocketError(queue_full_err, CurlECode.OUT_OF_MEMORY)
                        )
                        return
                    await queue_put((chunk, flags))
                await self._handle_close_frame(chunk)
                return
    except asyncio.CancelledError:
        # Cancellation (e.g. from termination) is a normal way to stop.
        pass
    # pylint: disable-next=broad-exception-caught
    except Exception as e:
        self._finalize_connection(e)
async def _write_loop(self) -> None:
    """
    The background task responsible for consuming the send queue
    and transmitting frames.

    To maximize performance, this loop hoists the configuration
    check and enters one of two distinct processing strategies:

    1. Standard Mode (No Coalescing):
       The default, low-latency path. Messages are consumed one-by-one
       from the queue and transmitted immediately. This guarantees that one
       ``send()`` call results in exactly one WebSocket message, preserving
       logical message boundaries.

    2. Coalescing Mode:
       An optimized throughput path for chatty streams. The loop greedily
       gathers already-pending messages from the queue (up to
       ``max_send_batch_size``, without waiting). It merges the payloads of
       *consecutive* messages that share the same flags (e.g. multiple text
       frames) into a single transmission. This reduces system call overhead
       but does not preserve individual message boundaries.

    Features:

    - Cooperative Multitasking: Yields to the event loop periodically, based
      on ``self._send_time_slice``, to prevent the writer from starving the
      reader task during high-volume transmission.
    - Control Frame Priority: CLOSE, PING and PONG frames are never coalesced;
      each is sent as its own transmission, in its original queue order.
    - Lifecycle Management: The loop exits after transmitting a CLOSE frame.
      If the loop exits while the WebSocket is not marked closed (error,
      cancellation or failed send), ``terminate()`` is called.

    Every item taken from the queue is marked ``task_done()`` (in ``finally``)
    so ``flush()``, which joins the queue, is not left waiting.
    """
    control_frame_flags: int = CurlWsFlag.CLOSE | CurlWsFlag.PING | CurlWsFlag.PONG
    close_flag: int = int(CurlWsFlag.CLOSE)
    send_payload: Callable[..., Awaitable[bool]] = self._send_payload
    queue_get: Callable[[], Awaitable[SEND_QUEUE_ITEM]] = self._send_queue.get
    queue_get_nowait: Callable[[], SEND_QUEUE_ITEM] = self._send_queue.get_nowait
    queue_done: Callable[[], None] = self._send_queue.task_done
    loop: asyncio.AbstractEventLoop = self.loop
    loop_time: Callable[[], float] = loop.time
    time_slice: float = self._send_time_slice
    next_yield: float = loop_time() + time_slice
    try:
        # Hoist the branch - decide loop strategy once at start
        if not self._coalesce_frames:
            # Optimized fast path, no batching overhead
            while True:
                payload, flags = await queue_get()
                try:
                    # False means _send_payload already finalized the connection.
                    if not await send_payload(payload, flags):
                        return
                    if flags & close_flag:
                        break
                    # Perform yield checks
                    if loop_time() >= next_yield:
                        await asyncio.sleep(0)
                        next_yield = loop_time() + time_slice
                finally:
                    queue_done()
        else:
            # Coalescing path: Batch multiple frames to merge payloads
            while True:
                payload, flags = await queue_get()
                # Build the rest of the batch without waiting.
                batch: list[SEND_QUEUE_ITEM] = [(payload, flags)]
                # A CLOSE frame always ends a batch; nothing is sent after it.
                if not (flags & close_flag):
                    while len(batch) < self._max_send_batch_size:
                        try:
                            payload, frame = queue_get_nowait()
                            batch.append((payload, frame))
                            if frame & close_flag:
                                break
                        except asyncio.QueueEmpty:
                            break
                try:
                    # Group consecutive frames with same flags to preserve order.
                    # Control frames are strictly isolated and never coalesced.
                    coalesced: list[
                        tuple[list[bytes | bytearray | memoryview], int]
                    ] = []
                    for payload, frame in batch:
                        if frame & control_frame_flags:
                            coalesced.append(([payload], frame))
                        else:
                            if coalesced and coalesced[-1][1] == frame:
                                coalesced[-1][0].append(payload)
                            else:
                                coalesced.append(([payload], frame))
                    # Transmit the coalesced groups in their exact original order
                    for payloads, frame_group in coalesced:
                        if not await send_payload(b"".join(payloads), frame_group):
                            return
                        # Perform yield checks
                        if loop_time() >= next_yield:
                            await asyncio.sleep(0)
                            next_yield = loop_time() + time_slice
                finally:
                    # Mark all processed items as done.
                    for _ in range(len(batch)):
                        queue_done()
                # Exit cleanly after sending a CLOSE frame.
                if batch[-1][1] & close_flag:
                    break
    except asyncio.CancelledError:
        pass
    # pylint: disable-next=broad-exception-caught
    except Exception as e:
        self._finalize_connection(e)
    finally:
        # If the loop exits unexpectedly, ensure we terminate the connection.
        if not self.closed:
            self.terminate()
async def _send_payload(
    self, payload: bytes | memoryview | bytearray, flags: CurlWsFlag | int
) -> bool:
    """
    Optimized low-level sender with fragmentation logic.

    Payloads larger than ``self._MAX_CURL_FRAME_SIZE`` are split into
    fragments. Every fragment except the last is sent with the CONT flag
    added to the base flags, and the last fragment uses ``flags`` unchanged.
    Partial writes resume from the current offset within the same fragment.
    An empty payload is sent as a single empty frame.

    On EAGAIN, the coroutine waits for the socket to become writable and
    retries. A 0-byte write of non-empty data is treated like EAGAIN, but after
    3 consecutive 0-byte writes the connection is failed with
    ``CurlECode.WRITE_ERROR``.

    Args:
        payload: Data to transmit; it is viewed through a ``memoryview`` and
            sliced without copying.
        flags: Frame flags for the message.

    Returns:
        ``True`` when the whole payload was handed to libcurl. ``False`` when
        the connection has been finalized (writer stall, socket error while
        waiting, or a non-EAGAIN ``CurlError``); the caller should stop.
    """
    # Cache locals to reduce lookup cost
    curl_ws_send: Callable[[memoryview, CurlWsFlag | int], int] = self.curl.ws_send
    loop: asyncio.AbstractEventLoop = self.loop
    loop_time: Callable[[], float] = loop.time
    create_future: Callable[[], asyncio.Future[None]] = loop.create_future
    add_writer: Callable[..., None] = loop.add_writer
    remove_writer: Callable[..., bool] = loop.remove_writer
    set_fut_result: Callable[[asyncio.Future[None]], None] = _safe_set_result
    # NOTE(review): the fd is captured once; if termination resets
    # self._sock_fd mid-send, this loop keeps using the old value.
    sock_fd: int = self._sock_fd
    time_slice: float = self._send_time_slice
    next_yield: float = loop_time() + time_slice
    max_frame_size: int = self._MAX_CURL_FRAME_SIZE
    e_again: int = int(CurlECode.AGAIN)
    cont_flag: int = int(CurlWsFlag.CONT)
    max_zero_writes: int = 3
    # Message specific values
    # Frame type without CONT; used for the non-final fragments.
    base_flags: int = flags & ~cont_flag
    view: memoryview = memoryview(payload)
    total_bytes: int = view.nbytes
    offset: int = 0
    write_retries: int = 0
    # End offset (exclusive) of the fragment currently being sent.
    frame_end: int = 0
    current_flags: CurlWsFlag | int = flags
    # Loop until the entire view is sent
    # (the second clause lets an empty payload go through exactly once).
    while offset < total_bytes or (offset == 0 and total_bytes == 0):
        # Boundary check: Calculate next fragment ONLY when needed
        if offset == frame_end:
            if total_bytes - offset > max_frame_size:
                frame_end = offset + max_frame_size
                current_flags = base_flags | cont_flag
            else:
                frame_end = total_bytes
                current_flags = flags
        try:
            # libcurl returns the number of bytes actually sent
            n_sent: int = curl_ws_send(view[offset:frame_end], current_flags)
            if n_sent == 0:
                # Handle 0-byte payload (Valid Empty Frame)
                if frame_end - offset == 0:
                    return True
                # Raise AGAIN to jump to the existing wait logic below
                write_retries += 1
                if write_retries >= max_zero_writes:
                    self._finalize_connection(
                        WebSocketError(
                            f"Writer stalled ({write_retries} attempts).",
                            CurlECode.WRITE_ERROR,
                        )
                    )
                    return False
                raise CurlError("0 bytes sent", e_again)
            # Progress was made: reset the stall counter.
            if write_retries:
                write_retries = 0
            offset += n_sent
            # Cooperative yield checks
            if loop_time() >= next_yield:
                await asyncio.sleep(0)
                next_yield = loop_time() + time_slice
        except CurlError as e:
            if e.code == e_again:
                # Wait for socket to be writable
                write_future: asyncio.Future[None] = create_future()
                try:
                    add_writer(sock_fd, set_fut_result, write_future)
                    await write_future
                # pylint: disable-next=broad-exception-caught
                except Exception as exc:
                    self._finalize_connection(
                        WebSocketError(
                            f"Socket closed unexpectedly during write: {exc}",
                            CurlECode.NO_CONNECTION_AVAILABLE,
                        )
                    )
                    return False
                finally:
                    # Always detach the writer callback.
                    if sock_fd != -1:
                        try:  # noqa: SIM105
                            _ = remove_writer(sock_fd)
                        # pylint: disable-next=broad-exception-caught
                        except Exception:
                            pass
                # Retry the exact same chunk
                continue
            # Fatal Error
            self._finalize_connection(e)
            return False
    return True
async def flush(self, timeout: float | None = None) -> None:
    """Waits until all items in the send queue have been processed.

    This ensures that all messages passed to ``send()`` have been handed off
    to the underlying socket for transmission. It does not guarantee that the
    data has been received by the remote peer.

    Args:
        timeout (float | None, optional): The maximum number of seconds to wait
            for the queue to drain. ``None`` waits without a limit.

    Raises:
        WebSocketTimeout: If the send queue is not fully processed within the
            specified ``timeout`` period.
        WebSocketError: If the writer task has already terminated (crashed,
            been cancelled, or exited) while unsent messages remain in the
            queue, or if it crashed with an exception.
    """
    write_task = self._write_task
    if write_task is None:
        return
    # Create a task for the queue join operation
    join_task: asyncio.Task[None] = asyncio.create_task(self._send_queue.join())
    try:
        done, _ = await asyncio.wait(
            {join_task, write_task},
            return_when=asyncio.FIRST_COMPLETED,
            timeout=timeout,
        )
        if not done:
            raise WebSocketTimeout("Timed out waiting for send queue to flush.")
        # Fast path: queue drained
        if join_task in done:
            return
        # Writer finished first. A cancelled writer (e.g. after terminate())
        # has no exception to report. Calling .result() on it would raise
        # CancelledError into *this* caller, as though the caller itself
        # had been cancelled.
        if not write_task.cancelled():
            writer_exc = write_task.exception()
            if isinstance(writer_exc, Exception):
                raise WebSocketError(
                    "Writer task crashed while flushing."
                ) from writer_exc
            if writer_exc is not None:
                # KeyboardInterrupt / SystemExit: propagate unchanged.
                raise writer_exc
        # If the writer finished gracefully
        if self._send_queue.empty():
            return
        raise WebSocketError("Writer task stopped unexpectedly while flushing.")
    finally:
        # Cancel join task on early exit to avoid leak.
        if not join_task.done():
            _ = join_task.cancel()
            with suppress(asyncio.CancelledError):
                await join_task
async def _terminate_helper(self) -> None:
    """Tear the connection down on the event loop.

    Steps, in order:

    1. Cancel the still-running reader/writer tasks and wait up to 2 seconds
       for them to finish.
    2. Discard anything left in the send queue, marking each item done so
       ``join()`` waiters are released.
    3. Detach the socket from the event loop.
    4. Close the curl handle.
    5. Tell the session not to reuse that handle.

    ``_terminated_event`` is set however this coroutine exits.
    """
    pending_io: set[asyncio.Task[None]] = set()
    for io_task in (self._read_task, self._write_task):
        if io_task is not None and not io_task.done():
            pending_io.add(io_task)
    grace_seconds: int = 2
    try:
        for io_task in pending_io:
            _ = io_task.cancel()
        if pending_io:
            _, still_running = await asyncio.wait(
                pending_io,
                timeout=grace_seconds,
                return_when=asyncio.ALL_COMPLETED,
            )
            # Re-issue the cancellation for anything that outlived the grace period.
            for laggard in still_running:
                _ = laggard.cancel()

        send_queue = self._send_queue
        while not send_queue.empty():
            try:
                _ = send_queue.get_nowait()
                send_queue.task_done()
            except (asyncio.QueueEmpty, ValueError):
                break

        if self._sock_fd != -1:
            fd = self._sock_fd
            with suppress(Exception):
                _ = self.loop.remove_reader(fd)
            with suppress(Exception):
                _ = self.loop.remove_writer(fd)
            self._sock_fd = -1

        # Close the underlying curl handle.
        super().terminate()
        # WebSocket curl handles cannot be handed back to the session for reuse.
        if self.session and not self.session._closed:
            self.session.push_curl(None)
    finally:
        self._terminated_event.set()
async def _handle_close_frame(self, message: bytes) -> None:
"""Unpack and handle the closing frame, then initiate shutdown."""
try:
self._close_code, self._close_reason = self._unpack_close_frame(message)
except WebSocketError as e:
self._close_code = e.code
if self.autoclose and not self.closed:
close_code: int | CurlECode | Literal[WsCloseCode.OK] = (
WsCloseCode.OK
if self._close_code == WsCloseCode.UNKNOWN
else (self._close_code or WsCloseCode.OK)
)
await self.close(close_code)
else:
# If not sending a reply, we must still terminate the connection.
self.terminate()
@final
class AsyncWebSocketContext:
    """Awaitable / async-context-manager wrapper around a pending connection.

    ``await ctx`` resolves directly to the :class:`AsyncWebSocket`.
    ``async with ctx as ws:`` also closes it on exit. If the block raised,
    a ``CurlError`` from ``close()`` is suppressed so it does not replace
    the original exception.
    """

    __slots__ = ("_coro", "_obj")

    def __init__(self, coro: Awaitable[AsyncWebSocket]) -> None:
        self._coro: Awaitable[AsyncWebSocket] = coro
        self._obj: AsyncWebSocket | None = None

    def __await__(self) -> Generator[object, object, AsyncWebSocket]:
        # Plain delegation: awaiting the wrapper awaits the wrapped coroutine.
        return self._coro.__await__()

    async def __aenter__(self) -> AsyncWebSocket:
        websocket = await self._coro
        self._obj = websocket
        return websocket

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object | None,
    ) -> None:
        websocket = self._obj
        if not websocket:
            return
        if exc_type is not None:
            # Keep the in-flight exception as the one the caller sees.
            with suppress(CurlError):
                await websocket.close()
        else:
            await websocket.close()