Skip to content

Commit

Permalink
minor fix in tendons energies
Browse files Browse the repository at this point in the history
  • Loading branch information
simeon-ned committed Jan 16, 2025
1 parent cbc81d9 commit 2920710
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions mjx/mujoco/mjx/_src/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mujoco.mjx._src import ray
from mujoco.mjx._src import smooth
from mujoco.mjx._src import support
from mujoco.mjx._src import scan
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import Model
Expand Down Expand Up @@ -616,37 +617,49 @@ def energy_pos(m: Model, d: Data) -> Data:

# Add gravitational potential energy for each body
if not m.opt.disableflags & DisableBit.GRAVITY:
energy = -jp.sum(m.body_mass[1:] * jp.dot(m.opt.gravity, d.xipos[1:]))
energy = -jp.sum(m.body_mass[1:] * jp.dot(d.xipos[1:,:], m.opt.gravity))

# Add joint spring potential energy using scan.flat
if not m.opt.disableflags & DisableBit.PASSIVE:
for i in range(m.njnt):
stiffness = m.jnt_stiffness[i]
padr = m.jnt_qposadr[i]

if m.jnt_type[i] == JointType.FREE:
def spring_energy(jnt_type, stiffness, qpos, qpos_spring, padr):

if jnt_type == JointType.FREE:
# Position springs
quat = d.qpos[padr:padr+4]
quat = qpos[padr:padr+4]
quat = math.normalize(quat)
dif = quat - m.qpos_spring[padr:padr+4]
energy += 0.5 * stiffness * jp.dot(dif[:3], dif[:3])
dif = quat - qpos_spring[padr:padr+4]
energy = 0.5 * stiffness * jp.dot(dif[:3], dif[:3])

# Handle rotations
padr += 3

if m.jnt_type[i] in (JointType.FREE, JointType.BALL):
elif jnt_type in (JointType.FREE, JointType.BALL):
# Convert quaternion difference to angular displacement
quat = d.qpos[padr:padr+4]
quat = qpos[padr:padr+4]
quat = math.normalize(quat)
dif = math.quat_sub(quat, m.qpos_spring[padr:padr+4])
energy += 0.5 * stiffness * jp.dot(dif, dif)
dif = math.quat_sub(quat, qpos_spring[padr:padr+4])
energy = 0.5 * stiffness * jp.dot(dif, dif)

elif jnt_type in (JointType.SLIDE, JointType.HINGE):
dif = qpos[padr] - qpos_spring[padr]
energy = 0.5 * stiffness * dif * dif

elif m.jnt_type[i] in (JointType.SLIDE, JointType.HINGE):
dif = d.qpos[padr] - m.qpos_spring[padr]
energy += 0.5 * stiffness * dif * dif
return energy

spring_energy = scan.flat(
m,
spring_energy,
'jjqqj', # input types: jnt_type, stiffness, qpos, qpos_spring, padr
'j', # output type: energy per joint
m.jnt_type,
m.jnt_stiffness,
d.qpos,
m.qpos_spring,
jp.array(m.jnt_qposadr),
group_by='j'
)

energy += jp.sum(spring_energy)

# Add tendon spring potential energy using vectorized operations
if not m.opt.disableflags & DisableBit.PASSIVE:
if not m.opt.disableflags & DisableBit.PASSIVE & m.tendon_lengthspring.size > 0:
# Get lower/upper bounds and current lengths
lower = m.tendon_lengthspring[::2] # Even indices
upper = m.tendon_lengthspring[1::2] # Odd indices
Expand Down

0 comments on commit 2920710

Please sign in to comment.