Skip to content

Commit

Permalink
Merge branch 'main' into massi/apiserver-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Oct 31, 2024
2 parents b3eb67c + 4d4efd7 commit eee25f5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 63 deletions.
58 changes: 20 additions & 38 deletions e2e_tests/basic_workflow/test_run_client.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,50 @@
import pytest

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
from llama_deploy import Client


@pytest.mark.e2e
def test_run_client(workflow):
client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# test connections
assert (
len(client.list_services()) == 1
), f"Expected 1 service, got {client.list_services()}"
assert (
len(client.list_sessions()) == 0
), f"Expected 0 sessions, got {client.list_sessions()}"
assert len(client.sync.core.services.list()) == 1
assert len(client.sync.core.sessions.list()) == 0

# test create session
session = client.get_or_create_session("fake_session_id")
sessions = client.list_sessions()
assert len(sessions) == 1, f"Expected 1 session, got {sessions}"
assert (
sessions[0].session_id == session.session_id
), f"Expected session id to be {session.session_id}, got {sessions[0].session_id}"
session = client.sync.core.sessions.get_or_create("fake_session_id")
sessions = client.sync.core.sessions.list()
assert len(sessions) == 1
assert sessions[0].id == session.id

# test run with session
result = session.run("outer", arg1="hello_world")
assert result == "hello_world_result"

# test number of tasks
tasks = session.get_tasks()
assert len(tasks) == 1, f"Expected 1 task, got {len(tasks)} tasks"
assert (
tasks[0].agent_id == "outer"
), f"Expected id to be 'outer', got {tasks[0].agent_id}"
assert len(tasks) == 1
assert tasks[0].agent_id == "outer"

# delete everything
client.delete_session(session.session_id)
assert (
len(client.list_sessions()) == 0
), f"Expected 0 sessions, got {client.list_sessions()}"
client.sync.core.sessions.delete(session.id)
assert len(client.sync.core.sessions.list()) == 0


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_run_client_async(workflow):
client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# test connections
assert (
len(await client.list_services()) == 1
), f"Expected 1 service, got {await client.list_services()}"
assert (
len(await client.list_sessions()) == 0
), f"Expected 0 sessions, got {await client.list_sessions()}"
assert len(await client.core.services.list()) == 1
assert len(await client.core.sessions.list()) == 0

# test create session
session = await client.get_or_create_session("fake_session_id")
sessions = await client.list_sessions()
session = await client.core.sessions.get_or_create("fake_session_id")
sessions = await client.core.sessions.list()
assert len(sessions) == 1, f"Expected 1 session, got {sessions}"
assert (
sessions[0].session_id == session.session_id
), f"Expected session id to be {session.session_id}, got {sessions[0].session_id}"
assert sessions[0].id == session.id

# test run with session
result = await session.run("outer", arg1="hello_world")
Expand All @@ -74,7 +58,5 @@ async def test_run_client_async(workflow):
), f"Expected id to be 'outer', got {tasks[0].agent_id}"

# delete everything
await client.delete_session(session.session_id)
assert (
len(await client.list_sessions()) == 0
), f"Expected 0 sessions, got {await client.list_sessions()}"
await client.core.sessions.delete(session.id)
assert len(await client.core.sessions.list()) == 0
16 changes: 7 additions & 9 deletions e2e_tests/core/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
def test_services(workflow):
client = Client()

services = client.sync.core.services()
assert len(services.items) == 1
services = client.sync.core.services
assert len(services.list()) == 1

services.deregister("basic")
assert len(services.items) == 0
Expand All @@ -26,14 +26,12 @@ def test_services(workflow):
async def test_services_async(workflow):
client = Client()

services = await client.core.services()
assert len(services.items) == 1
assert len(await client.core.services.list()) == 1
await client.core.services.deregister("basic")
assert len(await client.core.services.list()) == 0

await services.deregister("basic")
assert len(services.items) == 0

