Skip to content

Commit

Permalink
chore: improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ntamas committed Sep 4, 2024
1 parent 0e01364 commit 26b15f1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
20 changes: 14 additions & 6 deletions src/flockwave/channels/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
encodes them and writes them to a WritableConnection_.
"""

from __future__ import annotations

from collections import deque
from contextlib import asynccontextmanager
from logging import Logger
from trio import EndOfChannel
from trio.abc import Channel
from typing import Optional, Union, TYPE_CHECKING
from typing import Generic, Optional, Union, TYPE_CHECKING

from ..connections import Connection
from flockwave.connections.base import RWConnection

from .types import Encoder, MessageType, Parser, RawType, RPCRequestHandler

Expand All @@ -21,14 +23,20 @@
from tinyrpc.protocols import RPCProtocol


class MessageChannel(Channel[MessageType]):
class MessageChannel(Generic[MessageType, RawType], Channel[MessageType]):
"""Trio-style Channel_ that wraps a readable-writable connection and
uses a parser to decode the messages read from the connection and an
encoder to encode the messages to the wire format of the connection.
"""

_connection: RWConnection[RawType, RawType]
_encoder: Encoder[MessageType, RawType]
_parser: Parser[RawType, MessageType]

@classmethod
def for_rpc_protocol(cls, protocol: "RPCProtocol", connection: Connection):
def for_rpc_protocol(
cls, protocol: RPCProtocol, connection: RWConnection[RawType, RawType]
):
"""Helper method to construct a message channel that will send and
receive messages using the given RPC protocol.
Expand All @@ -44,13 +52,13 @@ def for_rpc_protocol(cls, protocol: "RPCProtocol", connection: Connection):
encoder = create_rpc_encoder(protocol=protocol)

# Wrap the parser and the encoder in a MessageChannel
result = cls(connection, parser=parser, encoder=encoder)
result = cls(connection, parser=parser, encoder=encoder) # type: ignore
result._protocol = protocol
return result

def __init__(
self,
connection: Connection,
connection: RWConnection[RawType, RawType],
parser: Parser[RawType, MessageType],
encoder: Encoder[MessageType, RawType],
):
Expand Down
16 changes: 15 additions & 1 deletion src/flockwave/channels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,24 @@
if TYPE_CHECKING:
from tinyrpc.protocols import RPCRequest, RPCResponse

__all__ = ("Encoder", "MessageType", "Parser", "RawType", "Reader", "Writer")
__all__ = (
"Encoder",
"MessageType",
"Parser",
"RawType",
"Reader",
"Writer",
)

RawType = TypeVar("RawType")
"""Type variable that is used to indicate the raw (encoded) type of a message
on a MessageChannel.
"""

MessageType = TypeVar("MessageType")
"""Type variable that is used to indicate the decoded, object-like type of a
message on a MessageChannel.
"""

Reader = Union[Callable[[], Awaitable[RawType]], ReadableConnection[RawType]]
Writer = Union[Callable[[RawType], None], WritableConnection[RawType]]
Expand Down

0 comments on commit 26b15f1

Please sign in to comment.