Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve socks proxy #1585

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/reference/exceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs.

.. autoexception:: SecurityError

.. autoexception:: InvalidMessage

.. autoexception:: InvalidStatus
.. autoexception:: ProxyError

.. autoexception:: InvalidProxyMessage

.. autoexception:: InvalidProxyStatus

.. autoexception:: InvalidMessage

.. autoexception:: InvalidStatus

.. autoexception:: InvalidHeader

.. autoexception:: InvalidHeaderFormat
Expand Down
3 changes: 3 additions & 0 deletions src/websockets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"NegotiationError",
"PayloadTooBig",
"ProtocolError",
"ProxyError",
"SecurityError",
"WebSocketException",
# .frames
Expand Down Expand Up @@ -112,6 +113,7 @@
NegotiationError,
PayloadTooBig,
ProtocolError,
ProxyError,
SecurityError,
WebSocketException,
)
Expand Down Expand Up @@ -173,6 +175,7 @@
"NegotiationError": ".exceptions",
"PayloadTooBig": ".exceptions",
"ProtocolError": ".exceptions",
"ProxyError": ".exceptions",
"SecurityError": ".exceptions",
"WebSocketException": ".exceptions",
# .frames
Expand Down
124 changes: 78 additions & 46 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import os
import socket
import traceback
import urllib.parse
from collections.abc import AsyncIterator, Generator, Sequence
Expand All @@ -11,7 +12,7 @@

from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidMessage, InvalidStatus, SecurityError
from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
Expand Down Expand Up @@ -147,7 +148,9 @@ def process_exception(exc: Exception) -> Exception | None:
That exception will be raised, breaking out of the retry loop.

