Skip to content

Commit

Permalink
fix: fix typing for wrapped async methods (#344)
Browse files Browse the repository at this point in the history
* fix typing for wrapped async methods

* leftover

* remove leftover

* fix rebase screwups

* restore test coverage
  • Loading branch information
masci authored Nov 1, 2024
1 parent 1ee388b commit 7455b96
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 186 deletions.
18 changes: 10 additions & 8 deletions llama_deploy/client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from .base import _BaseClient
from .models import ApiServer, Core
from .models import ApiServer, Core, make_sync


class Client(_BaseClient):
Expand Down Expand Up @@ -31,19 +33,19 @@ def sync(self) -> "_SyncClient":
@property
def apiserver(self) -> ApiServer:
"""Returns the ApiServer model."""
return ApiServer.instance(client=self, id="apiserver")
return ApiServer(client=self, id="apiserver")

@property
def core(self) -> Core:
"""Returns the Core model."""
return Core.instance(client=self, id="core")
return Core(client=self, id="core")


class _SyncClient(Client):
class _SyncClient(_BaseClient):
@property
def apiserver(self) -> ApiServer:
return ApiServer.instance(make_sync=True, client=self, id="apiserver")
def apiserver(self) -> Any:
return make_sync(ApiServer)(client=self, id="apiserver")

@property
def core(self) -> Core:
return Core.instance(make_sync=True, client=self, id="core")
def core(self) -> Any:
return make_sync(Core)(client=self, id="core")
4 changes: 2 additions & 2 deletions llama_deploy/client/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .apiserver import ApiServer
from .core import Core
from .model import Collection, Model
from .model import Collection, Model, make_sync

__all__ = ["ApiServer", "Collection", "Core", "Model"]
__all__ = ["ApiServer", "Collection", "Core", "Model", "make_sync"]
128 changes: 53 additions & 75 deletions llama_deploy/client/models/apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,8 @@ async def create(self) -> Session:

session_def = SessionDefinition(**r.json())

return Session.instance(
client=self.client,
make_sync=self._instance_is_sync,
id=session_def.session_id,
)
model_class = self._prepare(Session)
return model_class(client=self.client, id=session_def.session_id)

async def list(self) -> list[Session]: # type: ignore
"""Returns a collection of all the sessions in the given deployment."""
Expand All @@ -70,12 +67,9 @@ async def list(self) -> list[Session]: # type: ignore
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
model_class = self._prepare(Session)
items = [
Session.instance(
make_sync=self._instance_is_sync,
client=self.client,
id=session_def.session_id,
)
model_class(client=self.client, id=session_def.session_id)
for session_def in r.json()
]
return items
Expand Down Expand Up @@ -161,60 +155,56 @@ async def create(self, task: TaskDefinition) -> Task:
)
response_fields = r.json()

