138 lines
3.5 KiB
Python
138 lines
3.5 KiB
Python
import logging
|
|
import struct
|
|
|
|
|
|
# Module-level logger named after this module; WsFrameReader.next_frame
# emits its debug traces through it.
log = logging.getLogger(name=__name__)


class WsFrame:
    """A single WebSocket frame (RFC 6455, section 5.2).

    Holds the opcode, FIN flag, optional 4-byte masking key and payload
    bytes, and can serialize itself with ``to_network()``.
    """

    # Opcode values; the opcode occupies the low nibble of the first
    # header byte, so valid values are 0..15.
    CONT = 0
    TEXT = 1
    BINARY = 2
    RSVD3 = 3
    RSVD4 = 4
    RSVD5 = 5
    RSVD6 = 6
    RSVD7 = 7
    CLOSE = 8
    PING = 9
    PONG = 10
    RSVD11 = 11
    RSVD12 = 12
    RSVD13 = 13
    RSVD14 = 14
    RSVD15 = 15

    # Display names indexed by opcode value (used by __repr__).
    OP_NAMES = [
        "CONT",
        "TEXT",
        "BINARY",
        "RSVD3",
        "RSVD4",
        "RSVD5",
        "RSVD6",
        "RSVD7",
        "CLOSE",
        "PING",
        "PONG",
        "RSVD11",
        "RSVD12",
        "RSVD13",
        "RSVD14",
        "RSVD15",
    ]

    def __init__(self, opcode: int, fin: bool, mask: bytes, data: bytes):
        """Create a frame.

        :param opcode: frame opcode, one of the class constants.
        :param fin: True if this is the final fragment of a message.
        :param mask: masking key written after the length field, or None
            for an unmasked frame.
        :param data: payload bytes (may be None for an empty payload);
            stored and serialized verbatim.
        """
        self.opcode = opcode
        self.fin = fin
        self.mask = mask
        self.data = data
        # Snapshot of the payload length at construction time.
        self.length = len(data) if data is not None else 0

    def __repr__(self):
        if 0 <= self.opcode < len(self.OP_NAMES):
            name = self.OP_NAMES[self.opcode]
        else:
            # Out-of-range opcodes can only come from direct construction;
            # show the raw number instead of raising IndexError.
            name = str(self.opcode)
        return f'WsFrame[{name} fin={self.fin}, mask={self.mask}, len={self.data_len}]'

    @property
    def data_len(self) -> int:
        """Current payload length; 0 when there is no payload."""
        return len(self.data) if self.data else 0

    def to_network(self) -> bytes:
        """Serialize the frame into its wire format.

        Layout: FIN bit | opcode, then MASK bit | payload length (7-bit
        value, or 126 + 16-bit length, or 127 + 64-bit length), then the
        masking key if present, then the payload.

        :return: the encoded frame as a ``bytearray``.
        """
        nd = bytearray()
        h1 = self.opcode
        if self.fin:
            h1 |= 0x80
        nd.extend(struct.pack("!B", h1))
        mask_bit = 0x80 if self.mask is not None else 0x0
        length = self.data_len
        if length > 0xFFFF:
            nd.extend(struct.pack("!BQ", 127 | mask_bit, length))
        elif length >= 126:
            # 126 and 127 are length markers in the 7-bit field, so a payload
            # of exactly 126 bytes already needs the 16-bit extended form.
            nd.extend(struct.pack("!BH", 126 | mask_bit, length))
        else:
            nd.extend(struct.pack("!B", length | mask_bit))
        if self.mask is not None:
            # NOTE(review): the masking key is emitted, but the payload is NOT
            # XOR-ed with it; callers must supply already-masked data or use
            # an all-zero key (the default of client_ping/client_close).
            nd.extend(self.mask)
        if self.data is not None:
            nd.extend(self.data)
        return nd

    @classmethod
    def client_ping(cls, data: bytes, mask: bytes = None) -> 'WsFrame':
        """Build a final (FIN set) PING frame carrying ``data``.

        :param data: ping payload.
        :param mask: masking key; defaults to four zero bytes.
        """
        if mask is None:
            mask = bytes(4)
        return cls(opcode=cls.PING, fin=True, mask=mask, data=data)

    @classmethod
    def client_close(cls, code: int, reason: str = None,
                     mask: bytes = None) -> 'WsFrame':
        """Build a final (FIN set) CLOSE frame.

        :param code: close status code, packed as a big-endian uint16.
        :param reason: optional reason text, appended UTF-8 encoded.
        :param mask: masking key; defaults to four zero bytes.
        """
        data = bytearray(struct.pack("!H", code))
        if reason is not None:
            data.extend(reason.encode())
        if mask is None:
            mask = bytes(4)
        return cls(opcode=cls.CLOSE, fin=True, mask=mask, data=data)


class WsFrameReader:
    """Decode a buffer of raw wire bytes into WsFrame objects.

    Bytes are removed from the front of ``self.data`` as they are read.
    """

    def __init__(self, data: bytes):
        """:param data: raw wire bytes. A ``bytearray`` is used (and consumed)
        in place; any other bytes-like object is copied into a new one.
        """
        # _read() deletes consumed bytes from the front, which immutable
        # ``bytes`` does not support, so non-bytearray input is copied.
        self.data = data if isinstance(data, bytearray) else bytearray(data)

    def _read(self, n: int):
        """Remove and return the next ``n`` bytes of the buffer.

        :raises EOFError: if fewer than ``n`` bytes remain.
        """
        if len(self.data) < n:
            raise EOFError(f'have {len(self.data)} bytes left, but {n} requested')
        if n == 0:
            return b''
        chunk = self.data[:n]
        del self.data[:n]
        return chunk

    def next_frame(self) -> 'WsFrame':
        """Decode and consume the next frame from the buffer.

        The RSV bits are ignored and the payload is returned exactly as
        received (it is not unmasked).

        :raises EOFError: if the buffer ends before the frame does; bytes
            read up to that point have already been consumed.
        """
        data = self._read(2)
        h1, h2 = struct.unpack("!BB", data)
        log.debug('parsed h1=%s h2=%s from %s', h1, h2, data)
        fin = bool(h1 & 0x80)
        opcode = h1 & 0x0f
        has_mask = bool(h2 & 0x80)
        dlen = h2 & 0x7f
        # 126 / 127 in the 7-bit field announce a 16- / 64-bit extended length.
        if dlen == 126:
            (dlen,) = struct.unpack("!H", self._read(2))
        elif dlen == 127:
            (dlen,) = struct.unpack("!Q", self._read(8))
        mask = self._read(4) if has_mask else None
        return WsFrame(opcode=opcode, fin=fin, mask=mask, data=self._read(dlen))

    def eof(self):
        """Return True once every buffered byte has been consumed."""
        return len(self.data) == 0

    @classmethod
    def parse(cls, data: bytes):
        """Decode every frame in ``data`` and return them as a list.

        A ``bytearray`` argument is consumed in place.

        :raises EOFError: if ``data`` ends in the middle of a frame.
        """
        reader = cls(data=data)
        frames = []
        while not reader.eof():
            frames.append(reader.next_frame())
        return frames