From eeda4bad27e078cde5d57da3c9a4f9445266ccfb Mon Sep 17 00:00:00 2001 From: kywch Date: Mon, 15 Sep 2025 18:34:56 -0700 Subject: [PATCH 1/4] Added impl and njmax flags from rsl training --- learning/train_rsl_rl.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py index 71228eef8..3a5ce46e1 100644 --- a/learning/train_rsl_rl.py +++ b/learning/train_rsl_rl.py @@ -54,6 +54,10 @@ f"{', '.join(mujoco_playground.registry.ALL_ENVS)}" ), ) +_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation") +_NJMAX = flags.DEFINE_integer( + "njmax", None, "The maximum number of constraints per world." +) _LOAD_RUN_NAME = flags.DEFINE_string( "load_run_name", None, "Run name to load from (for checkpoint restoration)." ) @@ -108,6 +112,9 @@ def main(argv): # Load default config from registry env_cfg = registry.get_default_config(_ENV_NAME.value) + env_cfg.impl = _IMPL.value + if _NJMAX.present: + env_cfg.njmax = _NJMAX.value print(f"Environment config:\n{env_cfg}") # Generate unique experiment name From 5af04735ad119a6246b9bc64bce98f0fc9aa9463 Mon Sep 17 00:00:00 2001 From: kywch Date: Tue, 16 Sep 2025 10:37:33 -0700 Subject: [PATCH 2/4] Add playground_config_overrides flag like brax learner --- learning/train_jax_ppo.py | 22 +++++++++++++++++++--- learning/train_rsl_rl.py | 21 +++++++++++++++------ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index 32f61b636..a7f5b655d 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -68,6 +68,11 @@ f"Name of the environment. One of {', '.join(registry.ALL_ENVS)}", ) _IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation") +_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string( + "playground_config_overrides", + None, + "Overrides for the playground env config.", +) _VISION = flags.DEFINE_boolean("vision", False, "Use vision input") _LOAD_CHECKPOINT_PATH = flags.DEFINE_string( "load_checkpoint_path", None, "Path to load checkpoint from" @@ -260,7 +265,12 @@ def main(argv): if _VISION.value: env_cfg.vision = True env_cfg.vision_config.render_batch_size = ppo_params.num_envs - env = registry.load(_ENV_NAME.value, config=env_cfg) + env_cfg_overrides = {} + if _PLAYGROUND_CONFIG_OVERRIDES.value is not None: + env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value) + env = registry.load( + _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides + ) if _RUN_EVALS.present: ppo_params.run_evals = _RUN_EVALS.value if _LOG_TRAINING_METRICS.present: @@ -269,6 +279,8 @@ def main(argv): ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value print(f"Environment Config:\n{env_cfg}") + if env_cfg_overrides: + print(f"Environment Config Overrides:\n{env_cfg_overrides}\n") print(f"PPO Training Parameters:\n{ppo_params}") # Generate unique experiment name @@ -399,7 +411,9 @@ def progress(num_steps, metrics): # Load evaluation environment. eval_env = None if not _VISION.value: - eval_env = registry.load(_ENV_NAME.value, config=env_cfg) + eval_env = registry.load( + _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides + ) num_envs = 1 if _VISION.value: num_envs = env_cfg.vision_config.render_batch_size @@ -410,7 +424,9 @@ def progress(num_steps, metrics): from rscope import brax as rscope_utils if not _VISION.value: - rscope_env = registry.load(_ENV_NAME.value, config=env_cfg) + rscope_env = registry.load( + _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides + ) rscope_env = wrapper.wrap_for_brax_training( rscope_env, episode_length=ppo_params.episode_length, diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py index 3a5ce46e1..f7ce76c4b 100644 --- a/learning/train_rsl_rl.py +++ b/learning/train_rsl_rl.py @@ -55,8 +55,10 @@ ), ) _IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation") -_NJMAX = flags.DEFINE_integer( - "njmax", None, "The maximum number of constraints per world." +_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string( + "playground_config_overrides", + None, + "Overrides for the playground env config.", ) _LOAD_RUN_NAME = flags.DEFINE_string( "load_run_name", None, "Run name to load from (for checkpoint restoration)." @@ -113,10 +115,13 @@ def main(argv): # Load default config from registry env_cfg = registry.get_default_config(_ENV_NAME.value) env_cfg.impl = _IMPL.value - if _NJMAX.present: - env_cfg.njmax = _NJMAX.value print(f"Environment config:\n{env_cfg}") + env_cfg_overrides = {} + if _PLAYGROUND_CONFIG_OVERRIDES.value is not None: + env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value) + print(f"Environment config overrides:\n{env_cfg_overrides}\n") + # Generate unique experiment name now = datetime.now() timestamp = now.strftime("%Y%m%d-%H%M%S") @@ -159,7 +164,9 @@ def render_callback(_, state): render_trajectory.append(state) # Create the environment - raw_env = registry.load(_ENV_NAME.value, config=env_cfg) + raw_env = registry.load( + _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides + ) brax_env = wrapper_torch.RSLRLBraxWrapper( raw_env, num_envs, @@ -213,7 +220,9 @@ def render_callback(_, state): policy = runner.get_inference_policy(device=device) # Example: run a single rollout - eval_env = registry.load(_ENV_NAME.value, config=env_cfg) + eval_env = registry.load( + _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides + ) jit_reset = jax.jit(eval_env.reset) jit_step = jax.jit(eval_env.step) From 13225ccfd7726fdb9bfd69cd62271498b0561777 Mon Sep 17 00:00:00 2001 From: kywch Date: Fri, 19 Sep 2025 22:05:06 -0700 Subject: [PATCH 3/4] randomize cube size and mass --- .../_src/manipulation/leap_hand/reorient.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/mujoco_playground/_src/manipulation/leap_hand/reorient.py b/mujoco_playground/_src/manipulation/leap_hand/reorient.py index ce8e931d5..823b1ec4a 100644 --- a/mujoco_playground/_src/manipulation/leap_hand/reorient.py +++ b/mujoco_playground/_src/manipulation/leap_hand/reorient.py @@ -518,13 +518,24 @@ def rand(rng): fingertip_friction ) - # Scale cube mass: *U(0.8, 1.2). + # Scale cube size: *U(0.5, 1.5). + rng, key = jax.random.split(rng) + geom_size = model.geom_size.at[cube_geom_id].set( + model.geom_size[cube_geom_id] * jax.random.uniform(key, minval=0.5, maxval=1.5) + ) + + # Scale cube mass: *U(0.5, 1.5). rng, key1, key2 = jax.random.split(rng, 3) - dmass = jax.random.uniform(key1, minval=0.8, maxval=1.2) body_inertia = model.body_inertia.at[cube_body_id].set( - model.body_inertia[cube_body_id] * dmass + model.body_inertia[cube_body_id] * jax.random.uniform(key1, minval=0.5, maxval=1.5) ) - dpos = jax.random.uniform(key2, (3,), minval=-5e-3, maxval=5e-3) + body_mass = model.body_mass.at[cube_body_id].set( + model.body_mass[cube_body_id] * jax.random.uniform(key2, minval=0.5, maxval=1.5) + ) + + # Jitter cube qpos: +U(-0.005, 0.005). + rng, key = jax.random.split(rng) + dpos = jax.random.uniform(key, (3,), minval=-5e-3, maxval=5e-3) body_ipos = model.body_ipos.at[cube_body_id].set( model.body_ipos[cube_body_id] + dpos ) @@ -556,8 +567,8 @@ def rand(rng): dmass = jax.random.uniform( key, shape=(len(hand_body_ids),), minval=0.9, maxval=1.1 ) - body_mass = model.body_mass.at[hand_body_ids].set( - model.body_mass[hand_body_ids] * dmass + body_mass = body_mass.at[hand_body_ids].set( + body_mass[hand_body_ids] * dmass ) # Joint stiffness: *U(0.8, 1.2). @@ -577,6 +588,7 @@ def rand(rng): return ( geom_friction, + geom_size, body_mass, body_inertia, body_ipos, @@ -590,6 +602,7 @@ def rand(rng): ( geom_friction, + geom_size, body_mass, body_inertia, body_ipos, @@ -604,6 +617,7 @@ def rand(rng): in_axes = jax.tree_util.tree_map(lambda x: None, model) in_axes = in_axes.tree_replace({ "geom_friction": 0, + "geom_size": 0, "body_mass": 0, "body_inertia": 0, "body_ipos": 0, @@ -617,6 +631,7 @@ def rand(rng): model = model.tree_replace({ "geom_friction": geom_friction, + "geom_size": geom_size, "body_mass": body_mass, "body_inertia": body_inertia, "body_ipos": body_ipos, From feb969ef7c7b6848b219eb07bc447407bb5f011a Mon Sep 17 00:00:00 2001 From: kywch Date: Sat, 20 Sep 2025 09:32:05 -0700 Subject: [PATCH 4/4] Revert back dm-related changes unrelated to the PR --- .../_src/manipulation/leap_hand/reorient.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/mujoco_playground/_src/manipulation/leap_hand/reorient.py b/mujoco_playground/_src/manipulation/leap_hand/reorient.py index 823b1ec4a..6b2716263 100644 --- a/mujoco_playground/_src/manipulation/leap_hand/reorient.py +++ b/mujoco_playground/_src/manipulation/leap_hand/reorient.py @@ -518,24 +518,13 @@ def rand(rng): fingertip_friction ) - # Scale cube size: *U(0.5, 1.5). - rng, key = jax.random.split(rng) - geom_size = model.geom_size.at[cube_geom_id].set( - model.geom_size[cube_geom_id] * jax.random.uniform(key, minval=0.5, maxval=1.5) - ) - - # Scale cube mass: *U(0.5, 1.5). + # Scale cube mass: *U(0.8, 1.2). rng, key1, key2 = jax.random.split(rng, 3) + dmass = jax.random.uniform(key1, minval=0.8, maxval=1.2) body_inertia = model.body_inertia.at[cube_body_id].set( - model.body_inertia[cube_body_id] * jax.random.uniform(key1, minval=0.5, maxval=1.5) + model.body_inertia[cube_body_id] * dmass ) - body_mass = model.body_mass.at[cube_body_id].set( - model.body_mass[cube_body_id] * jax.random.uniform(key2, minval=0.5, maxval=1.5) - ) - - # Jitter cube qpos: +U(-0.005, 0.005). - rng, key = jax.random.split(rng) - dpos = jax.random.uniform(key, (3,), minval=-5e-3, maxval=5e-3) + dpos = jax.random.uniform(key2, (3,), minval=-5e-3, maxval=5e-3) body_ipos = model.body_ipos.at[cube_body_id].set( model.body_ipos[cube_body_id] + dpos ) @@ -567,8 +556,8 @@ def rand(rng): dmass = jax.random.uniform( key, shape=(len(hand_body_ids),), minval=0.9, maxval=1.1 ) - body_mass = body_mass.at[hand_body_ids].set( - body_mass[hand_body_ids] * dmass + body_mass = model.body_mass.at[hand_body_ids].set( + model.body_mass[hand_body_ids] * dmass ) # Joint stiffness: *U(0.8, 1.2). @@ -588,7 +577,6 @@ def rand(rng): return ( geom_friction, - geom_size, body_mass, body_inertia, body_ipos, @@ -602,7 +590,6 @@ def rand(rng): ( geom_friction, - geom_size, body_mass, body_inertia, body_ipos, @@ -617,7 +604,6 @@ def rand(rng): in_axes = jax.tree_util.tree_map(lambda x: None, model) in_axes = in_axes.tree_replace({ "geom_friction": 0, - "geom_size": 0, "body_mass": 0, "body_inertia": 0, "body_ipos": 0, @@ -631,7 +617,6 @@ def rand(rng): model = model.tree_replace({ "geom_friction": geom_friction, - "geom_size": geom_size, "body_mass": body_mass, "body_inertia": body_inertia, "body_ipos": body_ipos, @@ -643,4 +628,4 @@ def rand(rng): "actuator_biasprm": actuator_biasprm, }) - return model, in_axes + return model, in_axes \ No newline at end of file