Skip to content

Commit

Permalink
Simplify untangling of import cycles
Browse files Browse the repository at this point in the history
Three imports at the bottom, related to type annotations, is a lesser
evil compared to dozens of module-prefixed identifiers, a departure
from the coding style of this library.

Refs #989.
  • Loading branch information
aaugustin committed Aug 23, 2024
1 parent 4e17142 commit 2a9dfb5
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 78 deletions.
4 changes: 3 additions & 1 deletion src/websockets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import typing
import warnings

from . import frames, http11
from .imports import lazy_import


Expand Down Expand Up @@ -376,3 +375,6 @@ class InvalidState(WebSocketException, AssertionError):
"WebSocketProtocolError": ".legacy.exceptions",
},
)

# At the bottom to break import cycles created by type annotations.
from . import frames, http11 # noqa: E402
11 changes: 3 additions & 8 deletions src/websockets/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Sequence

from .. import frames
from ..frames import Frame
from ..typing import ExtensionName, ExtensionParameter


Expand All @@ -18,12 +18,7 @@ class Extension:
name: ExtensionName
"""Extension identifier."""

def decode(
self,
frame: frames.Frame,
*,
max_size: int | None = None,
) -> frames.Frame:
def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
"""
Decode an incoming frame.
Expand All @@ -40,7 +35,7 @@ def decode(
"""
raise NotImplementedError

def encode(self, frame: frames.Frame) -> frames.Frame:
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.
Expand Down
48 changes: 28 additions & 20 deletions src/websockets/extensions/permessage_deflate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
import zlib
from typing import Any, Sequence

from .. import exceptions, frames
from .. import frames
from ..exceptions import (
DuplicateParameter,
InvalidParameterName,
InvalidParameterValue,
NegotiationError,
PayloadTooBig,
ProtocolError,
)
from ..typing import ExtensionName, ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory

Expand Down Expand Up @@ -129,9 +137,9 @@ def decode(
try:
data = self.decoder.decompress(data, max_length)
except zlib.error as exc:
raise exceptions.ProtocolError("decompression failed") from exc
raise ProtocolError("decompression failed") from exc
if self.decoder.unconsumed_tail:
raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)")
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")

# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
Expand Down Expand Up @@ -215,40 +223,40 @@ def _extract_parameters(
for name, value in params:
if name == "server_no_context_takeover":
if server_no_context_takeover:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value is None:
server_no_context_takeover = True
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)

elif name == "client_no_context_takeover":
if client_no_context_takeover:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value is None:
client_no_context_takeover = True
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)

elif name == "server_max_window_bits":
if server_max_window_bits is not None:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if value in _MAX_WINDOW_BITS_VALUES:
server_max_window_bits = int(value)
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)

elif name == "client_max_window_bits":
if client_max_window_bits is not None:
raise exceptions.DuplicateParameter(name)
raise DuplicateParameter(name)
if is_server and value is None: # only in handshake requests
client_max_window_bits = True
elif value in _MAX_WINDOW_BITS_VALUES:
client_max_window_bits = int(value)
else:
raise exceptions.InvalidParameterValue(name, value)
raise InvalidParameterValue(name, value)

else:
raise exceptions.InvalidParameterName(name)
raise InvalidParameterName(name)

