Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sessions collections management #341

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions e2e_tests/basic_session/launch_core.py

This file was deleted.

45 changes: 0 additions & 45 deletions e2e_tests/basic_session/launch_workflow.py

This file was deleted.

25 changes: 0 additions & 25 deletions e2e_tests/basic_session/run.sh

This file was deleted.

14 changes: 7 additions & 7 deletions e2e_tests/basic_session/test_run_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
from llama_deploy import Client


@pytest.mark.e2e
def test_run_client(workflow):
client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# create session
session = client.get_or_create_session("fake_session_id")
session = client.sync.core.sessions.create()

# test run with session
result = session.run("session_workflow")
Expand All @@ -19,18 +19,18 @@ def test_run_client(workflow):
assert result == "2"

# create new session and run
session = client.get_or_create_session("fake_session_id_2")
session = client.sync.core.sessions.create()
result = session.run("session_workflow")
assert result == "1"


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_run_client_async(workflow):
client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# create session
session = await client.get_or_create_session("fake_session_id")
session = await client.core.sessions.create()

# run
result = await session.run("session_workflow")
Expand All @@ -41,6 +41,6 @@ async def test_run_client_async(workflow):
assert result == "2"

# create new session and run
session = await client.get_or_create_session("fake_session_id_2")
session = await client.core.sessions.create()
result = await session.run("session_workflow")
assert result == "1"
144 changes: 143 additions & 1 deletion llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,139 @@
from llama_deploy.types.core import ServiceDefinition
import asyncio
import json
from typing import Any

import httpx

from llama_deploy.types.core import ServiceDefinition, TaskDefinition, TaskResult

from .model import Collection, Model


class Session(Model):
async def run(self, service_name: str, **run_kwargs: Any) -> str:
"""Implements the workflow-based run API for a session."""
task_input = json.dumps(run_kwargs)
task_def = TaskDefinition(input=task_input, agent_id=service_name)
task_id = await self._do_create_task(task_def)

# wait for task to complete, up to timeout seconds
async def _get_result() -> str:
while True:
task_result = await self._do_get_task_result(task_id)

if isinstance(task_result, TaskResult):
return task_result.result or ""
await asyncio.sleep(self.client.poll_interval)

return await asyncio.wait_for(_get_result(), timeout=self.client.timeout)

async def create_task(self, task_def: TaskDefinition) -> str:
"""Create a new task in this session.

Args:
task_def (Union[str, TaskDefinition]): The task definition or input string.

Returns:
str: The ID of the created task.
"""
return await self._do_create_task(task_def)

async def _do_create_task(self, task_def: TaskDefinition) -> str:
"""Async-only version of create_task, to be used internally from other methods."""
task_def.session_id = self.id
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks"
response = await self.client.request("POST", url, json=task_def.model_dump())
return response.json()

async def get_task_result(self, task_id: str) -> TaskResult | None:
"""Get the result of a task in this session if it has one.

Args:
task_id (str): The ID of the task to get the result for.

Returns:
Optional[TaskResult]: The result of the task if it has one, otherwise None.
"""
return await self._do_get_task_result(task_id)

async def _do_get_task_result(self, task_id: str) -> TaskResult | None:
"""Async-only version of get_task_result, to be used internally from other methods."""
url = (
f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result"
)
response = await self.client.request("GET", url)
data = response.json()
return TaskResult(**data) if data else None


class SessionCollection(Collection):
async def list(self) -> list[Session]: # type: ignore
"""Returns a list of all the sessions in the collection."""
sessions_url = f"{self.client.control_plane_url}/sessions"
response = await self.client.request("GET", sessions_url)
sessions = []
for id, session_def in response.json().items():
sessions.append(
Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)
)
return sessions

async def create(self) -> Session:
"""Creates a new session and returns a Session object.

Returns:
Session: A Session object representing the newly created session.
"""
create_url = f"{self.client.control_plane_url}/sessions/create"
response = await self.client.request("POST", create_url)
session_id = response.json()
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=session_id
)

async def get(self, id: str) -> Session:
"""Gets a session by ID.

Args:
session_id: The ID of the session to get.

Returns:
Session: A Session object representing the specified session.

Raises:
ValueError: If the session does not exist.
"""
get_url = f"{self.client.control_plane_url}/sessions/{id}"
await self.client.request("GET", get_url)
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)

async def get_or_create(self, id: str) -> Session:
"""Gets a session by ID, or creates a new one if it doesn't exist.

Returns:
Session: A Session object representing the specified session.
"""
try:
return await self.get(id)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
return await self.create()
raise e

async def delete(self, session_id: str) -> None:
"""Deletes a session by ID.

Args:
session_id: The ID of the session to delete.
"""
delete_url = f"{self.client.control_plane_url}/sessions/{session_id}/delete"
await self.client.request("POST", delete_url)


class Service(Model):
pass

Expand Down Expand Up @@ -52,3 +183,14 @@ async def services(self) -> ServiceCollection:
return ServiceCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items=items
)

@property
def sessions(self) -> SessionCollection:
"""Returns a collection to access all the sessions registered with the control plane.

Returns:
SessionCollection: Collection of sessions registered with the control plane.
"""
return SessionCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items={}
)
Loading
Loading