diff --git a/src/easynetwork_asyncio/datagram/endpoint.py b/src/easynetwork_asyncio/datagram/endpoint.py index 788f51aa..d4d78ad8 100644 --- a/src/easynetwork_asyncio/datagram/endpoint.py +++ b/src/easynetwork_asyncio/datagram/endpoint.py @@ -90,8 +90,6 @@ def __init__( self.__transport: asyncio.DatagramTransport = transport self.__protocol: DatagramEndpointProtocol = protocol - _monkeypatch_transport(transport, protocol._get_loop()) - def close(self) -> None: self.__transport.close() @@ -175,7 +173,7 @@ def __init__( self.__loop: asyncio.AbstractEventLoop = loop self.__recv_queue: asyncio.Queue[tuple[bytes, tuple[Any, ...]] | None] = recv_queue self.__exception_queue: asyncio.Queue[Exception] = exception_queue - self.__transport: asyncio.BaseTransport | None = None + self.__transport: asyncio.DatagramTransport | None = None self.__closed: asyncio.Future[None] = loop.create_future() self.__drain_waiters: collections.deque[asyncio.Future[None]] = collections.deque() self.__write_paused: bool = False @@ -191,10 +189,11 @@ def __del__(self) -> None: # pragma: no cover if closed.done() and not closed.cancelled(): closed.exception() - def connection_made(self, transport: asyncio.BaseTransport) -> None: + def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore[override] assert self.__transport is None, "Transport already set" # nosec assert_used self.__transport = transport self.__connection_lost = False + _monkeypatch_transport(transport, self.__loop) def connection_lost(self, exc: Exception | None) -> None: self.__connection_lost = True diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py b/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py index 651050a2..7aa875ab 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py @@ -102,10 +102,6 @@ def mock_asyncio_protocol(mocker: MockerFixture, event_loop: asyncio.AbstractEve def mock_asyncio_transport(mocker: MockerFixture) -> MagicMock: mock = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) mock.is_closing.return_value = False - - # Tell connection_made() not to try to monkeypatch this mock object - del mock._address - return mock @staticmethod @@ -418,6 +414,10 @@ class TestDatagramEndpointProtocol: def mock_asyncio_transport(mocker: MockerFixture) -> MagicMock: mock = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) mock.is_closing.return_value = False + + # Tell connection_made() not to try to monkeypatch this mock object + del mock._address + return mock @staticmethod