Add aserve utility for serving multiple flows from an asynchronous context #15972

Merged
14 commits merged on Dec 3, 2024
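In short, `aserve` is the awaitable counterpart of `serve`: it registers the given deployments on a `Runner` and awaits `Runner.start()` directly instead of calling `asyncio.run()`, so it can be used from code that is already running inside an event loop. A minimal usage sketch against the API added here (the flow names and schedules below are illustrative, not taken from this PR):

```python
import asyncio
import datetime

from prefect import aserve, flow


@flow
def hello(name: str = "world"):
    print(f"hello {name}")


@flow
def goodbye(name: str = "world"):
    print(f"goodbye {name}")


async def main():
    # Awaiting to_deployment mirrors the aserve docstring example below;
    # both deployments are then served from the already-running event loop.
    hello_deploy = await hello.to_deployment(
        "hourly-hello", interval=datetime.timedelta(hours=1)
    )
    goodbye_deploy = await goodbye.to_deployment("nightly-goodbye", cron="0 0 * * *")

    await aserve(hello_deploy, goodbye_deploy)


if __name__ == "__main__":
    asyncio.run(main())
```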
3 changes: 3 additions & 0 deletions src/prefect/__init__.py
@@ -38,6 +38,7 @@
Transaction,
unmapped,
serve,
aserve,
deploy,
pause_flow_run,
resume_flow_run,
@@ -66,6 +67,7 @@
"Transaction": (__spec__.parent, ".main"),
"unmapped": (__spec__.parent, ".main"),
"serve": (__spec__.parent, ".main"),
"aserve": (__spec__.parent, ".main"),
"deploy": (__spec__.parent, ".main"),
"pause_flow_run": (__spec__.parent, ".main"),
"resume_flow_run": (__spec__.parent, ".main"),
@@ -86,6 +88,7 @@
"Transaction",
"unmapped",
"serve",
"aserve",
"deploy",
"pause_flow_run",
"resume_flow_run",
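The `__init__.py` entries above register `aserve` in the module's public-API map, so `from prefect import aserve` resolves the name from `prefect.main` on first access (presumably via module-level `__getattr__`, PEP 562) rather than importing it eagerly. A small check of that behaviour, assuming a Prefect version that includes this change:

```python
import prefect
from prefect.main import aserve as aserve_from_main

# "aserve" is looked up through the lazy map and points at the same function
# that prefect.main re-exports from prefect.flows.
assert prefect.aserve is aserve_from_main
print(prefect.aserve.__module__)  # expected: "prefect.flows"
```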
147 changes: 108 additions & 39 deletions src/prefect/flows.py
@@ -97,6 +97,7 @@
from prefect.utilities.hashing import file_hash
from prefect.utilities.importtools import import_object, safe_load_namespace

+from ._internal.compatibility.async_dispatch import is_in_async_context
from ._internal.pydantic.v2_schema import is_v2_type
from ._internal.pydantic.v2_validated_func import V2ValidatedFunction
from ._internal.pydantic.v2_validated_func import (
@@ -1812,61 +1813,129 @@ def my_other_flow(name):
        serve(hello_deploy, bye_deploy)
        ```
    """
-    from rich.console import Console, Group
-    from rich.table import Table
-
    from prefect.runner import Runner

+    if is_in_async_context():
+        raise RuntimeError(
+            "Cannot call `serve` in an asynchronous context. Use `aserve` instead."
+        )
+
    runner = Runner(pause_on_shutdown=pause_on_shutdown, limit=limit, **kwargs)
    for deployment in args:
        runner.add_deployment(deployment)

    if print_starting_message:
-        help_message_top = (
-            "[green]Your deployments are being served and polling for"
-            " scheduled runs!\n[/]"
-        )
-
-        table = Table(title="Deployments", show_header=False)
-
-        table.add_column(style="blue", no_wrap=True)
-
-        for deployment in args:
-            table.add_row(f"{deployment.flow_name}/{deployment.name}")
-
-        help_message_bottom = (
-            "\nTo trigger any of these deployments, use the"
-            " following command:\n[blue]\n\t$ prefect deployment run"
-            " [DEPLOYMENT_NAME]\n[/]"
-        )
-        if PREFECT_UI_URL:
-            help_message_bottom += (
-                "\nYou can also trigger your deployments via the Prefect UI:"
-                f" [blue]{PREFECT_UI_URL.value()}/deployments[/]\n"
-            )
-
-        console = Console()
-        console.print(
-            Group(help_message_top, table, help_message_bottom), soft_wrap=True
-        )
+        _display_serve_start_message(*args)

-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError as exc:
-        if "no running event loop" in str(exc):
-            loop = None
-        else:
-            raise
-
    try:
-        if loop is not None:
-            loop.run_until_complete(runner.start())
-        else:
-            asyncio.run(runner.start())
+        asyncio.run(runner.start())
    except (KeyboardInterrupt, TerminationSignal) as exc:
        logger.info(f"Received {type(exc).__name__}, shutting down...")
-        if loop is not None:
-            loop.stop()
+
+
+async def aserve(
+    *args: "RunnerDeployment",
+    pause_on_shutdown: bool = True,
+    print_starting_message: bool = True,
+    limit: Optional[int] = None,
+    **kwargs,
+):
+    """
+    Asynchronously serve the provided list of deployments.
+
+    Use `serve` instead if calling from a synchronous context.
+
+    Args:
+        *args: A list of deployments to serve.
+        pause_on_shutdown: A boolean for whether or not to automatically pause
+            deployment schedules on shutdown.
+        print_starting_message: Whether or not to print message to the console
+            on startup.
+        limit: The maximum number of runs that can be executed concurrently.
+        **kwargs: Additional keyword arguments to pass to the runner.
+
+    Examples:
+        Prepare deployment and asynchronous initialization function and serve them:
+
+        ```python
+        import asyncio
+        import datetime
+
+        from prefect import flow, aserve, get_client
+
+
+        async def init():
+            await set_concurrency_limit()
+
+
+        async def set_concurrency_limit():
+            async with get_client() as client:
+                await client.create_concurrency_limit(tag='dev', concurrency_limit=3)
+
+
+        @flow
+        async def my_flow(name):
+            print(f"hello {name}")
+
+
+        async def main():
+            # Initialization function
+            await init()
+
+            # Run once a day
+            hello_deploy = await my_flow.to_deployment(
+                "hello", tags=["dev"], interval=datetime.timedelta(days=1)
+            )
+
+            await aserve(hello_deploy)
+
+
+        if __name__ == "__main__":
+            asyncio.run(main())
+        ```
+    """
+
+    from prefect.runner import Runner
+
+    runner = Runner(pause_on_shutdown=pause_on_shutdown, limit=limit, **kwargs)
+    for deployment in args:
+        await runner.add_deployment(deployment)
+
+    if print_starting_message:
+        _display_serve_start_message(*args)
+
+    await runner.start()
+
+
+def _display_serve_start_message(*args: "RunnerDeployment"):
+    from rich.console import Console, Group
+    from rich.table import Table
+
+    help_message_top = (
+        "[green]Your deployments are being served and polling for"
+        " scheduled runs!\n[/]"
+    )
+
+    table = Table(title="Deployments", show_header=False)
+
+    table.add_column(style="blue", no_wrap=True)
+
+    for deployment in args:
+        table.add_row(f"{deployment.flow_name}/{deployment.name}")
+
+    help_message_bottom = (
+        "\nTo trigger any of these deployments, use the"
+        " following command:\n[blue]\n\t$ prefect deployment run"
+        " [DEPLOYMENT_NAME]\n[/]"
+    )
+    if PREFECT_UI_URL:
+        help_message_bottom += (
+            "\nYou can also trigger your deployments via the Prefect UI:"
+            f" [blue]{PREFECT_UI_URL.value()}/deployments[/]\n"
+        )
+
+    console = Console()
+    console.print(Group(help_message_top, table, help_message_bottom), soft_wrap=True)


@client_injector
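The new guard in `serve` relies on `is_in_async_context()` from `prefect._internal.compatibility.async_dispatch`, whose body is not part of this diff. As a rough, illustrative approximation only (not Prefect's actual implementation), such a check can be built on whether asyncio reports a running loop, which is also what `test_serve_in_async_context_raises_error` below monkeypatches:

```python
import asyncio


def is_in_async_context_sketch() -> bool:
    """Illustrative stand-in: True when called while an event loop is running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: we are in a plain synchronous call stack.
        return False
    return True


async def _inside_loop() -> bool:
    return is_in_async_context_sketch()


if __name__ == "__main__":
    print(is_in_async_context_sketch())  # False: no loop is running yet
    print(asyncio.run(_inside_loop()))   # True: checked from inside a running loop
```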
3 changes: 2 additions & 1 deletion src/prefect/main.py
@@ -2,7 +2,7 @@
from prefect.deployments import deploy
from prefect.states import State
from prefect.logging import get_run_logger
from prefect.flows import flow, Flow, serve
from prefect.flows import flow, Flow, serve, aserve
from prefect.transactions import Transaction
from prefect.tasks import task, Task
from prefect.context import tags
@@ -84,6 +84,7 @@
"Transaction",
"unmapped",
"serve",
"aserve",
"deploy",
"pause_flow_run",
"resume_flow_run",
94 changes: 93 additions & 1 deletion tests/runner/test_runner.py
@@ -21,7 +21,7 @@
from starlette import status

import prefect.runner
-from prefect import __version__, flow, serve, task
+from prefect import __version__, aserve, flow, serve, task
from prefect.client.orchestration import PrefectClient, SyncPrefectClient
from prefect.client.schemas.actions import DeploymentScheduleCreate
from prefect.client.schemas.objects import (
@@ -291,6 +291,98 @@ def test_log_level_lowercasing(self, monkeypatch):
webserver_mock, host=mock.ANY, port=mock.ANY, log_level="debug"
)

    def test_serve_in_async_context_raises_error(self, monkeypatch):
        monkeypatch.setattr(
            "asyncio.get_running_loop", lambda: asyncio.get_event_loop()
        )

        deployment = dummy_flow_1.to_deployment("test")

        with pytest.raises(
            RuntimeError,
            match="Cannot call `serve` in an asynchronous context. Use `aserve` instead.",
        ):
            serve(deployment)


class TestAServe:
    @pytest.fixture(autouse=True)
    async def mock_runner_start(self, monkeypatch):
        mock = AsyncMock()
        monkeypatch.setattr("prefect.runner.Runner.start", mock)
        return mock

    async def test_aserve_prints_help_message_on_startup(self, capsys):
        await aserve(
            await dummy_flow_1.to_deployment(__file__),
            await dummy_flow_2.to_deployment(__file__),
            await tired_flow.to_deployment(__file__),
        )

        captured = capsys.readouterr()

        assert (
            "Your deployments are being served and polling for scheduled runs!"
            in captured.out
        )
        assert "dummy-flow-1/test_runner" in captured.out
        assert "dummy-flow-2/test_runner" in captured.out
        assert "tired-flow/test_runner" in captured.out
        assert "$ prefect deployment run [DEPLOYMENT_NAME]" in captured.out

    async def test_aserve_typed_container_inputs_flow(self, capsys):
        @flow
        def type_container_input_flow(arg1: List[str]) -> str:
            print(arg1)
            return ",".join(arg1)

        await aserve(
            await type_container_input_flow.to_deployment(__file__),
        )

        captured = capsys.readouterr()

        assert (
            "Your deployments are being served and polling for scheduled runs!"
            in captured.out
        )
        assert "type-container-input-flow/test_runner" in captured.out
        assert "$ prefect deployment run [DEPLOYMENT_NAME]" in captured.out

    async def test_aserve_can_create_multiple_deployments(
        self,
        prefect_client: PrefectClient,
    ):
        deployment_1 = dummy_flow_1.to_deployment(__file__, interval=3600)
        deployment_2 = dummy_flow_2.to_deployment(__file__, cron="* * * * *")

        await aserve(await deployment_1, await deployment_2)

        deployment = await prefect_client.read_deployment_by_name(
            name="dummy-flow-1/test_runner"
        )

        assert deployment is not None
        assert deployment.schedules[0].schedule.interval == datetime.timedelta(
            seconds=3600
        )

        deployment = await prefect_client.read_deployment_by_name(
            name="dummy-flow-2/test_runner"
        )

        assert deployment is not None
        assert deployment.schedules[0].schedule.cron == "* * * * *"

    async def test_aserve_starts_a_runner(
        self, prefect_client: PrefectClient, mock_runner_start: AsyncMock
    ):
        deployment = dummy_flow_1.to_deployment("test")

        await aserve(await deployment)

        mock_runner_start.assert_awaited_once()


class TestRunner:
    async def test_add_flows_to_runner(self, prefect_client: PrefectClient):
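For reference, the `TestAServe` fixture above replaces `prefect.runner.Runner.start` with an `AsyncMock` so that `aserve` returns immediately instead of polling for runs. A self-contained sketch of that pattern, with stand-in names rather than Prefect's real classes:

```python
import asyncio
from unittest import mock


class FakeRunner:
    """Stand-in for a runner whose start() would normally poll indefinitely."""

    async def start(self) -> None:
        await asyncio.sleep(3600)  # never reached once start() is patched


async def aserve_like(runner: FakeRunner) -> None:
    # Mirrors the shape of aserve: set up deployments, then await start().
    await runner.start()


def test_start_is_awaited_once() -> None:
    with mock.patch.object(FakeRunner, "start", new_callable=mock.AsyncMock) as started:
        asyncio.run(aserve_like(FakeRunner()))
    started.assert_awaited_once()


if __name__ == "__main__":
    test_start_is_awaited_once()
    print("ok")
```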