From 6b138f58c0c1e648f72a8a56d67508de3c73110c Mon Sep 17 00:00:00 2001 From: Copybara-Service Date: Tue, 5 Dec 2023 06:02:19 -0800 Subject: [PATCH] Implement `GetAllModes` to obtain a list of all available modes. PiperOrigin-RevId: 588044219 Change-Id: Iad6c03685454d059b46da39025ee1e9e4177f896 --- mjpc/grpc/agent.proto | 8 ++++++++ mjpc/grpc/agent_service.cc | 16 +++++++++++++--- mjpc/grpc/agent_service.h | 4 ++++ mjpc/grpc/agent_service_test.cc | 27 +++++++++++++++++++++++++++ mjpc/grpc/grpc_agent_util.cc | 11 +++++++++++ mjpc/grpc/grpc_agent_util.h | 4 ++++ python/mujoco_mpc/agent.py | 3 +++ python/mujoco_mpc/agent_test.py | 12 ++++++++++++ 8 files changed, 82 insertions(+), 3 deletions(-) diff --git a/mjpc/grpc/agent.proto b/mjpc/grpc/agent.proto index 4fb9f02d2..5e15d8fa7 100644 --- a/mjpc/grpc/agent.proto +++ b/mjpc/grpc/agent.proto @@ -46,6 +46,8 @@ service Agent { rpc SetMode(SetModeRequest) returns (SetModeResponse); // Get mode. rpc GetMode(GetModeRequest) returns (GetModeResponse); + // Get all modes. + rpc GetAllModes(GetAllModesRequest) returns (GetAllModesResponse); } message MjModel { @@ -167,3 +169,9 @@ message SetModeRequest { } message SetModeResponse {} + +message GetAllModesRequest {} + +message GetAllModesResponse { + repeated string mode_names = 1; +} diff --git a/mjpc/grpc/agent_service.cc b/mjpc/grpc/agent_service.cc index 79b1713b8..7677acf2e 100644 --- a/mjpc/grpc/agent_service.cc +++ b/mjpc/grpc/agent_service.cc @@ -22,8 +22,8 @@ #include #include #include - #include "mjpc/grpc/agent.pb.h" +#include "mjpc/grpc/agent.proto.h" #include "mjpc/grpc/grpc_agent_util.h" #include "mjpc/task.h" @@ -31,8 +31,12 @@ namespace mjpc::agent_grpc { using ::agent::GetActionRequest; using ::agent::GetActionResponse; +using ::agent::GetAllModesRequest; +using ::agent::GetAllModesResponse; using ::agent::GetCostValuesAndWeightsRequest; using ::agent::GetCostValuesAndWeightsResponse; +using ::agent::GetModeRequest; +using ::agent::GetModeResponse; using ::agent::GetStateRequest; using ::agent::GetStateResponse; using ::agent::GetTaskParametersRequest; @@ -47,8 +51,6 @@ using ::agent::SetCostWeightsRequest; using ::agent::SetCostWeightsResponse; using ::agent::SetModeRequest; using ::agent::SetModeResponse; -using ::agent::GetModeRequest; -using ::agent::GetModeResponse; using ::agent::SetStateRequest; using ::agent::SetStateResponse; using ::agent::SetTaskParametersRequest; @@ -272,4 +274,12 @@ grpc::Status AgentService::GetMode(grpc::ServerContext* context, return grpc_agent_util::GetMode(request, &agent_, response); } +grpc::Status AgentService::GetAllModes(grpc::ServerContext* context, + const GetAllModesRequest* request, + GetAllModesResponse* response) { + if (!Initialized()) { + return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."}; + } + return grpc_agent_util::GetAllModes(request, &agent_, response); +} } // namespace mjpc::agent_grpc diff --git a/mjpc/grpc/agent_service.h b/mjpc/grpc/agent_service.h index 3a9283fc9..bec86d3df 100644 --- a/mjpc/grpc/agent_service.h +++ b/mjpc/grpc/agent_service.h @@ -100,6 +100,10 @@ class AgentService final : public agent::Agent::Service { const agent::GetModeRequest* request, agent::GetModeResponse* response) override; + grpc::Status GetAllModes(grpc::ServerContext* context, + const agent::GetAllModesRequest* request, + agent::GetAllModesResponse* response); + private: bool Initialized() const { return data_ != nullptr; } diff --git a/mjpc/grpc/agent_service_test.cc b/mjpc/grpc/agent_service_test.cc index 9b56107b3..bf65c7498 100644 --- a/mjpc/grpc/agent_service_test.cc +++ b/mjpc/grpc/agent_service_test.cc @@ -333,4 +333,31 @@ TEST_F(AgentServiceTest, SetCostWeights_RejectsInvalidName) { << "Error message should contain the list of cost term names."; } +TEST_F(AgentServiceTest, GetMode_Works) { + RunAndCheckInit("Cartpole", nullptr); + + grpc::ClientContext context; + + agent::GetModeRequest request; + agent::GetModeResponse response; + grpc::Status status = stub->GetMode(&context, request, &response); + + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.mode(), "default_mode"); +} + +TEST_F(AgentServiceTest, GetAllModes_Works) { + RunAndCheckInit("Cartpole", nullptr); + + grpc::ClientContext context; + + agent::GetAllModesRequest request; + agent::GetAllModesResponse response; + grpc::Status status = stub->GetAllModes(&context, request, &response); + + EXPECT_TRUE(status.ok()); + EXPECT_EQ(response.mode_names().size(), 1); + EXPECT_EQ(response.mode_names()[0], "default_mode"); +} + } // namespace mjpc::agent_grpc diff --git a/mjpc/grpc/grpc_agent_util.cc b/mjpc/grpc/grpc_agent_util.cc index 0b138c347..78eca0c46 100644 --- a/mjpc/grpc/grpc_agent_util.cc +++ b/mjpc/grpc/grpc_agent_util.cc @@ -40,6 +40,8 @@ namespace grpc_agent_util { using ::agent::GetActionRequest; using ::agent::GetActionResponse; +using ::agent::GetAllModesRequest; +using ::agent::GetAllModesResponse; using ::agent::GetCostValuesAndWeightsRequest; using ::agent::GetCostValuesAndWeightsResponse; using ::agent::GetModeRequest; @@ -373,6 +375,15 @@ grpc::Status GetMode(const GetModeRequest* request, mjpc::Agent* agent, return grpc::Status::OK; } +grpc::Status GetAllModes(const GetAllModesRequest* request, mjpc::Agent* agent, + GetAllModesResponse* response) { + std::vector mode_names = agent->GetAllModeNames(); + for (const auto& mode_name : mode_names) { + response->add_mode_names(mode_name); + } + return grpc::Status::OK; +} + mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error, int error_size) { static constexpr char file[] = "temporary-filename.xml"; diff --git a/mjpc/grpc/grpc_agent_util.h b/mjpc/grpc/grpc_agent_util.h index e67a16a2b..42814432f 100644 --- a/mjpc/grpc/grpc_agent_util.h +++ b/mjpc/grpc/grpc_agent_util.h @@ -47,6 +47,10 @@ grpc::Status SetMode(const agent::SetModeRequest* request, mjpc::Agent* agent); grpc::Status GetMode(const agent::GetModeRequest* request, mjpc::Agent* agent, agent::GetModeResponse* response); +grpc::Status GetAllModes(const agent::GetAllModesRequest* request, + mjpc::Agent* agent, + agent::GetAllModesResponse* response); + mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error, int error_size); mjpc::UniqueMjModel LoadModelFromBytes(std::string_view mjb); diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py index c560d49ed..2b3163e00 100644 --- a/python/mujoco_mpc/agent.py +++ b/python/mujoco_mpc/agent.py @@ -348,6 +348,9 @@ def set_mode(self, mode: str): request = agent_pb2.SetModeRequest(mode=mode) self.stub.SetMode(request) + def get_all_modes(self) -> Sequence[str]: + return self.stub.GetAllModes(agent_pb2.GetAllModesRequest()).mode_names + def set_parameters(self, parameters: mjpc_parameters.MjpcParameters): # TODO(nimrod): Add a single RPC that does this if parameters.mode is not None: diff --git a/python/mujoco_mpc/agent_test.py b/python/mujoco_mpc/agent_test.py index 231add92a..93de03f8f 100644 --- a/python/mujoco_mpc/agent_test.py +++ b/python/mujoco_mpc/agent_test.py @@ -325,6 +325,18 @@ def test_get_set_mode(self): agent.set_mode("Walk") self.assertEqual(agent.get_mode(), "Walk") + def test_get_all_modes(self): + model_path = ( + pathlib.Path(__file__).parent.parent.parent + / "mjpc/tasks/quadruped/task_flat.xml" + ) + model = mujoco.MjModel.from_xml_path(str(model_path)) + with agent_lib.Agent(task_id="Quadruped Flat", model=model) as agent: + self.assertEqual( + tuple(agent.get_all_modes()), + ("Quadruped", "Biped", "Walk", "Scramble", "Flip"), + ) + @absltest.skip("asset import issue") def test_set_mode_error(self): model_path = (