From 6ba7cd2f13fbf09d711484d2e73f7f4ec8909973 Mon Sep 17 00:00:00 2001 From: Nimrod Gileadi Date: Tue, 12 Dec 2023 03:47:08 -0800 Subject: [PATCH 1/6] Add a Python method for setting mocap body poses by name. Create a grpc method that can be used to group multiple changes of state - to be used by a future Python API. PiperOrigin-RevId: 590142795 Change-Id: I02aff2b1f819f7ea28b5c81e0545cf92ce63b7d3 --- mjpc/grpc/agent.proto | 24 ++++++ mjpc/grpc/agent_service.cc | 12 +++ mjpc/grpc/agent_service.h | 4 + mjpc/grpc/grpc_agent_util.cc | 130 ++++++++++++++++++++++++++++---- mjpc/grpc/grpc_agent_util.h | 6 ++ mjpc/grpc/ui_agent_service.cc | 12 +++ mjpc/grpc/ui_agent_service.h | 4 + mjpc/tasks/particle/particle.cc | 29 +++++-- mjpc/tasks/particle/particle.h | 43 +++++++++-- mjpc/tasks/tasks.cc | 1 + python/mujoco_mpc/agent.py | 4 + python/mujoco_mpc/agent_test.py | 20 ++++- 12 files changed, 262 insertions(+), 27 deletions(-) diff --git a/mjpc/grpc/agent.proto b/mjpc/grpc/agent.proto index 5e15d8fa7..e20f357f9 100644 --- a/mjpc/grpc/agent.proto +++ b/mjpc/grpc/agent.proto @@ -48,6 +48,9 @@ service Agent { rpc GetMode(GetModeRequest) returns (GetModeResponse); // Get all modes. rpc GetAllModes(GetAllModesRequest) returns (GetAllModesResponse); + + // A single method that can set many of the inputs. + rpc SetAnything(SetAnythingRequest) returns (SetAnythingResponse); } message MjModel { @@ -175,3 +178,24 @@ message GetAllModesRequest {} message GetAllModesResponse { repeated string mode_names = 1; } + +message Pose { + repeated double pos = 1 [packed = true]; + repeated double quat = 2 [packed = true]; +} + +message SetAnythingRequest { + State state = 1; + + // map from parameter name to desired value + map parameters = 2; + // cost weights by name + map cost_weights = 3; + string mode = 4; + + // set the positions of mocap bodies by name. If `state` is set too, mocap + // positions will be set after the state is set. + map mocap = 5; +} + +message SetAnythingResponse {} diff --git a/mjpc/grpc/agent_service.cc b/mjpc/grpc/agent_service.cc index 957597b60..b7d7d9c8c 100644 --- a/mjpc/grpc/agent_service.cc +++ b/mjpc/grpc/agent_service.cc @@ -46,6 +46,8 @@ using ::agent::PlannerStepRequest; using ::agent::PlannerStepResponse; using ::agent::ResetRequest; using ::agent::ResetResponse; +using ::agent::SetAnythingRequest; +using ::agent::SetAnythingResponse; using ::agent::SetCostWeightsRequest; using ::agent::SetCostWeightsResponse; using ::agent::SetModeRequest; @@ -281,4 +283,14 @@ grpc::Status AgentService::GetAllModes(grpc::ServerContext* context, } return grpc_agent_util::GetAllModes(request, &agent_, response); } + +grpc::Status AgentService::SetAnything( + grpc::ServerContext* context, const SetAnythingRequest* request, + SetAnythingResponse* response) { + if (!Initialized()) { + return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."}; + } + return grpc_agent_util::SetAnything(request, &agent_, agent_.GetModel(), + data_, response); +} } // namespace mjpc::agent_grpc diff --git a/mjpc/grpc/agent_service.h b/mjpc/grpc/agent_service.h index fa036778c..57dcb0b61 100644 --- a/mjpc/grpc/agent_service.h +++ b/mjpc/grpc/agent_service.h @@ -104,6 +104,10 @@ class AgentService final : public agent::Agent::Service { const agent::GetAllModesRequest* request, agent::GetAllModesResponse* response) override; + grpc::Status SetAnything(grpc::ServerContext* context, + const agent::SetAnythingRequest* request, + agent::SetAnythingResponse* response) override; + private: bool Initialized() const { return data_ != nullptr; } diff --git a/mjpc/grpc/grpc_agent_util.cc b/mjpc/grpc/grpc_agent_util.cc index 78eca0c46..cc0fea137 100644 --- a/mjpc/grpc/grpc_agent_util.cc +++ b/mjpc/grpc/grpc_agent_util.cc @@ -14,6 +14,7 @@ #include "mjpc/grpc/grpc_agent_util.h" +#include #include #include #include @@ -35,6 +36,7 @@ #include "mjpc/agent.h" #include "mjpc/states/state.h" #include "mjpc/task.h" +#include "mjpc/utilities.h" namespace grpc_agent_util { @@ -49,6 +51,9 @@ using ::agent::GetModeResponse; using ::agent::GetStateResponse; using ::agent::GetTaskParametersRequest; using ::agent::GetTaskParametersResponse; +using ::agent::Pose; +using ::agent::SetAnythingRequest; +using ::agent::SetAnythingResponse; using ::agent::SetCostWeightsRequest; using ::agent::SetModeRequest; using ::agent::SetStateRequest; @@ -104,10 +109,9 @@ if (!(expr).ok()) { \ } \ } -grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent, +namespace { +grpc::Status SetState(const agent::State& state, mjpc::Agent* agent, const mjModel* model, mjData* data) { - agent::State state = request->state(); - if (state.has_time()) data->time = state.time(); if (state.qpos_size() > 0) { @@ -144,6 +148,13 @@ grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent, return grpc::Status::OK; } +} // namespace + +grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent, + const mjModel* model, mjData* data) { + agent::State state = request->state(); + return SetState(state, agent, model, data); +} #undef CHECK_SIZE @@ -328,12 +339,10 @@ grpc::Status GetTaskParameters(const GetTaskParametersRequest* request, return grpc::Status::OK; } -grpc::Status SetCostWeights(const SetCostWeightsRequest* request, - mjpc::Agent* agent) { - if (request->reset_to_defaults()) { - agent->ActiveTask()->Reset(agent->GetModel()); - } - for (const auto& [name, weight] : request->cost_weights()) { +grpc::Status SetCostWeights( + const ::proto2::Map& cost_weights, + mjpc::Agent* agent) { + for (const auto& [name, weight] : cost_weights) { if (agent->SetWeightByName(name, weight) == -1) { std::ostringstream error_string; error_string << "Weight '" << name @@ -346,19 +355,29 @@ grpc::Status SetCostWeights(const SetCostWeightsRequest* request, agent_model->name_sensoradr[i]); error_string << " " << sensor_name << "\n"; } - return {grpc::StatusCode::INVALID_ARGUMENT, error_string.str()}; + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + error_string.str()); } } - return grpc::Status::OK; } -grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) { - int outcome = agent->SetModeByName(request->mode()); +grpc::Status SetCostWeights(const SetCostWeightsRequest* request, + mjpc::Agent* agent) { + if (request->reset_to_defaults()) { + agent->ActiveTask()->Reset(agent->GetModel()); + } + auto cost_weights = request->cost_weights(); + return SetCostWeights(cost_weights, agent); +} + + +grpc::Status SetMode(std::string_view mode, mjpc::Agent* agent) { + int outcome = agent->SetModeByName(mode); if (outcome == -1) { std::vector mode_names = agent->GetAllModeNames(); std::ostringstream error_string; - error_string << "Mode '" << request->mode() + error_string << "Mode '" << mode << "' not found in task. Available names are:\n"; for (const auto& mode_name : mode_names) { error_string << " " << mode_name << "\n"; @@ -369,6 +388,10 @@ grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) { } } +grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) { + return SetMode(request->mode(), agent); +} + grpc::Status GetMode(const GetModeRequest* request, mjpc::Agent* agent, GetModeResponse* response) { response->set_mode(agent->GetModeName()); @@ -384,6 +407,83 @@ grpc::Status GetAllModes(const GetAllModesRequest* request, mjpc::Agent* agent, return grpc::Status::OK; } +namespace { +grpc::Status SetMocap(const ::proto2::Map& mocap, + mjpc::Agent* agent, const mjModel* model, mjData* data) { + // Check all names and poses before applying changes. + for (const auto& [name, pose] : mocap) { + int id = mj_name2id(model, mjOBJ_BODY, name.c_str()); + if (id < 0) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + absl::StrFormat("Body '%s' not found.", name)); + } + int mocap_id = model->body_mocapid[id]; + if (mocap_id < 0) { + return grpc::Status( + grpc::StatusCode::INVALID_ARGUMENT, + absl::StrFormat("Body '%s' is not a mocap body.", name)); + } + if (pose.pos_size() != 0 && pose.pos_size() != 3) { + return grpc::Status( + grpc::StatusCode::INVALID_ARGUMENT, + absl::StrFormat("Mocap '%s' has invalid pose size %d.", name, + pose.pos_size())); + } + if (pose.quat_size() != 0 && pose.quat_size() != 4) { + return grpc::Status( + grpc::StatusCode::INVALID_ARGUMENT, + absl::StrFormat("Mocap '%s' has invalid quat size %d.", name, + pose.pos_size())); + } + } + for (const auto& [name, pose] : mocap) { + int id = mj_name2id(model, mjOBJ_BODY, name.c_str()); + int mocap_id = model->body_mocapid[id]; + for (int i = 0; i < pose.pos_size(); i++) { + data->mocap_pos[3*mocap_id + i] = pose.pos(i); + } + if (pose.quat_size() == 4) { + for (int i = 0; i < 4; i++) { + data->mocap_quat[4*mocap_id + i] = pose.quat(i); + } + mju_normalize4(data->mocap_quat + 4*mocap_id); + } + } + agent->SetState(data); + return grpc::Status::OK; +} +} // namespace + +grpc::Status SetAnything(const SetAnythingRequest* request, mjpc::Agent* agent, + const mjModel* model, mjData* data, + SetAnythingResponse* response) { + if (request->has_state()) { + grpc::Status status = SetState(request->state(), agent, model, data); + if (!status.ok()) { + return status; + } + } + if (request->cost_weights_size() > 0) { + grpc::Status status = SetCostWeights(request->cost_weights(), agent); + if (!status.ok()) { + return status; + } + } + if (!request->mode().empty()) { + grpc::Status status = SetMode(request->mode(), agent); + if (!status.ok()) { + return status; + } + } + if (request->mocap_size() > 0) { + grpc::Status status = SetMocap(request->mocap(), agent, model, data); + if (!status.ok()) { + return status; + } + } + return grpc::Status::OK; +} + mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error, int error_size) { static constexpr char file[] = "temporary-filename.xml"; @@ -392,7 +492,7 @@ mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error, mj_defaultVFS(vfs.get()); mj_makeEmptyFileVFS(vfs.get(), file, xml.size()); int file_idx = mj_findFileVFS(vfs.get(), file); - memcpy(vfs->filedata[file_idx], xml.data(), xml.size()); + std::memcpy(vfs->filedata[file_idx], xml.data(), xml.size()); mjpc::UniqueMjModel m = {mj_loadXML(file, vfs.get(), error, error_size), mj_deleteModel}; mj_deleteFileVFS(vfs.get(), file); diff --git a/mjpc/grpc/grpc_agent_util.h b/mjpc/grpc/grpc_agent_util.h index 42814432f..98338ad9d 100644 --- a/mjpc/grpc/grpc_agent_util.h +++ b/mjpc/grpc/grpc_agent_util.h @@ -15,11 +15,14 @@ #ifndef MJPC_MJPC_GRPC_GRPC_AGENT_UTIL_H_ #define MJPC_MJPC_GRPC_GRPC_AGENT_UTIL_H_ +#include #include #include #include "mjpc/grpc/agent.pb.h" #include "mjpc/agent.h" +#include "mjpc/states/state.h" +#include "mjpc/utilities.h" namespace grpc_agent_util { grpc::Status GetState(const mjModel* model, const mjData* data, @@ -50,6 +53,9 @@ grpc::Status GetMode(const agent::GetModeRequest* request, mjpc::Agent* agent, grpc::Status GetAllModes(const agent::GetAllModesRequest* request, mjpc::Agent* agent, agent::GetAllModesResponse* response); +grpc::Status SetAnything(const agent::SetAnythingRequest* request, + mjpc::Agent* agent, const mjModel* model, mjData* data, + agent::SetAnythingResponse* response); mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error, int error_size); diff --git a/mjpc/grpc/ui_agent_service.cc b/mjpc/grpc/ui_agent_service.cc index 27f3cf6e6..125b60ed4 100644 --- a/mjpc/grpc/ui_agent_service.cc +++ b/mjpc/grpc/ui_agent_service.cc @@ -49,6 +49,8 @@ using ::agent::PlannerStepRequest; using ::agent::PlannerStepResponse; using ::agent::ResetRequest; using ::agent::ResetResponse; +using ::agent::SetAnythingRequest; +using ::agent::SetAnythingResponse; using ::agent::SetCostWeightsRequest; using ::agent::SetCostWeightsResponse; using ::agent::SetModeRequest; @@ -206,6 +208,16 @@ grpc::Status UiAgentService::GetMode(grpc::ServerContext* context, }); } +grpc::Status UiAgentService::SetAnything(grpc::ServerContext* context, + const SetAnythingRequest* request, + SetAnythingResponse* response) { + return RunBeforeStep(context, [request, response](mjpc::Agent* agent, + const mjModel* model, + mjData* data) { + return grpc_agent_util::SetAnything(request, agent, model, data, response); + }); +} + namespace { bool WaitUntilDeadline(const absl::Notification& notification, const grpc::ServerContext* context) { diff --git a/mjpc/grpc/ui_agent_service.h b/mjpc/grpc/ui_agent_service.h index ee32e9d7f..cd56cc90d 100644 --- a/mjpc/grpc/ui_agent_service.h +++ b/mjpc/grpc/ui_agent_service.h @@ -90,6 +90,10 @@ class UiAgentService final : public agent::Agent::Service { const agent::GetModeRequest* request, agent::GetModeResponse* response) override; + grpc::Status SetAnything(grpc::ServerContext* context, + const agent::SetAnythingRequest* request, + agent::SetAnythingResponse* response) override; + private: using StatusStepJob = absl::AnyInvocable; diff --git a/mjpc/tasks/particle/particle.cc b/mjpc/tasks/particle/particle.cc index e1651ac3d..58c157b13 100644 --- a/mjpc/tasks/particle/particle.cc +++ b/mjpc/tasks/particle/particle.cc @@ -17,7 +17,6 @@ #include #include -#include "mjpc/task.h" #include "mjpc/utilities.h" namespace mjpc { @@ -33,11 +32,10 @@ std::string Particle::Name() const { return "Particle"; } // Residual (1): velocity // Residual (2): control // -------------------------------------------- -void Particle::ResidualFn::Residual(const mjModel* model, const mjData* data, - double* residual) const { +namespace { +void ResidualImpl(const mjModel* model, const mjData* data, + const double goal[2], double* residual) { // ----- residual (0) ----- // - // some Lissajous curve - double goal[2] {0.25 * mju_sin(data->time), 0.25 * mju_cos(data->time/mjPI)}; double* position = SensorByName(model, data, "position"); mju_sub(residual, position, goal, model->nq); @@ -48,6 +46,14 @@ void Particle::ResidualFn::Residual(const mjModel* model, const mjData* data, // ----- residual (2) ----- // mju_copy(residual + 4, data->ctrl, model->nu); } +} // namespace + +void Particle::ResidualFn::Residual(const mjModel* model, const mjData* data, + double* residual) const { + // some Lissajous curve + double goal[2]{0.25 * mju_sin(data->time), 0.25 * mju_cos(data->time / mjPI)}; + ResidualImpl(model, data, goal, residual); +} void Particle::TransitionLocked(mjModel* model, mjData* data) { // some Lissajous curve @@ -57,4 +63,17 @@ void Particle::TransitionLocked(mjModel* model, mjData* data) { data->mocap_pos[0] = goal[0]; data->mocap_pos[1] = goal[1]; } + +std::string ParticleFixed::XmlPath() const { + return GetModelPath("particle/task_timevarying.xml"); +} +std::string ParticleFixed::Name() const { return "ParticleFixed"; } + +void ParticleFixed::ResidualFn::Residual(const mjModel* model, + const mjData* data, + double* residual) const { + double goal[2]{data->mocap_pos[0], data->mocap_pos[1]}; + ResidualImpl(model, data, goal, residual); +} + } // namespace mjpc diff --git a/mjpc/tasks/particle/particle.h b/mjpc/tasks/particle/particle.h index d59f184c0..e81a1995c 100644 --- a/mjpc/tasks/particle/particle.h +++ b/mjpc/tasks/particle/particle.h @@ -15,6 +15,7 @@ #ifndef MJPC_TASKS_PARTICLE_PARTICLE_H_ #define MJPC_TASKS_PARTICLE_PARTICLE_H_ +#include #include #include #include "mjpc/task.h" @@ -27,12 +28,12 @@ class Particle : public Task { class ResidualFn : public mjpc::BaseResidualFn { public: explicit ResidualFn(const Particle* task) : mjpc::BaseResidualFn(task) {} -// -------- Residuals for particle task ------- -// Number of residuals: 3 -// Residual (0): position - goal_position -// Residual (1): velocity -// Residual (2): control -// -------------------------------------------- + // -------- Residuals for particle task ------- + // Number of residuals: 3 + // Residual (0): position - goal_position + // Residual (1): velocity + // Residual (2): control + // -------------------------------------------- void Residual(const mjModel* model, const mjData* data, double* residual) const override; }; @@ -48,6 +49,36 @@ class Particle : public Task { private: ResidualFn residual_; }; + +// The same task, but the goal mocap body doesn't move. +class ParticleFixed : public Task { + public: + std::string Name() const override; + std::string XmlPath() const override; + class ResidualFn : public mjpc::BaseResidualFn { + public: + explicit ResidualFn(const ParticleFixed* task) + : mjpc::BaseResidualFn(task) {} + // -------- Residuals for particle task ------- + // Number of residuals: 3 + // Residual (0): position - goal_position + // Residual (1): velocity + // Residual (2): control + // -------------------------------------------- + void Residual(const mjModel* model, const mjData* data, + double* residual) const override; + }; + ParticleFixed() : residual_(this) {} + + protected: + std::unique_ptr ResidualLocked() const override { + return std::make_unique(this); + } + ResidualFn* InternalResidual() override { return &residual_; } + + private: + ResidualFn residual_; +}; } // namespace mjpc #endif // MJPC_TASKS_PARTICLE_PARTICLE_H_ diff --git a/mjpc/tasks/tasks.cc b/mjpc/tasks/tasks.cc index 9b1c773e4..cc41e4646 100644 --- a/mjpc/tasks/tasks.cc +++ b/mjpc/tasks/tasks.cc @@ -48,6 +48,7 @@ std::vector> GetTasks() { // DEEPMIND INTERNAL TASKS std::make_shared(), std::make_shared(), + std::make_shared(), std::make_shared(), std::make_shared(), std::make_shared(), diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py index 2b3163e00..2a2f2c0a8 100644 --- a/python/mujoco_mpc/agent.py +++ b/python/mujoco_mpc/agent.py @@ -359,3 +359,7 @@ def set_parameters(self, parameters: mjpc_parameters.MjpcParameters): self.set_task_parameters(parameters.task_parameters) if parameters.cost_weights: self.set_cost_weights(parameters.cost_weights) + + def set_mocap(self, mocap_map: Mapping[str, agent_pb2.Pose]): + request = agent_pb2.SetAnythingRequest(mocap=mocap_map) + self.stub.SetAnything(request) diff --git a/python/mujoco_mpc/agent_test.py b/python/mujoco_mpc/agent_test.py index 93de03f8f..77f0aa140 100644 --- a/python/mujoco_mpc/agent_test.py +++ b/python/mujoco_mpc/agent_test.py @@ -22,6 +22,7 @@ import numpy as np import pathlib +from mujoco_mpc.proto import agent_pb2 def get_observation(model, data): @@ -338,7 +339,7 @@ def test_get_all_modes(self): ) @absltest.skip("asset import issue") - def test_set_mode_error(self): + def test_set_mode_error(self):# model_path = ( pathlib.Path(__file__).parent.parent.parent / "mjpc/tasks/quadruped/task_flat.xml" @@ -366,6 +367,23 @@ def test_set_task_parameters_from_another_agent(self): self.assertEqual(agent.get_task_parameters()["Goal"], 14) + def test_set_mocap(self): + model_path = ( + pathlib.Path(__file__).parent.parent.parent + / "mjpc/tasks/particle/task_timevarying.xml" + ) + model = mujoco.MjModel.from_xml_path(str(model_path)) + with agent_lib.Agent(task_id="ParticleFixed", model=model) as agent: + pose = agent_pb2.Pose(pos=[13, 14, 15], quat=[1, 1, 1, 1]) + agent.set_mocap({"goal": pose}) + final_state = agent.get_state() + self.assertEqual(final_state.mocap_pos, pose.pos) + self.assertEqual( + final_state.mocap_quat, + [0.5, 0.5, 0.5, 0.5], + "quaternions should be normalized", + ) + if __name__ == "__main__": absltest.main() From 80b816df6c8f2cb551f189e0a88e1d770454a609 Mon Sep 17 00:00:00 2001 From: Nimrod Gileadi Date: Tue, 12 Dec 2023 06:13:17 -0800 Subject: [PATCH 2/6] Use google::protobuf:: instead of proto2:: PiperOrigin-RevId: 590178986 Change-Id: I963bc1ef597923caf5bdf4792af5034991f9626c --- mjpc/grpc/grpc_agent_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mjpc/grpc/grpc_agent_util.cc b/mjpc/grpc/grpc_agent_util.cc index cc0fea137..9d38973d3 100644 --- a/mjpc/grpc/grpc_agent_util.cc +++ b/mjpc/grpc/grpc_agent_util.cc @@ -340,7 +340,7 @@ grpc::Status GetTaskParameters(const GetTaskParametersRequest* request, } grpc::Status SetCostWeights( - const ::proto2::Map& cost_weights, + const ::google::protobuf::Map& cost_weights, mjpc::Agent* agent) { for (const auto& [name, weight] : cost_weights) { if (agent->SetWeightByName(name, weight) == -1) { @@ -408,7 +408,7 @@ grpc::Status GetAllModes(const GetAllModesRequest* request, mjpc::Agent* agent, } namespace { -grpc::Status SetMocap(const ::proto2::Map& mocap, +grpc::Status SetMocap(const ::google::protobuf::Map& mocap, mjpc::Agent* agent, const mjModel* model, mjData* data) { // Check all names and poses before applying changes. for (const auto& [name, pose] : mocap) { From de62c1969af68408fbba3d2674707a0dea45107d Mon Sep 17 00:00:00 2001 From: Tom Erez Date: Thu, 14 Dec 2023 04:49:17 -0800 Subject: [PATCH 3/6] Use keyframe's ctrl to initialize the nominal policy. PiperOrigin-RevId: 590898044 Change-Id: I13401da5386868de7fbcb665c079ce67ab24f268 --- mjpc/agent.cc | 4 ++-- mjpc/agent.h | 2 +- mjpc/app.cc | 12 ++++++++---- mjpc/planners/gradient/planner.cc | 16 +++++++++++----- mjpc/planners/gradient/planner.h | 6 ++++-- mjpc/planners/gradient/policy.cc | 12 ++++++++++-- mjpc/planners/gradient/policy.h | 6 ++++-- mjpc/planners/ilqg/planner.cc | 17 +++++++++++------ mjpc/planners/ilqg/planner.h | 3 ++- mjpc/planners/ilqg/policy.cc | 4 ++-- mjpc/planners/ilqg/policy.h | 4 +++- mjpc/planners/ilqs/planner.cc | 6 +++--- mjpc/planners/ilqs/planner.h | 3 ++- mjpc/planners/planner.h | 3 ++- mjpc/planners/policy.h | 3 ++- mjpc/planners/robust/robust_planner.cc | 11 +++++++++-- mjpc/planners/robust/robust_planner.h | 9 ++++++++- mjpc/planners/sampling/planner.cc | 19 +++++++++++++------ mjpc/planners/sampling/planner.h | 3 ++- mjpc/planners/sampling/policy.cc | 16 +++++++++++++--- mjpc/planners/sampling/policy.h | 3 ++- .../planners/robust/robust_planner_test.cc | 17 +++++++++++++++-- mjpc/test/testdata/particle_task.xml | 3 ++- mjpc/trajectory.cc | 12 +++++++++--- mjpc/trajectory.h | 4 ++-- 25 files changed, 142 insertions(+), 56 deletions(-) diff --git a/mjpc/agent.cc b/mjpc/agent.cc index 412fca3c2..047503c3c 100644 --- a/mjpc/agent.cc +++ b/mjpc/agent.cc @@ -159,10 +159,10 @@ void Agent::Allocate() { } // reset data, settings, planners, state -void Agent::Reset() { +void Agent::Reset(const double* initial_repeated_action) { // planner for (const auto& planner : planners_) { - planner->Reset(kMaxTrajectoryHorizon); + planner->Reset(kMaxTrajectoryHorizon, initial_repeated_action); } // state diff --git a/mjpc/agent.h b/mjpc/agent.h index 55066e405..58b5ea44f 100644 --- a/mjpc/agent.h +++ b/mjpc/agent.h @@ -65,7 +65,7 @@ class Agent { void Allocate(); // reset data, settings, planners, states - void Reset(); + void Reset(const double* initial_repeated_action = nullptr); // single planner iteration void PlanIteration(ThreadPool* pool); diff --git a/mjpc/app.cc b/mjpc/app.cc index 54ec263fd..d39c92712 100644 --- a/mjpc/app.cc +++ b/mjpc/app.cc @@ -232,12 +232,16 @@ void PhysicsLoop(mj::Simulate& sim) { sim.agent->plot_enabled = absl::GetFlag(FLAGS_show_plot); sim.agent->plan_enabled = absl::GetFlag(FLAGS_planner_enabled); sim.agent->Allocate(); - sim.agent->Reset(); - sim.agent->PlotInitialize(); // set home keyframe int home_id = mj_name2id(mnew, mjOBJ_KEY, "home"); - if (home_id >= 0) mj_resetDataKeyframe(mnew, dnew, home_id); + if (home_id >= 0) { + mj_resetDataKeyframe(mnew, dnew, home_id); + sim.agent->Reset(dnew->ctrl); + } else { + sim.agent->Reset(); + } + sim.agent->PlotInitialize(); sim.Load(mnew, dnew, sim.filename, true); m = mnew; @@ -305,7 +309,7 @@ void PhysicsLoop(mj::Simulate& sim) { double slowdown = 100 / sim.percentRealTime[sim.real_time_index]; // misalignment condition: distance from target sim time is bigger - // than syncmisalign + // than maximum misalignment `syncMisalign` bool misaligned = mju_abs(Seconds(elapsedCPU).count() / slowdown - elapsedSim) > syncMisalign; diff --git a/mjpc/planners/gradient/planner.cc b/mjpc/planners/gradient/planner.cc index caa090a63..2ce00d071 100644 --- a/mjpc/planners/gradient/planner.cc +++ b/mjpc/planners/gradient/planner.cc @@ -16,15 +16,20 @@ #include #include -#include +#include +#include #include "mjpc/array_safety.h" #include "mjpc/planners/cost_derivatives.h" #include "mjpc/planners/gradient/gradient.h" #include "mjpc/planners/gradient/policy.h" #include "mjpc/planners/gradient/settings.h" +#include "mjpc/planners/gradient/spline_mapping.h" #include "mjpc/planners/model_derivatives.h" +#include "mjpc/planners/planner.h" #include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" #include "mjpc/trajectory.h" #include "mjpc/utilities.h" @@ -103,7 +108,8 @@ void GradientPlanner::Allocate() { } // reset memory to zeros -void GradientPlanner::Reset(int horizon) { +void GradientPlanner::Reset(int horizon, + const double* initial_repeated_action) { // state std::fill(state.begin(), state.end(), 0.0); std::fill(mocap.begin(), mocap.end(), 0.0); @@ -122,10 +128,10 @@ void GradientPlanner::Reset(int horizon) { // policy for (int i = 0; i < kMaxTrajectory; i++) { - candidate_policy[i].Reset(horizon); + candidate_policy[i].Reset(horizon, initial_repeated_action); } - policy.Reset(horizon); - previous_policy.Reset(horizon); + policy.Reset(horizon, initial_repeated_action); + previous_policy.Reset(horizon, initial_repeated_action); // scratch std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0); diff --git a/mjpc/planners/gradient/planner.h b/mjpc/planners/gradient/planner.h index 1b8f3a05a..f2ae52522 100644 --- a/mjpc/planners/gradient/planner.h +++ b/mjpc/planners/gradient/planner.h @@ -15,7 +15,6 @@ #ifndef MJPC_PLANNERS_GRADIENT_OPTIMIZER_H_ #define MJPC_PLANNERS_GRADIENT_OPTIMIZER_H_ -#include #include #include #include @@ -29,6 +28,8 @@ #include "mjpc/planners/model_derivatives.h" #include "mjpc/planners/planner.h" #include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" #include "mjpc/trajectory.h" namespace mjpc { @@ -52,7 +53,8 @@ class GradientPlanner : public Planner { void Allocate() override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set state void SetState(const State& state) override; diff --git a/mjpc/planners/gradient/policy.cc b/mjpc/planners/gradient/policy.cc index e92f01ed6..fd83d8e79 100644 --- a/mjpc/planners/gradient/policy.cc +++ b/mjpc/planners/gradient/policy.cc @@ -53,11 +53,19 @@ void GradientPolicy::Allocate(const mjModel* model, const Task& task, } // reset memory to zeros -void GradientPolicy::Reset(int horizon) { +void GradientPolicy::Reset(int horizon, const double* initial_repeated_action) { std::fill(k.begin(), k.begin() + horizon * model->nu, 0.0); // parameters - std::fill(parameters.begin(), parameters.begin() + model->nu * horizon, 0.0); + if (initial_repeated_action != nullptr) { + for (int i = 0; i < horizon; ++i) { + mju_copy(parameters.data() + i * model->nu, initial_repeated_action, + model->nu); + } + } else { + std::fill(parameters.begin(), + parameters.begin() + model->nu * horizon, 0.0); + } std::fill(parameter_update.begin(), parameter_update.begin() + model->nu * horizon, 0.0); diff --git a/mjpc/planners/gradient/policy.h b/mjpc/planners/gradient/policy.h index 390169b38..93c69b295 100644 --- a/mjpc/planners/gradient/policy.h +++ b/mjpc/planners/gradient/policy.h @@ -17,8 +17,9 @@ #include +#include #include "mjpc/planners/policy.h" -#include "mjpc/trajectory.h" +#include "mjpc/task.h" namespace mjpc { @@ -37,7 +38,8 @@ class GradientPolicy : public Policy { void Allocate(const mjModel* model, const Task& task, int horizon) override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // compute action from policy // state is not used diff --git a/mjpc/planners/ilqg/planner.cc b/mjpc/planners/ilqg/planner.cc index 5238ec510..708bba6b5 100644 --- a/mjpc/planners/ilqg/planner.cc +++ b/mjpc/planners/ilqg/planner.cc @@ -16,15 +16,20 @@ #include #include +#include +#include #include -#include +#include +#include #include "mjpc/array_safety.h" #include "mjpc/planners/ilqg/backward_pass.h" #include "mjpc/planners/ilqg/policy.h" #include "mjpc/planners/ilqg/settings.h" #include "mjpc/planners/planner.h" #include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" #include "mjpc/trajectory.h" #include "mjpc/utilities.h" @@ -97,7 +102,7 @@ void iLQGPlanner::Allocate() { } // reset memory to zeros -void iLQGPlanner::Reset(int horizon) { +void iLQGPlanner::Reset(int horizon, const double* initial_repeated_action) { // state std::fill(state.begin(), state.end(), 0.0); std::fill(mocap.begin(), mocap.end(), 0.0); @@ -115,15 +120,15 @@ void iLQGPlanner::Reset(int horizon) { backward_pass.Reset(dim_state_derivative, dim_action, horizon); // policy - policy.Reset(horizon); - previous_policy.Reset(horizon); + policy.Reset(horizon, initial_repeated_action); + previous_policy.Reset(horizon, initial_repeated_action); for (int i = 0; i < kMaxTrajectory; i++) { - candidate_policy[i].Reset(horizon); + candidate_policy[i].Reset(horizon, initial_repeated_action); } // candidate trajectories for (int i = 0; i < kMaxTrajectory; i++) { - trajectory[i].Reset(horizon); + trajectory[i].Reset(horizon, initial_repeated_action); } // values diff --git a/mjpc/planners/ilqg/planner.h b/mjpc/planners/ilqg/planner.h index ed2681225..b675c1c52 100644 --- a/mjpc/planners/ilqg/planner.h +++ b/mjpc/planners/ilqg/planner.h @@ -41,7 +41,8 @@ class iLQGPlanner : public Planner { void Allocate() override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set state void SetState(const State& state) override; diff --git a/mjpc/planners/ilqg/policy.cc b/mjpc/planners/ilqg/policy.cc index facc5161b..349347970 100644 --- a/mjpc/planners/ilqg/policy.cc +++ b/mjpc/planners/ilqg/policy.cc @@ -57,8 +57,8 @@ void iLQGPolicy::Allocate(const mjModel* model, const Task& task, int horizon) { } // reset memory to zeros -void iLQGPolicy::Reset(int horizon) { - trajectory.Reset(horizon); +void iLQGPolicy::Reset(int horizon, const double* initial_repeated_action) { + trajectory.Reset(horizon, initial_repeated_action); std::fill( feedback_gain.begin(), feedback_gain.begin() + horizon * model->nu * (2 * model->nv + model->na), diff --git a/mjpc/planners/ilqg/policy.h b/mjpc/planners/ilqg/policy.h index 942e1dfad..4dab73d2a 100644 --- a/mjpc/planners/ilqg/policy.h +++ b/mjpc/planners/ilqg/policy.h @@ -17,6 +17,7 @@ #include +#include #include "mjpc/planners/policy.h" #include "mjpc/task.h" #include "mjpc/trajectory.h" @@ -36,7 +37,8 @@ class iLQGPolicy : public Policy { void Allocate(const mjModel* model, const Task& task, int horizon) override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set action from policy // if state == nullptr, return the nominal action without a feedback term diff --git a/mjpc/planners/ilqs/planner.cc b/mjpc/planners/ilqs/planner.cc index 4ed2d7198..716a0f00e 100644 --- a/mjpc/planners/ilqs/planner.cc +++ b/mjpc/planners/ilqs/planner.cc @@ -56,12 +56,12 @@ void iLQSPlanner::Allocate() { } // reset memory to zeros -void iLQSPlanner::Reset(int horizon) { +void iLQSPlanner::Reset(int horizon, const double* initial_repeated_action) { // Sampling - sampling.Reset(horizon); + sampling.Reset(horizon, initial_repeated_action); // iLQG - ilqg.Reset(horizon); + ilqg.Reset(horizon, initial_repeated_action); // active_policy active_policy = kSampling; diff --git a/mjpc/planners/ilqs/planner.h b/mjpc/planners/ilqs/planner.h index 3a60d01bd..424b942b3 100644 --- a/mjpc/planners/ilqs/planner.h +++ b/mjpc/planners/ilqs/planner.h @@ -52,7 +52,8 @@ class iLQSPlanner : public Planner { void Allocate() override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set state void SetState(const State& state) override; diff --git a/mjpc/planners/planner.h b/mjpc/planners/planner.h index 2a8a280d6..5755bf7d6 100644 --- a/mjpc/planners/planner.h +++ b/mjpc/planners/planner.h @@ -40,7 +40,8 @@ class Planner { virtual void Allocate() = 0; // reset memory to zeros - virtual void Reset(int horizon) = 0; + virtual void Reset(int horizon, + const double* initial_repeated_action = nullptr) = 0; // set state virtual void SetState(const State& state) = 0; diff --git a/mjpc/planners/policy.h b/mjpc/planners/policy.h index 887358ff8..ebf349a62 100644 --- a/mjpc/planners/policy.h +++ b/mjpc/planners/policy.h @@ -38,7 +38,8 @@ class Policy { int horizon) = 0; // reset memory to zeros - virtual void Reset(int horizon) = 0; + virtual void Reset(int horizon, + const double* initial_repeated_action = nullptr) = 0; // set action from policy // for policies that have a feedback term, passing nullptr for state turns diff --git a/mjpc/planners/robust/robust_planner.cc b/mjpc/planners/robust/robust_planner.cc index cdf2be44c..68e291a85 100644 --- a/mjpc/planners/robust/robust_planner.cc +++ b/mjpc/planners/robust/robust_planner.cc @@ -13,9 +13,16 @@ // limitations under the License. #include "mjpc/planners/robust/robust_planner.h" +#include +#include #include "mjpc/array_safety.h" +#include "mjpc/planners/planner.h" +#include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" #include "mjpc/trajectory.h" +#include "mjpc/utilities.h" namespace mjpc { @@ -63,8 +70,8 @@ void RobustPlanner::Allocate() { ResizeTrajectories(ncandidates_ * nrepetitions_); } -void RobustPlanner::Reset(int horizon) { - delegate_->Reset(horizon); +void RobustPlanner::Reset(int horizon, const double* initial_repeated_action) { + delegate_->Reset(horizon, initial_repeated_action); // state std::fill(state_.begin(), state_.end(), 0.0); std::fill(mocap_.begin(), mocap_.end(), 0.0); diff --git a/mjpc/planners/robust/robust_planner.h b/mjpc/planners/robust/robust_planner.h index 66bc6c5b2..0880c4a44 100644 --- a/mjpc/planners/robust/robust_planner.h +++ b/mjpc/planners/robust/robust_planner.h @@ -22,9 +22,15 @@ #define MJPC_MJPC_PLANNERS_ROBUST_ROBUST_PLANNER_H_ #include +#include #include +#include #include "mjpc/planners/planner.h" +#include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" +#include "mjpc/trajectory.h" namespace mjpc { @@ -36,7 +42,8 @@ class RobustPlanner : public Planner { void Initialize(mjModel* model, const Task& task) override; void Allocate() override; - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; void SetState(const State& state) override; void OptimizePolicy(int horizon, ThreadPool& pool) override; void NominalTrajectory(int horizon, ThreadPool& pool) override; diff --git a/mjpc/planners/sampling/planner.cc b/mjpc/planners/sampling/planner.cc index 2e9c2a4d9..8a3766eb9 100644 --- a/mjpc/planners/sampling/planner.cc +++ b/mjpc/planners/sampling/planner.cc @@ -16,14 +16,16 @@ #include #include -#include #include #include #include #include "mjpc/array_safety.h" +#include "mjpc/planners/planner.h" #include "mjpc/planners/sampling/policy.h" #include "mjpc/states/state.h" +#include "mjpc/task.h" +#include "mjpc/threadpool.h" #include "mjpc/trajectory.h" #include "mjpc/utilities.h" @@ -97,7 +99,8 @@ void SamplingPlanner::Allocate() { } // reset memory to zeros -void SamplingPlanner::Reset(int horizon) { +void SamplingPlanner::Reset(int horizon, + const double* initial_repeated_action) { // state std::fill(state.begin(), state.end(), 0.0); std::fill(mocap.begin(), mocap.end(), 0.0); @@ -105,8 +108,8 @@ void SamplingPlanner::Reset(int horizon) { time = 0.0; // policy parameters - policy.Reset(horizon); - previous_policy.Reset(horizon); + policy.Reset(horizon, initial_repeated_action); + previous_policy.Reset(horizon, initial_repeated_action); // scratch std::fill(parameters_scratch.begin(), parameters_scratch.end(), 0.0); @@ -118,11 +121,15 @@ void SamplingPlanner::Reset(int horizon) { // trajectory samples for (int i = 0; i < kMaxTrajectory; i++) { trajectory[i].Reset(kMaxTrajectoryHorizon); - candidate_policy[i].Reset(horizon); + candidate_policy[i].Reset(horizon, initial_repeated_action); } for (const auto& d : data_) { - mju_zero(d->ctrl, model->nu); + if (initial_repeated_action) { + mju_copy(d->ctrl, initial_repeated_action, model->nu); + } else { + mju_zero(d->ctrl, model->nu); + } } // noise gradient diff --git a/mjpc/planners/sampling/planner.h b/mjpc/planners/sampling/planner.h index 62ee22a6b..18f6f0283 100644 --- a/mjpc/planners/sampling/planner.h +++ b/mjpc/planners/sampling/planner.h @@ -53,7 +53,8 @@ class SamplingPlanner : public RankedPlanner { void Allocate() override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set state void SetState(const State& state) override; diff --git a/mjpc/planners/sampling/policy.cc b/mjpc/planners/sampling/policy.cc index 294392f7f..f99cbfa36 100644 --- a/mjpc/planners/sampling/policy.cc +++ b/mjpc/planners/sampling/policy.cc @@ -15,8 +15,11 @@ #include "mjpc/planners/sampling/policy.h" #include +#include #include +#include "mjpc/planners/policy.h" +#include "mjpc/task.h" #include "mjpc/trajectory.h" #include "mjpc/utilities.h" @@ -47,10 +50,17 @@ void SamplingPolicy::Allocate(const mjModel* model, const Task& task, } // reset memory to zeros -void SamplingPolicy::Reset(int horizon) { +void SamplingPolicy::Reset(int horizon, const double* initial_repeated_action) { // parameters - std::fill(parameters.begin(), parameters.begin() + model->nu * horizon, 0.0); - + if (initial_repeated_action != nullptr) { + for (int i = 0; i < num_spline_points; ++i) { + mju_copy(parameters.data() + i * model->nu, initial_repeated_action, + model->nu); + } + } else { + std::fill(parameters.begin(), + parameters.begin() + model->nu * num_spline_points, 0.0); + } // policy parameter times std::fill(times.begin(), times.begin() + horizon, 0.0); } diff --git a/mjpc/planners/sampling/policy.h b/mjpc/planners/sampling/policy.h index e100ba981..453c7a3bc 100644 --- a/mjpc/planners/sampling/policy.h +++ b/mjpc/planners/sampling/policy.h @@ -39,7 +39,8 @@ class SamplingPolicy : public Policy { void Allocate(const mjModel* model, const Task& task, int horizon) override; // reset memory to zeros - void Reset(int horizon) override; + void Reset(int horizon, + const double* initial_repeated_action = nullptr) override; // set action from policy void Action(double* action, const double* state, double time) const override; diff --git a/mjpc/test/planners/robust/robust_planner_test.cc b/mjpc/test/planners/robust/robust_planner_test.cc index 241b16656..83aedb6c8 100644 --- a/mjpc/test/planners/robust/robust_planner_test.cc +++ b/mjpc/test/planners/robust/robust_planner_test.cc @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "gtest/gtest.h" #include #include "mjpc/planners/robust/robust_planner.h" #include "mjpc/planners/sampling/planner.h" #include "mjpc/states/state.h" -#include "mjpc/task.h" #include "mjpc/test/load.h" #include "mjpc/test/testdata/particle_residual.h" #include "mjpc/threadpool.h" +#include "mjpc/trajectory.h" namespace mjpc { namespace { @@ -49,6 +51,9 @@ TEST(RobustPlannerTest, RandomSearch) { // create data mjData* data = mj_makeData(model); + // the "home" keyframe initializes the state too far from the target + int home_id = mj_name2id(model, mjOBJ_KEY, "ctrl_test"); + mj_resetDataKeyframe(model, data, home_id); // ----- state ----- // state.Initialize(model); @@ -60,7 +65,15 @@ TEST(RobustPlannerTest, RandomSearch) { RobustPlanner planner(std::make_unique()); planner.Initialize(model, task); planner.Allocate(); - planner.Reset(kMaxTrajectoryHorizon); + // If there's no keyframe, data->ctrl will be zeros, so this is always safe. + planner.Reset(kMaxTrajectoryHorizon, data->ctrl); + + double res[2]; + // look at some arbitrary, hard-coded time: + planner.ActionFromPolicy(res, state.state().data(), 2); + // expected values copied from the keyframe in the xml: + EXPECT_NEAR(res[0], 0.1, 1.0e-4); + EXPECT_NEAR(res[1], 0.2, 1.0e-4); // ----- settings ----- // int iterations = 1000; diff --git a/mjpc/test/testdata/particle_task.xml b/mjpc/test/testdata/particle_task.xml index 36e769661..bf3e81157 100644 --- a/mjpc/test/testdata/particle_task.xml +++ b/mjpc/test/testdata/particle_task.xml @@ -33,6 +33,7 @@ - + + diff --git a/mjpc/trajectory.cc b/mjpc/trajectory.cc index 9446e98eb..365353d02 100644 --- a/mjpc/trajectory.cc +++ b/mjpc/trajectory.cc @@ -62,12 +62,18 @@ void Trajectory::Allocate(int T) { } // reset memory to zeros -void Trajectory::Reset(int T) { +void Trajectory::Reset(int T, const double* initial_repeated_action) { // states std::fill(states.begin(), states.begin() + dim_state * T, 0.0); - // actions - std::fill(actions.begin(), actions.begin() + dim_action * T, 0.0); + if (initial_repeated_action != nullptr) { + for (int i = 0; i < T; ++i) { + mju_copy(actions.data() + i * dim_action, initial_repeated_action, + dim_action); + } + } else { + std::fill(actions.begin(), actions.begin() + dim_action * T, 0.0); + } // times std::fill(times.begin(), times.begin() + T, 0.0); diff --git a/mjpc/trajectory.h b/mjpc/trajectory.h index dd554505e..afc345680 100644 --- a/mjpc/trajectory.h +++ b/mjpc/trajectory.h @@ -46,8 +46,8 @@ class Trajectory { // allocate memory void Allocate(int T); - // reset memory to zeros - void Reset(int T); + // reset memory to zeros (and perhaps a non-zero action) + void Reset(int T, const double* initial_repeated_action = nullptr); // simulate model forward in time with continuous-time indexed policy void Rollout( From 81419d5d7ae8ff23a6ba228b139bc1173320b5bc Mon Sep 17 00:00:00 2001 From: Nimrod Gileadi Date: Wed, 20 Dec 2023 00:31:13 -0800 Subject: [PATCH 4/6] Add Agent API method for changing mocap body poses. Add mocap information to MjpcParameters dataclass. PiperOrigin-RevId: 592467189 Change-Id: I47c059dd43315792ffc9a397293ec61abb0a6608 --- python/mujoco_mpc/agent.py | 9 +++++++-- python/mujoco_mpc/mjpc_parameters.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py index 2a2f2c0a8..9cdd310e9 100644 --- a/python/mujoco_mpc/agent.py +++ b/python/mujoco_mpc/agent.py @@ -360,6 +360,11 @@ def set_parameters(self, parameters: mjpc_parameters.MjpcParameters): if parameters.cost_weights: self.set_cost_weights(parameters.cost_weights) - def set_mocap(self, mocap_map: Mapping[str, agent_pb2.Pose]): - request = agent_pb2.SetAnythingRequest(mocap=mocap_map) + def set_mocap(self, mocap_map: Mapping[str, mjpc_parameters.Pose]): + request = agent_pb2.SetAnythingRequest() + for key, value in mocap_map.items(): + if value.pos is not None: + request.mocap[key].pos.extend(value.pos) + if value.quat is not None: + request.mocap[key].quat.extend(value.quat) self.stub.SetAnything(request) diff --git a/python/mujoco_mpc/mjpc_parameters.py b/python/mujoco_mpc/mjpc_parameters.py index 129d7df93..8f257c6ac 100644 --- a/python/mujoco_mpc/mjpc_parameters.py +++ b/python/mujoco_mpc/mjpc_parameters.py @@ -3,12 +3,22 @@ import dataclasses from typing import Optional, Union +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class Pose: + pos: np.ndarray | None # 3D vector + quat: np.ndarray | None # Unit quaternion + @dataclasses.dataclass(frozen=True) class MjpcParameters: - """Dataclass to store and set task mode, task parameters and cost weights.""" + """Dataclass to store and set task settings.""" mode: Optional[str] = None task_parameters: dict[str, Union[str, float]] = dataclasses.field( default_factory=dict ) cost_weights: dict[str, float] = dataclasses.field(default_factory=dict) + # A map from body name to pose + mocap: dict[str, Pose] = dataclasses.field(default_factory=dict) From 65ce09997cb77b3f6e96a12e9e37590eaec4a1e8 Mon Sep 17 00:00:00 2001 From: Yuval Tassa Date: Fri, 5 Jan 2024 06:14:30 -0800 Subject: [PATCH 5/6] Add videos to README, fixes #234 PiperOrigin-RevId: 595976688 Change-Id: I7059d8887baef1c8e8a9d4e4e27bd941ddad119f --- README.md | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8675de55b..9ae535df6 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,30 @@ For a quick video overview of MJPC, click below. [![Video](http://img.youtube.com/vi/Bdx7DuAMB6o/hqdefault.jpg)](https://dpmd.ai/mjpc) -For a longer talk at the MIT Robotics Seminar describing our results, click +For a longer talk at the MIT Robotics Seminar in December 2022 describing our results, click below. -[![Talk](http://img.youtube.com/vi/2xVN-qY78P4/hqdefault.jpg)](https://www.youtube.com/watch?v=2xVN-qY78P4) +[![2022Talk](http://img.youtube.com/vi/2xVN-qY78P4/hqdefault.jpg)](https://www.youtube.com/watch?v=2xVN-qY78P4) + +A more recent, December 2023 talk at the IEEE Technical Committee on Model-Based Optimization +is available here: + +[![2023Talk](https://img.youtube.com/vi/J-JO-lgaKtw/hqdefault.jpg)](https://www.youtube.com/watch?v=J-JO-lgaKtw&t=0s) + +### Example tasks + +Quadruped task: + +[![Quadruped](http://img.youtube.com/vi/esLuwaWz4oE/hqdefault.jpg)](https://www.youtube.com/watch?v=esLuwaWz4oE) + +Rubik's cube 10-move unscramble: + +[![Unscramble](http://img.youtube.com/vi/ZRRvVWV-Muk/hqdefault.jpg)](https://www.youtube.com/watch?v=ZRRvVWV-Muk) + +Humanoid motion-capture tracking: + +[![Tracking](http://img.youtube.com/vi/tEBVK-MO1Sw/hqdefault.jpg)](https://www.youtube.com/watch?v=tEBVK-MO1Sw) + ## Graphical User Interface @@ -81,7 +101,7 @@ Note, we are using `clang-14`. # Python API -We provide a simple Python API for MJPC. This API is still experimental and expects some more experience from its users. For example, the correct usage requires that the model (defined in Python) and the MJPC task (i.e., the residual and transition functions defined in C++) are compatible with each other. Currently, the Python API does not provide any particular error handling for verifying this compatibilty and may be difficult to debug without more in-depth knowedge about mujoco and MJPC. +We provide a simple Python API for MJPC. This API is still experimental and expects some more experience from its users. For example, the correct usage requires that the model (defined in Python) and the MJPC task (i.e., the residual and transition functions defined in C++) are compatible with each other. Currently, the Python API does not provide any particular error handling for verifying this compatibility and may be difficult to debug without more in-depth knowledge about mujoco and MJPC. - [agent.py](python/mujoco_mpc/agent.py) for available methods for planning. From 00a19eff2771c339d1e8519f68a3b3dd713c8cfb Mon Sep 17 00:00:00 2001 From: Tom Erez Date: Mon, 8 Jan 2024 10:02:00 -0800 Subject: [PATCH 6/6] correct order of includes. PiperOrigin-RevId: 596626522 Change-Id: Ic5b617bddaeb3d553a6f8f147db8127375cdd1e3 --- mjpc/tasks/tasks.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mjpc/tasks/tasks.h b/mjpc/tasks/tasks.h index 61b72d07b..59555fa41 100644 --- a/mjpc/tasks/tasks.h +++ b/mjpc/tasks/tasks.h @@ -17,11 +17,11 @@ #ifndef MJPC_TASKS_TASKS_H_ #define MJPC_TASKS_TASKS_H_ -#include "mjpc/task.h" - #include #include +#include "mjpc/task.h" + namespace mjpc { std::vector> GetTasks(); } // namespace mjpc