Skip to content

Commit

Permalink
Update the RPC stuff to a cleaner API.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgokmen committed Nov 18, 2023
1 parent 26ed23b commit ef2afe5
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 126 deletions.
125 changes: 62 additions & 63 deletions rl/service/environment_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,71 +9,70 @@ class EnvironmentServicer(environment_pb2_grpc.EnvironmentServicer):
def __init__(self, env) -> None:
self.env = env

def ManageEnvironment(
self, request: environment_pb2.EnvironmentRequest, unused_context
) -> environment_pb2.EnvironmentResponse:
response = environment_pb2.EnvironmentResponse()
def Step(self, request, unused_context):
action = pickle.loads(request.action)
observation, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
info["TimeLimit.truncated"] = truncated and not terminated

if request.WhichOneOf("command") == "step":
action = pickle.loads(request.step.action)
observation, reward, terminated, truncated, info = self.env.step(action)
# convert to SB3 VecEnv api
done = terminated or truncated
info["TimeLimit.truncated"] = truncated and not terminated
if done:
# save final observation where user can get it, then reset
info["terminal_observation"] = observation
observation, reset_info = self._env.reset()
subresponse = response.step_response
subresponse.observation = pickle.dumps(observation)
subresponse.reward = reward
subresponse.done = done
subresponse.info = pickle.dumps(info)
subresponse.reset_info = pickle.dumps(reset_info)
elif request.WhichOneOf("command") == "reset":
seed = request.reset.seed
maybe_options = {"options": pickle.loads(request.reset.options)} if request.reset.options else {}
observation, reset_info = self.env.reset(seed=seed, **maybe_options)
subresponse = response.reset_response
subresponse.observation = pickle.dumps(observation)
subresponse.reset_info = pickle.dumps(reset_info)
elif request.WhichOneOf("command") == "render":
image = self.env.render()
subresponse = response.render_response
subresponse.render_data = pickle.dumps(image)
elif request.WhichOneOf("command") == "close":
self._env.close()
response.close_response.SetInParent()
elif request.WhichOneOf("command") == "get_spaces":
subresponse = response.get_spaces_response
subresponse.observation_space = pickle.dumps(self.env.observation_space)
subresponse.action_space = pickle.dumps(self.env.action_space)
elif request.WhichOneOf("command") == "env_method":
method_name = request.env_method.method_name
args, kwargs = pickle.arguments(request.env_method.args)
method = getattr(self.env, method_name)
result = method(*args, **kwargs)
subresponse = response.env_method_response
subresponse.result = pickle.dumps(result)
elif request.WhichOneOf("command") == "get_attr":
attr = request.get_attr.attribute_name
result = getattr(self.env, attr)
subresponse = response.get_attr_response
subresponse.attribute_value = pickle.dumps(result)
elif request.WhichOneOf("command") == "set_attr":
attr = request.get_attr.attribute_name
val = pickle.loads(request.set_attr.attribute_value)
result = setattr(self.env, attr, val)
response.set_attr_response.SetInParent()
elif request.WhichOneOf("command") == "is_wrapped":
wrapper_type = request.is_wrapped.wrapper_type
result = is_wrapped(self.env, wrapper_type)
subresponse = response.is_wrapped_response
subresponse.is_wrapped = result
else:
raise NotImplementedError(f"Invalid request is not implemented in the worker")
if done:
info["terminal_observation"] = observation
observation, reset_info = self.env.reset()

return environment_pb2.StepResponse(
observation=pickle.dumps(observation),
reward=reward,
done=done,
info=pickle.dumps(info),
reset_info=pickle.dumps(reset_info)
)

def Reset(self, request, unused_context):
seed = request.seed
maybe_options = {"options": pickle.loads(request.options)} if request.options else {}
observation, reset_info = self.env.reset(seed=seed, **maybe_options)

return response
return environment_pb2.ResetResponse(
observation=pickle.dumps(observation),
reset_info=pickle.dumps(reset_info)
)

def Render(self, request, unused_context):
image = self.env.render()
return environment_pb2.RenderResponse(render_data=pickle.dumps(image))

def Close(self, request, unused_context):
self.env.close()
return environment_pb2.CloseResponse()

def GetSpaces(self, request, unused_context):
return environment_pb2.GetSpacesResponse(
observation_space=pickle.dumps(self.env.observation_space),
action_space=pickle.dumps(self.env.action_space)
)

def EnvMethod(self, request, unused_context):
method_name = request.method_name
args, kwargs = pickle.loads(request.arguments)
method = getattr(self.env, method_name)
result = method(*args, **kwargs)
return environment_pb2.EnvMethodResponse(result=pickle.dumps(result))

def GetAttr(self, request, unused_context):
attr = request.attribute_name
result = getattr(self.env, attr)
return environment_pb2.GetAttrResponse(attribute_value=pickle.dumps(result))

