Skip to content

Commit

Permalink
Implement GetAllModes to obtain a list of all available modes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588044219
Change-Id: Iad6c03685454d059b46da39025ee1e9e4177f896
  • Loading branch information
copybara-github committed Dec 5, 2023
1 parent 324e9eb commit 6b138f5
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 3 deletions.
8 changes: 8 additions & 0 deletions mjpc/grpc/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ service Agent {
rpc SetMode(SetModeRequest) returns (SetModeResponse);
// Get mode.
rpc GetMode(GetModeRequest) returns (GetModeResponse);
// Get all modes.
rpc GetAllModes(GetAllModesRequest) returns (GetAllModesResponse);
}

message MjModel {
Expand Down Expand Up @@ -167,3 +169,9 @@ message SetModeRequest {
}

message SetModeResponse {}

message GetAllModesRequest {}

message GetAllModesResponse {
repeated string mode_names = 1;
}
16 changes: 13 additions & 3 deletions mjpc/grpc/agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@
#include <grpcpp/server_context.h>
#include <grpcpp/support/status.h>
#include <mujoco/mujoco.h>

#include "mjpc/grpc/agent.pb.h"
#include "mjpc/grpc/agent.proto.h"
#include "mjpc/grpc/grpc_agent_util.h"
#include "mjpc/task.h"

namespace mjpc::agent_grpc {

using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
using ::agent::GetModeResponse;
using ::agent::GetStateRequest;
using ::agent::GetStateResponse;
using ::agent::GetTaskParametersRequest;
Expand All @@ -47,8 +51,6 @@ using ::agent::SetCostWeightsRequest;
using ::agent::SetCostWeightsResponse;
using ::agent::SetModeRequest;
using ::agent::SetModeResponse;
using ::agent::GetModeRequest;
using ::agent::GetModeResponse;
using ::agent::SetStateRequest;
using ::agent::SetStateResponse;
using ::agent::SetTaskParametersRequest;
Expand Down Expand Up @@ -272,4 +274,12 @@ grpc::Status AgentService::GetMode(grpc::ServerContext* context,
return grpc_agent_util::GetMode(request, &agent_, response);
}

grpc::Status AgentService::GetAllModes(grpc::ServerContext* context,
const GetAllModesRequest* request,
GetAllModesResponse* response) {
if (!Initialized()) {
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}
return grpc_agent_util::GetAllModes(request, &agent_, response);
}
} // namespace mjpc::agent_grpc
4 changes: 4 additions & 0 deletions mjpc/grpc/agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class AgentService final : public agent::Agent::Service {
const agent::GetModeRequest* request,
agent::GetModeResponse* response) override;

grpc::Status GetAllModes(grpc::ServerContext* context,
const agent::GetAllModesRequest* request,
agent::GetAllModesResponse* response);

private:
bool Initialized() const { return data_ != nullptr; }

Expand Down
27 changes: 27 additions & 0 deletions mjpc/grpc/agent_service_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,4 +333,31 @@ TEST_F(AgentServiceTest, SetCostWeights_RejectsInvalidName) {
<< "Error message should contain the list of cost term names.";
}

TEST_F(AgentServiceTest, GetMode_Works) {
RunAndCheckInit("Cartpole", nullptr);

grpc::ClientContext context;

agent::GetModeRequest request;
agent::GetModeResponse response;
grpc::Status status = stub->GetMode(&context, request, &response);

EXPECT_TRUE(status.ok());
EXPECT_EQ(response.mode(), "default_mode");
}

TEST_F(AgentServiceTest, GetAllModes_Works) {
RunAndCheckInit("Cartpole", nullptr);

grpc::ClientContext context;

agent::GetAllModesRequest request;
agent::GetAllModesResponse response;
grpc::Status status = stub->GetAllModes(&context, request, &response);

EXPECT_TRUE(status.ok());
EXPECT_EQ(response.mode_names().size(), 1);
EXPECT_EQ(response.mode_names()[0], "default_mode");
}

} // namespace mjpc::agent_grpc
11 changes: 11 additions & 0 deletions mjpc/grpc/grpc_agent_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace grpc_agent_util {

using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand Down Expand Up @@ -373,6 +375,15 @@ grpc::Status GetMode(const GetModeRequest* request, mjpc::Agent* agent,
return grpc::Status::OK;
}

grpc::Status GetAllModes(const GetAllModesRequest* request, mjpc::Agent* agent,
GetAllModesResponse* response) {
std::vector<std::string> mode_names = agent->GetAllModeNames();
for (const auto& mode_name : mode_names) {
response->add_mode_names(mode_name);
}
return grpc::Status::OK;
}

mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error,
int error_size) {
static constexpr char file[] = "temporary-filename.xml";
Expand Down
4 changes: 4 additions & 0 deletions mjpc/grpc/grpc_agent_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ grpc::Status SetMode(const agent::SetModeRequest* request, mjpc::Agent* agent);
grpc::Status GetMode(const agent::GetModeRequest* request, mjpc::Agent* agent,
agent::GetModeResponse* response);

grpc::Status GetAllModes(const agent::GetAllModesRequest* request,
mjpc::Agent* agent,
agent::GetAllModesResponse* response);

mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error,
int error_size);
mjpc::UniqueMjModel LoadModelFromBytes(std::string_view mjb);
Expand Down
3 changes: 3 additions & 0 deletions python/mujoco_mpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def set_mode(self, mode: str):
request = agent_pb2.SetModeRequest(mode=mode)
self.stub.SetMode(request)

def get_all_modes(self) -> Sequence[str]:
return self.stub.GetAllModes(agent_pb2.GetAllModesRequest()).mode_names

def set_parameters(self, parameters: mjpc_parameters.MjpcParameters):
# TODO(nimrod): Add a single RPC that does this
if parameters.mode is not None:
Expand Down
12 changes: 12 additions & 0 deletions python/mujoco_mpc/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ def test_get_set_mode(self):
agent.set_mode("Walk")
self.assertEqual(agent.get_mode(), "Walk")

def test_get_all_modes(self):
model_path = (
pathlib.Path(__file__).parent.parent.parent
/ "mjpc/tasks/quadruped/task_flat.xml"
)
model = mujoco.MjModel.from_xml_path(str(model_path))
with agent_lib.Agent(task_id="Quadruped Flat", model=model) as agent:
self.assertEqual(
tuple(agent.get_all_modes()),
("Quadruped", "Biped", "Walk", "Scramble", "Flip"),
)

@absltest.skip("asset import issue")
def test_set_mode_error(self):
model_path = (
Expand Down

0 comments on commit 6b138f5

Please sign in to comment.