diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index a6e24cf9414..1bcc2dbd10e 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -110,7 +110,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
     if cfg.loss.mixer_type == "qmix":
         mixer = TensorDictModule(
             module=QMixer(
-                state_shape=env.unbatched_observation_spec[
+                state_shape=env.observation_spec_unbatched[
                     "agents", "observation"
                 ].shape,
                 mixing_embed_dim=32,
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 8d2e3387e3c..140fb191cae 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import importlib.util
+import warnings
 
 from typing import Dict, List, Optional, Union
 
@@ -328,9 +329,9 @@ def _make_specs(
         self.group_map = self.group_map.get_group_map(self.agent_names)
         check_marl_grouping(self.group_map, self.agent_names)
 
-        self.unbatched_action_spec = Composite(device=self.device)
-        self.unbatched_observation_spec = Composite(device=self.device)
-        self.unbatched_reward_spec = Composite(device=self.device)
+        full_action_spec_unbatched = Composite(device=self.device)
+        full_observation_spec_unbatched = Composite(device=self.device)
+        full_reward_spec_unbatched = Composite(device=self.device)
 
         self.het_specs = False
         self.het_specs_map = {}
@@ -341,18 +342,18 @@ def _make_specs(
                 group_reward_spec,
                 group_info_spec,
             ) = self._make_unbatched_group_specs(group)
-            self.unbatched_action_spec[group] = group_action_spec
-            self.unbatched_observation_spec[group] = group_observation_spec
-            self.unbatched_reward_spec[group] = group_reward_spec
+            full_action_spec_unbatched[group] = group_action_spec
+            full_observation_spec_unbatched[group] = group_observation_spec
+            full_reward_spec_unbatched[group] = group_reward_spec
             if group_info_spec is not None:
-                self.unbatched_observation_spec[(group, "info")] = group_info_spec
+                full_observation_spec_unbatched[(group, "info")] = group_info_spec
             group_het_specs = isinstance(
                 group_observation_spec, StackedComposite
             ) or isinstance(group_action_spec, StackedComposite)
             self.het_specs_map[group] = group_het_specs
             self.het_specs = self.het_specs or group_het_specs
 
-        self.unbatched_done_spec = Composite(
+        full_done_spec_unbatched = Composite(
             {
                 "done": Categorical(
                     n=2,
@@ -363,18 +364,42 @@ def _make_specs(
             },
         )
 
-        self.action_spec = self.unbatched_action_spec.expand(
-            *self.batch_size, *self.unbatched_action_spec.shape
+        self.full_action_spec_unbatched = full_action_spec_unbatched
+        self.full_observation_spec_unbatched = full_observation_spec_unbatched
+        self.full_reward_spec_unbatched = full_reward_spec_unbatched
+        self.full_done_spec_unbatched = full_done_spec_unbatched
+
+    @property
+    def unbatched_action_spec(self):
+        warnings.warn(
+            "unbatched_action_spec is deprecated and will be removed in v0.9. "
+            "Please use full_action_spec_unbatched instead."
         )
-        self.observation_spec = self.unbatched_observation_spec.expand(
-            *self.batch_size, *self.unbatched_observation_spec.shape
+        return self.full_action_spec_unbatched
+
+    @property
+    def unbatched_observation_spec(self):
+        warnings.warn(
+            "unbatched_observation_spec is deprecated and will be removed in v0.9. "
+            "Please use full_observation_spec_unbatched instead."
         )
-        self.reward_spec = self.unbatched_reward_spec.expand(
-            *self.batch_size, *self.unbatched_reward_spec.shape
+        return self.full_observation_spec_unbatched
+
+    @property
+    def unbatched_reward_spec(self):
+        warnings.warn(
+            "unbatched_reward_spec is deprecated and will be removed in v0.9. "
+            "Please use full_reward_spec_unbatched instead."
         )
-        self.done_spec = self.unbatched_done_spec.expand(
-            *self.batch_size, *self.unbatched_done_spec.shape
+        return self.full_reward_spec_unbatched
+
+    @property
+    def unbatched_done_spec(self):
+        warnings.warn(
+            "unbatched_done_spec is deprecated and will be removed in v0.9. "
+            "Please use full_done_spec_unbatched instead."
         )
+        return self.full_done_spec_unbatched
 
     def _make_unbatched_group_specs(self, group: str):
         # Agent specs
@@ -618,7 +643,9 @@ def read_reward(self, rewards):
 
     def read_action(self, action, group: str = "agents"):
         if not self.continuous_actions and not self.categorical_actions:
-            action = self.unbatched_action_spec[group, "action"].to_categorical(action)
+            action = self.full_action_spec_unbatched[group, "action"].to_categorical(
+                action
+            )
         agent_actions = action.unbind(dim=1)
         return agent_actions
 
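
Reviewer note: the patch turns the four `unbatched_*_spec` attributes into deprecation shims that warn and forward to the renamed `full_*_spec_unbatched` attributes. Below is a minimal standalone sketch of that shim pattern, not part of the patch: the `Env` class and its dict-valued spec are hypothetical stand-ins for `VmasEnv` and torchrl's `Composite`; only the warning text mirrors the diff.

# Standalone sketch of the deprecation-shim pattern used in this patch.
# "Env" and the dict spec are hypothetical stand-ins, not torchrl classes.
import warnings


class Env:
    def __init__(self):
        # New canonical attribute, as set at the end of _make_specs.
        self.full_action_spec_unbatched = {"agents": {"action": "spec"}}

    @property
    def unbatched_action_spec(self):
        # Old name still resolves, but warns and forwards to the new one.
        warnings.warn(
            "unbatched_action_spec is deprecated and will be removed in v0.9. "
            "Please use full_action_spec_unbatched instead."
        )
        return self.full_action_spec_unbatched


env = Env()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = env.unbatched_action_spec  # old access path still works
assert old is env.full_action_spec_unbatched  # same object, no copy
assert "deprecated" in str(caught[0].message)  # but the access warned

Because the shim returns the same object rather than a copy, existing call sites keep working unchanged during the deprecation window; only the attribute name changes for new code.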