return Task.instance(
make_sync=self._instance_is_sync,
model_class = self._prepare(Task)
return model_class(
client=self.client,
deployment_id=self.deployment_id,
id=response_fields["task_id"],
session_id=response_fields["session_id"],
)

async def list(self) -> list[Task]: # type: ignore

class Deployment(Model):
"""A model representing a deployment."""

async def tasks(self) -> TaskCollection:
"""Returns a collection of tasks from all the sessions in the given deployment."""
tasks_url = (
f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks"
)
tasks_url = f"{self.client.api_server_url}/deployments/{self.id}/tasks"
r = await self.client.request(
"GET",
tasks_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
items = [
Task.instance(
make_sync=self._instance_is_sync,
task_model_class = self._prepare(Task)
items = {
"id": task_model_class(
client=self.client,
id=task_def.task_id,
session_id=task_def.session_id,
deployment_id=self.deployment_id,
deployment_id=self.id,
)
for task_def in r.json()
]
return items

}
model_class = self._prepare(TaskCollection)
return model_class(client=self.client, deployment_id=self.id, items=items)

class Deployment(Model):
"""A model representing a deployment."""

@property
def tasks(self) -> TaskCollection:
"""Returns a collection of tasks from all the sessions in the given deployment."""
return TaskCollection.instance(
make_sync=self._instance_is_sync,
client=self.client,
deployment_id=self.id,
items={},
)

@property
def sessions(self) -> SessionCollection:
async def sessions(self) -> SessionCollection:
"""Returns a collection of all the sessions in the given deployment."""
return SessionCollection.instance(
make_sync=self._instance_is_sync,
client=self.client,
deployment_id=self.id,
items={},
sessions_url = f"{self.client.api_server_url}/deployments/{self.id}/sessions"
r = await self.client.request(
"GET",
sessions_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
model_class = self._prepare(Session)
items = {
"id": model_class(client=self.client, id=session_def.session_id)
for session_def in r.json()
}
coll_model_class = self._prepare(SessionCollection)
return coll_model_class(client=self.client, deployment_id=self.id, items=items)


class DeploymentCollection(Collection):
Expand All @@ -233,43 +223,21 @@ async def create(self, config: TextIO) -> Deployment:
timeout=self.client.timeout,
)

return Deployment.instance(
make_sync=self._instance_is_sync,
client=self.client,
id=r.json().get("name"),
)
model_class = self._prepare(Deployment)
return model_class(client=self.client, id=r.json().get("name"))

async def get(self, id: str) -> Deployment:
async def get(self, deployment_id: str) -> Deployment:
"""Gets a deployment by id."""
get_url = f"{self.client.api_server_url}/deployments/{id}"
get_url = f"{self.client.api_server_url}/deployments/{deployment_id}"
# Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it
await self.client.request(
"GET",
get_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
return Deployment.instance(
client=self.client, make_sync=self._instance_is_sync, id=id
)

async def list(self) -> list[Deployment]: # type: ignore
"""Returns a collection of deployments currently active in the API Server."""
status_url = f"{self.client.api_server_url}/deployments/"

r = await self.client.request(
"GET",
status_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
deployments = [
Deployment.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)
for name in r.json()
]
return deployments
model_class = self._prepare(Deployment)
return model_class(client=self.client, id=deployment_id)


class ApiServer(Model):
Expand Down Expand Up @@ -312,9 +280,19 @@ async def status(self) -> Status:
deployments=deployments,
)

@property
def deployments(self) -> DeploymentCollection:
async def deployments(self) -> DeploymentCollection:
"""Returns a collection of deployments currently active in the API Server."""
return DeploymentCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items={}
status_url = f"{self.client.api_server_url}/deployments/"

r = await self.client.request(
"GET",
status_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
model_class = self._prepare(Deployment)
deployments = {
"id": model_class(client=self.client, id=name) for name in r.json()
}
coll_model_class = self._prepare(DeploymentCollection)
return coll_model_class(client=self.client, items=deployments)
39 changes: 15 additions & 24 deletions llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,9 @@ async def list(self) -> list[Session]: # type: ignore
sessions_url = f"{self.client.control_plane_url}/sessions"
response = await self.client.request("GET", sessions_url)
sessions = []
model_class = self._prepare(Session)
for id, session_def in response.json().items():
sessions.append(
Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)
)
sessions.append(model_class(client=self.client, id=id))
return sessions

async def create(self) -> Session:
Expand All @@ -103,9 +100,8 @@ async def _create(self) -> Session:
create_url = f"{self.client.control_plane_url}/sessions/create"
response = await self.client.request("POST", create_url)
session_id = response.json()
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=session_id
)
model_class = self._prepare(Session)
return model_class(client=self.client, id=session_id)

async def get(self, id: str) -> Session:
"""Gets a session by ID.
Expand All @@ -126,9 +122,8 @@ async def _get(self, id: str) -> Session:

get_url = f"{self.client.control_plane_url}/sessions/{id}"
await self.client.request("GET", get_url)
return Session.instance(
make_sync=self._instance_is_sync, client=self.client, id=id
)
model_class = self._prepare(Session)
return model_class(client=self.client, id=id)

async def get_or_create(self, id: str) -> Session:
"""Gets a session by ID, or creates a new one if it doesn't exist.
Expand Down Expand Up @@ -167,12 +162,10 @@ async def list(self) -> list[Service]: # type: ignore
services_url = f"{self.client.control_plane_url}/services"
response = await self.client.request("GET", services_url)
services = []
model_class = self._prepare(Service)

for name, service in response.json().items():
services.append(
Service.instance(
make_sync=self._instance_is_sync, client=self.client, id=name
)
)
services.append(model_class(client=self.client, id=name))

return services

Expand All @@ -184,7 +177,8 @@ async def register(self, service: ServiceDefinition) -> Service:
"""
register_url = f"{self.client.control_plane_url}/services/register"
await self.client.request("POST", register_url, json=service.model_dump())
s = Service.instance(id=service.service_name, client=self.client)
model_class = self._prepare(Service)
s = model_class(id=service.service_name, client=self.client)
self.items[service.service_name] = s
return s

Expand All @@ -210,10 +204,8 @@ def services(self) -> ServiceCollection:
Returns:
ServiceCollection: Collection of services registered with the control plane.
"""

return ServiceCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items={}
)
model_class = self._prepare(ServiceCollection)
return model_class(client=self.client, items={})

@property
def sessions(self) -> SessionCollection:
Expand All @@ -222,6 +214,5 @@ def sessions(self) -> SessionCollection:
Returns:
SessionCollection: Collection of sessions registered with the control plane.
"""
return SessionCollection.instance(
make_sync=self._instance_is_sync, client=self.client, items={}
)
model_class = self._prepare(SessionCollection)
return model_class(client=self.client, items={})
33 changes: 9 additions & 24 deletions llama_deploy/client/models/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
from typing import Any, Generic, TypeVar, cast
from typing import Any, Generic, TypeVar

from asgiref.sync import async_to_sync
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from typing_extensions import Self

from llama_deploy.client.base import _BaseClient

Expand All @@ -16,24 +15,10 @@ class _Base(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

def __new__(cls, *args, **kwargs): # type: ignore[no-untyped-def]
"""We prevent the usage of the constructor and force users to call `instance()` instead."""
raise TypeError("Please use instance() instead of direct instantiation")

@classmethod
def instance(cls, make_sync: bool = False, **kwargs: Any) -> Self:
"""Returns an instance of the given model.
Using the class constructor is not possible because we want to alter the class method to
accommodate sync/async usage before creating an instance, and __init__ would be too late.
"""
if make_sync:
cls = _make_sync(cls)

inst = super(_Base, cls).__new__(cls)
inst.__init__(**kwargs) # type: ignore[misc]
inst._instance_is_sync = make_sync
return inst
def _prepare(self, _class: type) -> type:
if self._instance_is_sync:
return make_sync(_class)
return _class


T = TypeVar("T", bound=_Base)
Expand All @@ -57,15 +42,15 @@ def list(self) -> list[T]:
return [self.get(id) for id in self.items.keys()]


def _make_sync(_class: type[T]) -> type[T]:
def make_sync(_class: type[T]) -> Any:
"""Wraps the methods of the given model class so that they can be called without `await`."""

class Wrapper(_class): # type: ignore
pass
_instance_is_sync: bool = True

for name, method in _class.__dict__.items():
# Only wrap async public methods
if asyncio.iscoroutinefunction(method) and not name.startswith("_"):
setattr(Wrapper, name, async_to_sync(method))
# Static type checkers can't assess Wrapper is indeed a type[T], let's promise it is.
return cast(type[T], Wrapper)

return Wrapper
Loading

0 comments on commit 7455b96

Please sign in to comment.