Skip to content

Commit

Permalink
[Refactor] Use <spec>_unbatched in VMAS
Browse files Browse the repository at this point in the history
ghstack-source-id: 2190278de44ba59a3bc8d38398fddae9ecc42a84
Pull Request resolved: #2593
  • Loading branch information
vmoens committed Nov 20, 2024
1 parent d30599e commit a126a6f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.loss.mixer_type == "qmix":
mixer = TensorDictModule(
module=QMixer(
state_shape=env.unbatched_observation_spec[
state_shape=env.observation_spec_unbatched[
"agents", "observation"
].shape,
mixing_embed_dim=32,
Expand Down
61 changes: 44 additions & 17 deletions torchrl/envs/libs/vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import importlib.util
import warnings

from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -328,9 +329,9 @@ def _make_specs(
self.group_map = self.group_map.get_group_map(self.agent_names)
check_marl_grouping(self.group_map, self.agent_names)

self.unbatched_action_spec = Composite(device=self.device)
self.unbatched_observation_spec = Composite(device=self.device)
self.unbatched_reward_spec = Composite(device=self.device)
full_action_spec_unbatched = Composite(device=self.device)
full_observation_spec_unbatched = Composite(device=self.device)
full_reward_spec_unbatched = Composite(device=self.device)

self.het_specs = False
self.het_specs_map = {}
Expand All @@ -341,18 +342,18 @@ def _make_specs(
group_reward_spec,
group_info_spec,
) = self._make_unbatched_group_specs(group)
self.unbatched_action_spec[group] = group_action_spec
self.unbatched_observation_spec[group] = group_observation_spec
self.unbatched_reward_spec[group] = group_reward_spec
full_action_spec_unbatched[group] = group_action_spec
full_observation_spec_unbatched[group] = group_observation_spec
full_reward_spec_unbatched[group] = group_reward_spec
if group_info_spec is not None:
self.unbatched_observation_spec[(group, "info")] = group_info_spec
full_observation_spec_unbatched[(group, "info")] = group_info_spec
group_het_specs = isinstance(
group_observation_spec, StackedComposite
) or isinstance(group_action_spec, StackedComposite)
self.het_specs_map[group] = group_het_specs
self.het_specs = self.het_specs or group_het_specs

self.unbatched_done_spec = Composite(
full_done_spec_unbatched = Composite(
{
"done": Categorical(
n=2,
Expand All @@ -363,18 +364,42 @@ def _make_specs(
},
)

self.action_spec = self.unbatched_action_spec.expand(
*self.batch_size, *self.unbatched_action_spec.shape
self.full_action_spec_unbatched = full_action_spec_unbatched
self.full_observation_spec_unbatched = full_observation_spec_unbatched
self.full_reward_spec_unbatched = full_reward_spec_unbatched
self.full_done_spec_unbatched = full_done_spec_unbatched

@property
def unbatched_action_spec(self):
warnings.warn(
"unbatched_action_spec is deprecated and will be removed in v0.9. "
"Please use full_action_spec_unbatched instead."
)
self.observation_spec = self.unbatched_observation_spec.expand(
*self.batch_size, *self.unbatched_observation_spec.shape
return self.full_action_spec_unbatched

@property
def unbatched_observation_spec(self):
warnings.warn(
"unbatched_observation_spec is deprecated and will be removed in v0.9. "
"Please use full_observation_spec_unbatched instead."
)
self.reward_spec = self.unbatched_reward_spec.expand(
*self.batch_size, *self.unbatched_reward_spec.shape
return self.full_observation_spec_unbatched

@property
def unbatched_reward_spec(self):
warnings.warn(
"unbatched_reward_spec is deprecated and will be removed in v0.9. "
"Please use full_reward_spec_unbatched instead."
)
self.done_spec = self.unbatched_done_spec.expand(
*self.batch_size, *self.unbatched_done_spec.shape
return self.full_reward_spec_unbatched

@property
def unbatched_done_spec(self):
warnings.warn(
"unbatched_done_spec is deprecated and will be removed in v0.9. "
"Please use full_done_spec_unbatched instead."
)
return self.full_done_spec_unbatched

def _make_unbatched_group_specs(self, group: str):
# Agent specs
Expand Down Expand Up @@ -618,7 +643,9 @@ def read_reward(self, rewards):

def read_action(self, action, group: str = "agents"):
if not self.continuous_actions and not self.categorical_actions:
action = self.unbatched_action_spec[group, "action"].to_categorical(action)
action = self.full_action_spec_unbatched[group, "action"].to_categorical(
action
)
agent_actions = action.unbind(dim=1)
return agent_actions

Expand Down

0 comments on commit a126a6f

Please sign in to comment.