Skip to content

Commit

Permalink
feat: add sessions collections management (#341)
Browse files Browse the repository at this point in the history
* feat: add sessions collections management

* prepare e2e test

* adjust e2e tests and fix sync/async calls

* finish unit tests
  • Loading branch information
masci authored Oct 29, 2024
1 parent 042e9c6 commit cf80392
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 94 deletions.
14 changes: 0 additions & 14 deletions e2e_tests/basic_session/launch_core.py

This file was deleted.

45 changes: 0 additions & 45 deletions e2e_tests/basic_session/launch_workflow.py

This file was deleted.

25 changes: 0 additions & 25 deletions e2e_tests/basic_session/run.sh

This file was deleted.

14 changes: 7 additions & 7 deletions e2e_tests/basic_session/test_run_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
from llama_deploy import Client


@pytest.mark.e2e
def test_run_client(workflow):
client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# create session
session = client.get_or_create_session("fake_session_id")
session = client.sync.core.sessions.create()

# test run with session
result = session.run("session_workflow")
Expand All @@ -19,18 +19,18 @@ def test_run_client(workflow):
assert result == "2"

# create new session and run
session = client.get_or_create_session("fake_session_id_2")
session = client.sync.core.sessions.create()
result = session.run("session_workflow")
assert result == "1"


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_run_client_async(workflow):
client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# create session
session = await client.get_or_create_session("fake_session_id")
session = await client.core.sessions.create()

# run
result = await session.run("session_workflow")
Expand All @@ -41,6 +41,6 @@ async def test_run_client_async(workflow):
assert result == "2"

# create new session and run
session = await client.get_or_create_session("fake_session_id_2")
session = await client.core.sessions.create()
result = await session.run("session_workflow")
assert result == "1"
144 changes: 143 additions & 1 deletion llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,139 @@
from llama_deploy.types.core import ServiceDefinition
import asyncio
import json
from typing import Any

import httpx

from llama_deploy.types.core import ServiceDefinition, TaskDefinition, TaskResult

from .model import Collection, Model


class Session(Model):
async def run(self, service_name: str, **run_kwargs: Any) -> str:
"""Implements the workflow-based run API for a session."""
task_input = json.dumps(run_kwargs)
task_def = TaskDefinition(input=task_input, agent_id=service_name)
task_id = await self._do_create_task(task_def)

# wait for task to complete, up to timeout seconds
async def _get_result() -> str:
while True:
task_result = await self._do_get_task_result(task_id)

if isinstance(task_result, TaskResult):
return task_result.result or ""
await asyncio.sleep(self.client.poll_interval)

return await asyncio.wait_for(_get_result(), timeout=self.client.timeout)

async def create_task(self, task_def: TaskDefinition) -> str:
"""Create a new task in this session.
Args:
task_def (Union[str, TaskDefinition]): The task definition or input string.
Returns:
str: The ID of the created task.
"""
return await self._do_create_task(task_def)

async def _do_create_task(self, task_def: TaskDefinition) -> str:
"""Async-only version of create_task, to be used internally from other methods."""
task_def.session_id = self.id
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks"
response = await self.client.request("POST", url, json=task_def.model_dump())
return response.json()

async def get_task_result(self, task_id: str) -> TaskResult | None:
"""Get the result of a task in this session if it has one.
Args:
task_id (str): The ID of the task to get the result for.
Returns:
Optional[TaskResult]: The result of the task if it has one, otherwise None.
"""
return await self._do_get_task_result(task_id)

async def _do_get_task_result(self, task_id: str) -> TaskResult | None:
"""Async-only version of get_task_result, to be used internally from other methods."""
url = (
f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result"
)
response = await self.client.request("GET", url)
data = response.json()
return TaskResult(**data) if data else None


class SessionCollection(Collection):
async def list(self) -> list[Session]: # type: ignore
"""Returns a list of all the sessions in the collection."""
sessions_url = f"{self.client.control_plane_url}/sessions"
response = await self.client.request("GET", sessions_url)
sessions = []
for id, session_def in response.json().items():
sessions.append(
Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)
)
return sessions

async def create(self) -> Session:
"""Creates a new session and returns a Session object.
Returns:
Session: A Session object representing the newly created session.
"""
create_url = f"{self.client.control_plane_url}/sessions/create"
response = await self.client.request("POST", create_url)
session_id = response.json()
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=session_id
)

async def get(self, id: str) -> Session:
"""Gets a session by ID.
Args:
session_id: The ID of the session to get.
Returns:
Session: A Session object representing the specified session.
Raises:
ValueError: If the session does not exist.
"""
get_url = f"{self.client.control_plane_url}/sessions/{id}"
await self.client.request("GET", get_url)
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)

async def get_or_create(self, id: str) -> Session:
"""Gets a session by ID, or creates a new one if it doesn't exist.
Returns:
Session: A Session object representing the specified session.
"""
try:
return await self.get(id)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
return await self.create()
raise e

async def delete(self, session_id: str) -> None:
"""Deletes a session by ID.
Args:
session_id: The ID of the session to delete.
"""
delete_url = f"{self.client.control_plane_url}/sessions/{session_id}/delete"
await self.client.request("POST", delete_url)


class Service(Model):
pass

Expand Down Expand Up @@ -52,3 +183,14 @@ async def services(self) -> ServiceCollection:
return ServiceCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items=items
)

@property
def sessions(self) -> SessionCollection:
"""Returns a collection to access all the sessions registered with the control plane.
Returns:
SessionCollection: Collection of sessions registered with the control plane.
"""
return SessionCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items={}
)
Loading

0 comments on commit cf80392

Please sign in to comment.