Skip to content

Commit

Permalink
[gym_jiminy/common] Add RBF kernel shape arg to all error-based rewards.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 19, 2025
1 parent 72721a8 commit 52e6fdb
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=missing-module-docstring

from .mixin import (CUTOFF_ESP,
from .mixin import (KernelShape,
radial_basis_function,
AdditiveMixtureReward,
MultiplicativeMixtureReward)
Expand Down Expand Up @@ -34,7 +34,7 @@
ImpactForceTermination)

__all__ = [
"CUTOFF_ESP",
"KernelShape",
"radial_basis_function",
"AdditiveMixtureReward",
"MultiplicativeMixtureReward",
Expand Down
58 changes: 28 additions & 30 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
MultiActuatedJointKinematic, MechanicalPowerConsumption,
AverageMechanicalPowerConsumption)

from .mixin import radial_basis_function
from .mixin import KernelShape, radial_basis_function


ValueT = TypeVar('ValueT')
Expand Down Expand Up @@ -69,9 +69,9 @@ class TrackingQuantityReward(QuantityReward):
otherwise an exception will be raised. See `DatasetTrajectoryQuantity` and
`AbstractQuantity` documentations for details.
The error is transformed into a normalized reward to maximize by applying RBF
kernel on the error. The reward will be 1.0 if the error cancels out
completely and less than 'CUTOFF_ESP' above the user-specified cutoff
The error is transformed into a normalized reward to maximize by applying a
given RBF kernel on the error. The reward will be 1.0 if the error cancels
out completely and less than 'CUTOFF_ESP' above the user-specified cutoff
threshold.
"""
def __init__(self,
Expand All @@ -80,6 +80,7 @@ def __init__(self,
quantity_creator: Callable[
[QuantityEvalMode], QuantityCreator[ValueT]],
cutoff: float,
shape: KernelShape = KernelShape.SQUARED_EXPONENTIAL,
*,
op: Callable[[ValueT, ValueT], ValueT] = sub,
order: int = 2) -> None:
Expand All @@ -97,6 +98,8 @@ def __init__(self,
keyword-arguments of its constructor except
'env' and 'parent'.
:param cutoff: Cutoff threshold for the RBF kernel transform.
:param shape: Desired type of RBF kernel.
Optional: `KernelShape.SQUARED_EXPONENTIAL` by default.
:param op: Any callable taking the true and reference values of the
quantity as input argument and returning the difference
between them, considering the algebra defined by their Lie
Expand All @@ -106,18 +109,17 @@ def __init__(self,
:param order: Order of L^p-norm that will be used as distance metric.
Optional: 2 by default.
"""
# Backup some user argument(s)
self.cutoff = cutoff

# Call base implementation
super().__init__(
env,
name,
(BinaryOpQuantity, dict(
quantity_left=quantity_creator(QuantityEvalMode.TRUE),
quantity_right=quantity_creator(QuantityEvalMode.REFERENCE),
op=op)),
partial(radial_basis_function, cutoff=self.cutoff, order=order),
partial(radial_basis_function, **dict(
cutoff=cutoff,
shape=int(shape),
order=order)),
is_normalized=True,
is_terminal=False)

Expand All @@ -131,23 +133,23 @@ class TrackingActuatedJointPositionsReward(TrackingQuantityReward):
"""
def __init__(self,
env: InterfaceJiminyEnv,
cutoff: float) -> None:
cutoff: float,
shape: KernelShape = KernelShape.SQUARED_EXPONENTIAL) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param cutoff: Cutoff threshold for the RBF kernel transform.
:param shape: Desired type of RBF kernel.
Optional: `KernelShape.SQUARED_EXPONENTIAL` by default.
"""
# Backup some user argument(s)
self.cutoff = cutoff

# Call base implementation
super().__init__(
env,
"reward_actuated_joint_positions",
lambda mode: (MultiActuatedJointKinematic, dict(
kinematic_level=pin.KinematicLevel.POSITION,
is_motor_side=False,
mode=mode)),
cutoff)
cutoff,
shape)


class MinimizeMechanicalPowerConsumption(QuantityReward):
Expand All @@ -162,31 +164,33 @@ def __init__(
self,
env: InterfaceJiminyEnv,
cutoff: float,
shape: KernelShape = KernelShape.SQUARED_EXPONENTIAL,
*,
horizon: float,
generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE
) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param cutoff: Cutoff threshold for the RBF kernel transform.
:param shape: Desired type of RBF kernel.
Optional: `KernelShape.SQUARED_EXPONENTIAL` by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the average.
:param generator_mode: Specify what happens to the energy generated by
motors when braking.
Optional: `EnergyGenerationMode.CHARGE` by
default.
"""
# Backup some user argument(s)
self.cutoff = cutoff

# Call base implementation
super().__init__(
env,
"reward_power_consumption",
(AverageMechanicalPowerConsumption, dict(
horizon=horizon,
generator_mode=generator_mode)),
partial(radial_basis_function, cutoff=self.cutoff, order=2),
partial(radial_basis_function, **dict(
cutoff=cutoff,
shape=int(shape),
order=2)),
is_normalized=True,
is_terminal=False)

Expand Down Expand Up @@ -632,21 +636,16 @@ def __init__(
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.max_power = max_power
self.horizon = horizon
self.generator_mode = generator_mode

# Pick the right quantity creator depending on the horizon
quantity_creator: QuantityCreator
if horizon is None:
quantity_creator = (MechanicalPowerConsumption, dict(
generator_mode=self.generator_mode,
generator_mode=generator_mode,
mode=QuantityEvalMode.TRUE))
else:
quantity_creator = (AverageMechanicalPowerConsumption, dict(
horizon=self.horizon,
generator_mode=self.generator_mode,
horizon=horizon,
generator_mode=generator_mode,
mode=QuantityEvalMode.TRUE))

# Call base implementation
Expand All @@ -655,7 +654,7 @@ def __init__(
"termination_power_consumption",
quantity_creator,
None,
self.max_power,
max_power,
grace_period,
is_truncation=False,
training_only=training_only)
Expand Down Expand Up @@ -701,7 +700,6 @@ def __init__(self,
evaluation mode.
Optional: False by default.
"""
# Call base implementation
super().__init__(
env,
"termination_tracking_motor_positions",
Expand Down
Loading

0 comments on commit 52e6fdb

Please sign in to comment.