From ef2afe5f4f769027cd3571df4ae30dcf70064bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cem=20G=C3=B6kmen?= <1408354+cgokmen@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:04:55 -0800 Subject: [PATCH] Update the RPC stuff to a cleaner API. --- rl/service/environment_servicer.py | 125 ++++++++++++++-------------- rl/service/protos/environment.proto | 100 ++++++++-------------- 2 files changed, 99 insertions(+), 126 deletions(-) diff --git a/rl/service/environment_servicer.py b/rl/service/environment_servicer.py index 1adb21938..bc0ff50b7 100644 --- a/rl/service/environment_servicer.py +++ b/rl/service/environment_servicer.py @@ -9,71 +9,70 @@ class EnvironmentServicer(environment_pb2_grpc.EnvironmentServicer): def __init__(self, env) -> None: self.env = env - def ManageEnvironment( - self, request: environment_pb2.EnvironmentRequest, unused_context - ) -> environment_pb2.EnvironmentResponse: - response = environment_pb2.EnvironmentResponse() + def Step(self, request, unused_context): + action = pickle.loads(request.action) + observation, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated + info["TimeLimit.truncated"] = truncated and not terminated - if request.WhichOneOf("command") == "step": - action = pickle.loads(request.step.action) - observation, reward, terminated, truncated, info = self.env.step(action) - # convert to SB3 VecEnv api - done = terminated or truncated - info["TimeLimit.truncated"] = truncated and not terminated - if done: - # save final observation where user can get it, then reset - info["terminal_observation"] = observation - observation, reset_info = self._env.reset() - subresponse = response.step_response - subresponse.observation = pickle.dumps(observation) - subresponse.reward = reward - subresponse.done = done - subresponse.info = pickle.dumps(info) - subresponse.reset_info = pickle.dumps(reset_info) - elif request.WhichOneOf("command") == "reset": - seed = request.reset.seed - maybe_options = {"options": pickle.loads(request.reset.options)} if request.reset.options else {} - observation, reset_info = self.env.reset(seed=seed, **maybe_options) - subresponse = response.reset_response - subresponse.observation = pickle.dumps(observation) - subresponse.reset_info = pickle.dumps(reset_info) - elif request.WhichOneOf("command") == "render": - image = self.env.render() - subresponse = response.render_response - subresponse.render_data = pickle.dumps(image) - elif request.WhichOneOf("command") == "close": - self._env.close() - response.close_response.SetInParent() - elif request.WhichOneOf("command") == "get_spaces": - subresponse = response.get_spaces_response - subresponse.observation_space = pickle.dumps(self.env.observation_space) - subresponse.action_space = pickle.dumps(self.env.action_space) - elif request.WhichOneOf("command") == "env_method": - method_name = request.env_method.method_name - args, kwargs = pickle.arguments(request.env_method.args) - method = getattr(self.env, method_name) - result = method(*args, **kwargs) - subresponse = response.env_method_response - subresponse.result = pickle.dumps(result) - elif request.WhichOneOf("command") == "get_attr": - attr = request.get_attr.attribute_name - result = getattr(self.env, attr) - subresponse = response.get_attr_response - subresponse.attribute_value = pickle.dumps(result) - elif request.WhichOneOf("command") == "set_attr": - attr = request.get_attr.attribute_name - val = pickle.loads(request.set_attr.attribute_value) - result = setattr(self.env, attr, val) - response.set_attr_response.SetInParent() - elif request.WhichOneOf("command") == "is_wrapped": - wrapper_type = request.is_wrapped.wrapper_type - result = is_wrapped(self.env, wrapper_type) - subresponse = response.is_wrapped_response - subresponse.is_wrapped = result - else: - raise NotImplementedError(f"Invalid request is not implemented in the worker") + if done: + info["terminal_observation"] = observation + observation, reset_info = self.env.reset() + + return environment_pb2.StepResponse( + observation=pickle.dumps(observation), + reward=reward, + done=done, + info=pickle.dumps(info), + reset_info=pickle.dumps(reset_info) + ) + + def Reset(self, request, unused_context): + seed = request.seed + maybe_options = {"options": pickle.loads(request.options)} if request.options else {} + observation, reset_info = self.env.reset(seed=seed, **maybe_options) - return response + return environment_pb2.ResetResponse( + observation=pickle.dumps(observation), + reset_info=pickle.dumps(reset_info) + ) + + def Render(self, request, unused_context): + image = self.env.render() + return environment_pb2.RenderResponse(render_data=pickle.dumps(image)) + + def Close(self, request, unused_context): + self.env.close() + return environment_pb2.CloseResponse() + + def GetSpaces(self, request, unused_context): + return environment_pb2.GetSpacesResponse( + observation_space=pickle.dumps(self.env.observation_space), + action_space=pickle.dumps(self.env.action_space) + ) + + def EnvMethod(self, request, unused_context): + method_name = request.method_name + args, kwargs = pickle.loads(request.arguments) + method = getattr(self.env, method_name) + result = method(*args, **kwargs) + return environment_pb2.EnvMethodResponse(result=pickle.dumps(result)) + + def GetAttr(self, request, unused_context): + attr = request.attribute_name + result = getattr(self.env, attr) + return environment_pb2.GetAttrResponse(attribute_value=pickle.dumps(result)) + + def SetAttr(self, request, unused_context): + attr = request.attribute_name + val = pickle.loads(request.attribute_value) + setattr(self.env, attr, val) + return environment_pb2.SetAttrResponse() + + def IsWrapped(self, request, unused_context): + wrapper_type = request.wrapper_type + is_wrapped = hasattr(self.env, wrapper_type) # Assuming is_wrapped is implemented as hasattr + return environment_pb2.IsWrappedResponse(is_wrapped=is_wrapped) async def serve(env): server = grpc.aio.server() diff --git a/rl/service/protos/environment.proto b/rl/service/protos/environment.proto index de41334f9..70ad4ae44 100644 --- a/rl/service/protos/environment.proto +++ b/rl/service/protos/environment.proto @@ -3,101 +3,75 @@ syntax = "proto3"; package environment; service EnvironmentService { - rpc ManageEnvironment(EnvironmentRequest) returns (EnvironmentResponse); -} - -message EnvironmentRequest { - oneof command { - StepCommand step = 1; - ResetCommand reset = 2; - RenderCommand render = 3; - CloseCommand close = 4; - GetSpacesCommand get_spaces = 5; - EnvMethodCommand env_method = 6; - GetAttrCommand get_attr = 7; - SetAttrCommand set_attr = 8; - IsWrappedCommand is_wrapped = 9; - } -} - -message EnvironmentResponse { - oneof response { - StepResponse step_response = 1; - ResetResponse reset_response = 2; - RenderResponse render_response = 3; - CloseResponse close_response = 4; - GetSpacesResponse get_spaces_response = 5; - EnvMethodResponse env_method_response = 6; - GetAttrResponse get_attr_response = 7; - SetAttrResponse set_attr_response = 8; - IsWrappedResponse is_wrapped_response = 9; - } -} - -// Step command and response -message StepCommand { - bytes action = 1; // You might need to change the type based on your action space + rpc Step(StepRequest) returns (StepResponse); + rpc Reset(ResetRequest) returns (ResetResponse); + rpc Render(RenderRequest) returns (RenderResponse); + rpc Close(CloseRequest) returns (CloseResponse); + rpc GetSpaces(GetSpacesRequest) returns (GetSpacesResponse); + rpc EnvMethod(EnvMethodRequest) returns (EnvMethodResponse); + rpc GetAttr(GetAttrRequest) returns (GetAttrResponse); + rpc SetAttr(SetAttrRequest) returns (SetAttrResponse); + rpc IsWrapped(IsWrappedRequest) returns (IsWrappedResponse); +} + +message StepRequest { + bytes action = 1; } message StepResponse { - bytes observation = 1; // Change type based on observation space + bytes observation = 1; float reward = 2; bool done = 3; - bytes info = 4; // Info might be a map or another structured type - bytes reset_info = 5; // Include if applicable + bytes info = 4; + bytes reset_info = 5; } -// Reset command and response -message ResetCommand { +message ResetRequest { int32 seed = 1; - bytes options = 2; // Options, if any + bytes options = 2; } message ResetResponse { - bytes observation = 1; // Change type based on observation space - bytes reset_info = 2; // Include additional reset information if applicable + bytes observation = 1; + bytes reset_info = 2; } -// Render command and response -message RenderCommand { - // Include fields if render command needs parameters +message RenderRequest { + // Include fields if render Request needs parameters } message RenderResponse { - bytes render_data = 1; // This could be an image or other render output + bytes render_data = 1; } -// Close command and response (usually empty) -message CloseCommand {} +message CloseRequest {} message CloseResponse {} -// GetSpaces command and response -message GetSpacesCommand {} +message GetSpacesRequest {} message GetSpacesResponse { - bytes observation_space = 1; // This could be a description of the space - bytes action_space = 2; // Description of the action space + bytes observation_space = 1; + bytes action_space = 2; } -// EnvMethod command and response (general purpose) -message EnvMethodCommand { +message EnvMethodRequest { string method_name = 1; - bytes arguments = 2; // Arguments for the method, if any + bytes arguments = 2; } message EnvMethodResponse { - bytes result = 1; // The result of the method call + bytes result = 1; } -// GetAttr and SetAttr commands and responses -message GetAttrCommand { +// GetAttr and SetAttr Requests and responses +message GetAttrRequest { string attribute_name = 1; } message GetAttrResponse { - bytes attribute_value = 1; // The value of the attribute + bytes attribute_value = 1; } -message SetAttrCommand { +message SetAttrRequest { string attribute_name = 1; bytes attribute_value = 2; } @@ -106,8 +80,8 @@ message SetAttrResponse { // Usually empty, can include confirmation or result } -// IsWrapped command and response -message IsWrappedCommand { +// IsWrapped Request and response +message IsWrappedRequest { string wrapper_type = 1; } @@ -126,4 +100,4 @@ message RegisterEnvironmentRequest { message RegisterEnvironmentResponse { bool success = 1; -} \ No newline at end of file +}