import asyncio
import struct
from enum import IntEnum
from typing import Callable, Optional, Tuple
from curl_cffi.const import CurlECode, CurlWsFlag
from curl_cffi.curl import CurlError
ON_MESSAGE_T = Callable[["WebSocket", bytes], None]
ON_ERROR_T = Callable[["WebSocket", CurlError], None]
ON_OPEN_T = Callable[["WebSocket"], None]
ON_CLOSE_T = Callable[["WebSocket", int, str], None]
class WsCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
UNKNOWN = 1005
ABNORMAL_CLOSURE = 1006
INVALID_DATA = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
class WebSocketError(CurlError):
pass
[docs]
class WebSocket:
[docs]
def __init__(
self,
session,
curl,
on_message: Optional[ON_MESSAGE_T] = None,
on_error: Optional[ON_ERROR_T] = None,
on_open: Optional[ON_OPEN_T] = None,
on_close: Optional[ON_CLOSE_T] = None,
):
self.session = session
self.curl = curl
self.on_message = on_message
self.on_error = on_error
self.on_open = on_open
self.on_close = on_close
self.keep_running = True
self._loop = None
def recv_fragment(self):
return self.curl.ws_recv()
[docs]
def recv(self) -> Tuple[bytes, int]:
"""
Receive a frame as bytes.
libcurl split frames into fragments, so we have to collect all the chunks for
a frame.
"""
chunks = []
flags = 0
# TODO use select here
while True:
try:
chunk, frame = self.curl.ws_recv()
flags = frame.flags
chunks.append(chunk)
if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
break
except CurlError as e:
if e.code == CurlECode.AGAIN:
pass
else:
raise
return b"".join(chunks), flags
[docs]
def send(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY):
"""Send a data frame"""
return self.curl.ws_send(payload, flags)
[docs]
def run_forever(self):
"""
libcurl automatically handles pings and pongs.
ref: https://curl.se/libcurl/c/libcurl-ws.html
"""
if self.on_open:
self.on_open(self)
# Keep reading the messages and invoke callbacks
while self.keep_running:
try:
msg, flags = self.recv()
if self.on_message:
self.on_message(self, msg)
if flags & CurlWsFlag.CLOSE:
self.keep_running = False
# Unpack close code and reason
if len(msg) < 2:
code = WsCloseCode.UNKNOWN
reason = ""
else:
try:
code = struct.unpack_from("!H", msg)[0]
reason = msg[2:].decode()
except UnicodeDecodeError as e:
raise WebSocketError(
"Invalid close message", WsCloseCode.INVALID_DATA
) from e
except Exception as e:
raise WebSocketError(
"Invalid close frame", WsCloseCode.PROTOCOL_ERROR
) from e
else:
if code < 3000 and (code not in WsCloseCode or code == 1005):
raise WebSocketError(
"Invalid close code", WsCloseCode.PROTOCOL_ERROR
)
if self.on_close:
self.on_close(self, code, reason)
except WebSocketError as e:
# Follow the spec to close the connection
# TODO: Consider adding setting to autoclose connection on error-free close
self.close(e.code)
if self.on_error:
self.on_error(self, e)
except CurlError as e:
if self.on_error:
self.on_error(self, e)
[docs]
def close(self, code: int = WsCloseCode.OK, message: bytes = b""):
msg = struct.pack("!H", code) + message
# FIXME how to reset, or can a curl handle connect to two websockets?
self.send(msg, CurlWsFlag.CLOSE)
self.keep_running = False
self.curl.close()
@property
def loop(self):
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop
[docs]
async def arecv(self) -> Tuple[bytes, int]:
return await self.loop.run_in_executor(None, self.recv)
[docs]
async def asend(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY):
return await self.loop.run_in_executor(None, self.send, payload, flags)
async def aclose(self, code: int = WsCloseCode.OK, message: bytes = b""):
await self.loop.run_in_executor(None, self.close, code, message)
self.curl.reset()
self.session.push_curl(self.curl)