Skip to content

Commit

Permalink
Add missing stuff from last commit
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Feb 21, 2025
1 parent 309d0c7 commit be14ce4
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 0 deletions.
254 changes: 254 additions & 0 deletions src/prefect/_waiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
from __future__ import annotations

import asyncio
import atexit
import threading
import uuid
from typing import (
TYPE_CHECKING,
Callable,
Self,
)

import anyio
from cachetools import TTLCache

from prefect._internal.concurrency.api import create_call, from_async, from_sync
from prefect._internal.concurrency.threads import get_global_loop
from prefect.client.schemas.objects import (
TERMINAL_STATES,
)
from prefect.events.clients import get_events_subscriber
from prefect.events.filters import EventFilter, EventNameFilter
from prefect.logging import get_logger

if TYPE_CHECKING:
import logging


class FlowRunWaiter:
"""
A service used for waiting for a flow run to finish.
This service listens for flow run events and provides a way to wait for a specific
flow run to finish. This is useful for waiting for a flow run to finish before
continuing execution.
The service is a singleton and must be started before use. The service will
automatically start when the first instance is created. A single websocket
connection is used to listen for flow run events.
The service can be used to wait for a flow run to finish by calling
`FlowRunWaiter.wait_for_flow_run` with the flow run ID to wait for. The method
will return when the flow run has finished or the timeout has elapsed.
The service will automatically stop when the Python process exits or when the
global loop thread is stopped.
Example:
```python
import asyncio
from uuid import uuid4
from prefect import flow
from prefect.flow_engine import run_flow_async
from prefect.flow_runs import FlowRunWaiter
@flow
async def test_flow():
await asyncio.sleep(5)
print("Done!")
async def main():
flow_run_id = uuid4()
asyncio.create_flow(run_flow_async(flow=test_flow, flow_run_id=flow_run_id))
await FlowRunWaiter.wait_for_flow_run(flow_run_id)
print("Flow run finished")
if __name__ == "__main__":
asyncio.run(main())
```
"""

_instance: Self | None = None
_instance_lock = threading.Lock()

def __init__(self):
self.logger: "logging.Logger" = get_logger("FlowRunWaiter")
self._consumer_task: asyncio.Task[None] | None = None
self._observed_completed_flow_runs: TTLCache[uuid.UUID, bool] = TTLCache(
maxsize=10000, ttl=600
)
self._completion_events: dict[uuid.UUID, asyncio.Event] = {}
self._completion_callbacks: dict[uuid.UUID, Callable[[], None]] = {}
self._loop: asyncio.AbstractEventLoop | None = None
self._observed_completed_flow_runs_lock = threading.Lock()
self._completion_events_lock = threading.Lock()
self._started = False

def start(self) -> None:
"""
Start the FlowRunWaiter service.
"""
if self._started:
return
self.logger.debug("Starting FlowRunWaiter")
loop_thread = get_global_loop()

if not asyncio.get_running_loop() == loop_thread.loop:
raise RuntimeError("FlowRunWaiter must run on the global loop thread.")

self._loop = loop_thread.loop
if TYPE_CHECKING:
assert self._loop is not None

consumer_started = asyncio.Event()
self._consumer_task = self._loop.create_task(
self._consume_events(consumer_started)
)
asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)

loop_thread.add_shutdown_call(create_call(self.stop))
atexit.register(self.stop)
self._started = True

async def _consume_events(self, consumer_started: asyncio.Event):
async with get_events_subscriber(
filter=EventFilter(
event=EventNameFilter(
name=[
f"prefect.flow-run.{state.name.title()}"
for state in TERMINAL_STATES
],
)
)
) as subscriber:
consumer_started.set()
async for event in subscriber:
try:
self.logger.debug(
f"Received event: {event.resource['prefect.resource.id']}"
)
flow_run_id = uuid.UUID(
event.resource["prefect.resource.id"].replace(
"prefect.flow-run.", ""
)
)

