diff --git a/api/src/opentrons/protocol_runner/legacy_context_plugin.py b/api/src/opentrons/protocol_runner/legacy_context_plugin.py
index 4e23038de4f..baf6ccbc716 100644
--- a/api/src/opentrons/protocol_runner/legacy_context_plugin.py
+++ b/api/src/opentrons/protocol_runner/legacy_context_plugin.py
@@ -1,9 +1,9 @@
 """Customize the ProtocolEngine to monitor and control legacy (APIv2) protocols."""
 from __future__ import annotations
 
-import asyncio
+from asyncio import create_task, Task
 from contextlib import ExitStack
-from typing import Optional
+from typing import List, Optional
 
 from opentrons.legacy_commands.types import CommandMessage as LegacyCommand
 from opentrons.legacy_broker import LegacyBroker
@@ -12,6 +12,7 @@
 from opentrons.util.broker import ReadOnlyBroker
 
 from .legacy_command_mapper import LegacyCommandMapper
+from .thread_async_queue import ThreadAsyncQueue
 
 
 class LegacyContextPlugin(AbstractPlugin):
@@ -20,36 +21,59 @@ class LegacyContextPlugin(AbstractPlugin):
     In the legacy ProtocolContext, protocol execution is accomplished
     by direct communication with the HardwareControlAPI, as opposed
     to an intermediate layer like the ProtocolEngine. This plugin wraps up
-    and hides this behavior, so the ProtocolEngine can monitor
+    and hides this behavior, so the ProtocolEngine can monitor and control
     the run of a legacy protocol without affecting the execution of the
     protocol commands themselves.
 
-    This plugin allows a ProtocolEngine to subscribe to what is being done with the
-    legacy ProtocolContext, and insert matching commands into ProtocolEngine state for
-    purely progress-tracking purposes.
+    This plugin allows a ProtocolEngine to:
+
+    1. Play/pause the protocol run using the HardwareControlAPI, as was done before
+       the ProtocolEngine existed.
+    2. Subscribe to what is being done with the legacy ProtocolContext,
+       and insert matching commands into ProtocolEngine state for
+       purely progress-tracking purposes.
     """
 
    def __init__(
        self,
-        engine_loop: asyncio.AbstractEventLoop,
        broker: LegacyBroker,
        equipment_broker: ReadOnlyBroker[LoadInfo],
        legacy_command_mapper: Optional[LegacyCommandMapper] = None,
    ) -> None:
        """Initialize the plugin with its dependencies."""
-        self._engine_loop = engine_loop
-
        self._broker = broker
        self._equipment_broker = equipment_broker
        self._legacy_command_mapper = legacy_command_mapper or LegacyCommandMapper()
+
+        # We use a non-blocking queue to communicate activity
+        # from the APIv2 protocol, which is running in its own thread,
+        # to the ProtocolEngine, which is running in the main thread's async event loop.
+        #
+        # The queue being non-blocking lets the protocol communicate its activity
+        # instantly *even if the event loop is currently occupied by something else.*
+        # Various things can accidentally occupy the event loop for too long.
+        # So if the protocol had to wait for the event loop to be free
+        # every time it reported some activity,
+        # it could visibly stall for a moment, making its motion jittery.
+        #
+        # TODO(mm, 2024-03-22): See if we can remove this non-blockingness now.
+        # It was one of several band-aids introduced in ~v5.0.0 to mitigate performance
+        # problems. v6.3.0 started running some Python protocols directly through
+        # Protocol Engine, without this plugin, and without any non-blocking queue.
+        # If performance is sufficient for those, that probably means the
+        # performance problems have been resolved in better ways elsewhere
+        # and we don't need this anymore.
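+        #
+        # As a rough sketch, the handoff looks like this (the variable names
+        # here are illustrative, not real attributes of this class):
+        #
+        #     # In the APIv2 protocol's thread -- returns instantly:
+        #     queue.put(actions)
+        #
+        #     # In the main thread's event loop -- yields while waiting:
+        #     async for actions in queue.get_async_until_closed():
+        #         dispatch(actions)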
+        self._actions_to_dispatch = ThreadAsyncQueue[List[pe_actions.Action]]()
+        self._action_dispatching_task: Optional[Task[None]] = None
+        self._subscription_exit_stack: Optional[ExitStack] = None
 
     def setup(self) -> None:
         """Set up the plugin.
 
-        Subscribe to the APIv2 context's message brokers to be informed
-        of the APIv2 protocol's activity.
+        * Subscribe to the APIv2 context's message brokers to be informed
+          of the APIv2 protocol's activity.
+        * Kick off a background task to inform Protocol Engine of that activity.
         """
         # Subscribe to activity on the APIv2 context,
         # and arrange to unsubscribe when this plugin is torn down.
@@ -73,16 +97,24 @@ def setup(self) -> None:
         # to clean up these subscriptions.
         self._subscription_exit_stack = exit_stack.pop_all()
 
-    # todo(mm, 2024-08-21): This no longer needs to be async.
+        # Kick off a background task to report activity to the ProtocolEngine.
+        self._action_dispatching_task = create_task(self._dispatch_all_actions())
+
     async def teardown(self) -> None:
         """Tear down the plugin, undoing the work done in `setup()`.
 
         Called by Protocol Engine.
         At this point, the APIv2 protocol script must have exited.
         """
-        if self._subscription_exit_stack is not None:
-            self._subscription_exit_stack.close()
-            self._subscription_exit_stack = None
+        self._actions_to_dispatch.done_putting()
+        try:
+            if self._action_dispatching_task is not None:
+                await self._action_dispatching_task
+                self._action_dispatching_task = None
+        finally:
+            if self._subscription_exit_stack is not None:
+                self._subscription_exit_stack.close()
+                self._subscription_exit_stack = None
 
     def handle_action(self, action: pe_actions.Action) -> None:
         """React to a ProtocolEngine action."""
@@ -95,10 +127,7 @@ def _handle_legacy_command(self, command: LegacyCommand) -> None:
         Used as a broker callback, so this will run in the APIv2 protocol's thread.
         """
         pe_actions = self._legacy_command_mapper.map_command(command=command)
-        future = asyncio.run_coroutine_threadsafe(
-            self._dispatch_action_list(pe_actions), self._engine_loop
-        )
-        future.result()
+        self._actions_to_dispatch.put(pe_actions)
 
     def _handle_equipment_loaded(self, load_info: LoadInfo) -> None:
         """Handle an equipment load reported by the legacy APIv2 protocol.
@@ -106,11 +135,26 @@ def _handle_equipment_loaded(self, load_info: LoadInfo) -> None:
         Used as a broker callback, so this will run in the APIv2 protocol's thread.
         """
         pe_actions = self._legacy_command_mapper.map_equipment_load(load_info=load_info)
-        future = asyncio.run_coroutine_threadsafe(
-            self._dispatch_action_list(pe_actions), self._engine_loop
-        )
-        future.result()
-
-    async def _dispatch_action_list(self, actions: list[pe_actions.Action]) -> None:
-        for action in actions:
-            self.dispatch(action)
+        self._actions_to_dispatch.put(pe_actions)
+
+    async def _dispatch_all_actions(self) -> None:
+        """Dispatch all actions to the `ProtocolEngine`.
+
+        Exits only when `self._actions_to_dispatch` is closed
+        (or an unexpected exception is raised).
+        """
+        async for action_batch in self._actions_to_dispatch.get_async_until_closed():
+            # It's critical that we dispatch this batch of actions as one atomic
+            # sequence, without yielding to the event loop.
+            # Although this plugin only means to use the ProtocolEngine as a way of
+            # passively exposing the protocol's progress, the ProtocolEngine is still
+            # theoretically active, which means it's constantly watching in the
+            # background to execute any commands that it finds `queued`.
+            #
+            # For example, one of these action batches will often want to
+            # instantaneously create a running command by having a queue action
+            # immediately followed by a run action. We cannot let the
+            # ProtocolEngine's background task see the command in the `queued` state,
+            # or it will try to execute it, which the legacy protocol is already doing.
+            for action in action_batch:
+                self.dispatch(action)
diff --git a/api/src/opentrons/protocol_runner/protocol_runner.py b/api/src/opentrons/protocol_runner/protocol_runner.py
index dcf4f224811..22c809bcde5 100644
--- a/api/src/opentrons/protocol_runner/protocol_runner.py
+++ b/api/src/opentrons/protocol_runner/protocol_runner.py
@@ -1,5 +1,4 @@
 """Protocol run control and management."""
-import asyncio
 from typing import List, NamedTuple, Optional, Union
 from abc import ABC, abstractmethod
@@ -221,9 +220,7 @@ async def load(
             equipment_broker = Broker[LoadInfo]()
             self._protocol_engine.add_plugin(
                 LegacyContextPlugin(
-                    engine_loop=asyncio.get_running_loop(),
-                    broker=self._broker,
-                    equipment_broker=equipment_broker,
+                    broker=self._broker, equipment_broker=equipment_broker
                 )
             )
             self._hardware_api.should_taskify_movement_execution(taskify=True)
diff --git a/api/src/opentrons/protocol_runner/thread_async_queue.py b/api/src/opentrons/protocol_runner/thread_async_queue.py
new file mode 100644
index 00000000000..6b31a3f4c5c
--- /dev/null
+++ b/api/src/opentrons/protocol_runner/thread_async_queue.py
@@ -0,0 +1,174 @@
+"""Safely pass values between threads and async tasks."""
+
+
+from __future__ import annotations
+
+from collections import deque
+from threading import Condition
+from typing import AsyncIterable, Deque, Generic, Iterable, TypeVar
+
+from anyio.to_thread import run_sync
+
+
+_T = TypeVar("_T")
+
+
+class ThreadAsyncQueue(Generic[_T]):
+    """A queue to safely pass values of type `_T` between threads and async tasks.
+
+    All methods are safe to call concurrently from any thread or task.
+
+    Compared to queue.Queue:
+
+    * This class lets you close the queue to signal that no more values will be added,
+      which makes common producer/consumer patterns easier.
+      (This is like Golang channels and AnyIO memory object streams.)
+    * This class has built-in support for async consumers.
+
+    Compared to asyncio.Queue and AnyIO memory object streams:
+
+    * You can use this class to communicate between async tasks and threads
+      without the threads having to wait for the event loop to be free
+      every time they access the queue.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the queue."""
+        self._is_closed = False
+        self._deque: Deque[_T] = deque()
+        self._condition = Condition()
+
+    def put(self, value: _T) -> None:
+        """Add a value to the back of the queue.
+
+        Returns immediately, without blocking. The queue can grow without bound.
+
+        Raises:
+            QueueClosed: If the queue is already closed.
+        """
+        with self._condition:
+            if self._is_closed:
+                raise QueueClosed("Can't add more values when queue is already closed.")
+            else:
+                self._deque.append(value)
+                self._condition.notify()
+
+    def get(self) -> _T:
+        """Remove and return the value at the front of the queue.
+
+        If the queue is empty, this blocks until a new value is available.
+        If you're calling from an async task, use one of the async methods instead
+        to avoid blocking the event loop.
+
+        Raises:
+            QueueClosed: If all values have been consumed
+                and the queue has been closed with `done_putting()`.
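+
+        Example:
+            A consumer loop, roughly what `get_until_closed()` does for you:
+
+                while True:
+                    try:
+                        value = queue.get()
+                    except QueueClosed:
+                        break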
+        """
+        with self._condition:
+            while True:
+                if len(self._deque) > 0:
+                    return self._deque.popleft()
+                elif self._is_closed:
+                    raise QueueClosed("Queue closed; no more items to get.")
+                else:
+                    # We don't have anything to return.
+                    # Wait for something to change, then check again.
+                    self._condition.wait()
+
+    def get_until_closed(self) -> Iterable[_T]:
+        """Remove and return values from the front of the queue until it's closed.
+
+        Example:
+            for value in queue.get_until_closed():
+                print(value)
+        """
+        while True:
+            try:
+                yield self.get()
+            except QueueClosed:
+                break
+
+    async def get_async(self) -> _T:
+        """Like `get()`, except yield to the event loop while waiting.
+
+        Warning:
+            A waiting `get_async()` won't be interrupted by an async cancellation.
+            The proper way to interrupt a waiting `get_async()`
+            is to close the queue, just like you have to do with `get()`.
+        """
+        return await run_sync(
+            self.get,
+            # We keep `cancellable` False so we don't leak this helper thread.
+            # If we made it True, an async cancellation here would detach us
+            # from the helper thread and allow the thread to "run to completion"--
+            # but if no more values are ever enqueued, and the queue is never closed,
+            # completion would never happen and it would hang around forever.
+            cancellable=False,
+        )
+
+    async def get_async_until_closed(self) -> AsyncIterable[_T]:
+        """Like `get_until_closed()`, except yield to the event loop while waiting.
+
+        Example:
+            async for value in queue.get_async_until_closed():
+                print(value)
+
+        Warning:
+            While the ``async for`` is waiting for a new value,
+            it won't be interrupted by an async cancellation.
+            The proper way to interrupt a waiting `get_async_until_closed()`
+            is to close the queue, just like you have to do with `get()`.
+        """
+        while True:
+            try:
+                yield await self.get_async()
+            except QueueClosed:
+                break
+
+    def done_putting(self) -> None:
+        """Close the queue, i.e. signal that no more values will be `put()`.
+
+        You normally *must* close the queue eventually
+        to inform consumers that they can stop waiting for new values.
+        Forgetting to do this can leave them waiting forever,
+        leaking tasks or threads or causing deadlocks.
+
+        Consider using a ``with`` block instead. See `__enter__()`.
+
+        Raises:
+            QueueClosed: If the queue is already closed.
+        """
+        with self._condition:
+            if self._is_closed:
+                raise QueueClosed("Can't close when queue is already closed.")
+            else:
+                self._is_closed = True
+                self._condition.notify_all()
+
+    def __enter__(self) -> ThreadAsyncQueue[_T]:
+        """Use the queue as a context manager, closing the queue upon exit.
+
+        Example:
+            This:
+
+                with queue:
+                    do_stuff()
+
+            Is equivalent to:
+
+                try:
+                    do_stuff()
+                finally:
+                    queue.done_putting()
+        """
+        return self
+
+    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
+        """See `__enter__()`."""
+        self.done_putting()
+
+
+class QueueClosed(Exception):
+    """See `ThreadAsyncQueue.done_putting()`."""
+
+    pass
diff --git a/api/tests/opentrons/protocol_runner/test_legacy_context_plugin.py b/api/tests/opentrons/protocol_runner/test_legacy_context_plugin.py
index 1714064bfa5..620b7afa1ba 100644
--- a/api/tests/opentrons/protocol_runner/test_legacy_context_plugin.py
+++ b/api/tests/opentrons/protocol_runner/test_legacy_context_plugin.py
@@ -1,5 +1,4 @@
 """Tests for the PythonAndLegacyRunner's LegacyContextPlugin."""
-import asyncio
 import pytest
 from anyio import to_thread
 from decoy import Decoy, matchers
@@ -61,7 +60,7 @@ def mock_action_dispatcher(decoy: Decoy) -> pe_actions.ActionDispatcher:
 
 
 @pytest.fixture
-async def subject(
+def subject(
     mock_legacy_broker: LegacyBroker,
     mock_equipment_broker: ReadOnlyBroker[LoadInfo],
     mock_legacy_command_mapper: LegacyCommandMapper,
@@ -70,7 +69,6 @@ async def subject(
 ) -> LegacyContextPlugin:
     """Get a configured LegacyContextPlugin with its dependencies mocked out."""
     plugin = LegacyContextPlugin(
-        engine_loop=asyncio.get_running_loop(),
         broker=mock_legacy_broker,
         equipment_broker=mock_equipment_broker,
         legacy_command_mapper=mock_legacy_command_mapper,
diff --git a/api/tests/opentrons/protocol_runner/test_thread_async_queue.py b/api/tests/opentrons/protocol_runner/test_thread_async_queue.py
new file mode 100644
index 00000000000..2cf31939348
--- /dev/null
+++ b/api/tests/opentrons/protocol_runner/test_thread_async_queue.py
@@ -0,0 +1,200 @@
+"""Tests for thread_async_queue."""
+
+from __future__ import annotations
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
+from typing import List, NamedTuple
+
+import pytest
+
+from opentrons.protocol_runner.thread_async_queue import (
+    ThreadAsyncQueue,
+    QueueClosed,
+)
+
+
+def test_basic_single_threaded_behavior() -> None:
+    """Test basic queue behavior in a single thread."""
+    subject = ThreadAsyncQueue[int]()
+
+    with subject:
+        subject.put(1)
+        subject.put(2)
+        subject.put(3)
+
+    # Putting isn't allowed after closing.
+    with pytest.raises(QueueClosed):
+        subject.put(4)
+    with pytest.raises(QueueClosed):
+        subject.put(5)
+
+    # Closing isn't allowed after closing.
+    with pytest.raises(QueueClosed):
+        subject.done_putting()
+
+    # Values are retrieved in order.
+    assert [subject.get(), subject.get(), subject.get()] == [1, 2, 3]
+
+    # After retrieving all values, further retrievals raise.
+    with pytest.raises(QueueClosed):
+        subject.get()
+    with pytest.raises(QueueClosed):
+        # If closing were naively implemented as a sentinel value being inserted
+        # into the queue, it might be that the first get() after the close
+        # correctly raises but the second get() doesn't.
+        subject.get()
+
+
+def test_multi_thread_producer_consumer() -> None:
+    """Stochastically smoke-test thread safety.
+
+    Use the queue to pass values between threads
+    in a multi-producer, multi-consumer setup.
+    Verify that all the values make it through in the correct order.
+    """
+    num_producers = 3
+    num_consumers = 3
+
+    producer_ids = list(range(num_producers))
+
+    # The values that each producer will put into the queue.
+    # Anecdotally, threads interleave meaningfully with at least 10000 values.
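+    # 30000 per producer gives a comfortable margin over that anecdotal minimum.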
+    values_per_producer = list(range(30000))
+
+    all_expected_values = [
+        _ProducedValue(producer_id=p, value=v)
+        for p in producer_ids
+        for v in values_per_producer
+    ]
+
+    subject = ThreadAsyncQueue[_ProducedValue]()
+
+    # Run producers concurrently with consumers.
+    with ThreadPoolExecutor(max_workers=num_producers + num_consumers) as executor:
+        # `with subject` needs to be inside `with ThreadPoolExecutor`
+        # to avoid deadlocks in case something in here raises.
+        # Consumers need to see the queue closed eventually to terminate,
+        # and `with ThreadPoolExecutor` will wait until all threads are terminated
+        # before exiting.
+        with subject:
+            producers = [
+                executor.submit(
+                    _produce,
+                    queue=subject,
+                    values=values_per_producer,
+                    producer_id=producer_id,
+                )
+                for producer_id in producer_ids
+            ]
+            consumers = [
+                executor.submit(_consume, queue=subject) for _ in range(num_consumers)
+            ]
+
+            # Ensure all producers are done before we exit the `with subject` block
+            # and close off the queue to further submissions.
+            for producer in producers:
+                producer.result()
+
+    consumer_results = [consumer.result() for consumer in consumers]
+    all_values = list(chain(*consumer_results))
+
+    # Assert that the total set of consumed values is as expected:
+    # No duplicates, no extras, and nothing missing.
+    assert sorted(all_values) == sorted(all_expected_values)
+
+    def assert_consumer_result_correctly_ordered(
+        consumer_result: List[_ProducedValue],
+    ) -> None:
+        # Assert that the consumer got values in the order the producer provided them.
+        # Allow values from different producers to be interleaved,
+        # and tolerate skipped values (assume they were given to a different consumer).
+
+        # [[All consumed from producer 0], [All consumed from producer 1], etc.]
+        consumed_values_per_producer = [
+            [pv for pv in consumer_result if pv.producer_id == producer_id]
+            for producer_id in producer_ids
+        ]
+        for values_from_single_producer in consumed_values_per_producer:
+            assert values_from_single_producer == sorted(values_from_single_producer)
+
+    for consumer_result in consumer_results:
+        assert_consumer_result_correctly_ordered(consumer_result)
+
+
+async def test_async() -> None:
+    """Smoke-test async support.
+
+    Use the queue to pass values
+    from a single async producer to a single async consumer,
+    running concurrently in the same event loop.
+
+    This verifies two things:
+
+    1. That async retrieval returns basically the expected values.
+    2. That async retrieval keeps the event loop free while waiting.
+       If it didn't, this test would reveal the problem by deadlocking.
+
+    We trust that more complicated multi-producer/multi-consumer interactions
+    are covered by the non-async tests.
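+
+    (The deadlock in case 2 would come from the producer and consumer sharing
+    one event loop: a consumer that blocked the loop while waiting for a value
+    would never let the producer run to supply that value.)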
+    """
+    expected_values = list(range(1000))
+
+    subject = ThreadAsyncQueue[_ProducedValue]()
+
+    consumer = asyncio.create_task(_consume_async(queue=subject))
+    try:
+        with subject:
+            await _produce_async(queue=subject, values=expected_values, producer_id=0)
+    finally:
+        consumed = await consumer
+
+    assert consumed == [_ProducedValue(producer_id=0, value=v) for v in expected_values]
+
+
+class _ProducedValue(NamedTuple):
+    producer_id: int
+    value: int
+
+
+def _produce(
+    queue: ThreadAsyncQueue[_ProducedValue],
+    values: List[int],
+    producer_id: int,
+) -> None:
+    """Put values in the queue, tagged with an ID representing this producer."""
+    for v in values:
+        queue.put(_ProducedValue(producer_id=producer_id, value=v))
+
+
+def _consume(queue: ThreadAsyncQueue[_ProducedValue]) -> List[_ProducedValue]:
+    """Consume values from the queue indiscriminately until it's closed.
+
+    Return everything consumed, in the order that this function consumed it.
+    """
+    result = []
+    for value in queue.get_until_closed():
+        result.append(value)
+    return result
+
+
+async def _produce_async(
+    queue: ThreadAsyncQueue[_ProducedValue],
+    values: List[int],
+    producer_id: int,
+) -> None:
+    """Like `_produce()`, except yield to the event loop after each insertion."""
+    for value in values:
+        queue.put(_ProducedValue(producer_id=producer_id, value=value))
+        await asyncio.sleep(0)
+
+
+async def _consume_async(
+    queue: ThreadAsyncQueue[_ProducedValue],
+) -> List[_ProducedValue]:
+    """Like `_consume()`, except yield to the event loop while waiting."""
+    result = []
+    async for value in queue.get_async_until_closed():
+        result.append(value)
+    return result
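
Taken together, the pattern this diff introduces is: a thread produces, an async task consumes, and closing the queue ends the conversation. The following standalone sketch (not part of the diff; the `main` and `worker` names are illustrative) shows that pattern end to end, assuming `opentrons` and its `anyio` dependency are importable:

    import asyncio
    from threading import Thread

    from opentrons.protocol_runner.thread_async_queue import ThreadAsyncQueue


    async def main() -> None:
        queue = ThreadAsyncQueue[int]()

        def worker() -> None:
            # Runs in its own thread. put() never waits for the event loop.
            with queue:  # Closes the queue on exit so the consumer can finish.
                for i in range(3):
                    queue.put(i)

        thread = Thread(target=worker)
        thread.start()

        # Runs in the event loop, yielding to other tasks while it waits.
        async for value in queue.get_async_until_closed():
            print(value)

        thread.join()


    asyncio.run(main())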