Skip to content

Commit

Permalink
refact: make collections in apiserver models lazy (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Oct 31, 2024
1 parent 4d4efd7 commit 1ee388b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 58 deletions.
6 changes: 2 additions & 4 deletions e2e_tests/apiserver/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
@pytest.mark.asyncio
async def test_deploy(apiserver, client):
here = Path(__file__).parent
deployments = await client.apiserver.deployments()
with open(here / "deployments" / "deployment1.yml") as f:
await deployments.create(f)
await client.apiserver.deployments.create(f)

status = await client.apiserver.status()
assert "TestDeployment1" in status.deployments


def test_deploy_sync(apiserver, client):
here = Path(__file__).parent
deployments = client.sync.apiserver.deployments()
with open(here / "deployments" / "deployment2.yml") as f:
deployments.create(f)
client.sync.apiserver.deployments.create(f)

assert "TestDeployment2" in client.sync.apiserver.status().deployments
2 changes: 1 addition & 1 deletion e2e_tests/apiserver/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_status_down_sync(client):

@pytest.mark.asyncio
async def test_status_up(apiserver, client):
res = await client.sync.apiserver.status()
res = await client.apiserver.status()
assert res.status.value == "Healthy"


Expand Down
117 changes: 67 additions & 50 deletions llama_deploy/client/models/apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from .model import Collection, Model

DEFAULT_POLL_INTERVAL = 0.5


class Session(Model):
"""A model representing a session."""
Expand Down Expand Up @@ -61,6 +59,27 @@ async def create(self) -> Session:
id=session_def.session_id,
)

async def list(self) -> list[Session]: # type: ignore
"""Returns a collection of all the sessions in the given deployment."""
sessions_url = (
f"{self.client.api_server_url}/deployments/{self.deployment_id}/sessions"
)
r = await self.client.request(
"GET",
sessions_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
items = [
Session.instance(
make_sync=self._instance_is_sync,
client=self.client,
id=session_def.session_id,
)
for session_def in r.json()
]
return items


class Task(Model):
"""A model representing a task belonging to a given session in the given deployment."""
Expand Down Expand Up @@ -101,7 +120,7 @@ async def events(self) -> AsyncGenerator[dict[str, Any], None]: # pragma: no co
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise # Re-raise if it's not a 404 error
await asyncio.sleep(DEFAULT_POLL_INTERVAL)
await asyncio.sleep(self.client.poll_interval)


class TaskCollection(Collection):
Expand Down Expand Up @@ -150,58 +169,51 @@ async def create(self, task: TaskDefinition) -> Task:
session_id=response_fields["session_id"],
)


class Deployment(Model):
"""A model representing a deployment."""

async def tasks(self) -> TaskCollection:
async def list(self) -> list[Task]: # type: ignore
"""Returns a collection of tasks from all the sessions in the given deployment."""
tasks_url = f"{self.client.api_server_url}/deployments/{self.id}/tasks"
tasks_url = (
f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks"
)
r = await self.client.request(
"GET",
tasks_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
items = {
"id": Task.instance(
items = [
Task.instance(
make_sync=self._instance_is_sync,
client=self.client,
id=task_def.task_id,
session_id=task_def.session_id,
deployment_id=self.id,
deployment_id=self.deployment_id,
)
for task_def in r.json()
}
]
return items


class Deployment(Model):
"""A model representing a deployment."""

@property
def tasks(self) -> TaskCollection:
"""Returns a collection of tasks from all the sessions in the given deployment."""
return TaskCollection.instance(
make_sync=self._instance_is_sync,
client=self.client,
deployment_id=self.id,
items=items,
items={},
)

async def sessions(self) -> SessionCollection:
@property
def sessions(self) -> SessionCollection:
"""Returns a collection of all the sessions in the given deployment."""
sessions_url = f"{self.client.api_server_url}/deployments/{self.id}/sessions"
r = await self.client.request(
"GET",
sessions_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
items = {
"id": Session.instance(
make_sync=self._instance_is_sync,
client=self.client,
id=session_def.session_id,
)
for session_def in r.json()
}
return SessionCollection.instance(
make_sync=self._instance_is_sync,
client=self.client,
deployment_id=self.id,
items=items,
items={},
)


Expand All @@ -227,9 +239,9 @@ async def create(self, config: TextIO) -> Deployment:
id=r.json().get("name"),
)

async def get(self, deployment_id: str) -> Deployment:
async def get(self, id: str) -> Deployment:
"""Gets a deployment by id."""
get_url = f"{self.client.api_server_url}/deployments/{deployment_id}"
get_url = f"{self.client.api_server_url}/deployments/{id}"
# Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it
await self.client.request(
"GET",
Expand All @@ -238,8 +250,26 @@ async def get(self, deployment_id: str) -> Deployment:
timeout=self.client.timeout,
)
return Deployment.instance(
client=self.client, make_sync=self._instance_is_sync, id=deployment_id
client=self.client, make_sync=self._instance_is_sync, id=id
)

async def list(self) -> list[Deployment]: # type: ignore
"""Returns a collection of deployments currently active in the API Server."""
status_url = f"{self.client.api_server_url}/deployments/"

r = await self.client.request(
"GET",
status_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
deployments = [
Deployment.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)
for name in r.json()
]
return deployments


class ApiServer(Model):
Expand Down Expand Up @@ -282,22 +312,9 @@ async def status(self) -> Status:
deployments=deployments,
)

async def deployments(self) -> DeploymentCollection:
@property
def deployments(self) -> DeploymentCollection:
"""Returns a collection of deployments currently active in the API Server."""
status_url = f"{self.client.api_server_url}/deployments/"

r = await self.client.request(
"GET",
status_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
deployments = {
"id": Deployment.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)
for name in r.json()
}
return DeploymentCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items=deployments
make_sync=self._instance_is_sync, client=self.client, items={}
)
6 changes: 3 additions & 3 deletions tests/client/models/test_apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def test_task_deployment_tasks(client: Any) -> None:
]
client.request.return_value = mock.MagicMock(json=lambda: res)

await d.tasks()
await d.tasks.list()

client.request.assert_awaited_with(
"GET",
Expand All @@ -163,7 +163,7 @@ async def test_task_deployment_sessions(client: Any) -> None:
res: list[SessionDefinition] = [SessionDefinition(session_id="a_session")]
client.request.return_value = mock.MagicMock(json=lambda: res)

await d.sessions()
await d.sessions.list()

client.request.assert_awaited_with(
"GET",
Expand Down Expand Up @@ -278,7 +278,7 @@ async def test_deployments(client: Any) -> None:
status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]}
)
apis = ApiServer.instance(client=client, id="apiserver")
await apis.deployments()
await apis.deployments.list()
client.request.assert_awaited_with(
"GET", "http://localhost:4501/deployments/", verify=True, timeout=120.0
)

0 comments on commit 1ee388b

Please sign in to comment.