return (
server_no_context_takeover,
Expand Down Expand Up @@ -340,7 +348,7 @@ def process_response_params(
"""
if any(other.name == self.name for other in accepted_extensions):
raise exceptions.NegotiationError(f"received duplicate {self.name}")
raise NegotiationError(f"received duplicate {self.name}")

# Request parameters are available in instance variables.

Expand All @@ -366,7 +374,7 @@ def process_response_params(

if self.server_no_context_takeover:
if not server_no_context_takeover:
raise exceptions.NegotiationError("expected server_no_context_takeover")
raise NegotiationError("expected server_no_context_takeover")

# client_no_context_takeover
#
Expand Down Expand Up @@ -396,9 +404,9 @@ def process_response_params(

else:
if server_max_window_bits is None:
raise exceptions.NegotiationError("expected server_max_window_bits")
raise NegotiationError("expected server_max_window_bits")
elif server_max_window_bits > self.server_max_window_bits:
raise exceptions.NegotiationError("unsupported server_max_window_bits")
raise NegotiationError("unsupported server_max_window_bits")

# client_max_window_bits

Expand All @@ -414,7 +422,7 @@ def process_response_params(

if self.client_max_window_bits is None:
if client_max_window_bits is not None:
raise exceptions.NegotiationError("unexpected client_max_window_bits")
raise NegotiationError("unexpected client_max_window_bits")

elif self.client_max_window_bits is True:
pass
Expand All @@ -423,7 +431,7 @@ def process_response_params(
if client_max_window_bits is None:
client_max_window_bits = self.client_max_window_bits
elif client_max_window_bits > self.client_max_window_bits:
raise exceptions.NegotiationError("unsupported client_max_window_bits")
raise NegotiationError("unsupported client_max_window_bits")

return PerMessageDeflate(
server_no_context_takeover, # remote_no_context_takeover
Expand Down Expand Up @@ -534,7 +542,7 @@ def process_request_params(
"""
if any(other.name == self.name for other in accepted_extensions):
raise exceptions.NegotiationError(f"skipped duplicate {self.name}")
raise NegotiationError(f"skipped duplicate {self.name}")

# Load request parameters in local variables.
(
Expand Down Expand Up @@ -613,7 +621,7 @@ def process_request_params(
else:
if client_max_window_bits is None:
if self.require_client_max_window_bits:
raise exceptions.NegotiationError("required client_max_window_bits")
raise NegotiationError("required client_max_window_bits")
elif client_max_window_bits is True:
client_max_window_bits = self.client_max_window_bits
elif self.client_max_window_bits < client_max_window_bits:
Expand Down
24 changes: 13 additions & 11 deletions src/websockets/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import struct
from typing import Callable, Generator, Sequence

from . import exceptions, extensions
from .exceptions import PayloadTooBig, ProtocolError


try:
Expand Down Expand Up @@ -239,10 +239,10 @@ def parse(
try:
opcode = Opcode(head1 & 0b00001111)
except ValueError as exc:
raise exceptions.ProtocolError("invalid opcode") from exc
raise ProtocolError("invalid opcode") from exc

if (True if head2 & 0b10000000 else False) != mask:
raise exceptions.ProtocolError("incorrect masking")
raise ProtocolError("incorrect masking")

length = head2 & 0b01111111
if length == 126:
Expand All @@ -252,9 +252,7 @@ def parse(
data = yield from read_exact(8)
(length,) = struct.unpack("!Q", data)
if max_size is not None and length > max_size:
raise exceptions.PayloadTooBig(
f"over size limit ({length} > {max_size} bytes)"
)
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
if mask:
mask_bytes = yield from read_exact(4)

Expand Down Expand Up @@ -342,13 +340,13 @@ def check(self) -> None:
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise exceptions.ProtocolError("reserved bits must be 0")
raise ProtocolError("reserved bits must be 0")

if self.opcode in CTRL_OPCODES:
if len(self.data) > 125:
raise exceptions.ProtocolError("control frame too long")
raise ProtocolError("control frame too long")
if not self.fin:
raise exceptions.ProtocolError("fragmented control frame")
raise ProtocolError("fragmented control frame")


@dataclasses.dataclass
Expand Down Expand Up @@ -405,7 +403,7 @@ def parse(cls, data: bytes) -> Close:
elif len(data) == 0:
return cls(CloseCode.NO_STATUS_RCVD, "")
else:
raise exceptions.ProtocolError("close frame too short")
raise ProtocolError("close frame too short")

def serialize(self) -> bytes:
"""
Expand All @@ -424,4 +422,8 @@ def check(self) -> None:
"""
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
raise exceptions.ProtocolError("invalid status code")
raise ProtocolError("invalid status code")


# At the bottom to break import cycles created by type annotations.
from . import extensions # noqa: E402
32 changes: 12 additions & 20 deletions src/websockets/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from typing import Callable, Sequence, TypeVar, cast

from . import exceptions
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
from .typing import (
ConnectionOption,
ExtensionHeader,
Expand Down Expand Up @@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
"""
match = _token_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos)
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
return match.group(), match.end()


Expand All @@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i
"""
match = _quoted_string_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected quoted string", header, pos
)
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()


Expand Down Expand Up @@ -206,9 +204,7 @@ def parse_list(
if peek_ahead(header, pos) == ",":
pos = parse_OWS(header, pos + 1)
else:
raise exceptions.InvalidHeaderFormat(
header_name, "expected comma", header, pos
)
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)

# Remove extra delimiters before the next item.
while peek_ahead(header, pos) == ",":
Expand Down Expand Up @@ -276,9 +272,7 @@ def parse_upgrade_protocol(
"""
match = _protocol_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected protocol", header, pos
)
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
return cast(UpgradeProtocol, match.group()), match.end()


Expand Down Expand Up @@ -324,7 +318,7 @@ def parse_extension_item_param(
# the value after quoted-string unescaping MUST conform to
# the 'token' ABNF.
if _token_re.fullmatch(value) is None:
raise exceptions.InvalidHeaderFormat(
raise InvalidHeaderFormat(
header_name, "invalid quoted header content", header, pos_before
)
else:
Expand Down Expand Up @@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
"""
match = _token68_re.match(header, pos)
if match is None:
raise exceptions.InvalidHeaderFormat(
header_name, "expected token68", header, pos
)
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
return match.group(), match.end()


Expand All @@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None:
"""
if pos < len(header):
raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos)
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)


def parse_authorization_basic(header: str) -> tuple[str, str]:
Expand All @@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
scheme, pos = parse_token(header, 0, "Authorization")
if scheme.lower() != "basic":
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
f"unsupported scheme: {scheme}",
)
if peek_ahead(header, pos) != " ":
raise exceptions.InvalidHeaderFormat(
raise InvalidHeaderFormat(
"Authorization", "expected space after scheme", header, pos
)
pos += 1
Expand All @@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
try:
user_pass = base64.b64decode(basic_credentials.encode()).decode()
except binascii.Error:
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
"expected base64-encoded credentials",
) from None
try:
username, password = user_pass.split(":", 1)
except ValueError:
raise exceptions.InvalidHeaderValue(
raise InvalidHeaderValue(
"Authorization",
"expected username:password credentials",
) from None
Expand Down
Loading

0 comments on commit 2a9dfb5

Please sign in to comment.