def SetAttr(self, request, unused_context):
attr = request.attribute_name
val = pickle.loads(request.attribute_value)
setattr(self.env, attr, val)
return environment_pb2.SetAttrResponse()

def IsWrapped(self, request, unused_context):
wrapper_type = request.wrapper_type
is_wrapped = hasattr(self.env, wrapper_type) # Assuming is_wrapped is implemented as hasattr
return environment_pb2.IsWrappedResponse(is_wrapped=is_wrapped)

async def serve(env):
server = grpc.aio.server()
Expand Down
100 changes: 37 additions & 63 deletions rl/service/protos/environment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,75 @@ syntax = "proto3";
package environment;

service EnvironmentService {
rpc ManageEnvironment(EnvironmentRequest) returns (EnvironmentResponse);
}

message EnvironmentRequest {
oneof command {
StepCommand step = 1;
ResetCommand reset = 2;
RenderCommand render = 3;
CloseCommand close = 4;
GetSpacesCommand get_spaces = 5;
EnvMethodCommand env_method = 6;
GetAttrCommand get_attr = 7;
SetAttrCommand set_attr = 8;
IsWrappedCommand is_wrapped = 9;
}
}

message EnvironmentResponse {
oneof response {
StepResponse step_response = 1;
ResetResponse reset_response = 2;
RenderResponse render_response = 3;
CloseResponse close_response = 4;
GetSpacesResponse get_spaces_response = 5;
EnvMethodResponse env_method_response = 6;
GetAttrResponse get_attr_response = 7;
SetAttrResponse set_attr_response = 8;
IsWrappedResponse is_wrapped_response = 9;
}
}

// Step command and response
message StepCommand {
bytes action = 1; // You might need to change the type based on your action space
rpc Step(StepRequest) returns (StepResponse);
rpc Reset(ResetRequest) returns (ResetResponse);
rpc Render(RenderRequest) returns (RenderResponse);
rpc Close(CloseRequest) returns (CloseResponse);
rpc GetSpaces(GetSpacesRequest) returns (GetSpacesResponse);
rpc EnvMethod(EnvMethodRequest) returns (EnvMethodResponse);
rpc GetAttr(GetAttrRequest) returns (GetAttrResponse);
rpc SetAttr(SetAttrRequest) returns (SetAttrResponse);
rpc IsWrapped(IsWrappedRequest) returns (IsWrappedResponse);
}

message StepRequest {
bytes action = 1;
}

message StepResponse {
bytes observation = 1; // Change type based on observation space
bytes observation = 1;
float reward = 2;
bool done = 3;
bytes info = 4; // Info might be a map or another structured type
bytes reset_info = 5; // Include if applicable
bytes info = 4;
bytes reset_info = 5;
}

// Reset command and response
message ResetCommand {
message ResetRequest {
int32 seed = 1;
bytes options = 2; // Options, if any
bytes options = 2;
}

message ResetResponse {
bytes observation = 1; // Change type based on observation space
bytes reset_info = 2; // Include additional reset information if applicable
bytes observation = 1;
bytes reset_info = 2;
}

// Render command and response
message RenderCommand {
// Include fields if render command needs parameters
message RenderRequest {
// Include fields if render Request needs parameters
}

message RenderResponse {
bytes render_data = 1; // This could be an image or other render output
bytes render_data = 1;
}

// Close command and response (usually empty)
message CloseCommand {}
message CloseRequest {}
message CloseResponse {}

// GetSpaces command and response
message GetSpacesCommand {}
message GetSpacesRequest {}
message GetSpacesResponse {
bytes observation_space = 1; // This could be a description of the space
bytes action_space = 2; // Description of the action space
bytes observation_space = 1;
bytes action_space = 2;
}

// EnvMethod command and response (general purpose)
message EnvMethodCommand {
message EnvMethodRequest {
string method_name = 1;
bytes arguments = 2; // Arguments for the method, if any
bytes arguments = 2;
}

message EnvMethodResponse {
bytes result = 1; // The result of the method call
bytes result = 1;
}

// GetAttr and SetAttr commands and responses
message GetAttrCommand {
// GetAttr and SetAttr Requests and responses
message GetAttrRequest {
string attribute_name = 1;
}

message GetAttrResponse {
bytes attribute_value = 1; // The value of the attribute
bytes attribute_value = 1;
}

message SetAttrCommand {
message SetAttrRequest {
string attribute_name = 1;
bytes attribute_value = 2;
}
Expand All @@ -106,8 +80,8 @@ message SetAttrResponse {
// Usually empty, can include confirmation or result
}

// IsWrapped command and response
message IsWrappedCommand {
// IsWrapped Request and response
message IsWrappedRequest {
string wrapper_type = 1;
}

Expand All @@ -126,4 +100,4 @@ message RegisterEnvironmentRequest {

message RegisterEnvironmentResponse {
bool success = 1;
}
}

0 comments on commit ef2afe5

Please sign in to comment.