with self._observed_completed_flow_runs_lock:
# Cache the flow run ID for a short period of time to avoid
# unnecessary waits
self._observed_completed_flow_runs[flow_run_id] = True
with self._completion_events_lock:
# Set the event for the flow run ID if it is in the cache
# so the waiter can wake up the waiting coroutine
if flow_run_id in self._completion_events:
self._completion_events[flow_run_id].set()
if flow_run_id in self._completion_callbacks:
self._completion_callbacks[flow_run_id]()
except Exception as exc:
self.logger.error(f"Error processing event: {exc}")

def stop(self) -> None:
"""
Stop the FlowRunWaiter service.
"""
self.logger.debug("Stopping FlowRunWaiter")
if self._consumer_task:
self._consumer_task.cancel()
self._consumer_task = None
self.__class__._instance = None
self._started = False

@classmethod
async def wait_for_flow_run(
cls, flow_run_id: uuid.UUID, timeout: float | None = None
) -> None:
"""
Wait for a flow run to finish.
Note this relies on a websocket connection to receive events from the server
and will not work with an ephemeral server.
Args:
flow_run_id: The ID of the flow run to wait for.
timeout: The maximum time to wait for the flow run to
finish. Defaults to None.
"""
instance = cls.instance()
with instance._observed_completed_flow_runs_lock:
if flow_run_id in instance._observed_completed_flow_runs:
return

# Need to create event in loop thread to ensure it can be set
# from the loop thread
finished_event = await from_async.wait_for_call_in_loop_thread(
create_call(asyncio.Event)
)
with instance._completion_events_lock:
# Cache the event for the flow run ID so the consumer can set it
# when the event is received
instance._completion_events[flow_run_id] = finished_event

try:
# Now check one more time whether the flow run arrived before we start to
# wait on it, in case it came in while we were setting up the event above.
with instance._observed_completed_flow_runs_lock:
if flow_run_id in instance._observed_completed_flow_runs:
return

with anyio.move_on_after(delay=timeout):
await from_async.wait_for_call_in_loop_thread(
create_call(finished_event.wait)
)
finally:
with instance._completion_events_lock:
# Remove the event from the cache after it has been waited on
instance._completion_events.pop(flow_run_id, None)

@classmethod
def add_done_callback(
cls, flow_run_id: uuid.UUID, callback: Callable[[], None]
) -> None:
"""
Add a callback to be called when a flow run finishes.
Args:
flow_run_id: The ID of the flow run to wait for.
callback: The callback to call when the flow run finishes.
"""
instance = cls.instance()
with instance._observed_completed_flow_runs_lock:
if flow_run_id in instance._observed_completed_flow_runs:
callback()
return

with instance._completion_events_lock:
# Cache the event for the flow run ID so the consumer can set it
# when the event is received
instance._completion_callbacks[flow_run_id] = callback

@classmethod
def instance(cls) -> Self:
"""
Get the singleton instance of FlowRunWaiter.
"""
with cls._instance_lock:
if cls._instance is None:
cls._instance = cls._new_instance()
return cls._instance

@classmethod
def _new_instance(cls):
instance = cls()

if threading.get_ident() == get_global_loop().thread.ident:
instance.start()
else:
from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

return instance
129 changes: 129 additions & 0 deletions tests/test_waiters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio

import pytest

from prefect import flow
from prefect._waiters import FlowRunWaiter
from prefect.client.orchestration import PrefectClient
from prefect.flow_engine import run_flow_async
from prefect.server.events.pipeline import EventsPipeline
from prefect.states import Pending


