From 2accf9d4428a8e8c2bb83ede148d280fe0061d0d Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Thu, 24 Oct 2024 18:39:29 +0200
Subject: [PATCH] feat: add apiserver support to Python SDK (#327)

* Client refactoring

try checkpoint
add asgiref to support async-to-sync:
fix unit tests
remove pydantic warning
add unit tests for client
fix mock path
test model
added tests and fix discovered bugs
fix connection error handling
fix bugs surfaced in end-to-end
use explicit properties
fix awaitable checks
add instance method
revert to return sync class
extract base model
use instance() method on models
add e2e tests
working state
fix unit tests
more fixes

* make change backward compat

* remove useless param

* add docstrings

* add sessions.create()

* merge settings into base client class

* assert condition in e2e

* export Client at the top level

---
 e2e_tests/apiserver/__init__.py               |   0
 e2e_tests/apiserver/conftest.py               |  27 ++
 .../apiserver/deployments/deployment1.yml     |  15 +
 .../apiserver/deployments/deployment2.yml     |  15 +
 .../deployments/deployment_streaming.yml      |  14 +
 .../apiserver/deployments/src/workflow.py     |  41 +++
 e2e_tests/apiserver/test_deploy.py            |  23 ++
 e2e_tests/apiserver/test_status.py            |  23 ++
 e2e_tests/apiserver/test_streaming.py         |  27 ++
 llama_deploy/__init__.py                      |  27 +-
 llama_deploy/apiserver/deployment.py          |  16 +-
 llama_deploy/apiserver/routers/deployments.py |  24 +-
 llama_deploy/client/__init__.py               |   7 +-
 llama_deploy/client/base.py                   |  29 ++
 llama_deploy/client/client.py                 |  40 +++
 llama_deploy/client/models/__init__.py        |   4 +
 llama_deploy/client/models/apiserver.py       | 303 ++++++++++++++++++
 llama_deploy/client/models/model.py           |  71 ++++
 llama_deploy/sdk/__init__.py                  |   0
 llama_deploy/types/__init__.py                |  37 +++
 llama_deploy/types/apiserver.py               |  16 +
 llama_deploy/{types.py => types/core.py}      |   0
 poetry.lock                                   |  19 +-
 pyproject.toml                                |   1 +
 tests/client/models/__init__.py               |   0
 tests/client/models/conftest.py               |  12 +
 tests/client/models/test_apiserver.py         | 284 ++++++++++++++++
 tests/client/models/test_model.py             |  41 +++
 tests/client/test_client.py                   |  50 +++
 tests/services/test_human_service.py          |  18 +-
 30 files changed, 1144 insertions(+), 40 deletions(-)
 create mode 100644 e2e_tests/apiserver/__init__.py
 create mode 100644 e2e_tests/apiserver/conftest.py
 create mode 100644 e2e_tests/apiserver/deployments/deployment1.yml
 create mode 100644 e2e_tests/apiserver/deployments/deployment2.yml
 create mode 100644 e2e_tests/apiserver/deployments/deployment_streaming.yml
 create mode 100644 e2e_tests/apiserver/deployments/src/workflow.py
 create mode 100644 e2e_tests/apiserver/test_deploy.py
 create mode 100644 e2e_tests/apiserver/test_status.py
 create mode 100644 e2e_tests/apiserver/test_streaming.py
 create mode 100644 llama_deploy/client/base.py
 create mode 100644 llama_deploy/client/client.py
 create mode 100644 llama_deploy/client/models/__init__.py
 create mode 100644 llama_deploy/client/models/apiserver.py
 create mode 100644 llama_deploy/client/models/model.py
 create mode 100644 llama_deploy/sdk/__init__.py
 create mode 100644 llama_deploy/types/__init__.py
 create mode 100644 llama_deploy/types/apiserver.py
 rename llama_deploy/{types.py => types/core.py} (100%)
 create mode 100644 tests/client/models/__init__.py
 create mode 100644 tests/client/models/conftest.py
 create mode 100644 tests/client/models/test_apiserver.py
 create mode 100644 tests/client/models/test_model.py
 create mode 100644 tests/client/test_client.py
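
A quick usage sketch of the client added by this patch (the deployment file path
and task input below are illustrative; everything else uses APIs introduced in
this changeset):

    import asyncio

    from llama_deploy import Client
    from llama_deploy.types import TaskDefinition


    async def main() -> None:
        client = Client(api_server_url="http://localhost:4501")

        # Check the API Server is reachable
        print((await client.apiserver.status()).status_message)

        # Deploy a workflow from a deployment file
        deployments = await client.apiserver.deployments()
        with open("deployments/deployment1.yml") as f:  # illustrative path
            deployment = await deployments.create(f)

        # Run a task to completion on the new deployment
        tasks = await deployment.tasks()
        print(await tasks.run(TaskDefinition(input='{"a": "b"}')))


    asyncio.run(main())

    # The same API is available synchronously via `client.sync`
    print(Client().sync.apiserver.status().status_message)
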
diff --git a/e2e_tests/apiserver/__init__.py b/e2e_tests/apiserver/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/e2e_tests/apiserver/conftest.py b/e2e_tests/apiserver/conftest.py
new file mode 100644
index 00000000..93802fa0
--- /dev/null
+++ b/e2e_tests/apiserver/conftest.py
@@ -0,0 +1,27 @@
+import multiprocessing
+import time
+
+import pytest
+import uvicorn
+
+from llama_deploy.client import Client
+
+
+def run_async_apiserver():
+    uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501)
+
+
+@pytest.fixture(scope="module")
+def apiserver():
+    p = multiprocessing.Process(target=run_async_apiserver)
+    p.start()
+    time.sleep(3)
+
+    yield
+
+    p.kill()
+
+
+@pytest.fixture
+def client():
+    return Client(api_server_url="http://localhost:4501")
diff --git a/e2e_tests/apiserver/deployments/deployment1.yml b/e2e_tests/apiserver/deployments/deployment1.yml
new file mode 100644
index 00000000..e63acffc
--- /dev/null
+++ b/e2e_tests/apiserver/deployments/deployment1.yml
@@ -0,0 +1,15 @@
+name: TestDeployment1
+
+control-plane: {}
+
+default-service: dummy_workflow
+
+services:
+  test-workflow:
+    name: Test Workflow
+    port: 8002
+    host: localhost
+    source:
+      type: git
+      name: https://github.com/run-llama/llama_deploy.git
+    path: tests/apiserver/data/workflow:my_workflow
diff --git a/e2e_tests/apiserver/deployments/deployment2.yml b/e2e_tests/apiserver/deployments/deployment2.yml
new file mode 100644
index 00000000..1699d78f
--- /dev/null
+++ b/e2e_tests/apiserver/deployments/deployment2.yml
@@ -0,0 +1,15 @@
+name: TestDeployment2
+
+control-plane: {}
+
+default-service: dummy_workflow
+
+services:
+  test-workflow:
+    name: Test Workflow
+    port: 8002
+    host: localhost
+    source:
+      type: git
+      name: https://github.com/run-llama/llama_deploy.git
+    path: tests/apiserver/data/workflow:my_workflow
diff --git a/e2e_tests/apiserver/deployments/deployment_streaming.yml b/e2e_tests/apiserver/deployments/deployment_streaming.yml
new file mode 100644
index 00000000..4d0c6ecf
--- /dev/null
+++ b/e2e_tests/apiserver/deployments/deployment_streaming.yml
@@ -0,0 +1,14 @@
+name: Streaming
+
+control-plane:
+  port: 8000
+
+default-service: streaming_workflow
+
+services:
+  streaming_workflow:
+    name: Streaming Workflow
+    source:
+      type: local
+      name: ./e2e_tests/apiserver/deployments/src
+    path: workflow:streaming_workflow
diff --git a/e2e_tests/apiserver/deployments/src/workflow.py b/e2e_tests/apiserver/deployments/src/workflow.py
new file mode 100644
index 00000000..ac3f47ad
--- /dev/null
+++ b/e2e_tests/apiserver/deployments/src/workflow.py
@@ -0,0 +1,41 @@
+import asyncio
+
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
+
+
+class Message(Event):
+    text: str
+
+
+class EchoWorkflow(Workflow):
+    """A dummy workflow streaming three events."""
+
+    @step()
+    async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
+        for i in range(3):
+            ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
+            await asyncio.sleep(0.5)
+
+        return StopEvent(result="Done.")
+
+
+streaming_workflow = EchoWorkflow()
+
+
+async def main():
+    h = streaming_workflow.run(message="Hello!")
+    async for ev in h.stream_events():
+        if type(ev) is Message:
+            print(ev.text)
+    print(await h)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/e2e_tests/apiserver/test_deploy.py b/e2e_tests/apiserver/test_deploy.py
new file mode 100644
index 00000000..fc836068
--- /dev/null
+++ b/e2e_tests/apiserver/test_deploy.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_deploy(apiserver, client):
+    here = Path(__file__).parent
+    deployments = await client.apiserver.deployments()
+    with open(here / "deployments" / "deployment1.yml") as f:
+        await deployments.create(f)
+
+    status = await client.apiserver.status()
+    assert "TestDeployment1" in status.deployments
+
+
+def test_deploy_sync(apiserver, client):
+    here = Path(__file__).parent
+    deployments = client.sync.apiserver.deployments()
+    with open(here / "deployments" / "deployment2.yml") as f:
+        deployments.create(f)
+
+    assert "TestDeployment2" in client.sync.apiserver.status().deployments
diff --git a/e2e_tests/apiserver/test_status.py b/e2e_tests/apiserver/test_status.py
new file mode 100644
index 00000000..8665e0bb
--- /dev/null
+++ b/e2e_tests/apiserver/test_status.py
@@ -0,0 +1,23 @@
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_status_down(client):
+    res = await client.apiserver.status()
+    assert res.status.value == "Down"
+
+
+def test_status_down_sync(client):
+    res = client.sync.apiserver.status()
+    assert res.status.value == "Down"
+
+
+@pytest.mark.asyncio
+async def test_status_up(apiserver, client):
+    res = await client.apiserver.status()
+    assert res.status.value == "Healthy"
+
+
+def test_status_up_sync(apiserver, client):
+    res = client.sync.apiserver.status()
+    assert res.status.value == "Healthy"
diff --git a/e2e_tests/apiserver/test_streaming.py b/e2e_tests/apiserver/test_streaming.py
new file mode 100644
index 00000000..89748ebc
--- /dev/null
+++ b/e2e_tests/apiserver/test_streaming.py
@@ -0,0 +1,27 @@
+import asyncio
+from pathlib import Path
+
+import pytest
+
+from llama_deploy.types import TaskDefinition
+
+
+@pytest.mark.asyncio
+async def test_stream(apiserver, client):
+    here = Path(__file__).parent
+
+    with open(here / "deployments" / "deployment_streaming.yml") as f:
+        deployments = await client.apiserver.deployments()
+        deployment = await deployments.create(f)
+        await asyncio.sleep(5)
+
+    tasks = await deployment.tasks()
+    task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
+    read_events = []
+    async for ev in task.events():
+        if "text" in ev:
+            read_events.append(ev)
+    assert len(read_events) == 3
+    # the workflow produces events sequentially, so here we can assume events arrived in order
+    for i, ev in enumerate(read_events):
+        assert ev["text"] == f"message number {i+1}"
diff --git a/llama_deploy/__init__.py b/llama_deploy/__init__.py
index 8dde7c74..fe9b3714 100644
--- a/llama_deploy/__init__.py
+++ b/llama_deploy/__init__.py
@@ -1,10 +1,21 @@
-from llama_deploy.client import AsyncLlamaDeployClient, LlamaDeployClient
-from llama_deploy.control_plane import ControlPlaneServer, ControlPlaneConfig
+# configure logger
+import logging
+
+from llama_deploy.client import AsyncLlamaDeployClient, Client, LlamaDeployClient
+from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
 from llama_deploy.deploy import deploy_core, deploy_workflow
 from llama_deploy.message_consumers import CallableMessageConsumer
 from llama_deploy.message_queues import SimpleMessageQueue, SimpleMessageQueueConfig
 from llama_deploy.messages import QueueMessage
 from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig
+from llama_deploy.services import (
+    AgentService,
+    ComponentService,
+    HumanService,
+    ToolService,
+    WorkflowService,
+    WorkflowServiceConfig,
+)
 from llama_deploy.tools import (
     AgentServiceTool,
     MetaServiceTool,
@@ -12,17 +23,6 @@
     ServiceComponent,
     ServiceTool,
 )
-from llama_deploy.services import (
-    AgentService,
-    ToolService,
-    HumanService,
-    ComponentService,
-    WorkflowService,
-    WorkflowServiceConfig,
-)
-
-# configure logger
-import logging
 
 root_logger = logging.getLogger("llama_deploy")
@@ -39,6 +39,7 @@
     # clients
     "LlamaDeployClient",
     "AsyncLlamaDeployClient",
+    "Client",
     # services
     "AgentService",
     "HumanService",
diff --git a/llama_deploy/apiserver/deployment.py b/llama_deploy/apiserver/deployment.py
index fdf70027..f4b20223 100644
--- a/llama_deploy/apiserver/deployment.py
+++ b/llama_deploy/apiserver/deployment.py
@@ -7,32 +7,26 @@
 from typing import Any
 
 from llama_deploy import (
+    AsyncLlamaDeployClient,
     ControlPlaneServer,
     SimpleMessageQueue,
-    SimpleOrchestratorConfig,
     SimpleOrchestrator,
+    SimpleOrchestratorConfig,
     WorkflowService,
     WorkflowServiceConfig,
-    AsyncLlamaDeployClient,
 )
 from llama_deploy.message_queues import (
-    BaseMessageQueue,
-    SimpleMessageQueueConfig,
     AWSMessageQueue,
+    BaseMessageQueue,
     KafkaMessageQueue,
     RabbitMQMessageQueue,
     RedisMessageQueue,
+    SimpleMessageQueueConfig,
 )
-from .config_parser import (
-    Config,
-    SourceType,
-    Service,
-    MessageQueueConfig,
-)
+from .config_parser import Config, MessageQueueConfig, Service, SourceType
 from .source_managers import GitSourceManager, LocalSourceManager, SourceManager
 
-
 SOURCE_MANAGERS: dict[SourceType, SourceManager] = {
     SourceType.git: GitSourceManager(),
     SourceType.local: LocalSourceManager(),
diff --git a/llama_deploy/apiserver/routers/deployments.py b/llama_deploy/apiserver/routers/deployments.py
index 2406047a..eec538f2 100644
--- a/llama_deploy/apiserver/routers/deployments.py
+++ b/llama_deploy/apiserver/routers/deployments.py
@@ -1,14 +1,13 @@
 import json
+from typing import AsyncGenerator
 
-from fastapi import APIRouter, File, UploadFile, HTTPException
+from fastapi import APIRouter, File, HTTPException, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
-from typing import AsyncGenerator
 
-from llama_deploy.apiserver.server import manager
 from llama_deploy.apiserver.config_parser import Config
+from llama_deploy.apiserver.server import manager
 from llama_deploy.types import TaskDefinition
 
-
 deployments_router = APIRouter(
     prefix="/deployments",
 )
@@ -144,6 +143,23 @@ async def get_task_result(
     return JSONResponse(result.result if result else "")
 
 
+@deployments_router.get("/{deployment_name}/tasks")
+async def get_tasks(
+    deployment_name: str,
+) -> JSONResponse:
+    """Get all the tasks from all the sessions in a given deployment."""
+    deployment = manager.get_deployment(deployment_name)
+    if deployment is None:
+        raise HTTPException(status_code=404, detail="Deployment not found")
+
+    tasks: list[TaskDefinition] = []
+    for session_def in await deployment.client.list_sessions():
+        session = await deployment.client.get_session(session_id=session_def.session_id)
+        for task_def in await session.get_tasks():
+            tasks.append(task_def)
+    return JSONResponse([task.model_dump() for task in tasks])
+
+
 @deployments_router.get("/{deployment_name}/sessions")
 async def get_sessions(
     deployment_name: str,
diff --git a/llama_deploy/client/__init__.py b/llama_deploy/client/__init__.py
index 9679b7a7..de8fbf04 100644
--- a/llama_deploy/client/__init__.py
+++ b/llama_deploy/client/__init__.py
@@ -1,4 +1,5 @@
-from llama_deploy.client.async_client import AsyncLlamaDeployClient
-from llama_deploy.client.sync_client import LlamaDeployClient
+from .async_client import AsyncLlamaDeployClient
+from .client import Client
+from .sync_client import LlamaDeployClient
 
-__all__ = ["AsyncLlamaDeployClient", "LlamaDeployClient"]
+__all__ = ["AsyncLlamaDeployClient", "Client", "LlamaDeployClient"]
diff --git a/llama_deploy/client/base.py b/llama_deploy/client/base.py
new file mode 100644
index 00000000..0215a6e7
--- /dev/null
+++ b/llama_deploy/client/base.py
@@ -0,0 +1,29 @@
+from typing import Any
+
+import httpx
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class _BaseClient(BaseSettings):
+    """Base type for clients, to be used in Pydantic models to avoid circular imports.
+
+    Settings can be passed to the Client constructor when creating an instance, or defined with environment variables
+    having names prefixed with the string `LLAMA_DEPLOY_`, e.g. `LLAMA_DEPLOY_DISABLE_SSL`.
+    """
+
+    model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")
+
+    api_server_url: str = "http://localhost:4501"
+    disable_ssl: bool = False
+    timeout: float = 120.0
+    poll_interval: float = 0.5
+
+    async def request(
+        self, method: str, url: str | httpx.URL, *args: Any, **kwargs: Any
+    ) -> httpx.Response:
+        """Performs an async HTTP request using httpx."""
+        verify = kwargs.pop("verify", True)
+        async with httpx.AsyncClient(verify=verify) as client:
+            response = await client.request(method, url, *args, **kwargs)
+            response.raise_for_status()
+            return response
diff --git a/llama_deploy/client/client.py b/llama_deploy/client/client.py
new file mode 100644
index 00000000..a42b1038
--- /dev/null
+++ b/llama_deploy/client/client.py
@@ -0,0 +1,40 @@
+from .base import _BaseClient
+from .models import ApiServer
+
+
+class Client(_BaseClient):
+    """The Llama Deploy Python client.
+
+    The client gives access to both the asyncio and non-asyncio APIs. To access the sync
+    API, just use the methods of `client.sync`.
+
+    Example usage:
+    ```py
+    from llama_deploy.client import Client
+
+    # Use the same client instance
+    c = Client()
+
+    async def an_async_function():
+        status = await c.apiserver.status()
+
+    def normal_function():
+        status = c.sync.apiserver.status()
+    ```
+    """
+
+    @property
+    def sync(self) -> "Client":
+        """Returns the sync version of the client API."""
+        return _SyncClient(**self.model_dump())
+
+    @property
+    def apiserver(self) -> ApiServer:
+        """Returns the ApiServer model."""
+        return ApiServer.instance(client=self, id="apiserver")
+
+
+class _SyncClient(Client):
+    @property
+    def apiserver(self) -> ApiServer:
+        return ApiServer.instance(make_sync=True, client=self, id="apiserver")
diff --git a/llama_deploy/client/models/__init__.py b/llama_deploy/client/models/__init__.py
new file mode 100644
index 00000000..ac7104c9
--- /dev/null
+++ b/llama_deploy/client/models/__init__.py
@@ -0,0 +1,4 @@
+from .apiserver import ApiServer
+from .model import Collection, Model
+
+__all__ = ["ApiServer", "Collection", "Model"]
diff --git a/llama_deploy/client/models/apiserver.py b/llama_deploy/client/models/apiserver.py
new file mode 100644
index 00000000..a2c6508f
--- /dev/null
+++ b/llama_deploy/client/models/apiserver.py
@@ -0,0 +1,303 @@
+import asyncio
+import json
+from typing import Any, AsyncGenerator, TextIO
+
+import httpx
+
+from llama_deploy.types.apiserver import Status, StatusEnum
+from llama_deploy.types.core import SessionDefinition, TaskDefinition, TaskResult
+
+from .model import Collection, Model
+
+DEFAULT_POLL_INTERVAL = 0.5
+
+
+class Session(Model):
+    """A model representing a session."""
+
+    pass
+
+
+class SessionCollection(Collection):
+    """A model representing a collection of sessions for a given deployment."""
+
+    deployment_id: str
+
+    async def delete(self, session_id: str) -> None:
+        """Deletes the session with the provided `session_id`.
+
+        Args:
+            session_id: The id of the session that will be removed
+
+        Raises:
+            httpx.HTTPStatusError: If the session couldn't be found with the id provided.
+        """
+        delete_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions/delete"
+
+        await self.client.request(
+            "POST",
+            delete_url,
+            params={"session_id": session_id},
+            verify=not self.client.disable_ssl,
+            timeout=self.client.timeout,
+        )
+
+    async def create(self) -> Session:
+        """Creates a new session for the given deployment."""
+        create_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions/create"
+
+        r = await self.client.request(
+            "POST",
+            create_url,
+            verify=not self.client.disable_ssl,
+            timeout=self.client.timeout,
+        )
+
+        session_def = SessionDefinition(**r.json())
+
+        return Session.instance(
+            client=self.client,
+            make_sync=self._instance_is_sync,
+            id=session_def.session_id,
+        )
+
+
+class Task(Model):
+    """A model representing a task belonging to a given session in the given deployment."""
+
+    deployment_id: str
+    session_id: str
+
+    async def results(self) -> TaskResult:
+        """Returns the result of a given task."""
+        results_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/{self.id}/results"
+
+        r = await self.client.request(
+            "GET",
+            results_url,
+            verify=not self.client.disable_ssl,
+            params={"session_id": self.session_id},
+            timeout=self.client.timeout,
+        )
+        return TaskResult.model_validate_json(r.json())
+
+    async def events(self) -> AsyncGenerator[dict[str, Any], None]:  # pragma: no cover
+        """Returns a generator object to consume the events streamed from a service."""
+        events_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/{self.id}/events"
+
+        while True:
+            try:
+                async with httpx.AsyncClient(
+                    verify=not self.client.disable_ssl
+                ) as client:
+                    async with client.stream(
+                        "GET", events_url, params={"session_id": self.session_id}
+                    ) as response:
+                        response.raise_for_status()
+                        async for line in response.aiter_lines():
+                            json_line = json.loads(line)
+                            yield json_line
+                        break  # Exit the loop once the stream completes successfully
+            except httpx.HTTPStatusError as e:
+                if e.response.status_code != 404:
+                    raise  # Re-raise if it's not a 404 error
+                await asyncio.sleep(DEFAULT_POLL_INTERVAL)
+
+
+class TaskCollection(Collection):
+    """A model representing a collection of tasks for a given deployment."""
+
+    deployment_id: str
+
+    async def run(self, task: TaskDefinition) -> Any:
+        """Runs a task and returns the results once it's done.
+
+        Args:
+            task: The definition of the task we want to run.
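+
+        Example (illustrative input, assuming a deployment was created first):
+            tasks = await deployment.tasks()
+            result = await tasks.run(TaskDefinition(input='{"a": "b"}'))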
+ """ + run_url = ( + f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/run" + ) + + r = await self.client.request( + "POST", + run_url, + verify=not self.client.disable_ssl, + json=task.model_dump(), + timeout=self.client.timeout, + ) + + return r.json() + + async def create(self, task: TaskDefinition) -> Task: + """Runs a task returns it immediately, without waiting for the results.""" + create_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/create" + + r = await self.client.request( + "POST", + create_url, + verify=not self.client.disable_ssl, + json=task.model_dump(), + timeout=self.client.timeout, + ) + response_fields = r.json() + + return Task.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.deployment_id, + id=response_fields["task_id"], + session_id=response_fields["session_id"], + ) + + +class Deployment(Model): + """A model representing a deployment.""" + + async def tasks(self) -> TaskCollection: + """Returns a collection of tasks from all the sessions in the given deployment.""" + tasks_url = f"{self.client.api_server_url}/deployments/{self.id}/tasks" + r = await self.client.request( + "GET", + tasks_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, + ) + items = { + "id": Task.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=task_def.task_id, + session_id=task_def.session_id, + deployment_id=self.id, + ) + for task_def in r.json() + } + return TaskCollection.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.id, + items=items, + ) + + async def sessions(self) -> SessionCollection: + """Returns a collection of all the sessions in the given deployment.""" + sessions_url = f"{self.client.api_server_url}/deployments/{self.id}/sessions" + r = await self.client.request( + "GET", + sessions_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, + ) + items = { + "id": Session.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=session_def.session_id, + ) + for session_def in r.json() + } + return SessionCollection.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.id, + items=items, + ) + + +class DeploymentCollection(Collection): + """A model representing a collection of deployments currently active.""" + + async def create(self, config: TextIO) -> Deployment: + """Creates a new deployment from a deployment file.""" + create_url = f"{self.client.api_server_url}/deployments/create" + + files = {"config_file": config.read()} + r = await self.client.request( + "POST", + create_url, + files=files, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, + ) + + return Deployment.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=r.json().get("name"), + ) + + async def get(self, deployment_id: str) -> Deployment: + """Gets a deployment by id.""" + get_url = f"{self.client.api_server_url}/deployments/{deployment_id}" + # Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it + await self.client.request( + "GET", + get_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, + ) + return Deployment.instance( + client=self.client, make_sync=self._instance_is_sync, id=deployment_id + ) + + +class ApiServer(Model): + """A model representing the API Server instance.""" + + async def status(self) -> Status: + """Returns the status of the API Server.""" 
+        status_url = f"{self.client.api_server_url}/status/"
+
+        try:
+            r = await self.client.request(
+                "GET",
+                status_url,
+                verify=not self.client.disable_ssl,
+                timeout=self.client.timeout,
+            )
+        except httpx.ConnectError:
+            return Status(
+                status=StatusEnum.DOWN,
+                status_message="API Server is down",
+            )
+
+        if r.status_code >= 400:
+            return Status(status=StatusEnum.UNHEALTHY, status_message=r.text)
+
+        description = "Llama Deploy is up and running."
+        body = r.json()
+        deployments = body.get("deployments") or []
+        if deployments:
+            description += "\nActive deployments:"
+            for d in deployments:
+                description += f"\n- {d}"
+        else:
+            description += "\nCurrently there are no active deployments"
+
+        return Status(
+            status=StatusEnum.HEALTHY,
+            status_message=description,
+            deployments=deployments,
+        )
+
+    async def deployments(self) -> DeploymentCollection:
+        """Returns a collection of deployments currently active in the API Server."""
+        deployments_url = f"{self.client.api_server_url}/deployments/"
+
+        r = await self.client.request(
+            "GET",
+            deployments_url,
+            verify=not self.client.disable_ssl,
+            timeout=self.client.timeout,
+        )
+        deployments = {
+            name: Deployment.instance(
+                make_sync=self._instance_is_sync, client=self.client, id=name
+            )
+            for name in r.json()
+        }
+        return DeploymentCollection.instance(
+            make_sync=self._instance_is_sync, client=self.client, items=deployments
+        )
diff --git a/llama_deploy/client/models/model.py b/llama_deploy/client/models/model.py
new file mode 100644
index 00000000..b236e3eb
--- /dev/null
+++ b/llama_deploy/client/models/model.py
@@ -0,0 +1,71 @@
+import asyncio
+from typing import Any, Generic, TypeVar, cast
+
+from asgiref.sync import async_to_sync
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+from typing_extensions import Self
+
+from llama_deploy.client.base import _BaseClient
+
+
+class _Base(BaseModel):
+    """The base model provides fields and functionalities common to derived models and collections."""
+
+    client: _BaseClient = Field(exclude=True)
+    _instance_is_sync: bool = PrivateAttr(default=False)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def __new__(cls, *args, **kwargs):  # type: ignore[no-untyped-def]
+        """We prevent the usage of the constructor and force users to call `instance()` instead."""
+        raise TypeError("Please use instance() instead of direct instantiation")
+
+    @classmethod
+    def instance(cls, make_sync: bool = False, **kwargs: Any) -> Self:
+        """Returns an instance of the given model.
+
+        Using the class constructor is not possible because we want to alter the class methods to
+        accommodate sync/async usage before creating an instance, and __init__ would be too late.
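+
+        Example (hypothetical model, any extra keyword argument is forwarded to the constructor):
+            session = Session.instance(client=client, make_sync=True, id="a_session")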
+ """ + if make_sync: + cls = _make_sync(cls) + + inst = super(_Base, cls).__new__(cls) + inst.__init__(**kwargs) # type: ignore[misc] + inst._instance_is_sync = make_sync + return inst + + +T = TypeVar("T", bound=_Base) + + +class Model(_Base): + id: str + + +class Collection(_Base, Generic[T]): + """A generic container of items of the same model type.""" + + items: dict[str, T] + + def get(self, id: str) -> T: + """Returns an item from the collection.""" + return self.items[id] + + def list(self) -> list[T]: + """Returns a list of all the items in the collection.""" + return [self.get(id) for id in self.items.keys()] + + +def _make_sync(_class: type[T]) -> type[T]: + """Wraps the methods of the given model class so that they can be called without `await`.""" + + class Wrapper(_class): # type: ignore + pass + + for name, method in _class.__dict__.items(): + # Only wrap async public methods + if asyncio.iscoroutinefunction(method) and not name.startswith("_"): + setattr(Wrapper, name, async_to_sync(method)) + # Static type checkers can't assess Wrapper is indeed a type[T], let's promise it is. + return cast(type[T], Wrapper) diff --git a/llama_deploy/sdk/__init__.py b/llama_deploy/sdk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llama_deploy/types/__init__.py b/llama_deploy/types/__init__.py new file mode 100644 index 00000000..e25a5ea7 --- /dev/null +++ b/llama_deploy/types/__init__.py @@ -0,0 +1,37 @@ +from .core import ( + CONTROL_PLANE_NAME, + ActionTypes, + ChatMessage, + EventDefinition, + HumanResponse, + MessageRole, + PydanticValidatedUrl, + ServiceDefinition, + SessionDefinition, + TaskDefinition, + TaskResult, + TaskStream, + ToolCall, + ToolCallBundle, + ToolCallResult, + generate_id, +) + +__all__ = [ + "CONTROL_PLANE_NAME", + "ActionTypes", + "ChatMessage", + "EventDefinition", + "HumanResponse", + "MessageRole", + "PydanticValidatedUrl", + "ServiceDefinition", + "SessionDefinition", + "TaskDefinition", + "TaskResult", + "TaskStream", + "ToolCall", + "ToolCallBundle", + "ToolCallResult", + "generate_id", +] diff --git a/llama_deploy/types/apiserver.py b/llama_deploy/types/apiserver.py new file mode 100644 index 00000000..ea977d69 --- /dev/null +++ b/llama_deploy/types/apiserver.py @@ -0,0 +1,16 @@ +from enum import Enum + +from pydantic import BaseModel + + +class StatusEnum(Enum): + HEALTHY = "Healthy" + UNHEALTHY = "Unhealthy" + DOWN = "Down" + + +class Status(BaseModel): + status: StatusEnum + status_message: str + max_deployments: int | None = None + deployments: list[str] | None = None diff --git a/llama_deploy/types.py b/llama_deploy/types/core.py similarity index 100% rename from llama_deploy/types.py rename to llama_deploy/types/core.py diff --git a/poetry.lock b/poetry.lock index 1948f10e..66f32e75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -288,6 +288,23 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "asgiref" +version = "3.8.1" +description = "ASGI specs, helper code, and adapters" +optional = false +python-versions = ">=3.8" +files = [ + {file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"}, + {file = "asgiref-3.8.1.tar.gz", hash = 
"sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} + +[package.extras] +tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] + [[package]] name = "async-timeout" version = "4.0.3" @@ -3004,4 +3021,4 @@ redis = ["redis"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "c223e362ffa6f52cfd0c7292961f7f0d38538e362eff496da63f0d6c8fa3dea9" +content-hash = "1063f9f5a1883f755d70e7af6480c831382c661b28518d2c70fa249ab52a09e9" diff --git a/pyproject.toml b/pyproject.toml index 29b186e9..d08ec64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ types-aiobotocore = {version = "^2.14.0", optional = true, extras = ["sqs", "sns gitpython = "^3.1.43" python-multipart = "^0.0.10" typing_extensions = "^4.0.0" +asgiref = "^3.8.1" [tool.poetry.extras] kafka = ["aiokafka", "kafka-python-ng"] diff --git a/tests/client/models/__init__.py b/tests/client/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/models/conftest.py b/tests/client/models/conftest.py new file mode 100644 index 00000000..46f20931 --- /dev/null +++ b/tests/client/models/conftest.py @@ -0,0 +1,12 @@ +from typing import Any, Iterator +from unittest import mock + +import pytest + +from llama_deploy.client import Client + + +@pytest.fixture +def client(monkeypatch: Any) -> Iterator[Client]: + monkeypatch.setattr(Client, "request", mock.AsyncMock()) + yield Client() diff --git a/tests/client/models/test_apiserver.py b/tests/client/models/test_apiserver.py new file mode 100644 index 00000000..42fd9240 --- /dev/null +++ b/tests/client/models/test_apiserver.py @@ -0,0 +1,284 @@ +import io +from typing import Any +from unittest import mock + +import httpx +import pytest + +from llama_deploy.client.models.apiserver import ( + ApiServer, + Deployment, + DeploymentCollection, + Session, + SessionCollection, + Task, + TaskCollection, +) +from llama_deploy.types import SessionDefinition, TaskDefinition, TaskResult + + +@pytest.mark.asyncio +async def test_session_collection_delete(client: Any) -> None: + coll = SessionCollection.instance( + client=client, + items={"a_session": Session.instance(id="a_session", client=client)}, + deployment_id="a_deployment", + ) + await coll.delete("a_session") + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/a_deployment/sessions/delete", + params={"session_id": "a_session"}, + timeout=120.0, + verify=True, + ) + + +@pytest.mark.asyncio +async def test_session_collection_create(client: Any) -> None: + client.request.return_value = mock.MagicMock( + json=lambda: {"session_id": "a_session"} + ) + coll = SessionCollection.instance( + client=client, + items={}, + deployment_id="a_deployment", + ) + await coll.create() + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/a_deployment/sessions/create", + verify=True, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_results(client: Any) -> None: + res = TaskResult(task_id="a_result", history=[], result="some_text", data={}) + client.request.return_value = mock.MagicMock(json=lambda: res.model_dump_json()) + + t = Task.instance( + client=client, + id="a_task", + deployment_id="a_deployment", + session_id="a_session", + ) + await t.results() + + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment/tasks/a_task/results", + verify=True, + 
+        params={"session_id": "a_session"},
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_collection_run(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(json=lambda: "some result")
+    coll = TaskCollection.instance(
+        client=client,
+        items={
+            "a_session": Task.instance(
+                id="a_session",
+                client=client,
+                deployment_id="a_deployment",
+                session_id="a_session",
+            )
+        },
+        deployment_id="a_deployment",
+    )
+    await coll.run(TaskDefinition(input="some input", task_id="test_id"))
+    client.request.assert_awaited_with(
+        "POST",
+        "http://localhost:4501/deployments/a_deployment/tasks/run",
+        verify=True,
+        json={
+            "input": "some input",
+            "task_id": "test_id",
+            "session_id": None,
+            "agent_id": None,
+        },
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_collection_create(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(
+        json=lambda: {"session_id": "a_session", "task_id": "test_id"}
+    )
+    coll = TaskCollection.instance(
+        client=client,
+        items={
+            "a_session": Task.instance(
+                id="a_session",
+                client=client,
+                deployment_id="a_deployment",
+                session_id="a_session",
+            )
+        },
+        deployment_id="a_deployment",
+    )
+    await coll.create(TaskDefinition(input='{"arg": "test_input"}', task_id="test_id"))
+    client.request.assert_awaited_with(
+        "POST",
+        "http://localhost:4501/deployments/a_deployment/tasks/create",
+        verify=True,
+        json={
+            "input": '{"arg": "test_input"}',
+            "task_id": "test_id",
+            "session_id": None,
+            "agent_id": None,
+        },
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_deployment_tasks(client: Any) -> None:
+    d = Deployment.instance(client=client, id="a_deployment")
+    res: list[TaskDefinition] = [
+        TaskDefinition(
+            input='{"arg": "input"}', task_id="a_task", session_id="a_session"
+        )
+    ]
+    client.request.return_value = mock.MagicMock(json=lambda: res)
+
+    await d.tasks()
+
+    client.request.assert_awaited_with(
+        "GET",
+        "http://localhost:4501/deployments/a_deployment/tasks",
+        verify=True,
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_deployment_sessions(client: Any) -> None:
+    d = Deployment.instance(client=client, id="a_deployment")
+    res: list[SessionDefinition] = [SessionDefinition(session_id="a_session")]
+    client.request.return_value = mock.MagicMock(json=lambda: res)
+
+    await d.sessions()
+
+    client.request.assert_awaited_with(
+        "GET",
+        "http://localhost:4501/deployments/a_deployment/sessions",
+        verify=True,
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_deployment_collection_create(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(json=lambda: {"name": "deployment"})
+
+    coll = DeploymentCollection.instance(client=client, items={})
+    await coll.create(io.StringIO("some config"))
+
+    client.request.assert_awaited_with(
+        "POST",
+        "http://localhost:4501/deployments/create",
+        files={"config_file": "some config"},
+        verify=True,
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_task_deployment_collection_get(client: Any) -> None:
+    d = Deployment.instance(client=client, id="a_deployment")
+    coll = DeploymentCollection.instance(client=client, items={"a_deployment": d})
+    client.request.return_value = mock.MagicMock(json=lambda: {"a_deployment": "Up!"})
+
+    await coll.get("a_deployment")
+
+    client.request.assert_awaited_with(
+        "GET",
+        "http://localhost:4501/deployments/a_deployment",
+        verify=True,
+        timeout=120.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_status_down(client: Any) -> None:
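+    # A transport-level ConnectError must not bubble up to the caller:
+    # status() is expected to swallow it and report the server as Down.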
+    client.request.side_effect = httpx.ConnectError(message="connection error")
+
+    apis = ApiServer.instance(client=client, id="apiserver")
+    res = await apis.status()
+
+    client.request.assert_awaited_with(
+        "GET", "http://localhost:4501/status/", verify=True, timeout=120.0
+    )
+    assert res.status.value == "Down"
+
+
+@pytest.mark.asyncio
+async def test_status_unhealthy(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(
+        status_code=400, text="This is a drill."
+    )
+
+    apis = ApiServer.instance(client=client, id="apiserver")
+    res = await apis.status()
+
+    client.request.assert_awaited_with(
+        "GET", "http://localhost:4501/status/", verify=True, timeout=120.0
+    )
+    assert res.status.value == "Unhealthy"
+    assert res.status_message == "This is a drill."
+
+
+@pytest.mark.asyncio
+async def test_status_healthy_no_deployments(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(
+        status_code=200, text="", json=lambda: {}
+    )
+
+    apis = ApiServer.instance(client=client, id="apiserver")
+    res = await apis.status()
+
+    client.request.assert_awaited_with(
+        "GET", "http://localhost:4501/status/", verify=True, timeout=120.0
+    )
+    assert res.status.value == "Healthy"
+    assert (
+        res.status_message
+        == "Llama Deploy is up and running.\nCurrently there are no active deployments"
+    )
+
+
+@pytest.mark.asyncio
+async def test_status_healthy(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(
+        status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]}
+    )
+
+    apis = ApiServer.instance(client=client, id="apiserver")
+    res = await apis.status()
+
+    client.request.assert_awaited_with(
+        "GET", "http://localhost:4501/status/", verify=True, timeout=120.0
+    )
+    assert res.status.value == "Healthy"
+    assert (
+        res.status_message
+        == "Llama Deploy is up and running.\nActive deployments:\n- foo\n- bar"
+    )
+
+
+@pytest.mark.asyncio
+async def test_deployments(client: Any) -> None:
+    client.request.return_value = mock.MagicMock(
+        status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]}
+    )
+    apis = ApiServer.instance(client=client, id="apiserver")
+    await apis.deployments()
+    client.request.assert_awaited_with(
+        "GET", "http://localhost:4501/deployments/", verify=True, timeout=120.0
+    )
diff --git a/tests/client/models/test_model.py b/tests/client/models/test_model.py
new file mode 100644
index 00000000..f01739e3
--- /dev/null
+++ b/tests/client/models/test_model.py
@@ -0,0 +1,41 @@
+import asyncio
+
+import pytest
+
+from llama_deploy.client import Client
+from llama_deploy.client.models import Collection, Model
+from llama_deploy.client.models.model import _make_sync
+
+
+class SomeAsyncModel(Model):
+    async def method(self) -> None:
+        pass
+
+
+def test_no_init(client: Client) -> None:
+    with pytest.raises(
+        TypeError, match=r"Please use instance\(\) instead of direct instantiation"
+    ):
+        SomeAsyncModel(id="foo", client=client)
+
+
+def test_make_sync() -> None:
+    assert asyncio.iscoroutinefunction(getattr(SomeAsyncModel, "method"))
+    some_sync = _make_sync(SomeAsyncModel)
+    assert not asyncio.iscoroutinefunction(getattr(some_sync, "method"))
+
+
+def test_collection_get() -> None:
+    class MyCollection(Collection):
+        pass
+
+    c = Client()
+    models_list = [
+        SomeAsyncModel.instance(client=c, id="foo"),
+        SomeAsyncModel.instance(client=c, id="bar"),
+    ]
+
+    coll = MyCollection.instance(client=c, items={m.id: m for m in models_list})
+    assert coll.get("foo").id == "foo"
+    assert coll.get("bar").id == "bar"
+    assert coll.list() == models_list
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
new file mode 100644
index 00000000..eaed8517
--- /dev/null
+++ b/tests/client/test_client.py
@@ -0,0 +1,50 @@
+from unittest import mock
+
+import pytest
+
+from llama_deploy.client import Client
+from llama_deploy.client.client import _SyncClient
+from llama_deploy.client.models.apiserver import ApiServer
+
+
+def test_client_init_default() -> None:
+    c = Client()
+    assert c.api_server_url == "http://localhost:4501"
+    assert c.disable_ssl is False
+    assert c.timeout == 120.0
+    assert c.poll_interval == 0.5
+
+
+def test_client_init_settings() -> None:
+    c = Client(api_server_url="test")
+    assert c.api_server_url == "test"
+
+
+def test_client_sync() -> None:
+    c = Client()
+    sc = c.sync
+    assert type(sc) is _SyncClient
+    assert sc.api_server_url == "http://localhost:4501"
+    assert sc.disable_ssl is False
+    assert sc.timeout == 120.0
+    assert sc.poll_interval == 0.5
+
+
+def test_client_attributes() -> None:
+    c = Client()
+    assert type(c.apiserver) is ApiServer
+    assert issubclass(type(c.sync.apiserver), ApiServer)
+
+
+@pytest.mark.asyncio
+async def test_client_request() -> None:
+    with mock.patch("llama_deploy.client.base.httpx") as _httpx:
+        mocked_response = mock.MagicMock()
+        _httpx.AsyncClient.return_value.__aenter__.return_value.request.return_value = (
+            mocked_response
+        )
+
+        c = Client()
+        await c.request("GET", "http://example.com", verify=False)
+        _httpx.AsyncClient.assert_called_with(verify=False)
+        mocked_response.raise_for_status.assert_called_once()
diff --git a/tests/services/test_human_service.py b/tests/services/test_human_service.py
index 07ca14e9..803cf8f6 100644
--- a/tests/services/test_human_service.py
+++ b/tests/services/test_human_service.py
@@ -1,18 +1,20 @@
 import asyncio
-import pytest
-from pydantic import PrivateAttr, ValidationError
 from typing import Any, List
 from unittest.mock import MagicMock, patch
-from llama_deploy.services import HumanService
-from llama_deploy.services.human import HELP_REQUEST_TEMPLATE_STR
-from llama_deploy.message_queues.simple import SimpleMessageQueue
+
+import pytest
+from pydantic import PrivateAttr, ValidationError
+
 from llama_deploy.message_consumers.base import BaseMessageQueueConsumer
+from llama_deploy.message_queues.simple import SimpleMessageQueue
 from llama_deploy.messages.base import QueueMessage
+from llama_deploy.services import HumanService
+from llama_deploy.services.human import HELP_REQUEST_TEMPLATE_STR
 from llama_deploy.types import (
-    TaskDefinition,
-    ActionTypes,
     CONTROL_PLANE_NAME,
+    ActionTypes,
     ChatMessage,
+    TaskDefinition,
 )
@@ -71,7 +73,7 @@ def test_invalid_human_prompt_raises_validation_error() -> None:
 
 
 @pytest.mark.asyncio()
-@patch("llama_deploy.types.uuid")
+@patch("llama_deploy.types.core.uuid")
 async def test_create_task(mock_uuid: MagicMock) -> None:
     # arrange
     human_service = HumanService(