From 7455b9661e0b736d7ec7b90b5459b83ddcb8e2c6 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 1 Nov 2024 18:13:28 +0100 Subject: [PATCH] fix: fix typing for wrapped async methods (#344) * fix typing for wrapped async methods * leftover * remove leftover * fix rebase screwups * restore test coverage --- llama_deploy/client/client.py | 18 ++-- llama_deploy/client/models/__init__.py | 4 +- llama_deploy/client/models/apiserver.py | 128 ++++++++++-------------- llama_deploy/client/models/core.py | 39 +++----- llama_deploy/client/models/model.py | 33 ++---- tests/client/models/test_apiserver.py | 77 ++++++++++---- tests/client/models/test_core.py | 32 +++--- tests/client/models/test_model.py | 35 ++++--- 8 files changed, 180 insertions(+), 186 deletions(-) diff --git a/llama_deploy/client/client.py b/llama_deploy/client/client.py index 4879a0ef..1dd53a3d 100644 --- a/llama_deploy/client/client.py +++ b/llama_deploy/client/client.py @@ -1,5 +1,7 @@ +from typing import Any + from .base import _BaseClient -from .models import ApiServer, Core +from .models import ApiServer, Core, make_sync class Client(_BaseClient): @@ -31,19 +33,19 @@ def sync(self) -> "_SyncClient": @property def apiserver(self) -> ApiServer: """Returns the ApiServer model.""" - return ApiServer.instance(client=self, id="apiserver") + return ApiServer(client=self, id="apiserver") @property def core(self) -> Core: """Returns the Core model.""" - return Core.instance(client=self, id="core") + return Core(client=self, id="core") -class _SyncClient(Client): +class _SyncClient(_BaseClient): @property - def apiserver(self) -> ApiServer: - return ApiServer.instance(make_sync=True, client=self, id="apiserver") + def apiserver(self) -> Any: + return make_sync(ApiServer)(client=self, id="apiserver") @property - def core(self) -> Core: - return Core.instance(make_sync=True, client=self, id="core") + def core(self) -> Any: + return make_sync(Core)(client=self, id="core") diff --git a/llama_deploy/client/models/__init__.py b/llama_deploy/client/models/__init__.py index db028558..8bf53ac4 100644 --- a/llama_deploy/client/models/__init__.py +++ b/llama_deploy/client/models/__init__.py @@ -1,5 +1,5 @@ from .apiserver import ApiServer from .core import Core -from .model import Collection, Model +from .model import Collection, Model, make_sync -__all__ = ["ApiServer", "Collection", "Core", "Model"] +__all__ = ["ApiServer", "Collection", "Core", "Model", "make_sync"] diff --git a/llama_deploy/client/models/apiserver.py b/llama_deploy/client/models/apiserver.py index 35d4efc7..0a4fa70c 100644 --- a/llama_deploy/client/models/apiserver.py +++ b/llama_deploy/client/models/apiserver.py @@ -53,11 +53,8 @@ async def create(self) -> Session: session_def = SessionDefinition(**r.json()) - return Session.instance( - client=self.client, - make_sync=self._instance_is_sync, - id=session_def.session_id, - ) + model_class = self._prepare(Session) + return model_class(client=self.client, id=session_def.session_id) async def list(self) -> list[Session]: # type: ignore """Returns a collection of all the sessions in the given deployment.""" @@ -70,12 +67,9 @@ async def list(self) -> list[Session]: # type: ignore verify=not self.client.disable_ssl, timeout=self.client.timeout, ) + model_class = self._prepare(Session) items = [ - Session.instance( - make_sync=self._instance_is_sync, - client=self.client, - id=session_def.session_id, - ) + model_class(client=self.client, id=session_def.session_id) for session_def in r.json() ] return items @@ -161,60 +155,56 @@ async def create(self, task: TaskDefinition) -> Task: ) response_fields = r.json() - return Task.instance( - make_sync=self._instance_is_sync, + model_class = self._prepare(Task) + return model_class( client=self.client, deployment_id=self.deployment_id, id=response_fields["task_id"], session_id=response_fields["session_id"], ) - async def list(self) -> list[Task]: # type: ignore + +class Deployment(Model): + """A model representing a deployment.""" + + async def tasks(self) -> TaskCollection: """Returns a collection of tasks from all the sessions in the given deployment.""" - tasks_url = ( - f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks" - ) + tasks_url = f"{self.client.api_server_url}/deployments/{self.id}/tasks" r = await self.client.request( "GET", tasks_url, verify=not self.client.disable_ssl, timeout=self.client.timeout, ) - items = [ - Task.instance( - make_sync=self._instance_is_sync, + task_model_class = self._prepare(Task) + items = { + "id": task_model_class( client=self.client, id=task_def.task_id, session_id=task_def.session_id, - deployment_id=self.deployment_id, + deployment_id=self.id, ) for task_def in r.json() - ] - return items - + } + model_class = self._prepare(TaskCollection) + return model_class(client=self.client, deployment_id=self.id, items=items) -class Deployment(Model): - """A model representing a deployment.""" - - @property - def tasks(self) -> TaskCollection: - """Returns a collection of tasks from all the sessions in the given deployment.""" - return TaskCollection.instance( - make_sync=self._instance_is_sync, - client=self.client, - deployment_id=self.id, - items={}, - ) - - @property - def sessions(self) -> SessionCollection: + async def sessions(self) -> SessionCollection: """Returns a collection of all the sessions in the given deployment.""" - return SessionCollection.instance( - make_sync=self._instance_is_sync, - client=self.client, - deployment_id=self.id, - items={}, + sessions_url = f"{self.client.api_server_url}/deployments/{self.id}/sessions" + r = await self.client.request( + "GET", + sessions_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, ) + model_class = self._prepare(Session) + items = { + "id": model_class(client=self.client, id=session_def.session_id) + for session_def in r.json() + } + coll_model_class = self._prepare(SessionCollection) + return coll_model_class(client=self.client, deployment_id=self.id, items=items) class DeploymentCollection(Collection): @@ -233,15 +223,12 @@ async def create(self, config: TextIO) -> Deployment: timeout=self.client.timeout, ) - return Deployment.instance( - make_sync=self._instance_is_sync, - client=self.client, - id=r.json().get("name"), - ) + model_class = self._prepare(Deployment) + return model_class(client=self.client, id=r.json().get("name")) - async def get(self, id: str) -> Deployment: + async def get(self, deployment_id: str) -> Deployment: """Gets a deployment by id.""" - get_url = f"{self.client.api_server_url}/deployments/{id}" + get_url = f"{self.client.api_server_url}/deployments/{deployment_id}" # Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it await self.client.request( "GET", @@ -249,27 +236,8 @@ async def get(self, id: str) -> Deployment: verify=not self.client.disable_ssl, timeout=self.client.timeout, ) - return Deployment.instance( - client=self.client, make_sync=self._instance_is_sync, id=id - ) - - async def list(self) -> list[Deployment]: # type: ignore - """Returns a collection of deployments currently active in the API Server.""" - status_url = f"{self.client.api_server_url}/deployments/" - - r = await self.client.request( - "GET", - status_url, - verify=not self.client.disable_ssl, - timeout=self.client.timeout, - ) - deployments = [ - Deployment.instance( - make_sync=self._instance_is_sync, client=self.client, id=name - ) - for name in r.json() - ] - return deployments + model_class = self._prepare(Deployment) + return model_class(client=self.client, id=deployment_id) class ApiServer(Model): @@ -312,9 +280,19 @@ async def status(self) -> Status: deployments=deployments, ) - @property - def deployments(self) -> DeploymentCollection: + async def deployments(self) -> DeploymentCollection: """Returns a collection of deployments currently active in the API Server.""" - return DeploymentCollection.instance( - make_sync=self._instance_is_sync, client=self.client, items={} + status_url = f"{self.client.api_server_url}/deployments/" + + r = await self.client.request( + "GET", + status_url, + verify=not self.client.disable_ssl, + timeout=self.client.timeout, ) + model_class = self._prepare(Deployment) + deployments = { + "id": model_class(client=self.client, id=name) for name in r.json() + } + coll_model_class = self._prepare(DeploymentCollection) + return coll_model_class(client=self.client, items=deployments) diff --git a/llama_deploy/client/models/core.py b/llama_deploy/client/models/core.py index 0b81b50e..f5a5c93f 100644 --- a/llama_deploy/client/models/core.py +++ b/llama_deploy/client/models/core.py @@ -82,12 +82,9 @@ async def list(self) -> list[Session]: # type: ignore sessions_url = f"{self.client.control_plane_url}/sessions" response = await self.client.request("GET", sessions_url) sessions = [] + model_class = self._prepare(Session) for id, session_def in response.json().items(): - sessions.append( - Session.instance( - make_sync=self._instance_is_sync, client=self.client, id=id - ) - ) + sessions.append(model_class(client=self.client, id=id)) return sessions async def create(self) -> Session: @@ -103,9 +100,8 @@ async def _create(self) -> Session: create_url = f"{self.client.control_plane_url}/sessions/create" response = await self.client.request("POST", create_url) session_id = response.json() - return Session.instance( - make_sync=self._instance_is_sync, client=self.client, id=session_id - ) + model_class = self._prepare(Session) + return model_class(client=self.client, id=session_id) async def get(self, id: str) -> Session: """Gets a session by ID. @@ -126,9 +122,8 @@ async def _get(self, id: str) -> Session: get_url = f"{self.client.control_plane_url}/sessions/{id}" await self.client.request("GET", get_url) - return Session.instance( - make_sync=self._instance_is_sync, client=self.client, id=id - ) + model_class = self._prepare(Session) + return model_class(client=self.client, id=id) async def get_or_create(self, id: str) -> Session: """Gets a session by ID, or creates a new one if it doesn't exist. @@ -167,12 +162,10 @@ async def list(self) -> list[Service]: # type: ignore services_url = f"{self.client.control_plane_url}/services" response = await self.client.request("GET", services_url) services = [] + model_class = self._prepare(Service) + for name, service in response.json().items(): - services.append( - Service.instance( - make_sync=self._instance_is_sync, client=self.client, id=name - ) - ) + services.append(model_class(client=self.client, id=name)) return services @@ -184,7 +177,8 @@ async def register(self, service: ServiceDefinition) -> Service: """ register_url = f"{self.client.control_plane_url}/services/register" await self.client.request("POST", register_url, json=service.model_dump()) - s = Service.instance(id=service.service_name, client=self.client) + model_class = self._prepare(Service) + s = model_class(id=service.service_name, client=self.client) self.items[service.service_name] = s return s @@ -210,10 +204,8 @@ def services(self) -> ServiceCollection: Returns: ServiceCollection: Collection of services registered with the control plane. """ - - return ServiceCollection.instance( - make_sync=self._instance_is_sync, client=self.client, items={} - ) + model_class = self._prepare(ServiceCollection) + return model_class(client=self.client, items={}) @property def sessions(self) -> SessionCollection: @@ -222,6 +214,5 @@ def sessions(self) -> SessionCollection: Returns: SessionCollection: Collection of sessions registered with the control plane. """ - return SessionCollection.instance( - make_sync=self._instance_is_sync, client=self.client, items={} - ) + model_class = self._prepare(SessionCollection) + return model_class(client=self.client, items={}) diff --git a/llama_deploy/client/models/model.py b/llama_deploy/client/models/model.py index b236e3eb..47151bd9 100644 --- a/llama_deploy/client/models/model.py +++ b/llama_deploy/client/models/model.py @@ -1,9 +1,8 @@ import asyncio -from typing import Any, Generic, TypeVar, cast +from typing import Any, Generic, TypeVar from asgiref.sync import async_to_sync from pydantic import BaseModel, ConfigDict, Field, PrivateAttr -from typing_extensions import Self from llama_deploy.client.base import _BaseClient @@ -16,24 +15,10 @@ class _Base(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - def __new__(cls, *args, **kwargs): # type: ignore[no-untyped-def] - """We prevent the usage of the constructor and force users to call `instance()` instead.""" - raise TypeError("Please use instance() instead of direct instantiation") - - @classmethod - def instance(cls, make_sync: bool = False, **kwargs: Any) -> Self: - """Returns an instance of the given model. - - Using the class constructor is not possible because we want to alter the class method to - accommodate sync/async usage before creating an instance, and __init__ would be too late. - """ - if make_sync: - cls = _make_sync(cls) - - inst = super(_Base, cls).__new__(cls) - inst.__init__(**kwargs) # type: ignore[misc] - inst._instance_is_sync = make_sync - return inst + def _prepare(self, _class: type) -> type: + if self._instance_is_sync: + return make_sync(_class) + return _class T = TypeVar("T", bound=_Base) @@ -57,15 +42,15 @@ def list(self) -> list[T]: return [self.get(id) for id in self.items.keys()] -def _make_sync(_class: type[T]) -> type[T]: +def make_sync(_class: type[T]) -> Any: """Wraps the methods of the given model class so that they can be called without `await`.""" class Wrapper(_class): # type: ignore - pass + _instance_is_sync: bool = True for name, method in _class.__dict__.items(): # Only wrap async public methods if asyncio.iscoroutinefunction(method) and not name.startswith("_"): setattr(Wrapper, name, async_to_sync(method)) - # Static type checkers can't assess Wrapper is indeed a type[T], let's promise it is. - return cast(type[T], Wrapper) + + return Wrapper diff --git a/tests/client/models/test_apiserver.py b/tests/client/models/test_apiserver.py index 7f586e02..2efffe00 100644 --- a/tests/client/models/test_apiserver.py +++ b/tests/client/models/test_apiserver.py @@ -19,9 +19,9 @@ @pytest.mark.asyncio async def test_session_collection_delete(client: Any) -> None: - coll = SessionCollection.instance( + coll = SessionCollection( client=client, - items={"a_session": Session.instance(id="a_session", client=client)}, + items={"a_session": Session(id="a_session", client=client)}, deployment_id="a_deployment", ) await coll.delete("a_session") @@ -39,7 +39,7 @@ async def test_session_collection_create(client: Any) -> None: client.request.return_value = mock.MagicMock( json=lambda: {"session_id": "a_session"} ) - coll = SessionCollection.instance( + coll = SessionCollection( client=client, items={}, deployment_id="a_deployment", @@ -53,12 +53,47 @@ async def test_session_collection_create(client: Any) -> None: ) +@pytest.mark.asyncio +async def test_session_collection_list(client: Any) -> None: + # Mock response containing list of sessions + client.request.return_value = mock.MagicMock( + json=lambda: [ + SessionDefinition(session_id="session1"), + SessionDefinition(session_id="session2"), + ] + ) + + # Create session collection instance + coll = SessionCollection( + client=client, + items={}, + deployment_id="a_deployment", + ) + + # Call list method + sessions = await coll.list() + + # Verify request was made correctly + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment/sessions", + verify=True, + timeout=120.0, + ) + + # Verify returned sessions + assert len(sessions) == 2 + assert all(isinstance(session, Session) for session in sessions) + assert sessions[0].id == "session1" + assert sessions[1].id == "session2" + + @pytest.mark.asyncio async def test_task_results(client: Any) -> None: res = TaskResult(task_id="a_result", history=[], result="some_text", data={}) client.request.return_value = mock.MagicMock(json=lambda: res.model_dump_json()) - t = Task.instance( + t = Task( client=client, id="a_task", deployment_id="a_deployment", @@ -78,10 +113,10 @@ async def test_task_results(client: Any) -> None: @pytest.mark.asyncio async def test_task_collection_run(client: Any) -> None: client.request.return_value = mock.MagicMock(json=lambda: "some result") - coll = TaskCollection.instance( + coll = TaskCollection( client=client, items={ - "a_session": Task.instance( + "a_session": Task( id="a_session", client=client, deployment_id="a_deployment", @@ -110,10 +145,10 @@ async def test_task_collection_create(client: Any) -> None: client.request.return_value = mock.MagicMock( json=lambda: {"session_id": "a_session", "task_id": "test_id"} ) - coll = TaskCollection.instance( + coll = TaskCollection( client=client, items={ - "a_session": Task.instance( + "a_session": Task( id="a_session", client=client, deployment_id="a_deployment", @@ -139,7 +174,7 @@ async def test_task_collection_create(client: Any) -> None: @pytest.mark.asyncio async def test_task_deployment_tasks(client: Any) -> None: - d = Deployment.instance(client=client, id="a_deployment") + d = Deployment(client=client, id="a_deployment") res: list[TaskDefinition] = [ TaskDefinition( input='{"arg": "input"}', task_id="a_task", session_id="a_session" @@ -147,7 +182,7 @@ async def test_task_deployment_tasks(client: Any) -> None: ] client.request.return_value = mock.MagicMock(json=lambda: res) - await d.tasks.list() + await d.tasks() client.request.assert_awaited_with( "GET", @@ -159,11 +194,11 @@ async def test_task_deployment_tasks(client: Any) -> None: @pytest.mark.asyncio async def test_task_deployment_sessions(client: Any) -> None: - d = Deployment.instance(client=client, id="a_deployment") + d = Deployment(client=client, id="a_deployment") res: list[SessionDefinition] = [SessionDefinition(session_id="a_session")] client.request.return_value = mock.MagicMock(json=lambda: res) - await d.sessions.list() + await d.sessions() client.request.assert_awaited_with( "GET", @@ -177,7 +212,7 @@ async def test_task_deployment_sessions(client: Any) -> None: async def test_task_deployment_collection_create(client: Any) -> None: client.request.return_value = mock.MagicMock(json=lambda: {"name": "deployment"}) - coll = DeploymentCollection.instance(client=client, items={}) + coll = DeploymentCollection(client=client, items={}) await coll.create(io.StringIO("some config")) client.request.assert_awaited_with( @@ -191,8 +226,8 @@ async def test_task_deployment_collection_create(client: Any) -> None: @pytest.mark.asyncio async def test_task_deployment_collection_get(client: Any) -> None: - d = Deployment.instance(client=client, id="a_deployment") - coll = DeploymentCollection.instance(client=client, items={"a_deployment": d}) + d = Deployment(client=client, id="a_deployment") + coll = DeploymentCollection(client=client, items={"a_deployment": d}) client.request.return_value = mock.MagicMock(json=lambda: {"a_deployment": "Up!"}) await coll.get("a_deployment") @@ -209,7 +244,7 @@ async def test_task_deployment_collection_get(client: Any) -> None: async def test_status_down(client: Any) -> None: client.request.side_effect = httpx.ConnectError(message="connection error") - apis = ApiServer.instance(client=client, id="apiserver") + apis = ApiServer(client=client, id="apiserver") res = await apis.status() client.request.assert_awaited_with( @@ -224,7 +259,7 @@ async def test_status_unhealthy(client: Any) -> None: status_code=400, text="This is a drill." ) - apis = ApiServer.instance(client=client, id="apiserver") + apis = ApiServer(client=client, id="apiserver") res = await apis.status() client.request.assert_awaited_with( @@ -240,7 +275,7 @@ async def test_status_healthy_no_deployments(client: Any) -> None: status_code=200, text="", json=lambda: {} ) - apis = ApiServer.instance(client=client, id="apiserver") + apis = ApiServer(client=client, id="apiserver") res = await apis.status() client.request.assert_awaited_with( @@ -259,7 +294,7 @@ async def test_status_healthy(client: Any) -> None: status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]} ) - apis = ApiServer.instance(client=client, id="apiserver") + apis = ApiServer(client=client, id="apiserver") res = await apis.status() client.request.assert_awaited_with( @@ -277,8 +312,8 @@ async def test_deployments(client: Any) -> None: client.request.return_value = mock.MagicMock( status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]} ) - apis = ApiServer.instance(client=client, id="apiserver") - await apis.deployments.list() + apis = ApiServer(client=client, id="apiserver") + await apis.deployments() client.request.assert_awaited_with( "GET", "http://localhost:4501/deployments/", verify=True, timeout=120.0 ) diff --git a/tests/client/models/test_core.py b/tests/client/models/test_core.py index cf5efe64..eeefa5b8 100644 --- a/tests/client/models/test_core.py +++ b/tests/client/models/test_core.py @@ -28,7 +28,7 @@ async def test_session_run(client: mock.AsyncMock) -> None: ), ] - session = Session.instance(client=client, id="test_session_id") + session = Session(client=client, id="test_session_id") result = await session.run("test_service", test_param="test_value") assert result == "test result" @@ -38,7 +38,7 @@ async def test_session_run(client: mock.AsyncMock) -> None: async def test_session_create_task(client: mock.AsyncMock) -> None: client.request.return_value = mock.MagicMock(json=lambda: "test_task_id") - session = Session.instance(client=client, id="test_session_id") + session = Session(client=client, id="test_session_id") task_def = TaskDefinition(input="test input", agent_id="test_service") task_id = await session.create_task(task_def) @@ -51,7 +51,7 @@ async def test_session_get_task_result(client: mock.AsyncMock) -> None: json=lambda: {"task_id": "test_task_id", "result": "test_result", "history": []} ) - session = Session.instance(client=client, id="test_session_id") + session = Session(client=client, id="test_session_id") result = await session.get_task_result("test_task_id") assert result.result == "test_result" if result else "" @@ -63,7 +63,7 @@ async def test_session_get_task_result(client: mock.AsyncMock) -> None: @pytest.mark.asyncio async def test_service_collection_register(client: mock.AsyncMock) -> None: - coll = ServiceCollection.instance(client=client, items={}) + coll = ServiceCollection(client=client, items={}) service = ServiceDefinition(service_name="test_service", description="some service") await coll.register(service) @@ -82,9 +82,9 @@ async def test_service_collection_register(client: mock.AsyncMock) -> None: @pytest.mark.asyncio async def test_service_collection_deregister(client: mock.AsyncMock) -> None: - coll = ServiceCollection.instance( + coll = ServiceCollection( client=client, - items={"test_service": Service.instance(client=client, id="test_service")}, + items={"test_service": Service(client=client, id="test_service")}, ) await coll.deregister("test_service") @@ -101,7 +101,7 @@ async def test_core_services(client: mock.AsyncMock) -> None: json=lambda: {"test_service": {"name": "test_service"}} ) - core = Core.instance(client=client, id="core") + core = Core(client=client, id="core") services = await core.services.list() client.request.assert_awaited_with("GET", "http://localhost:8000/services") @@ -112,7 +112,7 @@ async def test_core_services(client: mock.AsyncMock) -> None: async def test_session_collection_create(client: mock.AsyncMock) -> None: client.request.return_value = mock.MagicMock(json=lambda: "test_session_id") - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) session = await coll.create() client.request.assert_awaited_with("POST", "http://localhost:8000/sessions/create") @@ -122,7 +122,7 @@ async def test_session_collection_create(client: mock.AsyncMock) -> None: @pytest.mark.asyncio async def test_session_collection_get_existing(client: mock.AsyncMock) -> None: - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) session = await coll.get("test_session_id") client.request.assert_awaited_with( @@ -138,7 +138,7 @@ async def test_session_collection_get_nonexistent(client: mock.AsyncMock) -> Non "Not Found", request=mock.MagicMock(), response=mock.MagicMock(status_code=404) ) - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) with pytest.raises(httpx.HTTPStatusError, match="Not Found"): await coll.get("test_session_id") @@ -148,7 +148,7 @@ async def test_session_collection_get_nonexistent(client: mock.AsyncMock) -> Non async def test_session_collection_get_or_create_existing( client: mock.AsyncMock, ) -> None: - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) session = await coll.get_or_create("test_session_id") client.request.assert_awaited_with( @@ -171,7 +171,7 @@ async def test_session_collection_get_or_create_nonexistent( mock.MagicMock(json=lambda: "test_session_id"), ] - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) await coll.get_or_create("test_session_id") client.request.assert_awaited_with("POST", "http://localhost:8000/sessions/create") @@ -188,14 +188,14 @@ async def test_session_collection_get_or_create_error( ) ] - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) with pytest.raises(httpx.HTTPStatusError): await coll.get_or_create("test_session_id") @pytest.mark.asyncio async def test_session_collection_delete(client: mock.AsyncMock) -> None: - coll = SessionCollection.instance(client=client, items={}) + coll = SessionCollection(client=client, items={}) await coll.delete("test_session_id") client.request.assert_awaited_with( @@ -209,7 +209,7 @@ async def test_core_sessions(client: mock.AsyncMock) -> None: json=lambda: {"test_session": {"id": "test_session"}} ) - core = Core.instance(client=client, id="core") + core = Core(client=client, id="core") sessions = await core.sessions.list() client.request.assert_awaited_with("GET", "http://localhost:8000/sessions") @@ -233,7 +233,7 @@ async def test_session_get_tasks(client: mock.AsyncMock) -> None: ] ) - session = Session.instance(client=client, id="test_session_id") + session = Session(client=client, id="test_session_id") tasks = await session.get_tasks() client.request.assert_awaited_with( diff --git a/tests/client/models/test_model.py b/tests/client/models/test_model.py index f01739e3..7d07dcee 100644 --- a/tests/client/models/test_model.py +++ b/tests/client/models/test_model.py @@ -1,41 +1,44 @@ import asyncio -import pytest - from llama_deploy.client import Client from llama_deploy.client.models import Collection, Model -from llama_deploy.client.models.model import _make_sync +from llama_deploy.client.models.model import make_sync class SomeAsyncModel(Model): - async def method(self) -> None: - pass - - -def test_no_init(client: Client) -> None: - with pytest.raises( - TypeError, match=r"Please use instance\(\) instead of direct instantiation" - ): - SomeAsyncModel(id="foo", client=client) + async def method(self) -> int: + return 0 def test_make_sync() -> None: assert asyncio.iscoroutinefunction(getattr(SomeAsyncModel, "method")) - some_sync = _make_sync(SomeAsyncModel) + some_sync = make_sync(SomeAsyncModel) assert not asyncio.iscoroutinefunction(getattr(some_sync, "method")) +def test_make_sync_instance(client: Client) -> None: + some_sync = make_sync(SomeAsyncModel)(client=client, id="foo") + assert not asyncio.iscoroutinefunction(some_sync.method) + assert some_sync.method() + 1 == 1 + + +def test__prepare(client: Client) -> None: + some_sync = make_sync(SomeAsyncModel)(client=client, id="foo") + coll = some_sync._prepare(Collection) + assert coll._instance_is_sync + + def test_collection_get() -> None: class MyCollection(Collection): pass c = Client() models_list = [ - SomeAsyncModel.instance(client=c, id="foo"), - SomeAsyncModel.instance(client=c, id="bar"), + SomeAsyncModel(client=c, id="foo"), + SomeAsyncModel(client=c, id="bar"), ] - coll = MyCollection.instance(client=c, items={m.id: m for m in models_list}) + coll = MyCollection(client=c, items={m.id: m for m in models_list}) assert coll.get("foo").id == "foo" assert coll.get("bar").id == "bar" assert coll.list() == models_list