Skip to content

Commit

Permalink
render new marker offsets, add option to visualize marker error (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-zhng authored Oct 6, 2024
1 parent 57174a8 commit 527c0ee
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 27 deletions.
Binary file modified demos/demo_viz.p
Binary file not shown.
22 changes: 11 additions & 11 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def offset_optimization(

print(f"offset optimization finished in {time.time()-s}")

return mjx_model, mjx_data
return mjx_model, mjx_data, offset_opt_param


def pose_optimization(
Expand Down
64 changes: 49 additions & 15 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _create_body_sites(self, root: mjcf.Element):
name=key,
type="sphere",
size=[0.005],
rgba="0 0 0 1",
rgba="0 0 0 0.8",
pos=pos,
group=3,
)
Expand Down Expand Up @@ -258,7 +258,7 @@ def fit_offsets(self, kp_data):
print(f"Standard deviation: {std}")

print("starting offset optimization")
mjx_model, mjx_data = compute_stac.offset_optimization(
mjx_model, mjx_data, self._offsets = compute_stac.offset_optimization(
mjx_model,
mjx_data,
kp_data,
Expand Down Expand Up @@ -386,11 +386,11 @@ def _package_data(self, mjx_model, q, x, walker_body_sites, kp_data, batched=Fal
if batched:
# prepare batched data to be packaged
get_batch_offsets = jax.vmap(op.get_site_pos, in_axes=(0, None))
offsets = get_batch_offsets(mjx_model, self._body_site_idxs).copy()[0]
offsets = get_batch_offsets(mjx_model, self._body_site_idxs)[0]
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = op.get_site_pos(mjx_model, self._body_site_idxs).copy()
offsets = self._offsets

kp_data = kp_data.reshape(-1, kp_data.shape[-1])

Expand Down Expand Up @@ -468,6 +468,7 @@ def render(
camera: Union[int, str] = 0,
height: int = 1200,
width: int = 1920,
show_marker_error: bool = False,
):
"""Creates rendering using the instantiated model, given the user's qposes and kp_data.
Expand All @@ -481,6 +482,7 @@ def render(
camera (Union[int, str], optional): Mujoco camera name. Defaults to 0.
height (int, optional): Height in pixels. Defaults to 1200.
width (int, optional): Width in pixels. Defaults to 1920.
show_marker_error (bool, optional): Show distance between marker and keypoint. Defaults to False.
Raises:
ValueError: qposes and kp_data must have same length (shape[0])
Expand All @@ -506,28 +508,59 @@ def render(
render_mj_model, body_site_idxs, keypoint_site_idxs = (
self._create_keypoint_sites()
)
render_mj_model.site_pos[body_site_idxs] = offsets

# Add body sites for new offsets
for (key, v), pos in zip(
self.cfg.model.KEYPOINT_MODEL_PAIRS.items(), offsets.reshape((-1, 3))
):
parent = self._root.find("body", v)
parent.add(
"site",
name=key + "_new",
type="sphere",
size=[0.005],
rgba="0 0 0 1",
pos=pos,
group=2,
)

# Tendons from new marker sites to kp
if show_marker_error:
for key, v in self.cfg.model.KEYPOINT_MODEL_PAIRS.items():
tendon = self._root.tendon.add(
"spatial",
name=key + "-" + v,
width="0.001",
rgba="255 0 0 1", # Red
limited=False,
)
tendon.add("site", site=key + "_kp")
tendon.add("site", site=key + "_new")

physics = mjcf.Physics.from_mjcf_model(self._root)
render_mj_model = deepcopy(physics.model.ptr)

scene_option = mujoco.MjvOption()
scene_option.geomgroup[1] = 0
scene_option.geomgroup[2] = 1

scene_option.sitegroup[2] = 1

scene_option.sitegroup[3] = 1
scene_option.sitegroup[3] = 0
scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[enums.mjtVisFlag.mjVIS_LIGHT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_CONVEXHULL] = True
scene_option.flags[enums.mjtRndFlag.mjRND_SHADOW] = False
scene_option.flags[enums.mjtRndFlag.mjRND_REFLECTION] = False
scene_option.flags[enums.mjtRndFlag.mjRND_SKYBOX] = False
scene_option.flags[enums.mjtRndFlag.mjRND_FOG] = False

scene_option.flags[enums.mjtRndFlag.mjRND_SHADOW] = True
scene_option.flags[enums.mjtRndFlag.mjRND_REFLECTION] = True
scene_option.flags[enums.mjtRndFlag.mjRND_SKYBOX] = True
scene_option.flags[enums.mjtRndFlag.mjRND_FOG] = True
mj_data = mujoco.MjData(render_mj_model)

mujoco.mj_kinematics(render_mj_model, mj_data)

renderer = mujoco.Renderer(render_mj_model, height=height, width=width)

# slice kp_data to match qposes length
# Slice kp_data to match qposes length
kp_data = kp_data[: qposes.shape[0]]

# Slice arrays to be the range that is being rendered
Expand All @@ -538,10 +571,11 @@ def render(
# render while stepping using mujoco
with imageio.get_writer(save_path, fps=self.cfg.model.RENDER_FPS) as video:
for qpos, kps in tqdm(zip(qposes, kp_data)):
# Set keypoints
# Set keypoints--they're in cartesian space, but since they're attached to the worldbody they're the same as offsets
render_mj_model.site_pos[keypoint_site_idxs] = np.reshape(kps, (-1, 3))
mj_data.qpos = qpos
mujoco.mj_forward(render_mj_model, mj_data)

mujoco.mj_fwdPosition(render_mj_model, mj_data)

renderer.update_scene(mj_data, camera=camera, scene_option=scene_option)
pixels = renderer.render()
Expand Down
2 changes: 2 additions & 0 deletions stac_mjx/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def viz_stac(
height: int = 1200,
width: int = 1920,
base_path=None,
show_marker_error=False,
):
"""Render forward kinematics from keypoint positions.
Expand Down Expand Up @@ -61,4 +62,5 @@ def viz_stac(
camera,
height,
width,
show_marker_error,
)

0 comments on commit 527c0ee

Please sign in to comment.