diff --git a/bellows/ash.py b/bellows/ash.py index 3c2a4b4f..66c0e516 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -3,11 +3,13 @@ import abc import asyncio import binascii +from collections.abc import Coroutine import dataclasses import enum import logging import sys import time +import typing if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover @@ -55,7 +57,7 @@ class Reserved(enum.IntEnum): # Maximum number of DATA frames the NCP can transmit without having received # acknowledgements -TX_K = 1 +TX_K = 1 # TODO: investigate why this cannot be raised without causing a firmware crash # Maximum number of consecutive timeouts allowed while waiting to receive an ACK before # going to the FAILED state. The value 0 prevents the NCP from entering the error state @@ -81,6 +83,23 @@ def generate_random_sequence(length: int) -> bytes: # Since the sequence is static for every frame, we only need to generate it once PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256) +if sys.version_info[:2] < (3, 12): + create_eager_task = asyncio.create_task +else: + _T = typing.TypeVar("T") + + def create_eager_task( + coro: Coroutine[typing.Any, typing.Any, _T], + *, + name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + ) -> asyncio.Task[_T]: + """Create a task from a coroutine and schedule it to run immediately.""" + if loop is None: + loop = asyncio.get_running_loop() + + return asyncio.Task(coro, loop=loop, name=name, eager_start=True) + class NcpState(enum.Enum): CONNECTED = "connected" @@ -463,15 +482,14 @@ def data_received(self, data: bytes) -> None: def _handle_ack(self, frame: DataFrame | AckFrame) -> None: # Note that ackNum is the number of the next frame the receiver expects and it # is one greater than the last frame received. - ack_num = (frame.ack_num - 1) % 8 + for ack_num_offset in range(-TX_K, 0): + ack_num = (frame.ack_num + ack_num_offset) % 8 + fut = self._pending_data_frames.get(ack_num) - fut = self._pending_data_frames.get(ack_num) + if fut is None or fut.done(): + continue - if fut is None or fut.done(): - return - - # _LOGGER.debug("Resolving frame %d", ack_num) - self._pending_data_frames[ack_num].set_result(True) + self._pending_data_frames[ack_num].set_result(True) def frame_received(self, frame: AshFrame) -> None: _LOGGER.debug("Received frame %r", frame) @@ -537,13 +555,16 @@ def error_frame_received(self, frame: ErrorFrame) -> None: self._ncp_state = NcpState.FAILED # Cancel all pending requests - exc = NcpFailure(code=self._ncp_reset_code) + self._enter_failed_state(self._ncp_reset_code) + + def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None: + exc = NcpFailure(code=reset_code) for fut in self._pending_data_frames.values(): if not fut.done(): fut.set_exception(exc) - self._ezsp_protocol.reset_received(frame.reset_code) + self._ezsp_protocol.reset_received(reset_code) def _write_frame( self, @@ -582,7 +603,7 @@ async def _send_data_frame(self, frame: AshFrame) -> None: for attempt in range(ACK_TIMEOUTS): if self._ncp_state == NcpState.FAILED: _LOGGER.debug( - "NCP is in a failed state, not re-sending: %r", frame + "NCP is in a failed state, not sending: %r", frame ) raise NcpFailure( t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT @@ -618,6 +639,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None: self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) if attempt >= ACK_TIMEOUTS - 1: + self._enter_failed_state( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) raise except NcpFailure: _LOGGER.debug( @@ -635,6 +659,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None: self._change_ack_timeout(2 * self._t_rx_ack) if attempt >= ACK_TIMEOUTS - 1: + self._enter_failed_state( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) raise else: # Whenever an acknowledgement is received, t_rx_ack is set to @@ -649,9 +676,14 @@ async def _send_data_frame(self, frame: AshFrame) -> None: self._pending_data_frames.pop(frm_num) async def send_data(self, data: bytes) -> None: - await self._send_data_frame( - # All of the other fields will be set during transmission/retries - DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data) + # Sending data is a critical operation and cannot really be cancelled + await asyncio.shield( + create_eager_task( + self._send_data_frame( + # All of the other fields will be set during transmission/retries + DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data) + ) + ) ) def send_reset(self) -> None: