Skip to content

Go2 support #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions mujoco_playground/_src/locomotion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from mujoco_playground._src.locomotion.go1 import handstand as go1_handstand
from mujoco_playground._src.locomotion.go1 import joystick as go1_joystick
from mujoco_playground._src.locomotion.go1 import randomize as go1_randomize
from mujoco_playground._src.locomotion.go2 import getup as go2_getup
from mujoco_playground._src.locomotion.go2 import handstand as go2_handstand
from mujoco_playground._src.locomotion.go2 import joystick as go2_joystick
from mujoco_playground._src.locomotion.go2 import randomize as go2_randomize
from mujoco_playground._src.locomotion.h1 import inplace_gait_tracking as h1_inplace_gait_tracking
from mujoco_playground._src.locomotion.h1 import joystick_gait_tracking as h1_joystick_gait_tracking
from mujoco_playground._src.locomotion.op3 import joystick as op3_joystick
Expand Down Expand Up @@ -67,6 +71,15 @@
"Go1Getup": go1_getup.Getup,
"Go1Handstand": go1_handstand.Handstand,
"Go1Footstand": go1_handstand.Footstand,
"Go2JoystickFlatTerrain": functools.partial(
go2_joystick.Joystick, task="flat_terrain"
),
"Go2JoystickRoughTerrain": functools.partial(
go2_joystick.Joystick, task="rough_terrain"
),
"Go2Getup": go2_getup.Getup,
"Go2Handstand": go2_handstand.Handstand,
"Go2Footstand": go2_handstand.Footstand,
"H1InplaceGaitTracking": h1_inplace_gait_tracking.InplaceGaitTracking,
"H1JoystickGaitTracking": h1_joystick_gait_tracking.JoystickGaitTracking,
"Op3Joystick": op3_joystick.Joystick,
Expand Down Expand Up @@ -101,6 +114,11 @@
"Go1Getup": go1_getup.default_config,
"Go1Handstand": go1_handstand.default_config,
"Go1Footstand": go1_handstand.default_config,
"Go2JoystickFlatTerrain": go2_joystick.default_config,
"Go2JoystickRoughTerrain": go2_joystick.default_config,
"Go2Getup": go2_getup.default_config,
"Go2Handstand": go2_handstand.default_config,
"Go2Footstand": go2_handstand.default_config,
"H1InplaceGaitTracking": h1_inplace_gait_tracking.default_config,
"H1JoystickGaitTracking": h1_joystick_gait_tracking.default_config,
"Op3Joystick": op3_joystick.default_config,
Expand All @@ -125,6 +143,11 @@
"Go1Getup": go1_randomize.domain_randomize,
"Go1Handstand": go1_randomize.domain_randomize,
"Go1Footstand": go1_randomize.domain_randomize,
"Go2JoystickFlatTerrain": go2_randomize.domain_randomize,
"Go2JoystickRoughTerrain": go2_randomize.domain_randomize,
"Go2Getup": go2_randomize.domain_randomize,
"Go2Handstand": go2_randomize.domain_randomize,
"Go2Footstand": go2_randomize.domain_randomize,
"T1JoystickFlatTerrain": t1_randomize.domain_randomize,
"T1JoystickRoughTerrain": t1_randomize.domain_randomize,
}
Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/locomotion/go2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Unitree Go2 environments
14 changes: 14 additions & 0 deletions mujoco_playground/_src/locomotion/go2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
122 changes: 122 additions & 0 deletions mujoco_playground/_src/locomotion/go2/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for Go2."""

from typing import Any, Dict, Optional, Union

from etils import epath
import jax
import jax.numpy as jp
from ml_collections import config_dict
import mujoco
from mujoco import mjx

from mujoco_playground._src import mjx_env
from mujoco_playground._src.locomotion.go2 import go2_constants as consts


def get_assets() -> Dict[str, bytes]:
assets = {}
mjx_env.update_assets(assets, consts.ROOT_PATH / "xmls", "*.xml")
mjx_env.update_assets(assets, consts.ROOT_PATH / "xmls" / "assets")
path = mjx_env.MENAGERIE_PATH / "unitree_go2"
mjx_env.update_assets(assets, path, "*.xml")
mjx_env.update_assets(assets, path / "assets")
return assets


class Go2Env(mjx_env.MjxEnv):
"""Base class for Go2 environments."""

def __init__(
self,
xml_path: str,
config: config_dict.ConfigDict,
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
) -> None:
super().__init__(config, config_overrides)

self._mj_model = mujoco.MjModel.from_xml_string(
epath.Path(xml_path).read_text(), assets=get_assets()
)
self._mj_model.opt.timestep = self._config.sim_dt

# Modify PD gains.
self._mj_model.dof_damping[6:] = config.Kd
self._mj_model.actuator_gainprm[:, 0] = config.Kp
self._mj_model.actuator_biasprm[:, 1] = -config.Kp

# Increase offscreen framebuffer size to render at higher resolutions.
self._mj_model.vis.global_.offwidth = 3840
self._mj_model.vis.global_.offheight = 2160

self._mjx_model = mjx.put_model(self._mj_model)
self._xml_path = xml_path
self._imu_site_id = self._mj_model.site("imu").id

# Sensor readings.

def get_upvector(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(self.mj_model, data, consts.UPVECTOR_SENSOR)

def get_gravity(self, data: mjx.Data) -> jax.Array:
return data.site_xmat[self._imu_site_id].T @ jp.array([0, 0, -1])

def get_global_linvel(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(
self.mj_model, data, consts.GLOBAL_LINVEL_SENSOR
)

def get_global_angvel(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(
self.mj_model, data, consts.GLOBAL_ANGVEL_SENSOR
)

def get_local_linvel(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(
self.mj_model, data, consts.LOCAL_LINVEL_SENSOR
)

def get_accelerometer(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(
self.mj_model, data, consts.ACCELEROMETER_SENSOR
)

def get_gyro(self, data: mjx.Data) -> jax.Array:
return mjx_env.get_sensor_data(self.mj_model, data, consts.GYRO_SENSOR)

def get_feet_pos(self, data: mjx.Data) -> jax.Array:
return jp.vstack([
mjx_env.get_sensor_data(self.mj_model, data, sensor_name)
for sensor_name in consts.FEET_POS_SENSOR
])

# Accessors.

@property
def xml_path(self) -> str:
return self._xml_path

@property
def action_size(self) -> int:
return self._mjx_model.nu

@property
def mj_model(self) -> mujoco.MjModel:
return self._mj_model

@property
def mjx_model(self) -> mjx.Model:
return self._mjx_model
Loading