Skip to content

Commit

Permalink
Improve error handling for SOCKS proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Jan 28, 2025
1 parent 10175f7 commit 3a725f0
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 78 deletions.
8 changes: 5 additions & 3 deletions docs/reference/exceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs.

.. autoexception:: SecurityError

.. autoexception:: InvalidMessage

.. autoexception:: InvalidStatus
.. autoexception:: ProxyError

.. autoexception:: InvalidProxyMessage

.. autoexception:: InvalidProxyStatus

.. autoexception:: InvalidMessage

.. autoexception:: InvalidStatus

.. autoexception:: InvalidHeader

.. autoexception:: InvalidHeaderFormat
Expand Down
3 changes: 3 additions & 0 deletions src/websockets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"NegotiationError",
"PayloadTooBig",
"ProtocolError",
"ProxyError",
"SecurityError",
"WebSocketException",
# .frames
Expand Down Expand Up @@ -112,6 +113,7 @@
NegotiationError,
PayloadTooBig,
ProtocolError,
ProxyError,
SecurityError,
WebSocketException,
)
Expand Down Expand Up @@ -173,6 +175,7 @@
"NegotiationError": ".exceptions",
"PayloadTooBig": ".exceptions",
"ProtocolError": ".exceptions",
"ProxyError": ".exceptions",
"SecurityError": ".exceptions",
"WebSocketException": ".exceptions",
# .frames
Expand Down
17 changes: 14 additions & 3 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidMessage, InvalidStatus, SecurityError
from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
Expand Down Expand Up @@ -148,7 +148,9 @@ def process_exception(exc: Exception) -> Exception | None:
That exception will be raised, breaking out of the retry loop.
"""
if isinstance(exc, (OSError, asyncio.TimeoutError)):
# This catches socks_proxy's ProxyConnectionError and ProxyTimeoutError.
# Remove asyncio.TimeoutError when dropping Python < 3.11.
if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)):
return None
if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
return None
Expand Down Expand Up @@ -266,6 +268,7 @@ class connect:
Raises:
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
InvalidProxy: If ``proxy`` isn't a valid proxy.
OSError: If the TCP connection fails.
InvalidHandshake: If the opening handshake fails.
TimeoutError: If the opening handshake times out.
Expand Down Expand Up @@ -622,7 +625,15 @@ async def connect_socks_proxy(
proxy.password,
SOCKS_PROXY_RDNS[proxy.scheme],
)
return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
# connect() is documented to raise OSError.
# socks_proxy.connect() doesn't raise TimeoutError; it gets canceled.
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
try:
return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
except OSError:
raise
except Exception as exc:
raise ProxyError("failed to connect to SOCKS proxy") from exc

except ImportError:

Expand Down
41 changes: 25 additions & 16 deletions src/websockets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
* :exc:`InvalidProxy`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`ProxyError`
* :exc:`InvalidProxyMessage`
* :exc:`InvalidProxyStatus`
* :exc:`InvalidMessage`
* :exc:`InvalidStatus`
* :exc:`InvalidStatusCode` (legacy)
* :exc:`InvalidProxyMessage`
* :exc:`InvalidProxyStatus`
* :exc:`InvalidHeader`
* :exc:`InvalidHeaderFormat`
* :exc:`InvalidHeaderValue`
Expand Down Expand Up @@ -48,10 +49,11 @@
"InvalidProxy",
"InvalidHandshake",
"SecurityError",
"InvalidMessage",
"InvalidStatus",
"ProxyError",
"InvalidProxyMessage",
"InvalidProxyStatus",
"InvalidMessage",
"InvalidStatus",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
Expand Down Expand Up @@ -206,46 +208,53 @@ class SecurityError(InvalidHandshake):
"""


