Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add default parameters to all envs for consistent get_param return #3

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 153 additions & 69 deletions rrls/envs/ant.py

Large diffs are not rendered by default.

119 changes: 79 additions & 40 deletions rrls/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@

# from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv

# Fallback parameter values for the HalfCheetah wrapper in this file.
# The set_params shown in this diff uses the *mass entries when the matching
# argument is None and the instance has no previously stored value
# (getattr(self, name, DEFAULT_PARAMS[name])).
# NOTE(review): the visible set_params assigns worldfriction directly, without
# falling back to the "worldfriction" entry below — confirm whether that
# asymmetry with the mass parameters is intended.
# NOTE(review): presumably these numbers mirror the stock HalfCheetah model
# (one friction value per geom, one mass per body) — TODO confirm.
DEFAULT_PARAMS = {
"worldfriction": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"torsomass": 6.25020920502092,
"backthighmass": 1.5435146443514645,
"backshinmass": 1.5874476987447697,
"backfootmass": 1.0953974895397491,
"forwardthighmass": 1.4380753138075317,
"forwardshinmass": 1.200836820083682,
"forwardfootmass": 0.8845188284518829,
}


class HalfCheetahParamsBound(Enum):
ONE_DIM = {
Expand Down Expand Up @@ -47,7 +58,7 @@ class RobustHalfCheetah(Wrapper):
- forwardfootmass
"""

metadata = {
metadata = { # type: ignore
"render_modes": [
"human",
"rgb_array",
Expand All @@ -67,7 +78,7 @@ def __init__(
forwardfootmass: float | None = None,
**kwargs: dict[str, Any],
):
super().__init__(env=gym.make("HalfCheetah-v5", **kwargs))
super().__init__(env=gym.make("HalfCheetah-v5", **kwargs)) # type: ignore

self.set_params(
worldfriction=worldfriction,
Expand All @@ -93,13 +104,41 @@ def set_params(
forwardfootmass: float | None = None,
):
self.worldfriction = worldfriction
self.torsomass = torsomass
self.backthighmass = backthighmass
self.backshinmass = backshinmass
self.backfootmass = backfootmass
self.forwardthighmass = forwardthighmass
self.forwardshinmass = forwardshinmass
self.forwardfootmass = forwardfootmass
self.torsomass = (
torsomass
if torsomass is not None
else getattr(self, "torsomass", DEFAULT_PARAMS["torsomass"])
)
self.backthighmass = (
backthighmass
if backthighmass is not None
else getattr(self, "backthighmass", DEFAULT_PARAMS["backthighmass"])
)
self.backshinmass = (
backshinmass
if backshinmass is not None
else getattr(self, "backshinmass", DEFAULT_PARAMS["backshinmass"])
)
self.backfootmass = (
backfootmass
if backfootmass is not None
else getattr(self, "backfootmass", DEFAULT_PARAMS["backfootmass"])
)
self.forwardthighmass = (
forwardthighmass
if forwardthighmass is not None
else getattr(self, "forwardthighmass", DEFAULT_PARAMS["forwardthighmass"])
)
self.forwardshinmass = (
forwardshinmass
if forwardshinmass is not None
else getattr(self, "forwardshinmass", DEFAULT_PARAMS["forwardshinmass"])
)
self.forwardfootmass = (
forwardfootmass
if forwardfootmass is not None
else getattr(self, "forwardfootmass", DEFAULT_PARAMS["forwardfootmass"])
)
self._change_params()

def get_params(self):
Expand Down Expand Up @@ -128,21 +167,21 @@ def step(self, action):

def _change_params(self):
if self.worldfriction is not None:
self.unwrapped.model.geom_friction[:, 0] = self.worldfriction
self.unwrapped.model.geom_friction[:, 0] = self.worldfriction # type: ignore
if self.torsomass is not None:
self.unwrapped.model.body_mass[1] = self.torsomass
self.unwrapped.model.body_mass[1] = self.torsomass # type: ignore
if self.backthighmass is not None:
self.unwrapped.model.body_mass[2] = self.backthighmass
self.unwrapped.model.body_mass[2] = self.backthighmass # type: ignore
if self.backshinmass is not None:
self.unwrapped.model.body_mass[3] = self.backshinmass
self.unwrapped.model.body_mass[3] = self.backshinmass # type: ignore
if self.backfootmass is not None:
self.unwrapped.model.body_mass[4] = self.backfootmass
self.unwrapped.model.body_mass[4] = self.backfootmass # type: ignore
if self.forwardthighmass is not None:
self.unwrapped.model.body_mass[5] = self.forwardthighmass
self.unwrapped.model.body_mass[5] = self.forwardthighmass # type: ignore
if self.forwardshinmass is not None:
self.unwrapped.model.body_mass[6] = self.forwardshinmass
self.unwrapped.model.body_mass[6] = self.forwardshinmass # type: ignore
if self.forwardfootmass is not None:
self.unwrapped.model.body_mass[7] = self.forwardfootmass
self.unwrapped.model.body_mass[7] = self.forwardfootmass # type: ignore


class ForceHalfCheetah(Wrapper):
Expand Down Expand Up @@ -172,7 +211,7 @@ class ForceHalfCheetah(Wrapper):
- forwardfootforce_z
"""

metadata = {
metadata = { # type: ignore
"render_modes": [
"human",
"rgb_array",
Expand All @@ -181,7 +220,7 @@ class ForceHalfCheetah(Wrapper):
}

def __init__(self, **kwargs: dict[str, Any]):
super().__init__(env=gym.make("HalfCheetah-v5", **kwargs))
super().__init__(env=gym.make("HalfCheetah-v5", **kwargs)) # type: ignore
self.set_params()
self._change_params()

Expand Down Expand Up @@ -259,47 +298,47 @@ def get_params(self):

def _change_params(self):
if self.torsoforce_x is not None:
self.unwrapped.data.xfrc_applied[1, 0] = self.torsoforce_x
self.unwrapped.data.xfrc_applied[1, 0] = self.torsoforce_x # type: ignore
if self.torsoforce_y is not None:
self.unwrapped.data.xfrc_applied[1, 1] = self.torsoforce_y
self.unwrapped.data.xfrc_applied[1, 1] = self.torsoforce_y # type: ignore
if self.torsoforce_z is not None:
self.unwrapped.data.xfrc_applied[1, 2] = self.torsoforce_z
self.unwrapped.data.xfrc_applied[1, 2] = self.torsoforce_z # type: ignore
if self.backthighforce_x is not None:
self.unwrapped.data.xfrc_applied[2, 0] = self.backthighforce_x
self.unwrapped.data.xfrc_applied[2, 0] = self.backthighforce_x # type: ignore
if self.backthighforce_y is not None:
self.unwrapped.data.xfrc_applied[2, 1] = self.backthighforce_y
self.unwrapped.data.xfrc_applied[2, 1] = self.backthighforce_y # type: ignore
if self.backthighforce_z is not None:
self.unwrapped.data.xfrc_applied[2, 2] = self.backthighforce_z
self.unwrapped.data.xfrc_applied[2, 2] = self.backthighforce_z # type: ignore
if self.backshinforce_x is not None:
self.unwrapped.data.xfrc_applied[3, 0] = self.backshinforce_x
self.unwrapped.data.xfrc_applied[3, 0] = self.backshinforce_x # type: ignore
if self.backshinforce_y is not None:
self.unwrapped.data.xfrc_applied[3, 1] = self.backshinforce_y
self.unwrapped.data.xfrc_applied[3, 1] = self.backshinforce_y # type: ignore
if self.backshinforce_z is not None:
self.unwrapped.data.xfrc_applied[3, 2] = self.backshinforce_z
self.unwrapped.data.xfrc_applied[3, 2] = self.backshinforce_z # type: ignore
if self.backfootforce_x is not None:
self.unwrapped.data.xfrc_applied[4, 0] = self.backfootforce_x
self.unwrapped.data.xfrc_applied[4, 0] = self.backfootforce_x # type: ignore
if self.backfootforce_y is not None:
self.unwrapped.data.xfrc_applied[4, 1] = self.backfootforce_y
self.unwrapped.data.xfrc_applied[4, 1] = self.backfootforce_y # type: ignore
if self.backfootforce_z is not None:
self.unwrapped.data.xfrc_applied[4, 2] = self.backfootforce_z
self.unwrapped.data.xfrc_applied[4, 2] = self.backfootforce_z # type: ignore
if self.forwardthighforce_x is not None:
self.unwrapped.data.xfrc_applied[5, 0] = self.forwardthighforce_x
self.unwrapped.data.xfrc_applied[5, 0] = self.forwardthighforce_x # type: ignore
if self.forwardthighforce_y is not None:
self.unwrapped.data.xfrc_applied[5, 1] = self.forwardthighforce_y
self.unwrapped.data.xfrc_applied[5, 1] = self.forwardthighforce_y # type: ignore
if self.forwardthighforce_z is not None:
self.unwrapped.data.xfrc_applied[5, 2] = self.forwardthighforce_z
self.unwrapped.data.xfrc_applied[5, 2] = self.forwardthighforce_z # type: ignore
if self.forwardshinforce_x is not None:
self.unwrapped.data.xfrc_applied[6, 0] = self.forwardshinforce_x
self.unwrapped.data.xfrc_applied[6, 0] = self.forwardshinforce_x # type: ignore
if self.forwardshinforce_y is not None:
self.unwrapped.data.xfrc_applied[6, 1] = self.forwardshinforce_y
self.unwrapped.data.xfrc_applied[6, 1] = self.forwardshinforce_y # type: ignore
if self.forwardshinforce_z is not None:
self.unwrapped.data.xfrc_applied[6, 2] = self.forwardshinforce_z
self.unwrapped.data.xfrc_applied[6, 2] = self.forwardshinforce_z # type: ignore
if self.forwardfootforce_x is not None:
self.unwrapped.data.xfrc_applied[7, 0] = self.forwardfootforce_x
self.unwrapped.data.xfrc_applied[7, 0] = self.forwardfootforce_x # type: ignore
if self.forwardfootforce_y is not None:
self.unwrapped.data.xfrc_applied[7, 1] = self.forwardfootforce_y
self.unwrapped.data.xfrc_applied[7, 1] = self.forwardfootforce_y # type: ignore
if self.forwardfootforce_z is not None:
self.unwrapped.data.xfrc_applied[7, 2] = self.forwardfootforce_z
self.unwrapped.data.xfrc_applied[7, 2] = self.forwardfootforce_z # type: ignore

def reset(self, *, seed: int | None = None, options: dict | None = None):
if options is not None:
Expand Down
82 changes: 56 additions & 26 deletions rrls/envs/hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class HopperParamsBound(Enum):
}


# Fallback parameter values for RobustHopper.
# set_params uses an entry when the matching argument is None and the
# instance has no previously stored value
# (getattr(self, name, DEFAULT_PARAMS[name])); this covers worldfriction as
# well as the four body masses.
# NOTE(review): presumably these numbers mirror the stock Hopper model
# values — TODO confirm.
DEFAULT_PARAMS = {
"worldfriction": 0.7,
"torsomass": 3.6651914291880923,
"thighmass": 4.057890510886818,
"legmass": 2.7813566959781637,
"footmass": 5.315574769873931,
}


class RobustHopper(Wrapper):
"""
Robust Hopper environment. You can change the parameters of the environment using options in
Expand All @@ -38,7 +47,7 @@ class RobustHopper(Wrapper):
- footmass
"""

metadata = {
metadata = { # type: ignore
"render_modes": [
"human",
"rgb_array",
Expand All @@ -55,7 +64,7 @@ def __init__(
footmass: float | None = None,
**kwargs: dict[str, Any],
):
super().__init__(env=gym.make("Hopper-v5", **kwargs))
super().__init__(env=gym.make("Hopper-v5", **kwargs)) # type: ignore

self.set_params(
worldfriction=worldfriction,
Expand All @@ -74,11 +83,32 @@ def set_params(
legmass: float | None = None,
footmass: float | None = None,
):
self.worldfriction = worldfriction
self.torsomass = torsomass
self.thighmass = thighmass
self.legmass = legmass
self.footmass = footmass
self.worldfriction = (
worldfriction
if worldfriction is not None
else getattr(self, "worldfriction", DEFAULT_PARAMS["worldfriction"])
)
self.torsomass = (
torsomass
if torsomass is not None
else getattr(self, "torsomass", DEFAULT_PARAMS["torsomass"])
)
self.thighmass = (
thighmass
if thighmass is not None
else getattr(self, "thighmass", DEFAULT_PARAMS["thighmass"])
)
self.legmass = (
legmass
if legmass is not None
else getattr(self, "legmass", DEFAULT_PARAMS["legmass"])
)
self.footmass = (
footmass
if footmass is not None
else getattr(self, "footmass", DEFAULT_PARAMS["footmass"])
)

self._change_params()

def get_params(self):
Expand All @@ -104,19 +134,19 @@ def step(self, action):

def _change_params(self):
if self.worldfriction is not None:
self.unwrapped.model.geom_friction[0, 0] = self.worldfriction
self.unwrapped.model.geom_friction[0, 0] = self.worldfriction # type: ignore

if self.torsomass is not None:
self.unwrapped.model.body_mass[1] = self.torsomass
self.unwrapped.model.body_mass[1] = self.torsomass # type: ignore

if self.thighmass is not None:
self.unwrapped.model.body_mass[2] = self.thighmass
self.unwrapped.model.body_mass[2] = self.thighmass # type: ignore

if self.legmass is not None:
self.unwrapped.model.body_mass[3] = self.legmass
self.unwrapped.model.body_mass[3] = self.legmass # type: ignore

if self.footmass is not None:
self.unwrapped.model.body_mass[4] = self.footmass
self.unwrapped.model.body_mass[4] = self.footmass # type: ignore


class ForceHopper(Wrapper):
Expand All @@ -137,7 +167,7 @@ class ForceHopper(Wrapper):
- footforce_z
"""

metadata = {
metadata = { # type: ignore
"render_modes": [
"human",
"rgb_array",
Expand All @@ -146,7 +176,7 @@ class ForceHopper(Wrapper):
}

def __init__(self, **kwargs: dict[str, Any]):
super().__init__(env=gym.make("Hopper-v5", **kwargs))
super().__init__(env=gym.make("Hopper-v5", **kwargs)) # type: ignore
self.set_params()

def set_params(
Expand Down Expand Up @@ -196,40 +226,40 @@ def get_params(self):

def _change_params(self):
if self.torsoforce_x is not None:
self.unwrapped.data.xfrc_applied[1, 0] = self.torsoforce_x
self.unwrapped.data.xfrc_applied[1, 0] = self.torsoforce_x # type: ignore

if self.torsoforce_y is not None:
self.unwrapped.data.xfrc_applied[1, 1] = self.torsoforce_y
self.unwrapped.data.xfrc_applied[1, 1] = self.torsoforce_y # type: ignore

if self.torsoforce_z is not None:
self.unwrapped.data.xfrc_applied[1, 2] = self.torsoforce_z
self.unwrapped.data.xfrc_applied[1, 2] = self.torsoforce_z # type: ignore

if self.thighforce_x is not None:
self.unwrapped.data.xfrc_applied[2, 0] = self.thighforce_x
self.unwrapped.data.xfrc_applied[2, 0] = self.thighforce_x # type: ignore

if self.thighforce_y is not None:
self.unwrapped.data.xfrc_applied[2, 1] = self.thighforce_y
self.unwrapped.data.xfrc_applied[2, 1] = self.thighforce_y # type: ignore

if self.thighforce_z is not None:
self.unwrapped.data.xfrc_applied[2, 2] = self.thighforce_z
self.unwrapped.data.xfrc_applied[2, 2] = self.thighforce_z # type: ignore

if self.legforce_x is not None:
self.unwrapped.data.xfrc_applied[3, 0] = self.legforce_x
self.unwrapped.data.xfrc_applied[3, 0] = self.legforce_x # type: ignore

if self.legforce_y is not None:
self.unwrapped.data.xfrc_applied[3, 1] = self.legforce_y
self.unwrapped.data.xfrc_applied[3, 1] = self.legforce_y # type: ignore

if self.legforce_z is not None:
self.unwrapped.data.xfrc_applied[3, 2] = self.legforce_z
self.unwrapped.data.xfrc_applied[3, 2] = self.legforce_z # type: ignore

if self.footforce_x is not None:
self.unwrapped.data.xfrc_applied[4, 0] = self.footforce_x
self.unwrapped.data.xfrc_applied[4, 0] = self.footforce_x # type: ignore

if self.footforce_y is not None:
self.unwrapped.data.xfrc_applied[4, 1] = self.footforce_y
self.unwrapped.data.xfrc_applied[4, 1] = self.footforce_y # type: ignore

if self.footforce_z is not None:
self.unwrapped.data.xfrc_applied[4, 2] = self.footforce_z
self.unwrapped.data.xfrc_applied[4, 2] = self.footforce_z # type: ignore

def reset(self, *, seed: int | None = None, options: dict | None = None):
if options is not None:
Expand Down
Loading
Loading