class TestFlowRunWaiter:
    """Tests for the FlowRunWaiter singleton lifecycle and its waiting behaviour."""

    @pytest.fixture(autouse=True)
    def teardown(self):
        """Stop the singleton waiter once each test has finished."""
        yield

        active_waiter = FlowRunWaiter.instance()
        active_waiter.stop()

    def test_instance_returns_singleton(self):
        """Two consecutive `instance()` calls hand back the same object."""
        first = FlowRunWaiter.instance()
        second = FlowRunWaiter.instance()
        assert first is second

    def test_instance_returns_instance_after_stop(self):
        """After `stop()`, `instance()` builds a fresh waiter."""
        stopped = FlowRunWaiter.instance()
        stopped.stop()
        replacement = FlowRunWaiter.instance()
        assert replacement is not stopped

    @pytest.mark.timeout(20)
    async def test_wait_for_flow_run(
        self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline
    ):
        """Hits the pytest timeout if waiting never returns."""

        @flow
        async def test_flow():
            await asyncio.sleep(1)

        pending_run = await prefect_client.create_flow_run(test_flow, state=Pending())
        asyncio.create_task(run_flow_async(flow=test_flow, flow_run=pending_run))

        await FlowRunWaiter.wait_for_flow_run(pending_run.id)

        await emitting_events_pipeline.process_events()

        finished_run = await prefect_client.read_flow_run(pending_run.id)
        final_state = finished_run.state
        assert final_state
        assert final_state.is_completed()

    async def test_wait_for_flow_run_with_timeout(self, prefect_client: PrefectClient):
        """The waiter gives up after `timeout` even though the flow is still running."""

        @flow
        async def test_flow():
            await asyncio.sleep(5)

        pending_run = await prefect_client.create_flow_run(test_flow, state=Pending())
        engine_task = asyncio.create_task(
            run_flow_async(flow=test_flow, flow_run=pending_run)
        )

        await FlowRunWaiter.wait_for_flow_run(pending_run.id, timeout=1)

        # FlowRunWaiter stopped waiting before the task finished
        assert not engine_task.done()
        await engine_task

    @pytest.mark.timeout(20)
    async def test_non_singleton_mode(
        self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline
    ):
        """A directly constructed waiter is distinct from the singleton and can wait."""
        standalone = FlowRunWaiter()
        assert standalone is not FlowRunWaiter.instance()

        @flow
        async def test_flow():
            await asyncio.sleep(1)

        pending_run = await prefect_client.create_flow_run(test_flow, state=Pending())
        asyncio.create_task(run_flow_async(flow=test_flow, flow_run=pending_run))

        await standalone.wait_for_flow_run(pending_run.id)

        await emitting_events_pipeline.process_events()

        finished_run = await prefect_client.read_flow_run(pending_run.id)
        final_state = finished_run.state
        assert final_state
        assert final_state.is_completed()

        standalone.stop()

    @pytest.mark.timeout(20)
    async def test_handles_concurrent_task_runs(
        self, prefect_client: PrefectClient, emitting_events_pipeline: EventsPipeline
    ):
        """Waiting on the fast run returns before the slow run has completed."""

        @flow
        async def fast_flow():
            await asyncio.sleep(1)

        @flow
        async def slow_flow():
            await asyncio.sleep(5)

        fast_run = await prefect_client.create_flow_run(fast_flow, state=Pending())
        slow_run = await prefect_client.create_flow_run(slow_flow, state=Pending())

        asyncio.create_task(run_flow_async(flow=fast_flow, flow_run=fast_run))
        asyncio.create_task(run_flow_async(flow=slow_flow, flow_run=slow_run))

        # First checkpoint: only the fast flow has finished.
        await FlowRunWaiter.wait_for_flow_run(fast_run.id)

        await emitting_events_pipeline.process_events()

        fast_snapshot = await prefect_client.read_flow_run(fast_run.id)
        slow_snapshot = await prefect_client.read_flow_run(slow_run.id)

        assert fast_snapshot.state
        assert fast_snapshot.state.is_completed()

        assert slow_snapshot.state
        assert not slow_snapshot.state.is_completed()

        # Second checkpoint: both flows have finished.
        await FlowRunWaiter.wait_for_flow_run(slow_run.id)

        await emitting_events_pipeline.process_events()

        fast_snapshot = await prefect_client.read_flow_run(fast_run.id)
        slow_snapshot = await prefect_client.read_flow_run(slow_run.id)

        assert fast_snapshot.state
        assert fast_snapshot.state.is_completed()

        assert slow_snapshot.state
        assert slow_snapshot.state.is_completed()

0 comments on commit be14ce4

Please sign in to comment.