new_s = await services.register(
new_s = await client.core.services.register(
ServiceDefinition(service_name="another_basic", description="none")
)
assert new_s.id == "another_basic"
assert len(services.items) == 1
assert len(await client.core.services.list()) == 1
55 changes: 43 additions & 12 deletions llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ async def _do_get_task_result(self, task_id: str) -> TaskResult | None:
data = response.json()
return TaskResult(**data) if data else None

async def get_tasks(self) -> list[TaskDefinition]:
"""Get all tasks in this session.
Returns:
list[TaskDefinition]: A list of task definitions in the session.
"""
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks"
response = await self.client.request("GET", url)
return [TaskDefinition(**task) for task in response.json()]


class SessionCollection(Collection):
async def list(self) -> list[Session]: # type: ignore
Expand All @@ -86,6 +96,10 @@ async def create(self) -> Session:
Returns:
Session: A Session object representing the newly created session.
"""
return await self._create()

async def _create(self) -> Session:
"""Async-only version of create, to be used internally from other methods."""
create_url = f"{self.client.control_plane_url}/sessions/create"
response = await self.client.request("POST", create_url)
session_id = response.json()
Expand All @@ -105,6 +119,11 @@ async def get(self, id: str) -> Session:
Raises:
ValueError: If the session does not exist.
"""
return await self._get(id)

async def _get(self, id: str) -> Session:
"""Async-only version of get, to be used internally from other methods."""

get_url = f"{self.client.control_plane_url}/sessions/{id}"
await self.client.request("GET", get_url)
return Session.instance(
Expand All @@ -118,10 +137,10 @@ async def get_or_create(self, id: str) -> Session:
Session: A Session object representing the specified session.
"""
try:
return await self.get(id)
return await self._get(id)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
return await self.create()
return await self._create()
raise e

async def delete(self, session_id: str) -> None:
Expand All @@ -139,6 +158,24 @@ class Service(Model):


class ServiceCollection(Collection):
async def list(self) -> list[Service]: # type: ignore
"""Returns a list containing all the services registered with the control plane.
Returns:
list[Service]: List of services registered with the control plane.
"""
services_url = f"{self.client.control_plane_url}/services"
response = await self.client.request("GET", services_url)
services = []
for name, service in response.json().items():
services.append(
Service.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)
)

return services

async def register(self, service: ServiceDefinition) -> Service:
"""Registers a service with the control plane.
Expand All @@ -163,25 +200,19 @@ async def deregister(self, service_name: str) -> None:
deregister_url,
params={"service_name": service_name},
)
self.items.pop(service_name)


class Core(Model):
async def services(self) -> ServiceCollection:
@property
def services(self) -> ServiceCollection:
"""Returns a collection containing all the services registered with the control plane.
Returns:
ServiceCollection: Collection of services registered with the control plane.
"""
services_url = f"{self.client.control_plane_url}/services"
response = await self.client.request("GET", services_url)
items = {}
for name, service in response.json().items():
items[name] = Service.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)

return ServiceCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items=items
make_sync=self._instance_is_sync, client=self.client, items={}
)

@property
Expand Down
39 changes: 35 additions & 4 deletions tests/client/models/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ async def test_core_services(client: mock.AsyncMock) -> None:
)

core = Core.instance(client=client, id="core")
services = await core.services()
services = await core.services.list()

client.request.assert_awaited_with("GET", "http://localhost:8000/services")
assert isinstance(services, ServiceCollection)
assert "test_service" in services.items
assert isinstance(services.items["test_service"], Service)
assert services[0].id == "test_service"


@pytest.mark.asyncio
Expand Down Expand Up @@ -216,3 +214,36 @@ async def test_core_sessions(client: mock.AsyncMock) -> None:

client.request.assert_awaited_with("GET", "http://localhost:8000/sessions")
assert sessions[0].id == "test_session"


@pytest.mark.asyncio
async def test_session_get_tasks(client: mock.AsyncMock) -> None:
client.request.return_value = mock.MagicMock(
json=lambda: [
{
"input": "task1 input",
"agent_id": "agent1",
"session_id": "test_session_id",
},
{
"input": "task2 input",
"agent_id": "agent2",
"session_id": "test_session_id",
},
]
)

session = Session.instance(client=client, id="test_session_id")
tasks = await session.get_tasks()

client.request.assert_awaited_with(
"GET", "http://localhost:8000/sessions/test_session_id/tasks"
)
assert len(tasks) == 2
assert all(isinstance(task, TaskDefinition) for task in tasks)
assert tasks[0].input == "task1 input"
assert tasks[0].agent_id == "agent1"
assert tasks[0].session_id == "test_session_id"
assert tasks[1].input == "task2 input"
assert tasks[1].agent_id == "agent2"
assert tasks[1].session_id == "test_session_id"

0 comments on commit eee25f5

Please sign in to comment.