From 1c56c39a60f335ad83762c9947a47a6b20a831d5 Mon Sep 17 00:00:00 2001 From: taylor howell Date: Sat, 16 Mar 2024 15:55:21 -0400 Subject: [PATCH 1/2] change path --- .../mujoco_mpc/mjx/tasks/bimanual/handover.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py index 4d1d81205..124ef042c 100644 --- a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py +++ b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from etils import epath +from pathlib import Path import jax from jax import numpy as jp import mujoco @@ -51,19 +51,11 @@ def get_models_and_cost_fn() -> ( tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn] ): """Returns a tuple of the model and the cost function.""" - path = epath.Path( - 'build/mjpc/tasks/bimanual/' + model_path = ( + Path(__file__).parent.parent.parent + / "../../../build/mjpc/tasks/bimanual/mjx_scene.xml" ) - model_file_name = 'mjx_scene.xml' - xml = (path / model_file_name).read_text() - assets = {} - for f in path.glob('*.xml'): - if f.name == model_file_name: - continue - assets[f.name] = f.read_bytes() - for f in (path / 'assets').glob('*'): - assets[f.name] = f.read_bytes() - sim_model = mujoco.MjModel.from_xml_string(xml, assets) - plan_model = mujoco.MjModel.from_xml_string(xml, assets) + sim_model = mujoco.MjModel.from_xml_path(str(model_path)) + plan_model = mujoco.MjModel.from_xml_path(str(model_path)) plan_model.opt.timestep = 0.01 # incidentally, already the case return sim_model, plan_model, bring_to_target From 034e387d146285fe9ea754d0fc7d449e7a5605d9 Mon Sep 17 00:00:00 2001 From: Taylor Date: Mon, 25 Mar 2024 21:08:04 -0400 Subject: [PATCH 2/2] minor changes to run --- python/mujoco_mpc/mjx/README.md | 10 ++++++++++ python/mujoco_mpc/mjx/tasks/bimanual/handover.py | 5 +++-- python/mujoco_mpc/mjx/visualize.py | 4 ++-- 3 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 python/mujoco_mpc/mjx/README.md diff --git a/python/mujoco_mpc/mjx/README.md b/python/mujoco_mpc/mjx/README.md new file mode 100644 index 000000000..b4c82a32b --- /dev/null +++ b/python/mujoco_mpc/mjx/README.md @@ -0,0 +1,10 @@ +# MJX Predictive Sampling + +Run `handover` example: + +```sh +python visualize.py +``` + +## +Requires: mujoco, mujoco-mjx, jax[cuda], matplotlib, mediapy (Python), ffmpeg \ No newline at end of file diff --git a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py index 124ef042c..31efb3f97 100644 --- a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py +++ b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== +from typing import Callable from pathlib import Path import jax from jax import numpy as jp import mujoco from mujoco import mjx -from mujoco_mpc.mjx import predictive_sampling +CostFn = Callable[[mjx.Model, mjx.Data], jax.Array] def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array: """Returns cost for bimanual bring to target task.""" @@ -48,7 +49,7 @@ def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array: def get_models_and_cost_fn() -> ( - tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn] + tuple[mujoco.MjModel, mujoco.MjModel, CostFn] ): """Returns a tuple of the model and the cost function.""" model_path = ( diff --git a/python/mujoco_mpc/mjx/visualize.py b/python/mujoco_mpc/mjx/visualize.py index deb493d65..b3b21af87 100644 --- a/python/mujoco_mpc/mjx/visualize.py +++ b/python/mujoco_mpc/mjx/visualize.py @@ -16,8 +16,8 @@ import matplotlib.pyplot as plt import mediapy import mujoco -from mujoco_mpc.mjx import predictive_sampling -from mujoco_mpc.mjx.tasks.bimanual import handover +import predictive_sampling +from tasks.bimanual import handover # %% nsteps = 500 steps_per_plan = 4