diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a2e1bc2..74fac904 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* Added support for regular expressions in the ``origins`` argument of + :func:`~asyncio.server.serve`. + Bug fixes ......... diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 080ea3f1..ebe45c2a 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -600,10 +600,11 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/src/websockets/server.py b/src/websockets/server.py index 67082ed7..90e6c992 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -50,9 +50,11 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header, including regular - expressions; include :obj:`None` in the list if the lack of an origin - is acceptable. This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -310,17 +312,14 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - valid = False - for acceptable_origin_or_regex in self.origins: - if isinstance(acceptable_origin_or_regex, re.Pattern): - # `str(origin)` is needed for compatibility - # between `Pattern.match(string=...)` and `origin`. - valid = acceptable_origin_or_regex.match(str(origin)) is not None - else: - valid = acceptable_origin_or_regex == origin - if valid: + for origin_or_regex in self.origins: + if origin_or_regex == origin or ( + isinstance(origin_or_regex, re.Pattern) + and origin is not None + and origin_or_regex.fullmatch(origin) is not None + ): break - if not valid: + else: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c14e558a..50a2f3c0 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -400,10 +400,11 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index dd5e0d09..9f328ded 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -608,7 +608,7 @@ def test_supported_origin(self): self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): - """Handshake succeeds when checking origins and the origin is unsupported.""" + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) @@ -624,13 +624,10 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) - def test_supported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is supported - by a regular expression. - """ + def test_supported_origin_regex(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://other.example.com" @@ -640,13 +637,10 @@ def test_supported_origin_by_regex(self): self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") - def test_unsupported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is unsupported - by a regular expression. - """ + def test_unsupported_origin_regex(self): + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://original.example.com" @@ -660,6 +654,23 @@ def test_unsupported_origin_by_regex(self): "invalid Origin header: https://original.example.com", ) + def test_partial_match_origin_regex(self): + """Handshake fails when checking origins and the origin a partial match.""" + server = ServerProtocol( + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com.hacked" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://other.example.com.hacked", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None])