diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index 4b762f03cd..3180cfb09c 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -152,8 +152,7 @@ def handle_client_request( def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: """Handler called right after receiving raw response from upstream server. - For HTTPS connections, chunk will be encrypted unless - TLS interception is also enabled.""" + For HTTPS connections, chunk will be encrypted unless TLS interception is also enabled.""" return chunk # pragma: no cover # No longer abstract since 2.4.0 diff --git a/proxy/http/websocket/frame.py b/proxy/http/websocket/frame.py index e17490591f..d0bbfb9781 100644 --- a/proxy/http/websocket/frame.py +++ b/proxy/http/websocket/frame.py @@ -78,18 +78,55 @@ def reset(self) -> None: self.mask = None self.data = None - def parse_fin_and_rsv(self, byte: int) -> None: - self.fin = bool(byte & 1 << 7) - self.rsv1 = bool(byte & 1 << 6) - self.rsv2 = bool(byte & 1 << 5) - self.rsv3 = bool(byte & 1 << 4) - self.opcode = byte & 0b00001111 + def parse(self, raw: bytes) -> bytes: + cur = 0 + self._parse_fin_and_rsv(raw[cur]) + cur += 1 - def parse_mask_and_payload(self, byte: int) -> None: - self.masked = bool(byte & 0b10000000) - self.payload_length = byte & 0b01111111 + self._parse_mask_and_payload(raw[cur]) + cur += 1 + + if self.payload_length == 126: + data = raw[cur: cur + 2] + self.payload_length, = struct.unpack('!H', data) + cur += 2 + elif self.payload_length == 127: + data = raw[cur: cur + 8] + self.payload_length, = struct.unpack('!Q', data) + cur += 8 + + if self.masked: + self.mask = raw[cur: cur + 4] + cur += 4 + + if self.payload_length and self.payload_length > 0: + self.data = raw[cur: cur + self.payload_length] + cur += self.payload_length + if self.masked: + assert self.mask is not None + self.data = self.apply_mask(self.data, self.mask) + + return raw[cur:] def build(self) -> bytes: + """Payload length: 7 bits, 7+16 bits, or 7+64 bits + + The length of the "Payload data", in bytes: if 0-125, that is the + payload length. If 126, the following 2 bytes interpreted as a + 16-bit unsigned integer are the payload length. If 127, the + following 8 bytes interpreted as a 64-bit unsigned integer (the + most significant bit MUST be 0) are the payload length. Multi-byte + length quantities are expressed in network byte order. Note that + in all cases, the minimal number of bytes MUST be used to encode + the length, for example, the length of a 124-byte-long string + can't be encoded as the sequence 126, 0, 124. The payload length + is the length of the "Extension data" + the length of the + "Application data". The length of the "Extension data" may be + zero, in which case the payload length is the length of the + "Application data". + + Ref https://datatracker.ietf.org/doc/html/rfc6455 + """ if self.payload_length is None and self.data: self.payload_length = len(self.data) raw = io.BytesIO() @@ -122,7 +159,7 @@ def build(self) -> bytes: elif self.payload_length < 1 << 64: raw.write( struct.pack( - '!BHQ', + '!BQ', (1 << 7 if self.masked else 0) | 127, self.payload_length, ), @@ -140,35 +177,16 @@ def build(self) -> bytes: raw.write(self.data) return raw.getvalue() - def parse(self, raw: bytes) -> bytes: - cur = 0 - self.parse_fin_and_rsv(raw[cur]) - cur += 1 - - self.parse_mask_and_payload(raw[cur]) - cur += 1 - - if self.payload_length == 126: - data = raw[cur: cur + 2] - self.payload_length, = struct.unpack('!H', data) - cur += 2 - elif self.payload_length == 127: - data = raw[cur: cur + 8] - self.payload_length, = struct.unpack('!Q', data) - cur += 8 - - if self.masked: - self.mask = raw[cur: cur + 4] - cur += 4 - - assert self.payload_length - self.data = raw[cur: cur + self.payload_length] - cur += self.payload_length - if self.masked: - assert self.mask is not None - self.data = self.apply_mask(self.data, self.mask) + def _parse_fin_and_rsv(self, byte: int) -> None: + self.fin = bool(byte & 1 << 7) + self.rsv1 = bool(byte & 1 << 6) + self.rsv2 = bool(byte & 1 << 5) + self.rsv3 = bool(byte & 1 << 4) + self.opcode = byte & 0b00001111 - return raw[cur:] + def _parse_mask_and_payload(self, byte: int) -> None: + self.masked = bool(byte & 0b10000000) + self.payload_length = byte & 0b01111111 @staticmethod def apply_mask(data: bytes, mask: bytes) -> bytes: diff --git a/proxy/plugin/__init__.py b/proxy/plugin/__init__.py index 578bc3bcc4..f56ad23c8b 100644 --- a/proxy/plugin/__init__.py +++ b/proxy/plugin/__init__.py @@ -28,6 +28,7 @@ from .custom_dns_resolver import CustomDnsResolverPlugin from .cloudflare_dns import CloudflareDnsResolverPlugin from .program_name import ProgramNamePlugin +from .modify_websocket_response import ModifyWebsocketResponsePlugin __all__ = [ 'CacheResponsesPlugin', @@ -47,4 +48,5 @@ 'CustomDnsResolverPlugin', 'CloudflareDnsResolverPlugin', 'ProgramNamePlugin', + 'ModifyWebsocketResponsePlugin', ] diff --git a/proxy/plugin/modify_websocket_response.py b/proxy/plugin/modify_websocket_response.py new file mode 100644 index 0000000000..1ed4d88dcf --- /dev/null +++ b/proxy/plugin/modify_websocket_response.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" +from ..http.proxy import HttpProxyBasePlugin +from ..http.websocket import WebsocketFrame + + +class ModifyWebsocketResponsePlugin(HttpProxyBasePlugin): + """Inspect/Modify/Send custom websocket responses.""" + + def handle_upstream_chunk(self, chunk: memoryview) -> memoryview: + # Parse the response. + # Note that these chunks also include headers + remaining = chunk.tobytes() + while len(remaining) > 0: + response = WebsocketFrame() + remaining = response.parse(remaining) + self.client.queue( + memoryview( + WebsocketFrame.text(b'modified websocket response'), + ), + ) + return memoryview(b'') diff --git a/proxy/plugin/web_server_route.py b/proxy/plugin/web_server_route.py index 5f881a68f7..398b66860e 100644 --- a/proxy/plugin/web_server_route.py +++ b/proxy/plugin/web_server_route.py @@ -14,6 +14,7 @@ from ..http.responses import okResponse from ..http.parser import HttpParser from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes +from ..http.websocket import WebsocketFrame logger = logging.getLogger(__name__) @@ -36,3 +37,10 @@ def handle_request(self, request: HttpParser) -> None: self.client.queue(HTTP_RESPONSE) elif request.path == b'/https-route-example': self.client.queue(HTTPS_RESPONSE) + + def on_websocket_message(self, frame: WebsocketFrame) -> None: + self.client.queue( + memoryview( + WebsocketFrame.text(b'Websocket route response'), + ), + )