diff --git a/src/prefect/_waiters.py b/src/prefect/_waiters.py new file mode 100644 index 000000000000..328e18b15cdc --- /dev/null +++ b/src/prefect/_waiters.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import asyncio +import atexit +import threading +import uuid +from typing import ( + TYPE_CHECKING, + Callable, + Self, +) + +import anyio +from cachetools import TTLCache + +from prefect._internal.concurrency.api import create_call, from_async, from_sync +from prefect._internal.concurrency.threads import get_global_loop +from prefect.client.schemas.objects import ( + TERMINAL_STATES, +) +from prefect.events.clients import get_events_subscriber +from prefect.events.filters import EventFilter, EventNameFilter +from prefect.logging import get_logger + +if TYPE_CHECKING: + import logging + + +class FlowRunWaiter: + """ + A service used for waiting for a flow run to finish. + + This service listens for flow run events and provides a way to wait for a specific + flow run to finish. This is useful for waiting for a flow run to finish before + continuing execution. + + The service is a singleton and must be started before use. The service will + automatically start when the first instance is created. A single websocket + connection is used to listen for flow run events. + + The service can be used to wait for a flow run to finish by calling + `FlowRunWaiter.wait_for_flow_run` with the flow run ID to wait for. The method + will return when the flow run has finished or the timeout has elapsed. + + The service will automatically stop when the Python process exits or when the + global loop thread is stopped. + + Example: + ```python + import asyncio + from uuid import uuid4 + + from prefect import flow + from prefect.flow_engine import run_flow_async + from prefect.flow_runs import FlowRunWaiter + + + @flow + async def test_flow(): + await asyncio.sleep(5) + print("Done!") + + + async def main(): + flow_run_id = uuid4() + asyncio.create_flow(run_flow_async(flow=test_flow, flow_run_id=flow_run_id)) + + await FlowRunWaiter.wait_for_flow_run(flow_run_id) + print("Flow run finished") + + + if __name__ == "__main__": + asyncio.run(main()) + ``` + """ + + _instance: Self | None = None + _instance_lock = threading.Lock() + + def __init__(self): + self.logger: "logging.Logger" = get_logger("FlowRunWaiter") + self._consumer_task: asyncio.Task[None] | None = None + self._observed_completed_flow_runs: TTLCache[uuid.UUID, bool] = TTLCache( + maxsize=10000, ttl=600 + ) + self._completion_events: dict[uuid.UUID, asyncio.Event] = {} + self._completion_callbacks: dict[uuid.UUID, Callable[[], None]] = {} + self._loop: asyncio.AbstractEventLoop | None = None + self._observed_completed_flow_runs_lock = threading.Lock() + self._completion_events_lock = threading.Lock() + self._started = False + + def start(self) -> None: + """ + Start the FlowRunWaiter service. + """ + if self._started: + return + self.logger.debug("Starting FlowRunWaiter") + loop_thread = get_global_loop() + + if not asyncio.get_running_loop() == loop_thread.loop: + raise RuntimeError("FlowRunWaiter must run on the global loop thread.") + + self._loop = loop_thread.loop + if TYPE_CHECKING: + assert self._loop is not None + + consumer_started = asyncio.Event() + self._consumer_task = self._loop.create_task( + self._consume_events(consumer_started) + ) + asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop) + + loop_thread.add_shutdown_call(create_call(self.stop)) + atexit.register(self.stop) + self._started = True + + async def _consume_events(self, consumer_started: asyncio.Event): + async with get_events_subscriber( + filter=EventFilter( + event=EventNameFilter( + name=[ + f"prefect.flow-run.{state.name.title()}" + for state in TERMINAL_STATES + ], + ) + ) + ) as subscriber: + consumer_started.set() + async for event in subscriber: + try: + self.logger.debug( + f"Received event: {event.resource['prefect.resource.id']}" + ) + flow_run_id = uuid.UUID( + event.resource["prefect.resource.id"].replace( + "prefect.flow-run.", "" + ) + ) + + with self._observed_completed_flow_runs_lock: + # Cache the flow run ID for a short period of time to avoid + # unnecessary waits + self._observed_completed_flow_runs[flow_run_id] = True + with self._completion_events_lock: + # Set the event for the flow run ID if it is in the cache + # so the waiter can wake up the waiting coroutine + if flow_run_id in self._completion_events: + self._completion_events[flow_run_id].set() + if flow_run_id in self._completion_callbacks: + self._completion_callbacks[flow_run_id]() + except Exception as exc: + self.logger.error(f"Error processing event: {exc}") + + def stop(self) -> None: + """ + Stop the FlowRunWaiter service. + """ + self.logger.debug("Stopping FlowRunWaiter") + if self._consumer_task: + self._consumer_task.cancel() + self._consumer_task = None + self.__class__._instance = None + self._started = False + + @classmethod + async def wait_for_flow_run( + cls, flow_run_id: uuid.UUID, timeout: float | None = None + ) -> None: + """ + Wait for a flow run to finish. + + Note this relies on a websocket connection to receive events from the server + and will not work with an ephemeral server. + + Args: + flow_run_id: The ID of the flow run to wait for. + timeout: The maximum time to wait for the flow run to + finish. Defaults to None. + """ + instance = cls.instance() + with instance._observed_completed_flow_runs_lock: + if flow_run_id in instance._observed_completed_flow_runs: + return + + # Need to create event in loop thread to ensure it can be set + # from the loop thread + finished_event = await from_async.wait_for_call_in_loop_thread( + create_call(asyncio.Event) + ) + with instance._completion_events_lock: + # Cache the event for the flow run ID so the consumer can set it + # when the event is received + instance._completion_events[flow_run_id] = finished_event + + try: + # Now check one more time whether the flow run arrived before we start to + # wait on it, in case it came in while we were setting up the event above. + with instance._observed_completed_flow_runs_lock: + if flow_run_id in instance._observed_completed_flow_runs: + return + + with anyio.move_on_after(delay=timeout): + await from_async.wait_for_call_in_loop_thread( + create_call(finished_event.wait) + ) + finally: + with instance._completion_events_lock: + # Remove the event from the cache after it has been waited on + instance._completion_events.pop(flow_run_id, None) + + @classmethod + def add_done_callback( + cls, flow_run_id: uuid.UUID, callback: Callable[[], None] + ) -> None: + """ + Add a callback to be called when a flow run finishes. + + Args: + flow_run_id: The ID of the flow run to wait for. + callback: The callback to call when the flow run finishes. + """ + instance = cls.instance() + with instance._observed_completed_flow_runs_lock: + if flow_run_id in instance._observed_completed_flow_runs: + callback() + return + + with instance._completion_events_lock: + # Cache the event for the flow run ID so the consumer can set it + # when the event is received + instance._completion_callbacks[flow_run_id] = callback + + @classmethod + def instance(cls) -> Self: + """ + Get the singleton instance of FlowRunWaiter. + """ + with cls._instance_lock: + if cls._instance is None: + cls._instance = cls._new_instance() + return cls._instance + + @classmethod + def _new_instance(cls): + instance = cls() + + if threading.get_ident() == get_global_loop().thread.ident: + instance.start() + else: + from_sync.call_soon_in_loop_thread(create_call(instance.start)).result() + + return instance diff --git a/tests/test_waiters.py b/tests/test_waiters.py new file mode 100644 index 000000000000..4bcc59d0ca4b --- /dev/null +++ b/tests/test_waiters.py @@ -0,0 +1,129 @@ +import asyncio + +import pytest + +from prefect import flow +from prefect._waiters import FlowRunWaiter +from prefect.client.orchestration import PrefectClient +from prefect.flow_engine import run_flow_async +from prefect.server.events.pipeline import EventsPipeline +from prefect.states import Pending + + +class TestFlowRunWaiter: + @pytest.fixture(autouse=True) + def teardown(self): + yield + + FlowRunWaiter.instance().stop() + + def test_instance_returns_singleton(self): + assert FlowRunWaiter.instance() is FlowRunWaiter.instance() + + def test_instance_returns_instance_after_stop(self): + instance = FlowRunWaiter.instance() + instance.stop() + assert FlowRunWaiter.instance() is not instance + + @pytest.mark.timeout(20) + async def test_wait_for_flow_run( + self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline + ): + """This test will fail with a timeout error if waiting is not working correctly.""" + + @flow + async def test_flow(): + await asyncio.sleep(1) + + flow_run = await prefect_client.create_flow_run(test_flow, state=Pending()) + asyncio.create_task(run_flow_async(flow=test_flow, flow_run=flow_run)) + + await FlowRunWaiter.wait_for_flow_run(flow_run.id) + + await emitting_events_pipeline.process_events() + + flow_run = await prefect_client.read_flow_run(flow_run.id) + assert flow_run.state + assert flow_run.state.is_completed() + + async def test_wait_for_flow_run_with_timeout(self, prefect_client: PrefectClient): + @flow + async def test_flow(): + await asyncio.sleep(5) + + flow_run = await prefect_client.create_flow_run(test_flow, state=Pending()) + run = asyncio.create_task(run_flow_async(flow=test_flow, flow_run=flow_run)) + + await FlowRunWaiter.wait_for_flow_run(flow_run.id, timeout=1) + + # FlowRunWaiter stopped waiting before the task finished + assert not run.done() + await run + + @pytest.mark.timeout(20) + async def test_non_singleton_mode( + self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline + ): + waiter = FlowRunWaiter() + assert waiter is not FlowRunWaiter.instance() + + @flow + async def test_flow(): + await asyncio.sleep(1) + + flow_run = await prefect_client.create_flow_run(test_flow, state=Pending()) + asyncio.create_task(run_flow_async(flow=test_flow, flow_run=flow_run)) + + await waiter.wait_for_flow_run(flow_run.id) + + await emitting_events_pipeline.process_events() + + flow_run = await prefect_client.read_flow_run(flow_run.id) + assert flow_run.state + assert flow_run.state.is_completed() + + waiter.stop() + + @pytest.mark.timeout(20) + async def test_handles_concurrent_task_runs( + self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline + ): + @flow + async def fast_flow(): + await asyncio.sleep(1) + + @flow + async def slow_flow(): + await asyncio.sleep(5) + + flow_run_1 = await prefect_client.create_flow_run(fast_flow, state=Pending()) + flow_run_2 = await prefect_client.create_flow_run(slow_flow, state=Pending()) + + asyncio.create_task(run_flow_async(flow=fast_flow, flow_run=flow_run_1)) + asyncio.create_task(run_flow_async(flow=slow_flow, flow_run=flow_run_2)) + + await FlowRunWaiter.wait_for_flow_run(flow_run_1.id) + + await emitting_events_pipeline.process_events() + + flow_run_1 = await prefect_client.read_flow_run(flow_run_1.id) + flow_run_2 = await prefect_client.read_flow_run(flow_run_2.id) + + assert flow_run_1.state + assert flow_run_1.state.is_completed() + + assert flow_run_2.state + assert not flow_run_2.state.is_completed() + + await FlowRunWaiter.wait_for_flow_run(flow_run_2.id) + + await emitting_events_pipeline.process_events() + + flow_run_1 = await prefect_client.read_flow_run(flow_run_1.id) + flow_run_2 = await prefect_client.read_flow_run(flow_run_2.id) + + assert flow_run_1.state + assert flow_run_1.state.is_completed() + + assert flow_run_2.state + assert flow_run_2.state.is_completed()