diff --git a/e2e_tests/apiserver/test_deploy.py b/e2e_tests/apiserver/test_deploy.py index fc836068..df609b32 100644 --- a/e2e_tests/apiserver/test_deploy.py +++ b/e2e_tests/apiserver/test_deploy.py @@ -6,9 +6,8 @@ @pytest.mark.asyncio async def test_deploy(apiserver, client): here = Path(__file__).parent - deployments = await client.apiserver.deployments() with open(here / "deployments" / "deployment1.yml") as f: - await deployments.create(f) + await client.apiserver.deployments.create(f) status = await client.apiserver.status() assert "TestDeployment1" in status.deployments @@ -16,8 +15,7 @@ async def test_deploy(apiserver, client): def test_deploy_sync(apiserver, client): here = Path(__file__).parent - deployments = client.sync.apiserver.deployments() with open(here / "deployments" / "deployment2.yml") as f: - deployments.create(f) + client.sync.apiserver.deployments.create(f) assert "TestDeployment2" in client.sync.apiserver.status().deployments diff --git a/e2e_tests/apiserver/test_status.py b/e2e_tests/apiserver/test_status.py index 8665e0bb..d8f7759d 100644 --- a/e2e_tests/apiserver/test_status.py +++ b/e2e_tests/apiserver/test_status.py @@ -14,7 +14,7 @@ def test_status_down_sync(client): @pytest.mark.asyncio async def test_status_up(apiserver, client): - res = await client.sync.apiserver.status() + res = await client.apiserver.status() assert res.status.value == "Healthy" diff --git a/llama_deploy/client/models/apiserver.py b/llama_deploy/client/models/apiserver.py index a2c6508f..35d4efc7 100644 --- a/llama_deploy/client/models/apiserver.py +++ b/llama_deploy/client/models/apiserver.py @@ -9,8 +9,6 @@ from .model import Collection, Model -DEFAULT_POLL_INTERVAL = 0.5 - class Session(Model): """A model representing a session.""" @@ -61,6 +59,27 @@ async def create(self) -> Session: id=session_def.session_id, ) + async def list(self) -> list[Session]: # type: ignore + """Returns a collection of all the sessions in the given deployment.""" + sessions_url = ( + f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions" + ) + r = await self.client.request( + "GET", + sessions_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, + ) + items = [ + Session.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=session_def.session_id, + ) + for session_def in r.json() + ] + return items + class Task(Model): """A model representing a task belonging to a given session in the given deployment.""" @@ -101,7 +120,7 @@ async def events(self) -> AsyncGenerator[dict[str, Any], None]: # pragma: no co except httpx.HTTPStatusError as e: if e.response.status_code != 404: raise # Re-raise if it's not a 404 error - await asyncio.sleep(DEFAULT_POLL_INTERVAL) + await asyncio.sleep(self.client.poll_interval) class TaskCollection(Collection): @@ -150,58 +169,51 @@ async def create(self, task: TaskDefinition) -> Task: session_id=response_fields["session_id"], ) - -class Deployment(Model): - """A model representing a deployment.""" - - async def tasks(self) -> TaskCollection: + async def list(self) -> list[Task]: # type: ignore """Returns a collection of tasks from all the sessions in the given deployment.""" - tasks_url = f"{self.client.api_server_url}/deployments/{self.id}/tasks" + tasks_url = ( + f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks" + ) r = await self.client.request( "GET", tasks_url, verify=not self.client.disable_ssl, timeout=self.client.timeout, ) - items = { - "id": Task.instance( + items = [ + Task.instance( make_sync=self._instance_is_sync, client=self.client, id=task_def.task_id, session_id=task_def.session_id, - deployment_id=self.id, + deployment_id=self.deployment_id, ) for task_def in r.json() - } + ] + return items + + +class Deployment(Model): + """A model representing a deployment.""" + + @property + def tasks(self) -> TaskCollection: + """Returns a collection of tasks from all the sessions in the given deployment.""" return TaskCollection.instance( make_sync=self._instance_is_sync, client=self.client, deployment_id=self.id, - items=items, + items={}, ) - async def sessions(self) -> SessionCollection: + @property + def sessions(self) -> SessionCollection: """Returns a collection of all the sessions in the given deployment.""" - sessions_url = f"{self.client.api_server_url}/deployments/{self.id}/sessions" - r = await self.client.request( - "GET", - sessions_url, - verify=not self.client.disable_ssl, - timeout=self.client.timeout, - ) - items = { - "id": Session.instance( - make_sync=self._instance_is_sync, - client=self.client, - id=session_def.session_id, - ) - for session_def in r.json() - } return SessionCollection.instance( make_sync=self._instance_is_sync, client=self.client, deployment_id=self.id, - items=items, + items={}, ) @@ -227,9 +239,9 @@ async def create(self, config: TextIO) -> Deployment: id=r.json().get("name"), ) - async def get(self, deployment_id: str) -> Deployment: + async def get(self, id: str) -> Deployment: """Gets a deployment by id.""" - get_url = f"{self.client.api_server_url}/deployments/{deployment_id}" + get_url = f"{self.client.api_server_url}/deployments/{id}" # Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it await self.client.request( "GET", @@ -238,8 +250,26 @@ async def get(self, deployment_id: str) -> Deployment: timeout=self.client.timeout, ) return Deployment.instance( - client=self.client, make_sync=self._instance_is_sync, id=deployment_id + client=self.client, make_sync=self._instance_is_sync, id=id + ) + + async def list(self) -> list[Deployment]: # type: ignore + """Returns a collection of deployments currently active in the API Server.""" + status_url = f"{self.client.api_server_url}/deployments/" + + r = await self.client.request( + "GET", + status_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, ) + deployments = [ + Deployment.instance( + make_sync=self._instance_is_sync, client=self.client, id=name + ) + for name in r.json() + ] + return deployments class ApiServer(Model): @@ -282,22 +312,9 @@ async def status(self) -> Status: deployments=deployments, ) - async def deployments(self) -> DeploymentCollection: + @property + def deployments(self) -> DeploymentCollection: """Returns a collection of deployments currently active in the API Server.""" - status_url = f"{self.client.api_server_url}/deployments/" - - r = await self.client.request( - "GET", - status_url, - verify=not self.client.disable_ssl, - timeout=self.client.timeout, - ) - deployments = { - "id": Deployment.instance( - make_sync=self._instance_is_sync, client=self.client, id=name - ) - for name in r.json() - } return DeploymentCollection.instance( - make_sync=self._instance_is_sync, client=self.client, items=deployments + make_sync=self._instance_is_sync, client=self.client, items={} ) diff --git a/tests/client/models/test_apiserver.py b/tests/client/models/test_apiserver.py index 42fd9240..7f586e02 100644 --- a/tests/client/models/test_apiserver.py +++ b/tests/client/models/test_apiserver.py @@ -147,7 +147,7 @@ async def test_task_deployment_tasks(client: Any) -> None: ] client.request.return_value = mock.MagicMock(json=lambda: res) - await d.tasks() + await d.tasks.list() client.request.assert_awaited_with( "GET", @@ -163,7 +163,7 @@ async def test_task_deployment_sessions(client: Any) -> None: res: list[SessionDefinition] = [SessionDefinition(session_id="a_session")] client.request.return_value = mock.MagicMock(json=lambda: res) - await d.sessions() + await d.sessions.list() client.request.assert_awaited_with( "GET", @@ -278,7 +278,7 @@ async def test_deployments(client: Any) -> None: status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]} ) apis = ApiServer.instance(client=client, id="apiserver") - await apis.deployments() + await apis.deployments.list() client.request.assert_awaited_with( "GET", "http://localhost:4501/deployments/", verify=True, timeout=120.0 )