Skip to content

Commit

Permalink
update environment module class
Browse files Browse the repository at this point in the history
  • Loading branch information
richardblythman committed Dec 12, 2024
1 parent aae413b commit e3886f7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
2 changes: 1 addition & 1 deletion naptha_sdk/client/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from naptha_sdk.client import grpc_server_pb2
from naptha_sdk.client import grpc_server_pb2_grpc
from naptha_sdk.schemas import AgentRun, AgentRunInput, EnvironmentRun, EnvironmentRunInput, OrchestratorRun, \
from naptha_sdk.schemas import AgentRun, AgentRunInput, ChatCompletionRequest, EnvironmentRun, EnvironmentRunInput, OrchestratorRun, \
OrchestratorRunInput, AgentDeployment, EnvironmentDeployment, OrchestratorDeployment
from naptha_sdk.utils import get_logger

Expand Down
20 changes: 5 additions & 15 deletions naptha_sdk/environment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from naptha_sdk.client.node import Node
from naptha_sdk.schemas import AgentRun, EnvironmentRunInput, OrchestratorRun
from naptha_sdk.schemas import AgentRun, EnvironmentDeployment, EnvironmentRunInput, OrchestratorRun
from typing import Any, Dict, List, Union
import logging

logger = logging.getLogger(__name__)

class Environment:
def __init__(self, module_run: Union[OrchestratorRun, AgentRun]):
self.module_run = module_run
self.environment_deployment = module_run.environment_deployments[0]
def __init__(self, environment_deployment: EnvironmentDeployment):
self.environment_deployment = environment_deployment
self.environment_node = Node(self.environment_deployment.environment_node_url)
self.table_name = "multi_chat_simulations"

Expand Down Expand Up @@ -83,16 +82,7 @@ async def get_simulation(self, run_id: str) -> List[Dict[str, Any]]:
logger.error(f"Error retrieving simulation: {str(e)}")
raise

async def call_environment_func(self, *args, **kwargs):
async def call_environment_func(self, environment_run_input: EnvironmentRunInput):
logger.info(f"Running environment on environment node {self.environment_node.node_url}")

environment_run_input = EnvironmentRunInput(
consumer_id=self.module_run.consumer_id,
inputs=kwargs,
agent_deployment=self.module_run.agent_deployment.model_dump(),
)

environment_run = await self.environment_node.run_environment_and_poll(
environment_run_input=environment_run_input
)
environment_run = await self.environment_node.run_environment_and_poll(environment_run_input)
return environment_run

0 comments on commit e3886f7

Please sign in to comment.