Skip to content

Commit

Permalink
Use a default config for backward compatibility when runner cfg's mem…
Browse files Browse the repository at this point in the history
…ory field is not defined
  • Loading branch information
Toni-SM committed Aug 21, 2024
1 parent 8cd1210 commit 7ce5153
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
16 changes: 13 additions & 3 deletions skrl/utils/runner/jax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi
# override cfg
cfg["models"]["separate"] = True # shared model is not supported in JAX

try:
agent_class = self._class(cfg["agent"]["class"])
del cfg["agent"]["class"]
except KeyError:
agent_class = self._class("PPO")
logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default")

# instantiate models
models = {}
for agent_id in possible_agents:
Expand Down Expand Up @@ -181,7 +188,7 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi
logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default")
# instantiate model
models[agent_id]["value"] = model_class(
observation_space=(state_spaces if self._class(_cfg["agent"]["class"]) in [MAPPO] else observation_spaces)[agent_id],
observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["value"]),
Expand Down Expand Up @@ -236,15 +243,18 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin
observation_spaces = env.observation_spaces if multi_agent else {"agent": env.observation_space}
action_spaces = env.action_spaces if multi_agent else {"agent": env.action_space}

# instantiate memories
memories = {}
# check for memory configuration (backward compatibility)
if not "memory" in cfg:
logger.warning("Deprecation warning: No 'memory' field defined in cfg. Using the default generated configuration")
cfg["memory"] = {"class": "RandomMemory", "memory_size": -1}
# get memory class and remove 'class' field
try:
memory_class = self._class(cfg["memory"]["class"])
del cfg["memory"]["class"]
except KeyError:
memory_class = self._class("RandomMemory")
logger.warning("No 'class' field defined in 'memory' cfg. 'RandomMemory' will be used as default")
memories = {}
# instantiate memory
if cfg["memory"]["memory_size"] < 0:
cfg["memory"]["memory_size"] = cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
Expand Down
16 changes: 13 additions & 3 deletions skrl/utils/runner/torch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi
observation_spaces = env.observation_spaces if multi_agent else {"agent": env.observation_space}
action_spaces = env.action_spaces if multi_agent else {"agent": env.action_space}

try:
agent_class = self._class(cfg["agent"]["class"])
del cfg["agent"]["class"]
except KeyError:
agent_class = self._class("PPO")
logger.warning("No 'class' field defined in 'agent' cfg. 'PPO' will be used as default")

# instantiate models
models = {}
for agent_id in possible_agents:
Expand Down Expand Up @@ -178,7 +185,7 @@ def _generate_models(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappi
logger.warning("No 'class' field defined in 'models:value' cfg. 'DeterministicMixin' will be used as default")
# instantiate model
models[agent_id]["value"] = model_class(
observation_space=(state_spaces if self._class(_cfg["agent"]["class"]) in [MAPPO] else observation_spaces)[agent_id],
observation_space=(state_spaces if agent_class in [MAPPO] else observation_spaces)[agent_id],
action_space=action_spaces[agent_id],
device=device,
**self._process_cfg(_cfg["models"]["value"]),
Expand Down Expand Up @@ -228,15 +235,18 @@ def _generate_agent(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mappin
observation_spaces = env.observation_spaces if multi_agent else {"agent": env.observation_space}
action_spaces = env.action_spaces if multi_agent else {"agent": env.action_space}

# instantiate memories
memories = {}
# check for memory configuration (backward compatibility)
if not "memory" in cfg:
logger.warning("Deprecation warning: No 'memory' field defined in cfg. Using the default generated configuration")
cfg["memory"] = {"class": "RandomMemory", "memory_size": -1}
# get memory class and remove 'class' field
try:
memory_class = self._class(cfg["memory"]["class"])
del cfg["memory"]["class"]
except KeyError:
memory_class = self._class("RandomMemory")
logger.warning("No 'class' field defined in 'memory' cfg. 'RandomMemory' will be used as default")
memories = {}
# instantiate memory
if cfg["memory"]["memory_size"] < 0:
cfg["memory"]["memory_size"] = cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
Expand Down

0 comments on commit 7ce5153

Please sign in to comment.