Skip to content

Commit

Permalink
[gym_jiminy/common] Add termination condition for contact slippage.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 19, 2025
1 parent 92fd778 commit 72721a8
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
FallingTermination,
FootCollisionTermination,
FlyingTermination,
SlippageTermination,
ImpactForceTermination)

__all__ = [
Expand Down Expand Up @@ -60,6 +61,7 @@
"MechanicalSafetyTermination",
"MechanicalPowerConsumptionTermination",
"FlyingTermination",
"SlippageTermination",
"BaseRollPitchTermination",
"FallingTermination",
"FootCollisionTermination",
Expand Down
196 changes: 181 additions & 15 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""
from functools import partial
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast
from typing import (
Optional, Union, Sequence, Literal, Callable, Tuple, List, cast)

import numpy as np
import numba as nb
Expand All @@ -11,7 +12,8 @@
import pinocchio as pin

from ..bases import (
InterfaceJiminyEnv, InterfaceQuantity, QuantityEvalMode, QuantityReward)
InterfaceJiminyEnv, InterfaceQuantity, QuantityEvalMode, AbstractQuantity,
QuantityReward)
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar
from ..quantities import (
OrientationType, MaskedQuantity, UnaryOpQuantity, ConcatenatedQuantity,
Expand All @@ -20,7 +22,7 @@
BaseOdometryAverageVelocity, CapturePoint, MultiFramePosition,
MultiFootRelativeXYZQuat, MultiContactNormalizedSpatialForce,
MultiFootNormalizedForceVertical, MultiFootCollisionDetection,
AverageBaseMomentum)
MultiFrameSpatialAverageVelocity, AverageBaseMomentum)
from ..utils import quat_difference, quat_to_yaw