class InvalidMessage(InvalidHandshake):
class ProxyError(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
Raised when failing to connect to a proxy.
"""


class InvalidStatus(InvalidHandshake):
class InvalidProxyMessage(ProxyError):
"""
Raised when a handshake response rejects the WebSocket upgrade.
Raised when an HTTP proxy response is malformed.
"""


class InvalidProxyStatus(ProxyError):
"""
Raised when an HTTP proxy rejects the connection.
"""

def __init__(self, response: http11.Response) -> None:
self.response = response

def __str__(self) -> str:
return (
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
)
return f"proxy rejected connection: HTTP {self.response.status_code:d}"


class InvalidProxyMessage(InvalidHandshake):
class InvalidMessage(InvalidHandshake):
"""
Raised when a proxy response is malformed.
Raised when a handshake request or response is malformed.
"""


class InvalidProxyStatus(InvalidHandshake):
class InvalidStatus(InvalidHandshake):
"""
Raised when a proxy rejects the connection.
Raised when a handshake response rejects the WebSocket upgrade.
"""

def __init__(self, response: http11.Response) -> None:
self.response = response

def __str__(self) -> str:
return f"proxy rejected connection: HTTP {self.response.status_code:d}"
return (
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
)


class InvalidHeader(InvalidHandshake):
Expand Down
10 changes: 9 additions & 1 deletion src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ..client import ClientProtocol
from ..datastructures import HeadersLike
from ..exceptions import ProxyError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
Expand Down Expand Up @@ -420,7 +421,14 @@ def connect_socks_proxy(
SOCKS_PROXY_RDNS[proxy.scheme],
)
kwargs.setdefault("timeout", deadline.timeout())
return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
# connect() is documented to raise OSError and TimeoutError.
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
try:
return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
except (OSError, TimeoutError, socket.timeout):
raise
except Exception as exc:
raise ProxyError("failed to connect to SOCKS proxy") from exc

except ImportError:

Expand Down
93 changes: 68 additions & 25 deletions tests/asyncio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
InvalidProxy,
InvalidStatus,
InvalidURI,
ProxyError,
SecurityError,
)
from websockets.extensions.permessage_deflate import PerMessageDeflate
Expand Down Expand Up @@ -379,24 +380,16 @@ def remove_accept_header(self, request, response):

async def test_timeout_during_handshake(self):
"""Client times out before receiving handshake response from server."""
gate = asyncio.get_running_loop().create_future()

async def stall_connection(self, request):
await gate

# The connection will be open for the server but failed for the client.
# Use a connection handler that exits immediately to avoid an exception.
async with serve(*args, process_request=stall_connection) as server:
try:
with self.assertRaises(TimeoutError) as raised:
async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"timed out during handshake",
)
finally:
gate.set_result(None)
# Replace the WebSocket server with a TCP server that does't respond.
with socket.create_server(("localhost", 0)) as sock:
host, port = sock.getsockname()
with self.assertRaises(TimeoutError) as raised:
async with connect(f"ws://{host}:{port}", open_timeout=MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"timed out during handshake",
)

async def test_connection_closed_during_handshake(self):
"""Client reads EOF before receiving handshake response from server."""
Expand Down Expand Up @@ -570,11 +563,13 @@ class ProxyClientTests(unittest.IsolatedAsyncioTestCase):
async def socks_proxy(self, auth=None):
if auth:
proxyauth = "hello:iloveyou"
proxy_uri = "http://hello:iloveyou@localhost:1080"
proxy_uri = "http://hello:iloveyou@localhost:51080"
else:
proxyauth = None
proxy_uri = "http://localhost:1080"
async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows:
proxy_uri = "http://localhost:51080"
async with async_proxy(
mode=["socks5@51080"], proxyauth=proxyauth
) as record_flows:
with patch_environ({"socks_proxy": proxy_uri}):
yield record_flows

Expand Down Expand Up @@ -602,14 +597,62 @@ async def test_authenticated_socks_proxy(self):
self.assertEqual(client.protocol.state.name, "OPEN")
self.assertEqual(len(proxy.get_flows()), 1)

async def test_socks_proxy_connection_error(self):
"""Client receives an error when connecting to the SOCKS5 proxy."""
from python_socks import ProxyError as SocksProxyError

async with self.socks_proxy(auth=True) as proxy:
with self.assertRaises(ProxyError) as raised:
async with connect(
"ws://example.com/",
proxy="socks5h://localhost:51080", # remove credentials
):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"failed to connect to SOCKS proxy",
)
self.assertIsInstance(raised.exception.__cause__, SocksProxyError)
self.assertEqual(len(proxy.get_flows()), 0)

async def test_socks_proxy_connection_fails(self):
"""Client fails to connect to the SOCKS5 proxy."""
from python_socks import ProxyConnectionError as SocksProxyConnectionError

with self.assertRaises(OSError) as raised:
async with connect(
"ws://example.com/",
proxy="socks5h://localhost:51080", # nothing at this address
):
self.fail("did not raise")
# Don't test str(raised.exception) because we don't control it.
self.assertIsInstance(raised.exception, SocksProxyConnectionError)

async def test_socks_proxy_connection_timeout(self):
"""Client times out while connecting to the SOCKS5 proxy."""
# Replace the proxy with a TCP server that does't respond.
with socket.create_server(("localhost", 0)) as sock:
host, port = sock.getsockname()
with self.assertRaises(TimeoutError) as raised:
async with connect(
"ws://example.com/",
proxy=f"socks5h://{host}:{port}/",
open_timeout=MS,
):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"timed out during handshake",
)

async def test_explicit_proxy(self):
"""Client connects to server through a proxy set explicitly."""
async with async_proxy(mode=["socks5"]) as proxy:
async with async_proxy(mode=["socks5@51080"]) as proxy:
async with serve(*args) as server:
async with connect(
get_uri(server),
# Take this opportunity to test socks5 instead of socks5h.
proxy="socks5://localhost:1080",
proxy="socks5://localhost:51080",
) as client:
self.assertEqual(client.protocol.state.name, "OPEN")
self.assertEqual(len(proxy.get_flows()), 1)
Expand All @@ -626,13 +669,13 @@ async def test_ignore_proxy_with_existing_socket(self):

async def test_unsupported_proxy(self):
"""Client connects to server through an unsupported proxy."""
with patch_environ({"ws_proxy": "other://localhost:1080"}):
with patch_environ({"ws_proxy": "other://localhost:51080"}):
with self.assertRaises(InvalidProxy) as raised:
async with connect("ws://example.com/"):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"other://localhost:1080 isn't a valid proxy: scheme other isn't supported",
"other://localhost:51080 isn't a valid proxy: scheme other isn't supported",
)


Expand Down
6 changes: 3 additions & 3 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ async def test_connection_handler_raises_exception(self):
async def test_existing_socket(self):
"""Server receives connection using a pre-existing socket."""
with socket.create_server(("localhost", 0)) as sock:
async with serve(handler, sock=sock, host=None, port=None):
uri = "ws://{}:{}/".format(*sock.getsockname())
async with connect(uri) as client:
host, port = sock.getsockname()
async with serve(handler, sock=sock):
async with connect(f"ws://{host}:{port}/") as client:
await self.assertEval(client, "ws.protocol.state.name", "OPEN")

async def test_select_subprotocol(self):
Expand Down
Loading

0 comments on commit 3a725f0

Please sign in to comment.