Skip to content

Commit

Permalink
Merge branch 'main' into improve_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
erez-tom authored Jan 12, 2024
2 parents 05d292f + fd9c479 commit 4572f3d
Show file tree
Hide file tree
Showing 42 changed files with 472 additions and 108 deletions.
26 changes: 23 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,30 @@ For a quick video overview of MJPC, click below.

[![Video](http://img.youtube.com/vi/Bdx7DuAMB6o/hqdefault.jpg)](https://dpmd.ai/mjpc)

For a longer talk at the MIT Robotics Seminar describing our results, click
For a longer talk at the MIT Robotics Seminar in December 2022 describing our results, click
below.

[![Talk](http://img.youtube.com/vi/2xVN-qY78P4/hqdefault.jpg)](https://www.youtube.com/watch?v=2xVN-qY78P4)
[![2022Talk](http://img.youtube.com/vi/2xVN-qY78P4/hqdefault.jpg)](https://www.youtube.com/watch?v=2xVN-qY78P4)

A more recent, December 2023 talk at the IEEE Technical Committee on Model-Based Optimization
is available here:

[![2023Talk](https://img.youtube.com/vi/J-JO-lgaKtw/hqdefault.jpg)](https://www.youtube.com/watch?v=J-JO-lgaKtw&t=0s)

### Example tasks

Quadruped task:

[![Quadruped](http://img.youtube.com/vi/esLuwaWz4oE/hqdefault.jpg)](https://www.youtube.com/watch?v=esLuwaWz4oE)

Rubik's cube 10-move unscramble:

[![Unscramble](http://img.youtube.com/vi/ZRRvVWV-Muk/hqdefault.jpg)](https://www.youtube.com/watch?v=ZRRvVWV-Muk)

Humanoid motion-capture tracking:

[![Tracking](http://img.youtube.com/vi/tEBVK-MO1Sw/hqdefault.jpg)](https://www.youtube.com/watch?v=tEBVK-MO1Sw)


## Graphical User Interface

Expand Down Expand Up @@ -81,7 +101,7 @@ Note, we are using `clang-14`.

# Python API

We provide a simple Python API for MJPC. This API is still experimental and expects some more experience from its users. For example, the correct usage requires that the model (defined in Python) and the MJPC task (i.e., the residual and transition functions defined in C++) are compatible with each other. Currently, the Python API does not provide any particular error handling for verifying this compatibilty and may be difficult to debug without more in-depth knowedge about MuJoCo and MJPC.
We provide a simple Python API for MJPC. This API is still experimental and expects some more experience from its users. For example, the correct usage requires that the model (defined in Python) and the MJPC task (i.e., the residual and transition functions defined in C++) are compatible with each other. Currently, the Python API does not provide any particular error handling for verifying this compatibility and may be difficult to debug without more in-depth knowledge about mujoco and MJPC.

- [agent.py](python/mujoco_mpc/agent.py) for available methods for planning.

Expand Down
2 changes: 1 addition & 1 deletion docs/OVERVIEW.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ As an example, consider these snippets specifying the Swimmer task:
</sensor>
```
```c++
void Swimmer::ResidualFn::Residual(const mjModel* model,
void Swimmer::ResidualFn::Residual(const mjModel* model,
const mjData* data,
double* residual) const {
// initialize counter
Expand Down
4 changes: 2 additions & 2 deletions mjpc/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ void Agent::Allocate() {
}

// reset data, settings, planners, state
void Agent::Reset() {
void Agent::Reset(const double* initial_repeated_action) {
// planner
for (const auto& planner : planners_) {
planner->Reset(kMaxTrajectoryHorizon);
planner->Reset(kMaxTrajectoryHorizon, initial_repeated_action);
}

// state
Expand Down
2 changes: 1 addition & 1 deletion mjpc/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Agent {
void Allocate();

// reset data, settings, planners, states
void Reset();
void Reset(const double* initial_repeated_action = nullptr);

// single planner iteration
void PlanIteration(ThreadPool* pool);
Expand Down
12 changes: 8 additions & 4 deletions mjpc/app.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,16 @@ void PhysicsLoop(mj::Simulate& sim) {
sim.agent->plot_enabled = absl::GetFlag(FLAGS_show_plot);
sim.agent->plan_enabled = absl::GetFlag(FLAGS_planner_enabled);
sim.agent->Allocate();
sim.agent->Reset();
sim.agent->PlotInitialize();

// set home keyframe
int home_id = mj_name2id(mnew, mjOBJ_KEY, "home");
if (home_id >= 0) mj_resetDataKeyframe(mnew, dnew, home_id);
if (home_id >= 0) {
mj_resetDataKeyframe(mnew, dnew, home_id);
sim.agent->Reset(dnew->ctrl);
} else {
sim.agent->Reset();
}
sim.agent->PlotInitialize();

sim.Load(mnew, dnew, sim.filename, true);
m = mnew;
Expand Down Expand Up @@ -305,7 +309,7 @@ void PhysicsLoop(mj::Simulate& sim) {
double slowdown = 100 / sim.percentRealTime[sim.real_time_index];

// misalignment condition: distance from target sim time is bigger
// than syncmisalign
// than maximum misalignment `syncMisalign`
bool misaligned = mju_abs(Seconds(elapsedCPU).count() / slowdown -
elapsedSim) > syncMisalign;

Expand Down
28 changes: 27 additions & 1 deletion mjpc/grpc/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ service Agent {
rpc GetMode(GetModeRequest) returns (GetModeResponse);
// Get all modes.
rpc GetAllModes(GetAllModesRequest) returns (GetAllModesResponse);

// Get best trajectory (states, actions, times).
rpc GetBestTrajectory(GetBestTrajectoryRequest) returns (GetBestTrajectoryResponse);
rpc GetBestTrajectory(GetBestTrajectoryRequest)
returns (GetBestTrajectoryResponse);

// A single method that can set many of the inputs.
rpc SetAnything(SetAnythingRequest) returns (SetAnythingResponse);
}

message MjModel {
Expand Down Expand Up @@ -186,3 +191,24 @@ message GetBestTrajectoryResponse {
repeated double times = 3 [packed = true];
int32 steps = 4;
}

message Pose {
repeated double pos = 1 [packed = true];
repeated double quat = 2 [packed = true];
}

message SetAnythingRequest {
State state = 1;

// map from parameter name to desired value
map<string, TaskParameterValue> parameters = 2;
// cost weights by name
map<string, double> cost_weights = 3;
string mode = 4;

// set the positions of mocap bodies by name. If `state` is set too, mocap
// positions will be set after the state is set.
map<string, Pose> mocap = 5;
}

message SetAnythingResponse {}
13 changes: 13 additions & 0 deletions mjpc/grpc/agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mjpc/grpc/agent.pb.h"
#include "mjpc/grpc/grpc_agent_util.h"
#include "mjpc/task.h"
#include "mjpc/trajectory.h"

namespace mjpc::agent_grpc {

Expand All @@ -48,6 +49,8 @@ using ::agent::PlannerStepRequest;
using ::agent::PlannerStepResponse;
using ::agent::ResetRequest;
using ::agent::ResetResponse;
using ::agent::SetAnythingRequest;
using ::agent::SetAnythingResponse;
using ::agent::SetCostWeightsRequest;
using ::agent::SetCostWeightsResponse;
using ::agent::SetModeRequest;
Expand Down Expand Up @@ -323,4 +326,14 @@ grpc::Status AgentService::GetBestTrajectory(
return grpc::Status::OK;
}


grpc::Status AgentService::SetAnything(
grpc::ServerContext* context, const SetAnythingRequest* request,
SetAnythingResponse* response) {
if (!Initialized()) {
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}
return grpc_agent_util::SetAnything(request, &agent_, agent_.GetModel(),
data_, response);
}
} // namespace mjpc::agent_grpc
4 changes: 4 additions & 0 deletions mjpc/grpc/agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ class AgentService final : public agent::Agent::Service {
const agent::GetBestTrajectoryRequest* request,
agent::GetBestTrajectoryResponse* response) override;

grpc::Status SetAnything(grpc::ServerContext* context,
const agent::SetAnythingRequest* request,
agent::SetAnythingResponse* response) override;

private:
bool Initialized() const { return data_ != nullptr; }

Expand Down
130 changes: 115 additions & 15 deletions mjpc/grpc/grpc_agent_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "mjpc/grpc/grpc_agent_util.h"

#include <cstring>
#include <memory>
#include <sstream>
#include <string>
Expand All @@ -35,6 +36,7 @@
#include "mjpc/agent.h"
#include "mjpc/states/state.h"
#include "mjpc/task.h"
#include "mjpc/utilities.h"

namespace grpc_agent_util {

Expand All @@ -49,6 +51,9 @@ using ::agent::GetModeResponse;
using ::agent::GetStateResponse;
using ::agent::GetTaskParametersRequest;
using ::agent::GetTaskParametersResponse;
using ::agent::Pose;
using ::agent::SetAnythingRequest;
using ::agent::SetAnythingResponse;
using ::agent::SetCostWeightsRequest;
using ::agent::SetModeRequest;
using ::agent::SetStateRequest;
Expand Down Expand Up @@ -104,10 +109,9 @@ if (!(expr).ok()) { \
} \
}

grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent,
namespace {
grpc::Status SetState(const agent::State& state, mjpc::Agent* agent,
const mjModel* model, mjData* data) {
agent::State state = request->state();

if (state.has_time()) data->time = state.time();

if (state.qpos_size() > 0) {
Expand Down Expand Up @@ -144,6 +148,13 @@ grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent,

return grpc::Status::OK;
}
} // namespace

grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent,
const mjModel* model, mjData* data) {
agent::State state = request->state();
return SetState(state, agent, model, data);
}

#undef CHECK_SIZE

Expand Down Expand Up @@ -328,12 +339,10 @@ grpc::Status GetTaskParameters(const GetTaskParametersRequest* request,
return grpc::Status::OK;
}

grpc::Status SetCostWeights(const SetCostWeightsRequest* request,
mjpc::Agent* agent) {
if (request->reset_to_defaults()) {
agent->ActiveTask()->Reset(agent->GetModel());
}
for (const auto& [name, weight] : request->cost_weights()) {
grpc::Status SetCostWeights(
const ::google::protobuf::Map<std::string, double>& cost_weights,
mjpc::Agent* agent) {
for (const auto& [name, weight] : cost_weights) {
if (agent->SetWeightByName(name, weight) == -1) {
std::ostringstream error_string;
error_string << "Weight '" << name
Expand All @@ -346,19 +355,29 @@ grpc::Status SetCostWeights(const SetCostWeightsRequest* request,
agent_model->name_sensoradr[i]);
error_string << " " << sensor_name << "\n";
}
return {grpc::StatusCode::INVALID_ARGUMENT, error_string.str()};
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
error_string.str());
}
}

return grpc::Status::OK;
}

grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) {
int outcome = agent->SetModeByName(request->mode());
grpc::Status SetCostWeights(const SetCostWeightsRequest* request,
mjpc::Agent* agent) {
if (request->reset_to_defaults()) {
agent->ActiveTask()->Reset(agent->GetModel());
}
auto cost_weights = request->cost_weights();
return SetCostWeights(cost_weights, agent);
}


grpc::Status SetMode(std::string_view mode, mjpc::Agent* agent) {
int outcome = agent->SetModeByName(mode);
if (outcome == -1) {
std::vector<std::string> mode_names = agent->GetAllModeNames();
std::ostringstream error_string;
error_string << "Mode '" << request->mode()
error_string << "Mode '" << mode
<< "' not found in task. Available names are:\n";
for (const auto& mode_name : mode_names) {
error_string << " " << mode_name << "\n";
Expand All @@ -369,6 +388,10 @@ grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) {
}
}

grpc::Status SetMode(const SetModeRequest* request, mjpc::Agent* agent) {
return SetMode(request->mode(), agent);
}

grpc::Status GetMode(const GetModeRequest* request, mjpc::Agent* agent,
GetModeResponse* response) {
response->set_mode(agent->GetModeName());
Expand All @@ -384,6 +407,83 @@ grpc::Status GetAllModes(const GetAllModesRequest* request, mjpc::Agent* agent,
return grpc::Status::OK;
}

namespace {
grpc::Status SetMocap(const ::google::protobuf::Map<std::string, Pose>& mocap,
mjpc::Agent* agent, const mjModel* model, mjData* data) {
// Check all names and poses before applying changes.
for (const auto& [name, pose] : mocap) {
int id = mj_name2id(model, mjOBJ_BODY, name.c_str());
if (id < 0) {
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
absl::StrFormat("Body '%s' not found.", name));
}
int mocap_id = model->body_mocapid[id];
if (mocap_id < 0) {
return grpc::Status(
grpc::StatusCode::INVALID_ARGUMENT,
absl::StrFormat("Body '%s' is not a mocap body.", name));
}
if (pose.pos_size() != 0 && pose.pos_size() != 3) {
return grpc::Status(
grpc::StatusCode::INVALID_ARGUMENT,
absl::StrFormat("Mocap '%s' has invalid pose size %d.", name,
pose.pos_size()));
}
if (pose.quat_size() != 0 && pose.quat_size() != 4) {
return grpc::Status(
grpc::StatusCode::INVALID_ARGUMENT,
absl::StrFormat("Mocap '%s' has invalid quat size %d.", name,
pose.pos_size()));
}
}
for (const auto& [name, pose] : mocap) {
int id = mj_name2id(model, mjOBJ_BODY, name.c_str());
int mocap_id = model->body_mocapid[id];
for (int i = 0; i < pose.pos_size(); i++) {
data->mocap_pos[3*mocap_id + i] = pose.pos(i);
}
if (pose.quat_size() == 4) {
for (int i = 0; i < 4; i++) {
data->mocap_quat[4*mocap_id + i] = pose.quat(i);
}
mju_normalize4(data->mocap_quat + 4*mocap_id);
}
}
agent->SetState(data);
return grpc::Status::OK;
}
} // namespace

grpc::Status SetAnything(const SetAnythingRequest* request, mjpc::Agent* agent,
const mjModel* model, mjData* data,
SetAnythingResponse* response) {
if (request->has_state()) {
grpc::Status status = SetState(request->state(), agent, model, data);
if (!status.ok()) {
return status;
}
}
if (request->cost_weights_size() > 0) {
grpc::Status status = SetCostWeights(request->cost_weights(), agent);
if (!status.ok()) {
return status;
}
}
if (!request->mode().empty()) {
grpc::Status status = SetMode(request->mode(), agent);
if (!status.ok()) {
return status;
}
}
if (request->mocap_size() > 0) {
grpc::Status status = SetMocap(request->mocap(), agent, model, data);
if (!status.ok()) {
return status;
}
}
return grpc::Status::OK;
}

mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error,
int error_size) {
static constexpr char file[] = "temporary-filename.xml";
Expand All @@ -392,7 +492,7 @@ mjpc::UniqueMjModel LoadModelFromString(std::string_view xml, char* error,
mj_defaultVFS(vfs.get());
mj_makeEmptyFileVFS(vfs.get(), file, xml.size());
int file_idx = mj_findFileVFS(vfs.get(), file);
memcpy(vfs->filedata[file_idx], xml.data(), xml.size());
std::memcpy(vfs->filedata[file_idx], xml.data(), xml.size());
mjpc::UniqueMjModel m = {mj_loadXML(file, vfs.get(), error, error_size),
mj_deleteModel};
mj_deleteFileVFS(vfs.get(), file);
Expand Down
Loading

0 comments on commit 4572f3d

Please sign in to comment.