feat: add apiserver support to Python SDK (#327)
* Client refactoring

try

checkpoint

add asgiref to support async-to-sync

fix unit tests

remove pydantic warning

add unit tests for client

fix mock path

test model

add tests and fix discovered bugs

fix connection error handling

fix bugs surfaced in end-to-end

use explicit properties

fix awaitable checks

add instance method

revert to return sync class

extract base model

use instance() method on models

add e2e tests

working state

fix unit tests

more fixes

* make change backward compat

* remove useless param

* add docstrings

* add sessions.create()

* merge settings into base client class

* assert condition in e2e

* export Client at the top level
masci authored Oct 24, 2024
1 parent 3c8410b commit 2accf9d
Showing 30 changed files with 1,144 additions and 40 deletions.
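
Taken together, these changes introduce a unified Client whose methods are async-first, with a .sync accessor for blocking code. A minimal sketch of the new surface, inferred from the e2e tests below (illustrative only; the method names are the ones the tests exercise):

import asyncio

from llama_deploy.client import Client

client = Client(api_server_url="http://localhost:4501")


async def check_async() -> None:
    # Async-first usage: every API server call is awaitable.
    status = await client.apiserver.status()
    print(status.status.value)  # e.g. "Healthy" or "Down"


def check_sync() -> None:
    # Blocking usage from regular code, via the .sync accessor.
    print(client.sync.apiserver.status().status.value)


if __name__ == "__main__":
    asyncio.run(check_async())
    check_sync()
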
Empty file added e2e_tests/apiserver/__init__.py
27 changes: 27 additions & 0 deletions e2e_tests/apiserver/conftest.py
@@ -0,0 +1,27 @@
import multiprocessing
import time

import pytest
import uvicorn

from llama_deploy.client import Client


def run_async_apiserver():
    uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501)


@pytest.fixture(scope="module")
def apiserver():
    # Boot the API server in a separate process so the sync and async tests
    # in this module can share a single instance.
    p = multiprocessing.Process(target=run_async_apiserver)
    p.start()
    time.sleep(3)  # give uvicorn time to bind before the tests hit it

    yield

    p.kill()


@pytest.fixture
def client():
    return Client(api_server_url="http://localhost:4501")
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment1.yml
@@ -0,0 +1,15 @@
name: TestDeployment1

control-plane: {}

default-service: test-workflow

services:
  test-workflow:
    name: Test Workflow
    port: 8002
    host: localhost
    source:
      type: git
      name: https://github.com/run-llama/llama_deploy.git
    path: tests/apiserver/data/workflow:my_workflow
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment2.yml
@@ -0,0 +1,15 @@
name: TestDeployment2

control-plane: {}

default-service: test-workflow

services:
  test-workflow:
    name: Test Workflow
    port: 8002
    host: localhost
    source:
      type: git
      name: https://github.com/run-llama/llama_deploy.git
    path: tests/apiserver/data/workflow:my_workflow
14 changes: 14 additions & 0 deletions e2e_tests/apiserver/deployments/deployment_streaming.yml
@@ -0,0 +1,14 @@
name: Streaming

control-plane:
  port: 8000

default-service: streaming_workflow

services:
  streaming_workflow:
    name: Streaming Workflow
    source:
      type: local
      name: ./e2e_tests/apiserver/deployments/src
    path: workflow:streaming_workflow
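
Note on the service path fields: they appear to follow a module:attribute convention. workflow:streaming_workflow points at the module-level streaming_workflow object defined in workflow.py below, and the git-sourced tests/apiserver/data/workflow:my_workflow presumably resolves the same way.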
41 changes: 41 additions & 0 deletions e2e_tests/apiserver/deployments/src/workflow.py
@@ -0,0 +1,41 @@
import asyncio

from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class Message(Event):
    text: str


class EchoWorkflow(Workflow):
    """A dummy workflow streaming three events."""

    @step()
    async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
        for i in range(3):
            ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
            await asyncio.sleep(0.5)

        return StopEvent(result="Done.")


streaming_workflow = EchoWorkflow()


async def main():
    h = streaming_workflow.run(message="Hello!")
    async for ev in h.stream_events():
        if type(ev) is Message:
            print(ev.text)
    print(await h)


if __name__ == "__main__":
    asyncio.run(main())
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_deploy.py
@@ -0,0 +1,23 @@
from pathlib import Path

import pytest


@pytest.mark.asyncio
async def test_deploy(apiserver, client):
    here = Path(__file__).parent
    deployments = await client.apiserver.deployments()
    with open(here / "deployments" / "deployment1.yml") as f:
        await deployments.create(f)

    status = await client.apiserver.status()
    assert "TestDeployment1" in status.deployments


def test_deploy_sync(apiserver, client):
    here = Path(__file__).parent
    deployments = client.sync.apiserver.deployments()
    with open(here / "deployments" / "deployment2.yml") as f:
        deployments.create(f)

    assert "TestDeployment2" in client.sync.apiserver.status().deployments
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_status.py
@@ -0,0 +1,23 @@
import pytest


@pytest.mark.asyncio
async def test_status_down(client):
    res = await client.apiserver.status()
    assert res.status.value == "Down"


def test_status_down_sync(client):
    res = client.sync.apiserver.status()
    assert res.status.value == "Down"


@pytest.mark.asyncio
async def test_status_up(apiserver, client):
    res = await client.apiserver.status()
    assert res.status.value == "Healthy"


def test_status_up_sync(apiserver, client):
    res = client.sync.apiserver.status()
    assert res.status.value == "Healthy"
27 changes: 27 additions & 0 deletions e2e_tests/apiserver/test_streaming.py
@@ -0,0 +1,27 @@
import asyncio
from pathlib import Path

import pytest

from llama_deploy.types import TaskDefinition


@pytest.mark.asyncio
async def test_stream(apiserver, client):
    here = Path(__file__).parent

    with open(here / "deployments" / "deployment_streaming.yml") as f:
        deployments = await client.apiserver.deployments()
        deployment = await deployments.create(f)
        await asyncio.sleep(5)  # give the deployed workflow service time to start

    tasks = await deployment.tasks()
    task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
    read_events = []
    async for ev in task.events():
        if "text" in ev:
            read_events.append(ev)
    assert len(read_events) == 3
    # The workflow produces events sequentially, so we can assume they arrived in order.
    for i, ev in enumerate(read_events):
        assert ev["text"] == f"message number {i+1}"
27 changes: 14 additions & 13 deletions llama_deploy/__init__.py
@@ -1,28 +1,28 @@
-from llama_deploy.client import AsyncLlamaDeployClient, LlamaDeployClient
-from llama_deploy.control_plane import ControlPlaneServer, ControlPlaneConfig
+# configure logger
+import logging
+
+from llama_deploy.client import AsyncLlamaDeployClient, Client, LlamaDeployClient
+from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
 from llama_deploy.deploy import deploy_core, deploy_workflow
 from llama_deploy.message_consumers import CallableMessageConsumer
 from llama_deploy.message_queues import SimpleMessageQueue, SimpleMessageQueueConfig
 from llama_deploy.messages import QueueMessage
 from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig
+from llama_deploy.services import (
+    AgentService,
+    ComponentService,
+    HumanService,
+    ToolService,
+    WorkflowService,
+    WorkflowServiceConfig,
+)
 from llama_deploy.tools import (
     AgentServiceTool,
     MetaServiceTool,
     ServiceAsTool,
     ServiceComponent,
     ServiceTool,
 )
-from llama_deploy.services import (
-    AgentService,
-    ToolService,
-    HumanService,
-    ComponentService,
-    WorkflowService,
-    WorkflowServiceConfig,
-)
-
-# configure logger
-import logging
 
 root_logger = logging.getLogger("llama_deploy")

@@ -39,6 +39,7 @@
     # clients
     "LlamaDeployClient",
     "AsyncLlamaDeployClient",
+    "Client",
     # services
     "AgentService",
     "HumanService",
16 changes: 5 additions & 11 deletions llama_deploy/apiserver/deployment.py
@@ -7,32 +7,26 @@
 from typing import Any
 
 from llama_deploy import (
+    AsyncLlamaDeployClient,
     ControlPlaneServer,
     SimpleMessageQueue,
-    SimpleOrchestratorConfig,
     SimpleOrchestrator,
+    SimpleOrchestratorConfig,
     WorkflowService,
     WorkflowServiceConfig,
-    AsyncLlamaDeployClient,
 )
 from llama_deploy.message_queues import (
-    BaseMessageQueue,
-    SimpleMessageQueueConfig,
     AWSMessageQueue,
+    BaseMessageQueue,
     KafkaMessageQueue,
     RabbitMQMessageQueue,
     RedisMessageQueue,
+    SimpleMessageQueueConfig,
 )
 
-from .config_parser import (
-    Config,
-    SourceType,
-    Service,
-    MessageQueueConfig,
-)
+from .config_parser import Config, MessageQueueConfig, Service, SourceType
 from .source_managers import GitSourceManager, LocalSourceManager, SourceManager
 
 
 SOURCE_MANAGERS: dict[SourceType, SourceManager] = {
     SourceType.git: GitSourceManager(),
     SourceType.local: LocalSourceManager(),
24 changes: 20 additions & 4 deletions llama_deploy/apiserver/routers/deployments.py
@@ -1,14 +1,13 @@
 import json
+from typing import AsyncGenerator
 
-from fastapi import APIRouter, File, UploadFile, HTTPException
+from fastapi import APIRouter, File, HTTPException, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
-from typing import AsyncGenerator
 
-from llama_deploy.apiserver.server import manager
 from llama_deploy.apiserver.config_parser import Config
+from llama_deploy.apiserver.server import manager
 from llama_deploy.types import TaskDefinition
 
-
 deployments_router = APIRouter(
     prefix="/deployments",
 )
@@ -144,6 +143,23 @@ async def get_task_result(
     return JSONResponse(result.result if result else "")
 
 
+@deployments_router.get("/{deployment_name}/tasks")
+async def get_tasks(
+    deployment_name: str,
+) -> JSONResponse:
+    """Get all the tasks from all the sessions in a given deployment."""
+    deployment = manager.get_deployment(deployment_name)
+    if deployment is None:
+        raise HTTPException(status_code=404, detail="Deployment not found")
+
+    tasks: list[TaskDefinition] = []
+    for session_def in await deployment.client.list_sessions():
+        session = await deployment.client.get_session(session_id=session_def.session_id)
+        for task_def in await session.get_tasks():
+            tasks.append(task_def)
+    return JSONResponse(tasks)
+
+
 @deployments_router.get("/{deployment_name}/sessions")
 async def get_sessions(
     deployment_name: str,
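
Outside the SDK, the new route can be exercised directly over HTTP. A minimal sketch with httpx, assuming the default port and a deployment named as in the e2e files above:

import httpx

# GET /deployments/{deployment_name}/tasks lists every task across all sessions.
resp = httpx.get("http://localhost:4501/deployments/TestDeployment1/tasks")
resp.raise_for_status()
for task in resp.json():
    print(task)
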
7 changes: 4 additions & 3 deletions llama_deploy/client/__init__.py
@@ -1,4 +1,5 @@
-from llama_deploy.client.async_client import AsyncLlamaDeployClient
-from llama_deploy.client.sync_client import LlamaDeployClient
+from .async_client import AsyncLlamaDeployClient
+from .sync_client import LlamaDeployClient
+from .client import Client
 
-__all__ = ["AsyncLlamaDeployClient", "LlamaDeployClient"]
+__all__ = ["AsyncLlamaDeployClient", "Client", "LlamaDeployClient"]
29 changes: 29 additions & 0 deletions llama_deploy/client/base.py
@@ -0,0 +1,29 @@
from typing import Any

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict


class _BaseClient(BaseSettings):
    """Base type for clients, to be used in Pydantic models to avoid circular imports.

    Settings can be passed to the Client constructor when creating an instance, or
    defined with environment variables having names prefixed with the string
    `LLAMA_DEPLOY_`, e.g. `LLAMA_DEPLOY_DISABLE_SSL`.
    """

    model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")

    api_server_url: str = "http://localhost:4501"
    disable_ssl: bool = False
    timeout: float = 120.0
    poll_interval: float = 0.5

    async def request(
        self, method: str, url: str | httpx.URL, *args: Any, **kwargs: Any
    ) -> httpx.Response:
        """Performs an async HTTP request using httpx."""
        verify = kwargs.pop("verify", True)
        async with httpx.AsyncClient(verify=verify) as client:
            response = await client.request(method, url, *args, **kwargs)
            response.raise_for_status()
            return response
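
Because _BaseClient is a pydantic-settings model, configuration can come either from constructor arguments or from LLAMA_DEPLOY_-prefixed environment variables. A small sketch, assuming Client inherits these fields from _BaseClient:

import os

from llama_deploy.client import Client

# Environment variables are read when the client is instantiated, using the
# LLAMA_DEPLOY_ prefix declared in model_config.
os.environ["LLAMA_DEPLOY_TIMEOUT"] = "60"

client = Client(disable_ssl=True)  # constructor arguments take effect too
assert client.disable_ssl is True
assert client.timeout == 60.0
assert client.api_server_url == "http://localhost:4501"  # default value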