diff --git a/naptha_sdk/client/node.py b/naptha_sdk/client/node.py index e99aeb0b..2a8de2c9 100644 --- a/naptha_sdk/client/node.py +++ b/naptha_sdk/client/node.py @@ -18,7 +18,7 @@ from naptha_sdk.client import grpc_server_pb2 from naptha_sdk.client import grpc_server_pb2_grpc -from naptha_sdk.schemas import AgentRun, AgentRunInput, EnvironmentRun, EnvironmentRunInput, OrchestratorRun, \ +from naptha_sdk.schemas import AgentRun, AgentRunInput, ChatCompletionRequest, EnvironmentRun, EnvironmentRunInput, OrchestratorRun, \ OrchestratorRunInput, AgentDeployment, EnvironmentDeployment, OrchestratorDeployment from naptha_sdk.utils import get_logger diff --git a/naptha_sdk/environment.py b/naptha_sdk/environment.py index 9d700bdf..02395092 100644 --- a/naptha_sdk/environment.py +++ b/naptha_sdk/environment.py @@ -1,14 +1,13 @@ from naptha_sdk.client.node import Node -from naptha_sdk.schemas import AgentRun, EnvironmentRunInput, OrchestratorRun +from naptha_sdk.schemas import AgentRun, EnvironmentDeployment, EnvironmentRunInput, OrchestratorRun from typing import Any, Dict, List, Union import logging logger = logging.getLogger(__name__) class Environment: - def __init__(self, module_run: Union[OrchestratorRun, AgentRun]): - self.module_run = module_run - self.environment_deployment = module_run.environment_deployments[0] + def __init__(self, environment_deployment: EnvironmentDeployment): + self.environment_deployment = environment_deployment self.environment_node = Node(self.environment_deployment.environment_node_url) self.table_name = "multi_chat_simulations" @@ -83,16 +82,7 @@ async def get_simulation(self, run_id: str) -> List[Dict[str, Any]]: logger.error(f"Error retrieving simulation: {str(e)}") raise - async def call_environment_func(self, *args, **kwargs): + async def call_environment_func(self, environment_run_input: EnvironmentRunInput): logger.info(f"Running environment on environment node {self.environment_node.node_url}") - - environment_run_input = EnvironmentRunInput( - consumer_id=self.module_run.consumer_id, - inputs=kwargs, - agent_deployment=self.module_run.agent_deployment.model_dump(), - ) - - environment_run = await self.environment_node.run_environment_and_poll( - environment_run_input=environment_run_input - ) + environment_run = await self.environment_node.run_environment_and_poll(environment_run_input) return environment_run \ No newline at end of file