-
Notifications
You must be signed in to change notification settings - Fork 678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update inverse kinematics to support multiple sites #399
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,7 @@ | |
|
||
|
||
def qpos_from_site_pose(physics, | ||
site_name, | ||
sites_names, | ||
target_pos=None, | ||
target_quat=None, | ||
joint_names=None, | ||
|
@@ -46,12 +46,13 @@ def qpos_from_site_pose(physics, | |
max_update_norm=2.0, | ||
progress_thresh=20.0, | ||
max_steps=100, | ||
inplace=False): | ||
inplace=False, | ||
null_space_method=True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about not using a boolean variable, and instead having an enum for regularization_method that can be NULLSPACE_METHOD or DAMPED_LEAST_SQUARES? That would allow us to add more methosd in the future without breaking the API. BTW: I think it's confusing that there's a function called nullspace_method and a variable called null_space_method. If we stick with the boolean, name the variable use_nullspace_method. Same below. |
||
"""Find joint positions that satisfy a target site position and/or rotation. | ||
|
||
Args: | ||
physics: A `mujoco.Physics` instance. | ||
site_name: A string specifying the name of the target site. | ||
sites_names: A list of strings specifying the names of the target sites. | ||
target_pos: A (3,) numpy array specifying the desired Cartesian position of | ||
the site, or None if the position should be unconstrained (default). | ||
One or both of `target_pos` or `target_quat` must be specified. | ||
|
@@ -79,7 +80,8 @@ def qpos_from_site_pose(physics, | |
max_steps: (optional) The maximum number of iterations to perform. | ||
inplace: (optional) If True, `physics.data` will be modified in place. | ||
Default value is False, i.e. a copy of `physics.data` will be made. | ||
|
||
null_space_method: (optional) If True uses the null space method to find the | ||
update norm for the joint angles, otherwise uses the damped least squares. | ||
Returns: | ||
An `IKResult` namedtuple with the following fields: | ||
qpos: An (nq,) numpy array of joint positions. | ||
|
@@ -102,8 +104,12 @@ def qpos_from_site_pose(physics, | |
jac_pos, jac_rot = jac[:3], jac[3:] | ||
err_pos, err_rot = err[:3], err[3:] | ||
else: | ||
jac = np.empty((3, physics.model.nv), dtype=dtype) | ||
err = np.empty(3, dtype=dtype) | ||
if len(sites_names) > 1: | ||
jac = np.empty((len(sites_names), 3, physics.model.nv), dtype=dtype) | ||
err = np.empty((len(sites_names), 3), dtype=dtype) | ||
else: | ||
jac = np.empty((3, physics.model.nv), dtype=dtype) | ||
err = np.empty(3, dtype=dtype) | ||
if target_pos is not None: | ||
jac_pos, jac_rot = jac, None | ||
err_pos, err_rot = err, None | ||
|
@@ -127,12 +133,13 @@ def qpos_from_site_pose(physics, | |
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr) | ||
|
||
# Convert site name to index. | ||
site_id = physics.model.name2id(site_name, 'site') | ||
site_ids = [physics.model.name2id(site_name, 'site') | ||
for site_name in sites_names] | ||
|
||
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will | ||
# update them in place, so we can avoid indexing overhead in the main loop. | ||
site_xpos = physics.named.data.site_xpos[site_name] | ||
site_xmat = physics.named.data.site_xmat[site_name] | ||
site_xpos = np.squeeze(physics.named.data.site_xpos[sites_names]) | ||
site_xmat = np.squeeze(physics.named.data.site_xmat[sites_names]) | ||
|
||
# This is an index into the rows of `update` and the columns of `jac` | ||
# that selects DOFs associated with joints that we are allowed to manipulate. | ||
|
@@ -170,24 +177,41 @@ def qpos_from_site_pose(physics, | |
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1) | ||
err_norm += np.linalg.norm(err_rot) * rot_weight | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove |
||
if err_norm < tol: | ||
logging.debug('Converged after %i steps: err_norm=%3g', steps, err_norm) | ||
success = True | ||
break | ||
else: | ||
# TODO(b/112141670): Generalize this to other entities besides sites. | ||
mjlib.mj_jacSite( | ||
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id) | ||
jac_joints = jac[:, dof_indices] | ||
|
||
# TODO(b/112141592): This does not take joint limits into consideration. | ||
reg_strength = ( | ||
if len(site_ids) > 1: | ||
for idx in range(len(site_ids)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switch to
|
||
site_id = site_ids[idx] | ||
mjlib.mj_jacSite( | ||
physics.model.ptr, physics.data.ptr, jac_pos[idx], jac_rot, site_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent 2 more spaces. Same for other line continuations below. |
||
else: | ||
mjlib.mj_jacSite( | ||
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_ids[0]) | ||
|
||
if null_space_method: | ||
jac_joints = jac[:, dof_indices] | ||
|
||
# TODO(b/112141592): This does not take joint limits into consideration. | ||
reg_strength = ( | ||
regularization_strength if err_norm > regularization_threshold | ||
else 0.0) | ||
update_joints = nullspace_method( | ||
update_joints = nullspace_method( | ||
jac_joints, err, regularization_strength=reg_strength) | ||
|
||
update_norm = np.linalg.norm(update_joints) | ||
update_norm = np.linalg.norm(update_joints) | ||
else: | ||
update_joints = np.empty((len(site_ids), physics.model.nv)) | ||
for idx in range(len(site_ids)): | ||
update_joints[idx] = damped_least_squares(jac[idx], err[idx], | ||
regularization_strength=regularization_strength) | ||
|
||
update_joints = np.mean(update_joints, axis=0) | ||
update_norm = np.linalg.norm(update_joints) | ||
|
||
# Check whether we are still making enough progress, and halt if not. | ||
progress_criterion = err_norm / update_norm | ||
|
@@ -207,8 +231,18 @@ def qpos_from_site_pose(physics, | |
# Update `physics.qpos`, taking quaternions into account. | ||
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1) | ||
|
||
# clip joint angles to their respective limits | ||
if len(sites_names) > 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why only when there are multiple site names? |
||
joint_names = physics.named.data.qpos.axes.row.names | ||
limited_joints = joint_names[1:] # ignore root joint | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You probably want to explicitly check which joints are actually limited using |
||
lower, upper = physics.named.model.jnt_range[limited_joints].T | ||
physics.named.data.qpos[limited_joints] = np.clip(physics.named.data.qpos[limited_joints], | ||
lower, upper) | ||
|
||
# Compute the new Cartesian position of the site. | ||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr) | ||
site_xpos = np.squeeze(physics.named.data.site_xpos[sites_names]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need the additional index here? I think |
||
site_xmat = np.squeeze(physics.named.data.site_xmat[sites_names]) | ||
|
||
logging.debug('Step %2i: err_norm=%-10.3g update_norm=%-10.3g', | ||
steps, err_norm, update_norm) | ||
|
@@ -258,3 +292,27 @@ def nullspace_method(jac_joints, delta, regularization_strength=0.0): | |
return np.linalg.solve(hess_approx, joint_delta) | ||
else: | ||
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0] | ||
|
||
def damped_least_squares(jac_joints, delta, regularization_strength): | ||
"""Calculates the joint velocities to achieve a specified end effector delta. | ||
|
||
Args: | ||
jac_joints: The Jacobian of the end effector with respect to the joints. A | ||
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta` | ||
and `nv` is the number of degrees of freedom. | ||
delta: The desired end-effector delta. A numpy array of shape `(3,)` or | ||
`(6,)` containing either position deltas, rotation deltas, or both. | ||
regularization_strength: (optional) Coefficient of the quadratic penalty | ||
on joint movements. Default is zero, i.e. no regularization. | ||
|
||
Returns: | ||
An `(nv,)` numpy array of joint velocities. | ||
|
||
Reference: | ||
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian | ||
transpose, pseudoinverse and damped least squares methods. | ||
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf | ||
""" | ||
JJ_t = jac_joints.dot(jac_joints.T) | ||
JJ_t += np.eye(JJ_t.shape[0]) * regularization_strength | ||
return jac_joints.T.dot(np.linalg.inv(JJ_t)).dot(delta) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add newline at the end of the file. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
from dm_control.mujoco.wrapper import mjbindings | ||
from dm_control.utils import inverse_kinematics as ik | ||
import numpy as np | ||
import os | ||
|
||
mjlib = mjbindings.mjlib | ||
|
||
|
@@ -80,6 +81,54 @@ def __call__(self, physics): | |
|
||
class InverseKinematicsTest(parameterized.TestCase): | ||
|
||
def testQposFromMultipleSitesPose(self): | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
model_dir = os.path.join(dir_path, "./testing/assets/task.xml") | ||
physics = mujoco.Physics.from_xml_path(model_dir) | ||
|
||
target_pos = physics.model.key_mpos[0] | ||
target_pos = target_pos.reshape((-1, 3)) | ||
target_quat = None | ||
|
||
_SITES_NAMES = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This naming is reserved for module-wide constants. Name it site_names. |
||
|
||
body_names = [ | ||
"pelvis", "head", "ltoe", "rtoe", "lheel", "rheel", | ||
"lknee", "rknee", "lhand", "rhand", "lelbow", "relbow", | ||
"lshoulder", "rshoulder", "lhip", "rhip", | ||
] | ||
|
||
for name in body_names: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just use a list comprehension, e.g.:
|
||
_SITES_NAMES.append("tracking[" + name + "]") | ||
|
||
_MAX_STEPS = 5000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. max_steps |
||
result = ik.qpos_from_site_pose( | ||
physics=physics, | ||
sites_names=_SITES_NAMES, | ||
target_pos=target_pos, | ||
target_quat=target_quat, | ||
joint_names=None, | ||
tol=1e-14, | ||
regularization_threshold=0.5, | ||
regularization_strength=1e-2, | ||
max_update_norm=2.0, | ||
progress_thresh=5000.0, | ||
max_steps=_MAX_STEPS, | ||
inplace=False, | ||
null_space_method=False | ||
) | ||
|
||
self.assertLessEqual(result.steps, _MAX_STEPS) | ||
physics.data.qpos[:] = result.qpos | ||
|
||
save_path = os.path.join(dir_path, "./testing/assets/result_qpos") | ||
np.save(save_path, result.qpos) | ||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr) | ||
|
||
pos = physics.named.data.site_xpos[_SITES_NAMES] | ||
err_norm = np.linalg.norm(target_pos - pos) | ||
self.assertLessEqual(err_norm, 0.11) | ||
|
||
@parameterized.parameters(itertools.product(_TARGETS, _INPLACE)) | ||
def testQposFromSitePose(self, target, inplace): | ||
physics = mujoco.Physics.from_xml_string(_ARM_XML) | ||
|
@@ -90,7 +139,7 @@ def testQposFromSitePose(self, target, inplace): | |
while True: | ||
result = ik.qpos_from_site_pose( | ||
physics=physics2, | ||
site_name=_SITE_NAME, | ||
sites_names=[_SITE_NAME], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests should prove that the existing functions continue to work. Maybe add a parameterized annotation that will run both qpos_from_site_pose and qpos_from_site_poses on the same input. |
||
target_pos=target_pos, | ||
target_quat=target_quat, | ||
joint_names=_JOINTS, | ||
|
@@ -133,7 +182,7 @@ def testNamedJointsWithMultipleDOFs(self): | |
target_pos = (0.05, 0.05, 0) | ||
result = ik.qpos_from_site_pose( | ||
physics=physics, | ||
site_name=site_name, | ||
sites_names=[site_name], | ||
target_pos=target_pos, | ||
joint_names=joint_names, | ||
tol=_TOL, | ||
|
@@ -150,7 +199,7 @@ def testNamedJointsWithMultipleDOFs(self): | |
physics.reset() | ||
result = ik.qpos_from_site_pose( | ||
physics=physics, | ||
site_name=site_name, | ||
sites_names=[site_name], | ||
target_pos=target_pos, | ||
joint_names=joint_names[:1], | ||
tol=_TOL, | ||
|
@@ -170,7 +219,7 @@ def testAllowedJointNameTypes(self, joint_names): | |
target_pos = (0.05, 0.05, 0) | ||
ik.qpos_from_site_pose( | ||
physics=physics, | ||
site_name=site_name, | ||
sites_names=[site_name], | ||
target_pos=target_pos, | ||
joint_names=joint_names, | ||
tol=_TOL, | ||
|
@@ -192,7 +241,7 @@ def testDisallowedJointNameTypes(self, joint_names): | |
with self.assertRaisesWithLiteralMatch(ValueError, expected_message): | ||
ik.qpos_from_site_pose( | ||
physics=physics, | ||
site_name=site_name, | ||
sites_names=[site_name], | ||
target_pos=target_pos, | ||
joint_names=joint_names, | ||
tol=_TOL, | ||
|
@@ -206,7 +255,7 @@ def testNoTargetPosOrQuat(self): | |
ValueError, ik._REQUIRE_TARGET_POS_OR_QUAT): | ||
ik.qpos_from_site_pose( | ||
physics=physics, | ||
site_name=site_name, | ||
sites_names=[site_name], | ||
tol=_TOL, | ||
max_steps=_MAX_STEPS, | ||
inplace=True) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<mujoco> | ||
<keyframe> | ||
<key name="jump_1" mpos="0.0 0.0 0.9079 0.02246 -0.01716 1.48528 0.05642 0.11259 0.02555 0.05707 -0.10348 0.01632 -0.06685 0.06839 0.05535 -0.06679 -0.06662 0.05473 0.03258 0.09748 0.4423 0.05061 -0.09743 0.43688 0.07061 0.19745 0.86066 0.02813 -0.23827 0.85899 -0.07446 0.17572 1.06235 -0.11842 -0.18084 1.06022 -0.02793 0.17321 1.31954 -0.07074 -0.17717 1.31069 0.00543 0.07048 0.81714 -0.00302 -0.06647 0.81644"/> | ||
</keyframe> | ||
</mujoco> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
<mujoco> | ||
|
||
<visual> | ||
<headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/> | ||
<map znear=".01"/> | ||
<quality shadowsize="2048"/> | ||
<global elevation="-15"/> | ||
</visual> | ||
|
||
<asset> | ||
<texture name="blue_grid" type="2d" builtin="checker" rgb1=".02 .14 .44" rgb2=".27 .55 1" width="300" height="300" mark="edge" markrgb="1 1 1"/> | ||
<material name="blue_grid" texture="blue_grid" texrepeat="1 1" texuniform="true" reflectance=".2"/> | ||
|
||
<texture name="grey_grid" type="2d" builtin="checker" rgb1=".26 .26 .26" rgb2=".6 .6 .6" width="300" height="300" mark="edge" markrgb="1 1 1"/> | ||
<material name="grey_grid" texture="blue_grid" texrepeat="1 1" texuniform="true" reflectance=".2"/> | ||
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".66 .79 1" rgb2=".9 .91 .93" width="800" height="800"/> | ||
|
||
<material name="self" rgba=".7 .5 .3 1"/> | ||
<material name="self_default" rgba=".7 .5 .3 1"/> | ||
<material name="self_highlight" rgba="0 .5 .3 1"/> | ||
<material name="effector" rgba=".7 .4 .2 1"/> | ||
<material name="effector_default" rgba=".7 .4 .2 1"/> | ||
<material name="effector_highlight" rgba="0 .5 .3 1"/> | ||
<material name="decoration" rgba=".2 .6 .3 1"/> | ||
<material name="eye" rgba="0 .2 1 1"/> | ||
<material name="target" rgba=".6 .3 .3 1"/> | ||
<material name="target_default" rgba=".6 .3 .3 1"/> | ||
<material name="target_highlight" rgba=".6 .3 .3 .4"/> | ||
<material name="site" rgba=".5 .5 .5 .3"/> | ||
</asset> | ||
</mujoco> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a widely used function, that we don't want to break API compatibility on.
Can you instead add a new function, qpos_from_site_poses which takes a parameter (called site_names rather than sites_names, BTW)?
The existing
qpos_from_site_pose
function could call down to the newqpos_from_site_poses
.