Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix runner config parsing and support permute operation in model instantiators #196

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Changed
- Move the KL reduction from the PyTorch `KLAdaptiveLR` class to each agent that uses it in distributed runs
- Move the PyTorch distributed initialization from the agent base class to the ML framework configuration
- Implement model instantiators using dynamic execution of Python code
- Upgrade model instantiator implementations to support CNN layers and complex network definitions,
and implement them using dynamic execution of Python code
- Update Isaac Lab environment loader argument parser options to match Isaac Lab version

### Changed (breaking changes)
Expand All @@ -24,7 +25,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Fixed
- Catch TensorBoard summary iterator exceptions in `TensorboardFileIterator` postprocessing utils
- Fix automatic wrapper detection for Isaac Gym (previews), DeepMind and vectorized Gymnasium environments
- Fix automatic wrapper detection issue (introduced in previous version) for Isaac Gym (previews),
DeepMind and vectorized Gymnasium environments
- Fix vectorized/parallel environments `reset` method return values when called more than once
- Fix IPPO and MAPPO `act` method return values when JAX-NumPy backend is enabled

Expand Down
2 changes: 2 additions & 0 deletions docs/source/api/utils/model_instantiators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ Supported operations:
- ``features_extractor + ACTIONS``
* - Concatenation
- ``concatenate([features_extractor, ACTIONS])``
* - Permute dimensions
- ``permute(STATES, (0, 3, 1, 2))``

|

Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def visit_Call(self, node: ast.Call):
if node.func.id == "concatenate":
node.func = ast.Attribute(value=ast.Name("jnp"), attr="concatenate")
node.keywords = [ast.keyword(arg="axis", value=ast.Constant(value=-1))]
# operation: permute
if node.func.id == "permute":
node.func = ast.Attribute(value=ast.Name("jnp"), attr="permute_dims")
return node

# apply operations by modifying the source syntax grammar
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/torch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def visit_Call(self, node: ast.Call):
if node.func.id == "concatenate":
node.func = ast.Attribute(value=ast.Name("torch"), attr="cat")
node.keywords = [ast.keyword(arg="dim", value=ast.Constant(value=1))]
# operation: permute
if node.func.id == "permute":
node.func = ast.Attribute(value=ast.Name("torch"), attr="permute")
return node

# apply operations by modifying the source syntax grammar
Expand Down
10 changes: 5 additions & 5 deletions skrl/utils/runner/jax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

import copy

from skrl import config, logger
from skrl import logger
from skrl.agents.jax import Agent
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper
from skrl.memories.jax import RandomMemory
from skrl.models.jax import Model
from skrl.multi_agents.jax.ippo import IPPO, IPPO_DEFAULT_CONFIG
from skrl.multi_agents.jax.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
from skrl.resources.preprocessors.jax import RunningStandardScaler
from skrl.resources.schedulers.jax import KLAdaptiveLR
from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa
from skrl.trainers.jax import SequentialTrainer, Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
Expand Down Expand Up @@ -121,12 +121,12 @@ def update_dict(d):
update_dict(value)
else:
if key in _direct_eval:
d[key] = eval(value)
if type(d[key]) is str:
d[key] = eval(value)
elif key.endswith("_kwargs"):
d[key] = value if value is not None else {}
elif key in ["rewards_shaper_scale"]:
d["rewards_shaper"] = reward_shaper_function(value)

return d

return update_dict(copy.deepcopy(cfg))
Expand Down
8 changes: 4 additions & 4 deletions skrl/utils/runner/torch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from skrl.models.torch import Model
from skrl.multi_agents.torch.ippo import IPPO, IPPO_DEFAULT_CONFIG
from skrl.multi_agents.torch.mappo import MAPPO, MAPPO_DEFAULT_CONFIG
from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.resources.schedulers.torch import KLAdaptiveLR
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa
from skrl.trainers.torch import SequentialTrainer, Trainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
Expand Down Expand Up @@ -121,12 +121,12 @@ def update_dict(d):
update_dict(value)
else:
if key in _direct_eval:
d[key] = eval(value)
if type(d[key]) is str:
d[key] = eval(value)
elif key.endswith("_kwargs"):
d[key] = value if value is not None else {}
elif key in ["rewards_shaper_scale"]:
d["rewards_shaper"] = reward_shaper_function(value)

return d

return update_dict(copy.deepcopy(cfg))
Expand Down
Loading