"""
if isinstance(exc, (OSError, asyncio.TimeoutError)):
# This catches python-socks' ProxyConnectionError and ProxyTimeoutError.
# Remove asyncio.TimeoutError when dropping Python < 3.11.
if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)):
return None
if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
return None
Expand Down Expand Up @@ -265,6 +268,7 @@ class connect:

Raises:
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
InvalidProxy: If ``proxy`` isn't a valid proxy.
OSError: If the TCP connection fails.
InvalidHandshake: If the opening handshake fails.
TimeoutError: If the opening handshake times out.
Expand Down Expand Up @@ -357,15 +361,12 @@ async def create_connection(self) -> ClientConnection:
ws_uri = parse_uri(self.uri)

proxy = self.proxy
proxy_uri: Proxy | None = None
if kwargs.get("unix", False):
proxy = None
if kwargs.get("sock") is not None:
proxy = None
if proxy is True:
proxy = get_proxy(ws_uri)
if proxy is not None:
proxy_uri = parse_proxy(proxy)

def factory() -> ClientConnection:
return self.protocol_factory(ws_uri)
Expand All @@ -381,48 +382,14 @@ def factory() -> ClientConnection:

if kwargs.pop("unix", False):
_, connection = await loop.create_unix_connection(factory, **kwargs)
elif proxy is not None:
kwargs["sock"] = await connect_proxy(
parse_proxy(proxy),
ws_uri,
local_addr=kwargs.pop("local_addr", None),
)
_, connection = await loop.create_connection(factory, **kwargs)
else:
if proxy_uri is not None:
if proxy_uri.scheme[:5] == "socks":
try:
from python_socks import ProxyType
from python_socks.async_.asyncio import Proxy
except ImportError:
raise ImportError(
"python-socks is required to use a SOCKS proxy"
)
if proxy_uri.scheme == "socks5h":
proxy_type = ProxyType.SOCKS5
rdns = True
elif proxy_uri.scheme == "socks5":
proxy_type = ProxyType.SOCKS5
rdns = False
# We use mitmproxy for testing and it doesn't support SOCKS4.
elif proxy_uri.scheme == "socks4a": # pragma: no cover
proxy_type = ProxyType.SOCKS4
rdns = True
elif proxy_uri.scheme == "socks4": # pragma: no cover
proxy_type = ProxyType.SOCKS4
rdns = False
# Proxy types are enforced in parse_proxy().
else:
raise AssertionError("unsupported SOCKS proxy")
socks_proxy = Proxy(
proxy_type,
proxy_uri.host,
proxy_uri.port,
proxy_uri.username,
proxy_uri.password,
rdns,
)
kwargs["sock"] = await socks_proxy.connect(
ws_uri.host,
ws_uri.port,
local_addr=kwargs.pop("local_addr", None),
)
# Proxy types are enforced in parse_proxy().
else:
raise AssertionError("unsupported proxy")
if kwargs.get("sock") is None:
kwargs.setdefault("host", ws_uri.host)
kwargs.setdefault("port", ws_uri.port)
Expand Down Expand Up @@ -624,3 +591,68 @@ def unix_connect(
else:
uri = "wss://localhost/"
return connect(uri=uri, unix=True, path=path, **kwargs)


try:
from python_socks import ProxyType
from python_socks.async_.asyncio import Proxy as SocksProxy

SOCKS_PROXY_TYPES = {
"socks5h": ProxyType.SOCKS5,
"socks5": ProxyType.SOCKS5,
"socks4a": ProxyType.SOCKS4,
"socks4": ProxyType.SOCKS4,
}

SOCKS_PROXY_RDNS = {
"socks5h": True,
"socks5": False,
"socks4a": True,
"socks4": False,
}

async def connect_socks_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
**kwargs: Any,
) -> socket.socket:
"""Connect via a SOCKS proxy and return the socket."""
socks_proxy = SocksProxy(
SOCKS_PROXY_TYPES[proxy.scheme],
proxy.host,
proxy.port,
proxy.username,
proxy.password,
SOCKS_PROXY_RDNS[proxy.scheme],
)
# connect() is documented to raise OSError.
# socks_proxy.connect() doesn't raise TimeoutError; it gets canceled.
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
try:
return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
except OSError:
raise
except Exception as exc:
raise ProxyError("failed to connect to SOCKS proxy") from exc

except ImportError:

async def connect_socks_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
**kwargs: Any,
) -> socket.socket:
raise ImportError("python-socks is required to use a SOCKS proxy")


async def connect_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
**kwargs: Any,
) -> socket.socket:
"""Connect via a proxy and return the socket."""
# parse_proxy() validates proxy.scheme.
if proxy.scheme[:5] == "socks":
return await connect_socks_proxy(proxy, ws_uri, **kwargs)
else:
raise AssertionError("unsupported proxy")
41 changes: 25 additions & 16 deletions src/websockets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
* :exc:`InvalidProxy`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`ProxyError`
* :exc:`InvalidProxyMessage`
* :exc:`InvalidProxyStatus`
* :exc:`InvalidMessage`
* :exc:`InvalidStatus`
* :exc:`InvalidStatusCode` (legacy)
* :exc:`InvalidProxyMessage`
* :exc:`InvalidProxyStatus`
* :exc:`InvalidHeader`
* :exc:`InvalidHeaderFormat`
* :exc:`InvalidHeaderValue`
Expand Down Expand Up @@ -48,10 +49,11 @@
"InvalidProxy",
"InvalidHandshake",
"SecurityError",
"InvalidMessage",
"InvalidStatus",
"ProxyError",
"InvalidProxyMessage",
"InvalidProxyStatus",
"InvalidMessage",
"InvalidStatus",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
Expand Down Expand Up @@ -206,46 +208,53 @@ class SecurityError(InvalidHandshake):
"""


class InvalidMessage(InvalidHandshake):
class ProxyError(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
Raised when failing to connect to a proxy.

"""


class InvalidStatus(InvalidHandshake):
class InvalidProxyMessage(ProxyError):
"""
Raised when a handshake response rejects the WebSocket upgrade.
Raised when an HTTP proxy response is malformed.

"""


class InvalidProxyStatus(ProxyError):
"""
Raised when an HTTP proxy rejects the connection.

"""

def __init__(self, response: http11.Response) -> None:
self.response = response

def __str__(self) -> str:
return (
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
)
return f"proxy rejected connection: HTTP {self.response.status_code:d}"


class InvalidProxyMessage(InvalidHandshake):
class InvalidMessage(InvalidHandshake):
"""
Raised when a proxy response is malformed.
Raised when a handshake request or response is malformed.

"""


class InvalidProxyStatus(InvalidHandshake):
class InvalidStatus(InvalidHandshake):
"""
Raised when a proxy rejects the connection.
Raised when a handshake response rejects the WebSocket upgrade.

"""

def __init__(self, response: http11.Response) -> None:
self.response = response

def __str__(self) -> str:
return f"proxy rejected connection: HTTP {self.response.status_code:d}"
return (
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
)


class InvalidHeader(InvalidHandshake):
Expand Down
Loading
Loading