Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update inverse kinematics to support multiple sites #399

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions dm_control/utils/inverse_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def qpos_from_site_pose(physics,
site_name,
sites_names,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a widely used function, that we don't want to break API compatibility on.

Can you instead add a new function, qpos_from_site_poses which takes a parameter (called site_names rather than sites_names, BTW)?

The existing qpos_from_site_pose function could call down to the new qpos_from_site_poses.

target_pos=None,
target_quat=None,
joint_names=None,
Expand All @@ -46,12 +46,13 @@ def qpos_from_site_pose(physics,
max_update_norm=2.0,
progress_thresh=20.0,
max_steps=100,
inplace=False):
inplace=False,
null_space_method=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about not using a boolean variable, and instead having an enum for regularization_method that can be NULLSPACE_METHOD or DAMPED_LEAST_SQUARES? That would allow us to add more methosd in the future without breaking the API.

BTW: I think it's confusing that there's a function called nullspace_method and a variable called null_space_method. If we stick with the boolean, name the variable use_nullspace_method.

Same below.

"""Find joint positions that satisfy a target site position and/or rotation.

Args:
physics: A `mujoco.Physics` instance.
site_name: A string specifying the name of the target site.
sites_names: A list of strings specifying the names of the target sites.
target_pos: A (3,) numpy array specifying the desired Cartesian position of
the site, or None if the position should be unconstrained (default).
One or both of `target_pos` or `target_quat` must be specified.
Expand Down Expand Up @@ -79,7 +80,8 @@ def qpos_from_site_pose(physics,
max_steps: (optional) The maximum number of iterations to perform.
inplace: (optional) If True, `physics.data` will be modified in place.
Default value is False, i.e. a copy of `physics.data` will be made.

null_space_method: (optional) If True uses the null space method to find the
update norm for the joint angles, otherwise uses the damped least squares.
Returns:
An `IKResult` namedtuple with the following fields:
qpos: An (nq,) numpy array of joint positions.
Expand All @@ -102,8 +104,12 @@ def qpos_from_site_pose(physics,
jac_pos, jac_rot = jac[:3], jac[3:]
err_pos, err_rot = err[:3], err[3:]
else:
jac = np.empty((3, physics.model.nv), dtype=dtype)
err = np.empty(3, dtype=dtype)
if len(sites_names) > 1:
jac = np.empty((len(sites_names), 3, physics.model.nv), dtype=dtype)
err = np.empty((len(sites_names), 3), dtype=dtype)
else:
jac = np.empty((3, physics.model.nv), dtype=dtype)
err = np.empty(3, dtype=dtype)
if target_pos is not None:
jac_pos, jac_rot = jac, None
err_pos, err_rot = err, None
Expand All @@ -127,12 +133,13 @@ def qpos_from_site_pose(physics,
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)

# Convert site name to index.
site_id = physics.model.name2id(site_name, 'site')
site_ids = [physics.model.name2id(site_name, 'site')
for site_name in sites_names]

# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
# update them in place, so we can avoid indexing overhead in the main loop.
site_xpos = physics.named.data.site_xpos[site_name]
site_xmat = physics.named.data.site_xmat[site_name]
site_xpos = np.squeeze(physics.named.data.site_xpos[sites_names])
site_xmat = np.squeeze(physics.named.data.site_xmat[sites_names])

# This is an index into the rows of `update` and the columns of `jac`
# that selects DOFs associated with joints that we are allowed to manipulate.
Expand Down Expand Up @@ -170,24 +177,41 @@ def qpos_from_site_pose(physics,
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
err_norm += np.linalg.norm(err_rot) * rot_weight


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

if err_norm < tol:
logging.debug('Converged after %i steps: err_norm=%3g', steps, err_norm)
success = True
break
else:
# TODO(b/112141670): Generalize this to other entities besides sites.
mjlib.mj_jacSite(
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id)
jac_joints = jac[:, dof_indices]

# TODO(b/112141592): This does not take joint limits into consideration.
reg_strength = (
if len(site_ids) > 1:
for idx in range(len(site_ids)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch to

for idx, site_id in enumerate(site_ids):
  ...

site_id = site_ids[idx]
mjlib.mj_jacSite(
physics.model.ptr, physics.data.ptr, jac_pos[idx], jac_rot, site_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent 2 more spaces. Same for other line continuations below.

else:
mjlib.mj_jacSite(
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_ids[0])

if null_space_method:
jac_joints = jac[:, dof_indices]

# TODO(b/112141592): This does not take joint limits into consideration.
reg_strength = (
regularization_strength if err_norm > regularization_threshold
else 0.0)
update_joints = nullspace_method(
update_joints = nullspace_method(
jac_joints, err, regularization_strength=reg_strength)

update_norm = np.linalg.norm(update_joints)
update_norm = np.linalg.norm(update_joints)
else:
update_joints = np.empty((len(site_ids), physics.model.nv))
for idx in range(len(site_ids)):
update_joints[idx] = damped_least_squares(jac[idx], err[idx],
regularization_strength=regularization_strength)

update_joints = np.mean(update_joints, axis=0)
update_norm = np.linalg.norm(update_joints)

# Check whether we are still making enough progress, and halt if not.
progress_criterion = err_norm / update_norm
Expand All @@ -207,8 +231,18 @@ def qpos_from_site_pose(physics,
# Update `physics.qpos`, taking quaternions into account.
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)

# clip joint angles to their respective limits
if len(sites_names) > 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only when there are multiple site names?

joint_names = physics.named.data.qpos.axes.row.names
limited_joints = joint_names[1:] # ignore root joint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want to explicitly check which joints are actually limited using jnt_limited.

lower, upper = physics.named.model.jnt_range[limited_joints].T
physics.named.data.qpos[limited_joints] = np.clip(physics.named.data.qpos[limited_joints],
lower, upper)

# Compute the new Cartesian position of the site.
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
site_xpos = np.squeeze(physics.named.data.site_xpos[sites_names])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the additional index here? I think np.squeeze is a view.

site_xmat = np.squeeze(physics.named.data.site_xmat[sites_names])

logging.debug('Step %2i: err_norm=%-10.3g update_norm=%-10.3g',
steps, err_norm, update_norm)
Expand Down Expand Up @@ -258,3 +292,27 @@ def nullspace_method(jac_joints, delta, regularization_strength=0.0):
return np.linalg.solve(hess_approx, joint_delta)
else:
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]

def damped_least_squares(jac_joints, delta, regularization_strength):
"""Calculates the joint velocities to achieve a specified end effector delta.

Args:
jac_joints: The Jacobian of the end effector with respect to the joints. A
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
and `nv` is the number of degrees of freedom.
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
`(6,)` containing either position deltas, rotation deltas, or both.
regularization_strength: (optional) Coefficient of the quadratic penalty
on joint movements. Default is zero, i.e. no regularization.

Returns:
An `(nv,)` numpy array of joint velocities.

Reference:
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
transpose, pseudoinverse and damped least squares methods.
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
"""
JJ_t = jac_joints.dot(jac_joints.T)
JJ_t += np.eye(JJ_t.shape[0]) * regularization_strength
return jac_joints.T.dot(np.linalg.inv(JJ_t)).dot(delta)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add newline at the end of the file.

61 changes: 55 additions & 6 deletions dm_control/utils/inverse_kinematics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dm_control.mujoco.wrapper import mjbindings
from dm_control.utils import inverse_kinematics as ik
import numpy as np
import os

mjlib = mjbindings.mjlib

Expand Down Expand Up @@ -80,6 +81,54 @@ def __call__(self, physics):

class InverseKinematicsTest(parameterized.TestCase):

def testQposFromMultipleSitesPose(self):
dir_path = os.path.dirname(os.path.realpath(__file__))
model_dir = os.path.join(dir_path, "./testing/assets/task.xml")
physics = mujoco.Physics.from_xml_path(model_dir)

target_pos = physics.model.key_mpos[0]
target_pos = target_pos.reshape((-1, 3))
target_quat = None

_SITES_NAMES = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This naming is reserved for module-wide constants. Name it site_names.


body_names = [
"pelvis", "head", "ltoe", "rtoe", "lheel", "rheel",
"lknee", "rknee", "lhand", "rhand", "lelbow", "relbow",
"lshoulder", "rshoulder", "lhip", "rhip",
]

for name in body_names:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just use a list comprehension, e.g.:

site_names = [f"tracking[{name}]" for name in body_names]

_SITES_NAMES.append("tracking[" + name + "]")

_MAX_STEPS = 5000
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_steps

result = ik.qpos_from_site_pose(
physics=physics,
sites_names=_SITES_NAMES,
target_pos=target_pos,
target_quat=target_quat,
joint_names=None,
tol=1e-14,
regularization_threshold=0.5,
regularization_strength=1e-2,
max_update_norm=2.0,
progress_thresh=5000.0,
max_steps=_MAX_STEPS,
inplace=False,
null_space_method=False
)

self.assertLessEqual(result.steps, _MAX_STEPS)
physics.data.qpos[:] = result.qpos

save_path = os.path.join(dir_path, "./testing/assets/result_qpos")
np.save(save_path, result.qpos)
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)

pos = physics.named.data.site_xpos[_SITES_NAMES]
err_norm = np.linalg.norm(target_pos - pos)
self.assertLessEqual(err_norm, 0.11)

@parameterized.parameters(itertools.product(_TARGETS, _INPLACE))
def testQposFromSitePose(self, target, inplace):
physics = mujoco.Physics.from_xml_string(_ARM_XML)
Expand All @@ -90,7 +139,7 @@ def testQposFromSitePose(self, target, inplace):
while True:
result = ik.qpos_from_site_pose(
physics=physics2,
site_name=_SITE_NAME,
sites_names=[_SITE_NAME],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests should prove that the existing functions continue to work.

Maybe add a parameterized annotation that will run both qpos_from_site_pose and qpos_from_site_poses on the same input.

target_pos=target_pos,
target_quat=target_quat,
joint_names=_JOINTS,
Expand Down Expand Up @@ -133,7 +182,7 @@ def testNamedJointsWithMultipleDOFs(self):
target_pos = (0.05, 0.05, 0)
result = ik.qpos_from_site_pose(
physics=physics,
site_name=site_name,
sites_names=[site_name],
target_pos=target_pos,
joint_names=joint_names,
tol=_TOL,
Expand All @@ -150,7 +199,7 @@ def testNamedJointsWithMultipleDOFs(self):
physics.reset()
result = ik.qpos_from_site_pose(
physics=physics,
site_name=site_name,
sites_names=[site_name],
target_pos=target_pos,
joint_names=joint_names[:1],
tol=_TOL,
Expand All @@ -170,7 +219,7 @@ def testAllowedJointNameTypes(self, joint_names):
target_pos = (0.05, 0.05, 0)
ik.qpos_from_site_pose(
physics=physics,
site_name=site_name,
sites_names=[site_name],
target_pos=target_pos,
joint_names=joint_names,
tol=_TOL,
Expand All @@ -192,7 +241,7 @@ def testDisallowedJointNameTypes(self, joint_names):
with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
ik.qpos_from_site_pose(
physics=physics,
site_name=site_name,
sites_names=[site_name],
target_pos=target_pos,
joint_names=joint_names,
tol=_TOL,
Expand All @@ -206,7 +255,7 @@ def testNoTargetPosOrQuat(self):
ValueError, ik._REQUIRE_TARGET_POS_OR_QUAT):
ik.qpos_from_site_pose(
physics=physics,
site_name=site_name,
sites_names=[site_name],
tol=_TOL,
max_steps=_MAX_STEPS,
inplace=True)
Expand Down
5 changes: 5 additions & 0 deletions dm_control/utils/testing/assets/CMU-CMU-02-02_04_poses.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<mujoco>
<keyframe>
<key name="jump_1" mpos="0.0 0.0 0.9079 0.02246 -0.01716 1.48528 0.05642 0.11259 0.02555 0.05707 -0.10348 0.01632 -0.06685 0.06839 0.05535 -0.06679 -0.06662 0.05473 0.03258 0.09748 0.4423 0.05061 -0.09743 0.43688 0.07061 0.19745 0.86066 0.02813 -0.23827 0.85899 -0.07446 0.17572 1.06235 -0.11842 -0.18084 1.06022 -0.02793 0.17321 1.31954 -0.07074 -0.17717 1.31069 0.00543 0.07048 0.81714 -0.00302 -0.06647 0.81644"/>
</keyframe>
</mujoco>
31 changes: 31 additions & 0 deletions dm_control/utils/testing/assets/common.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<mujoco>

<visual>
<headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
<map znear=".01"/>
<quality shadowsize="2048"/>
<global elevation="-15"/>
</visual>

<asset>
<texture name="blue_grid" type="2d" builtin="checker" rgb1=".02 .14 .44" rgb2=".27 .55 1" width="300" height="300" mark="edge" markrgb="1 1 1"/>
<material name="blue_grid" texture="blue_grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>

<texture name="grey_grid" type="2d" builtin="checker" rgb1=".26 .26 .26" rgb2=".6 .6 .6" width="300" height="300" mark="edge" markrgb="1 1 1"/>
<material name="grey_grid" texture="blue_grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".66 .79 1" rgb2=".9 .91 .93" width="800" height="800"/>

<material name="self" rgba=".7 .5 .3 1"/>
<material name="self_default" rgba=".7 .5 .3 1"/>
<material name="self_highlight" rgba="0 .5 .3 1"/>
<material name="effector" rgba=".7 .4 .2 1"/>
<material name="effector_default" rgba=".7 .4 .2 1"/>
<material name="effector_highlight" rgba="0 .5 .3 1"/>
<material name="decoration" rgba=".2 .6 .3 1"/>
<material name="eye" rgba="0 .2 1 1"/>
<material name="target" rgba=".6 .3 .3 1"/>
<material name="target_default" rgba=".6 .3 .3 1"/>
<material name="target_highlight" rgba=".6 .3 .3 .4"/>
<material name="site" rgba=".5 .5 .5 .3"/>
</asset>
</mujoco>
Loading