diff --git a/mjx/mujoco/mjx/_src/sensor.py b/mjx/mujoco/mjx/_src/sensor.py index 3a027a8d9f..7e1848fa2a 100644 --- a/mjx/mujoco/mjx/_src/sensor.py +++ b/mjx/mujoco/mjx/_src/sensor.py @@ -22,6 +22,7 @@ from mujoco.mjx._src import ray from mujoco.mjx._src import smooth from mujoco.mjx._src import support +from mujoco.mjx._src import scan from mujoco.mjx._src.types import Data from mujoco.mjx._src.types import DisableBit from mujoco.mjx._src.types import Model @@ -616,37 +617,49 @@ def energy_pos(m: Model, d: Data) -> Data: # Add gravitational potential energy for each body if not m.opt.disableflags & DisableBit.GRAVITY: - energy = -jp.sum(m.body_mass[1:] * jp.dot(m.opt.gravity, d.xipos[1:])) + energy = -jp.sum(m.body_mass[1:] * jp.dot(d.xipos[1:,:], m.opt.gravity)) # Add joint spring potential energy using scan.flat if not m.opt.disableflags & DisableBit.PASSIVE: - for i in range(m.njnt): - stiffness = m.jnt_stiffness[i] - padr = m.jnt_qposadr[i] - - if m.jnt_type[i] == JointType.FREE: + def spring_energy(jnt_type, stiffness, qpos, qpos_spring, padr): + + if jnt_type == JointType.FREE: # Position springs - quat = d.qpos[padr:padr+4] + quat = qpos[padr:padr+4] quat = math.normalize(quat) - dif = quat - m.qpos_spring[padr:padr+4] - energy += 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) + dif = quat - qpos_spring[padr:padr+4] + energy = 0.5 * stiffness * jp.dot(dif[:3], dif[:3]) - # Handle rotations - padr += 3 - - if m.jnt_type[i] in (JointType.FREE, JointType.BALL): + elif jnt_type in (JointType.FREE, JointType.BALL): # Convert quaternion difference to angular displacement - quat = d.qpos[padr:padr+4] + quat = qpos[padr:padr+4] quat = math.normalize(quat) - dif = math.quat_sub(quat, m.qpos_spring[padr:padr+4]) - energy += 0.5 * stiffness * jp.dot(dif, dif) + dif = math.quat_sub(quat, qpos_spring[padr:padr+4]) + energy = 0.5 * stiffness * jp.dot(dif, dif) + + elif jnt_type in (JointType.SLIDE, JointType.HINGE): + dif = qpos[padr] - qpos_spring[padr] + energy = 0.5 * stiffness * dif * dif - elif m.jnt_type[i] in (JointType.SLIDE, JointType.HINGE): - dif = d.qpos[padr] - m.qpos_spring[padr] - energy += 0.5 * stiffness * dif * dif + return energy + + spring_energy = scan.flat( + m, + spring_energy, + 'jjqqj', # input types: jnt_type, stiffness, qpos, qpos_spring, padr + 'j', # output type: energy per joint + m.jnt_type, + m.jnt_stiffness, + d.qpos, + m.qpos_spring, + jp.array(m.jnt_qposadr), + group_by='j' + ) + + energy += jp.sum(spring_energy) # Add tendon spring potential energy using vectorized operations - if not m.opt.disableflags & DisableBit.PASSIVE: + if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0: # Get lower/upper bounds and current lengths lower = m.tendon_lengthspring[::2] # Even indices upper = m.tendon_lengthspring[1::2] # Odd indices