Skip to content

Commit

Permalink
Add RARL force envs (#2)
Browse files Browse the repository at this point in the history
* add monitor compatibility

* add force envs

* reduce RARL parambounds as in original paper

* after pre-commit

* remove typo

* add more params bound

* flake8 increase max-complexity
  • Loading branch information
DavidBert authored Dec 8, 2023
1 parent 64d9d2f commit 8f7d1cc
Show file tree
Hide file tree
Showing 13 changed files with 1,388 additions and 202 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/
sandbox/
*.mp4
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
args:
- '--per-file-ignores=*/__init__.py:F401'
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-complexity=45
- --max-line-length=456
- --show-source
- --statistics
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"mujoco-py<2.2,>=2.1",
"numpy>=1.21.0",
"gymnasium>=0.26",
"moviepy>=1.0.3",
]
dynamic = ["version"]

Expand Down
103 changes: 103 additions & 0 deletions rrls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,42 @@ def register_robotics_envs():
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-ant-v0",
entry_point="rrls.envs.ant:ForceAnt",
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-halfcheetah-v0",
entry_point="rrls.envs.half_cheetah:ForceHalfCheetah",
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-hopper-v0",
entry_point="rrls.envs.hopper:ForceHopper",
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-humanoidstandup-v0",
entry_point="rrls.envs.humanoid:ForceHumanoidStandUp",
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-invertedpendulum-v0",
entry_point="rrls.envs.pendulum:ForceInvertedPendulum",
order_enforce=False,
disable_env_checker=True,
)
register(
id="rrls/force-walker-v0",
entry_point="rrls.envs.walker:ForceWalker2d",
order_enforce=False,
disable_env_checker=True,
)

# Advserarial environments
# HalfCheetah
Expand Down Expand Up @@ -127,6 +163,73 @@ def register_robotics_envs():
"params_bound": envs.AntParamsBound.ONE_DIM.value,
},
)
register(
id="rrls/robust-ant-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceAnt,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.AntParamsBound.RARL.value,
},
)
register(
id="rrls/robust-halfcheetah-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceHalfCheetah,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.HalfCheetahParamsBound.RARL.value,
},
)
register(
id="rrls/robust-hopper-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceHopper,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.HopperParamsBound.RARL.value,
},
)
register(
id="rrls/robust-humanoidstandup-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceHumanoidStandUp,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.HumanoidStandupParamsBound.RARL.value,
},
)
register(
id="rrls/robust-invertedpendulum-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceInvertedPendulum,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.InvertedPendulumParamsBound.RARL.value,
},
)
register(
id="rrls/robust-walker-adversarial-forces-v0",
entry_point=make_wrapped_env, # type: ignore
order_enforce=False,
disable_env_checker=True,
kwargs={
"cls_env": envs.ForceWalker2d,
"wrapper": wrappers.DynamicAdversarial,
"params_bound": envs.Walker2dParamsBound.RARL.value,
},
)

# Hopper
register(
id="rrls/robust-hopper-adversarial-3d-v0",
Expand Down
26 changes: 20 additions & 6 deletions rrls/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

from .ant import AntParamsBound, RobustAnt
from .half_cheetah import HalfCheetahParamsBound, RobustHalfCheetah
from .hopper import HopperParamsBound, RobustHopper
from .humanoid import HumanoidStandupParamsBound, RobustHumanoidStandUp
from .pendulum import InvertedPendulumParamsBound, RobustInvertedPendulum
from .walker import RobustWalker2d, Walker2dParamsBound
from .ant import AntParamsBound, ForceAnt, RobustAnt
from .half_cheetah import ForceHalfCheetah, HalfCheetahParamsBound, RobustHalfCheetah
from .hopper import ForceHopper, HopperParamsBound, RobustHopper
from .humanoid import (
ForceHumanoidStandUp,
HumanoidStandupParamsBound,
RobustHumanoidStandUp,
)
from .pendulum import (
ForceInvertedPendulum,
InvertedPendulumParamsBound,
RobustInvertedPendulum,
)
from .walker import ForceWalker2d, RobustWalker2d, Walker2dParamsBound

__all__ = [
"AntParamsBound",
Expand All @@ -20,4 +28,10 @@
"RobustHumanoidStandUp",
"RobustInvertedPendulum",
"RobustWalker2d",
"ForceAnt",
"ForceHalfCheetah",
"ForceHopper",
"ForceHumanoidStandUp",
"ForceInvertedPendulum",
"ForceWalker2d",
]
Loading

0 comments on commit 8f7d1cc

Please sign in to comment.