from .generic import (
Expand Down Expand Up @@ -447,11 +449,11 @@ def __init__(self,


@nb.jit(nopython=True, cache=True, fastmath=True)
def min_depth(positions: np.ndarray,
heights: np.ndarray,
normals: np.ndarray) -> float:
"""Approximate minimum distance from the ground profile among a set of the
query points.
def depth_approx(positions: np.ndarray,
heights: np.ndarray,
normals: np.ndarray) -> np.ndarray:
"""Approximate signed distance from the ground profile (positive if above
the ground, negative otherwise) of a set of the query points.
Internally, it uses a first order approximation assuming zero local
curvature around each query point.
Expand All @@ -469,12 +471,14 @@ def min_depth(positions: np.ndarray,
while the second correponds to the N individual query
points.
"""
return np.min((positions[2] - heights) * normals[2])
return (positions[2] - heights) * normals[2]


@dataclass(unsafe_hash=True)
class _MultiContactMinGroundDistance(InterfaceQuantity[float]):
"""Minimum distance from the ground profile among all the contact points.
class _MultiContactGroundDistanceAndNormal(
InterfaceQuantity[Tuple[np.ndarray, np.ndarray]]):
"""Signed distance (positive if above the ground, negative otherwise) and
normal from the ground profile of all the candidate contact points.
.. note::
Internally, it does not compute the exact shortest distance from the
Expand Down Expand Up @@ -524,7 +528,7 @@ def initialize(self) -> None:
engine_options = self.env.unwrapped.engine.get_options()
self._heightmap = engine_options["world"]["groundProfile"]

def refresh(self) -> float:
def refresh(self) -> Tuple[np.ndarray, np.ndarray]:
# Query the height and normal to the ground profile for the position in
# world plane of all the contact points.
positions = self.positions.get()
Expand All @@ -533,11 +537,13 @@ def refresh(self) -> float:
self._heights,
self._normals)

# Make sure the ground normal is normalized
# Make sure the ground normal has unit length
# self._normals /= np.linalg.norm(self._normals, axis=0)

# First-order distance estimation assuming no curvature
return min_depth(positions, self._heights, self._normals)
depth = depth_approx(positions, self._heights, self._normals)

return depth, self._normals


class FlyingTermination(QuantityTermination):
Expand Down Expand Up @@ -571,14 +577,174 @@ def __init__(self,
super().__init__(
env,
"termination_flying",
(_MultiContactMinGroundDistance, {}), # type: ignore[arg-type]
(UnaryOpQuantity, dict(
quantity=(_MultiContactGroundDistanceAndNormal, {}),
op=lambda depths_and_normals: depths_and_normals[0].min()
)),
None,
max_height,
grace_period,
is_truncation=False,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True)
def compute_velocity_tangential(velocity: np.ndarray,
normal: np.ndarray) -> np.ndarray:
"""Compute the norm of the velocity projected in the plan orthogonal to a
given normal direction vector.
.. warning::
The normal direction vector used assumed to be normalized. It is up to
the pratitioner to make sure this holds true.
:param velocity: Linear velocity in world-aligned reference frame.
:param normal: Normal direction vector in world reference frame.
"""
return np.sqrt(
np.sum(np.square(velocity), 0) - np.sum(velocity * normal, 0) ** 2)


@nb.jit(nopython=True, cache=True, fastmath=True)
def _compute_max_velocity_tangential(
velocities: np.ndarray,
depths: np.ndarray,
normals: np.ndarray,
height_thr: float) -> float:
"""Compute the maximum norm of the tangential velocity wrt local curvature
of the ground profile of all the frames that are considered in contact.
:param velocities: Linear velocity of each frame in local-world-aligned
reference frame.
:param depths: Signed distance of each frames from the ground as a vector.
:param normals: Normal direction vector that fully specify the local
curvature of the ground profile at the location of each
frame as a 2D array whose first dimension gathers the
position components (X, Y, Z) and the second corresponds
to individual frames.
:param height_thr: Distance threshold below which frames are considered in
contact with the ground.
"""
# Compute the norm of tangential velocity for all frames at once
velocities_tangential = compute_velocity_tangential(velocities, normals)

# Compute the maximum tangential velocity
velocity_tangential_max = 0.0
for depth, velocity_tangential in zip(
depths, velocities_tangential.T):
# Ignore frames that are not close enough from the ground
if depth > height_thr:
continue

# Update the maximum tangential velocity
velocity_tangential_max = max(
velocity_tangential_max, velocity_tangential)

return velocity_tangential_max


@dataclass(unsafe_hash=True)
class _MultiContactMaxVelocityTangential(AbstractQuantity[float]):
"""Maximum norm of the tangential velocity wrt local curvature of the
ground profile of all the candidate contact frames that are close enough
from the ground.
.. note::
The maximum norm of the tangential velocity is considered to be 0.0 if
none of the candidate contact frames are close enough from the ground.
"""

height_thr: float
"""Height threshold above which a candidate contact point is deemed too far
from the ground and is discarded from the set of frames being considered
when looking for the maximum norm of the tangential velocity.
"""

def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
height_thr: float) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param parent: Higher-level quantity from which this quantity is a
requirement if any, `None` otherwise.
:param height_thr: Height threshold above which a candidate contact
point is ignored for being too far from the ground.
"""
# Backup some user-argument(s)
self.height_thr = height_thr

# Call base implementation
super().__init__(
env,
parent,
requirements=dict(
depths_and_normals=(
_MultiContactGroundDistanceAndNormal, {}),
v_spatial=(
MultiFrameSpatialAverageVelocity, dict(
frame_names=env.robot.contact_frame_names,
reference_frame=pin.LOCAL_WORLD_ALIGNED))),
auto_refresh=False)

def refresh(self) -> float:
# Get the average linear velocity of all the contact points
v_spatial = self.v_spatial.get()
velocities = v_spatial[:3]

# Get the distance and normal of all the contact points from the ground
depths, normals = self.depths_and_normals.get()

# Compute the maximum tangential velocity
return _compute_max_velocity_tangential(
velocities, depths, normals, self.height_thr)


class SlippageTermination(QuantityTermination):
"""Discourage the agent of sliding on the ground purposedly by terminating
the episode immediately if some of the active contact points are slipping
on the ground.
This kind of behavior is usually undesirable because they are hardly
repeatable and tend to transfer poorly to reality. Moreover, it may cause
a sense of poorly controlled motion to people nearby.
"""
def __init__(self,
env: InterfaceJiminyEnv,
height_thr: float,
max_velocity: float,
grace_period: float = 0.0,
*,
training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param height_thr: Height threshold below which a candidate contact
point is closed enough from the ground for its
tangential velocity to be considered.
:param max_velocity: Maximum norm of the tangential velocity wrt ground
of the contact points that are close enough above
which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
super().__init__(
env,
"termination_slippage",
(_MultiContactMaxVelocityTangential, dict(
height_thr=height_thr)),
None,
max_velocity,
grace_period,
is_truncation=False,
training_only=training_only)


class ImpactForceTermination(QuantityTermination):
"""Terminate the episode immediately in case of violent impact on the
ground.
Expand Down

0 comments on commit 72721a8

Please sign in to comment.