Commit
[Version] 1.3.0 (#140)
* Deps

* composite

* Unbounded

* Tests

* amend
matteobettini authored Oct 24, 2024
1 parent 03d7d8a commit 277f05d
Showing 23 changed files with 166 additions and 234 deletions.
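
The substance of the commit is a one-for-one rename of torchrl's tensor-spec classes; BenchMARL switches every usage to the new names. A minimal sketch of the mapping, assuming a torchrl version recent enough to ship the renamed classes:

# The four renames this commit applies (old torchrl name -> new name):
#   CompositeSpec                 -> Composite
#   DiscreteTensorSpec            -> Categorical
#   OneHotDiscreteTensorSpec      -> OneHot
#   UnboundedContinuousTensorSpec -> Unbounded
from torchrl.data import Categorical, Composite, OneHot, Unbounded

# Quick smoke test: the new classes build the same specs as the old ones did.
spec = Composite({"action_value": Unbounded(shape=(2, 9))})
assert spec["action_value"].shape == (2, 9)
assert Categorical(5).space.n == OneHot(5).space.n == 5
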
22 changes: 8 additions & 14 deletions benchmarl/algorithms/common.py
@@ -12,10 +12,10 @@
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torchrl.data import (
-    CompositeSpec,
-    DiscreteTensorSpec,
+    Categorical,
+    Composite,
     LazyTensorStorage,
-    OneHotDiscreteTensorSpec,
+    OneHot,
     ReplayBuffer,
     TensorDictReplayBuffer,
 )
@@ -122,9 +122,7 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
         """
         if group not in self._losses_and_updaters.keys():
             action_space = self.action_spec[group, "action"]
-            continuous = not isinstance(
-                action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-            )
+            continuous = not isinstance(action_space, (Categorical, OneHot))
             loss, use_target = self._get_loss(
                 group=group,
                 policy_for_loss=self.get_policy_for_loss(group),
@@ -193,9 +191,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
         """
         if group not in self._policies_for_loss.keys():
             action_space = self.action_spec[group, "action"]
-            continuous = not isinstance(
-                action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-            )
+            continuous = not isinstance(action_space, (Categorical, OneHot))
             self._policies_for_loss.update(
                 {
                     group: self._get_policy_for_loss(
@@ -220,9 +216,7 @@ def get_policy_for_collection(self) -> TensorDictSequential:
         if group not in self._policies_for_collection.keys():
             policy_for_loss = self.get_policy_for_loss(group)
             action_space = self.action_spec[group, "action"]
-            continuous = not isinstance(
-                action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-            )
+            continuous = not isinstance(action_space, (Categorical, OneHot))
             policy_for_collection = self._get_policy_for_collection(
                 policy_for_loss,
                 group,
@@ -263,9 +257,9 @@ def model_fun():
             env = env_fun()

             spec_actor = self.model_config.get_model_state_spec()
-            spec_actor = CompositeSpec(
+            spec_actor = Composite(
                 {
-                    group: CompositeSpec(
+                    group: Composite(
                         spec_actor.expand(len(agents), *spec_actor.shape),
                         shape=(len(agents),),
                     )
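
For context, the recurring change in common.py above is the action-space test collapsing onto the renamed classes. A standalone sketch of the same check; the helper name, the Bounded spec, and the example sizes are illustrative, not from the repository:

from torchrl.data import Bounded, Categorical, OneHot

def is_continuous(action_space) -> bool:
    # An action spec counts as continuous unless it is (one-hot) categorical.
    return not isinstance(action_space, (Categorical, OneHot))

assert not is_continuous(Categorical(5))  # 5 discrete actions
assert not is_continuous(OneHot(5))       # the same space, one-hot encoded
assert is_continuous(Bounded(low=-1.0, high=1.0, shape=(3,)))  # e.g. 3 torques
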
22 changes: 9 additions & 13 deletions benchmarl/algorithms/iddpg.py
@@ -9,7 +9,7 @@

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import (
     AdditiveGaussianWrapper,
     Delta,
@@ -94,13 +94,13 @@ def _get_policy_for_loss(
         if continuous:
             n_agents = len(self.group_map[group])
             logits_shape = list(self.action_spec[group, "action"].shape)
-            actor_input_spec = CompositeSpec(
+            actor_input_spec = Composite(
                 {group: self.observation_spec[group].clone().to(self.device)}
             )
-            actor_output_spec = CompositeSpec(
+            actor_output_spec = Composite(
                 {
-                    group: CompositeSpec(
-                        {"param": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                    group: Composite(
+                        {"param": Unbounded(shape=logits_shape)},
                         shape=(n_agents,),
                     )
                 }
@@ -190,21 +190,17 @@ def get_value_module(self, group: str) -> TensorDictModule:
         n_agents = len(self.group_map[group])
         modules = []

-        critic_input_spec = CompositeSpec(
+        critic_input_spec = Composite(
             {
                 group: self.observation_spec[group]
                 .clone()
                 .update(self.action_spec[group])
             }
         )
-        critic_output_spec = CompositeSpec(
+        critic_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {
-                        "state_action_value": UnboundedContinuousTensorSpec(
-                            shape=(n_agents, 1)
-                        )
-                    },
+                group: Composite(
+                    {"state_action_value": Unbounded(shape=(n_agents, 1))},
                     shape=(n_agents,),
                 )
             }
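
The actor and critic specs above show the pattern every algorithm file repeats: a per-group Composite whose leaves are Unbounded tensors batched over the agent dimension. A self-contained sketch, with hypothetical group name and sizes:

import torch
from torchrl.data import Composite, Unbounded

n_agents, action_dim = 3, 2  # hypothetical sizes
group = "agents"             # hypothetical group name

actor_output_spec = Composite(
    {
        group: Composite(
            {"param": Unbounded(shape=(n_agents, action_dim))},
            shape=(n_agents,),  # leading batch dim shared by the group's entries
        )
    }
)

# Sampling yields a TensorDict carrying the nested ("agents", "param") leaf.
sample = actor_output_spec.rand()
assert sample[group, "param"].shape == torch.Size([n_agents, action_dim])
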
18 changes: 9 additions & 9 deletions benchmarl/algorithms/ippo.py
@@ -12,7 +12,7 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from tensordict.nn.distributions import NormalParamExtractor
 from torch.distributions import Categorical
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import IndependentNormal, ProbabilisticActor, TanhNormal
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss, LossModule, ValueEstimators
@@ -118,14 +118,14 @@ def _get_policy_for_loss(
                 self.action_spec[group, "action"].space.n,
             ]

-        actor_input_spec = CompositeSpec(
+        actor_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )

-        actor_output_spec = CompositeSpec(
+        actor_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {"logits": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                group: Composite(
+                    {"logits": Unbounded(shape=logits_shape)},
                     shape=(n_agents,),
                 )
             }
@@ -270,13 +270,13 @@ def process_loss_vals(
     def get_critic(self, group: str) -> TensorDictModule:
         n_agents = len(self.group_map[group])

-        critic_input_spec = CompositeSpec(
+        critic_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )
-        critic_output_spec = CompositeSpec(
+        critic_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {"state_value": UnboundedContinuousTensorSpec(shape=(n_agents, 1))},
+                group: Composite(
+                    {"state_value": Unbounded(shape=(n_agents, 1))},
                     shape=(n_agents,),
                 )
             }
10 changes: 5 additions & 5 deletions benchmarl/algorithms/iql.py
@@ -9,7 +9,7 @@

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import EGreedyModule, QValueModule
 from torchrl.objectives import DQNLoss, LossModule, ValueEstimators

@@ -77,14 +77,14 @@ def _get_policy_for_loss(
                 self.action_spec[group, "action"].space.n,
             ]

-        actor_input_spec = CompositeSpec(
+        actor_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )

-        actor_output_spec = CompositeSpec(
+        actor_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {"action_value": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                group: Composite(
+                    {"action_value": Unbounded(shape=logits_shape)},
                     shape=(n_agents,),
                 )
             }
34 changes: 13 additions & 21 deletions benchmarl/algorithms/isac.py
@@ -10,7 +10,7 @@
 from tensordict import TensorDictBase
 from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
 from torch.distributions import Categorical
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import (
     IndependentNormal,
     MaskedCategorical,
@@ -164,14 +164,14 @@ def _get_policy_for_loss(
                 self.action_spec[group, "action"].space.n,
             ]

-        actor_input_spec = CompositeSpec(
+        actor_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )

-        actor_output_spec = CompositeSpec(
+        actor_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {"logits": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                group: Composite(
+                    {"logits": Unbounded(shape=logits_shape)},
                     shape=(n_agents,),
                 )
             }
@@ -283,18 +283,14 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
         n_agents = len(self.group_map[group])
         n_actions = self.action_spec[group, "action"].space.n

-        critic_input_spec = CompositeSpec(
+        critic_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )

-        critic_output_spec = CompositeSpec(
+        critic_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {
-                        "action_value": UnboundedContinuousTensorSpec(
-                            shape=(n_agents, n_actions)
-                        )
-                    },
+                group: Composite(
+                    {"action_value": Unbounded(shape=(n_agents, n_actions))},
                     shape=(n_agents,),
                 )
             }
@@ -317,22 +313,18 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
         n_agents = len(self.group_map[group])
         modules = []

-        critic_input_spec = CompositeSpec(
+        critic_input_spec = Composite(
             {
                 group: self.observation_spec[group]
                 .clone()
                 .update(self.action_spec[group])
             }
         )

-        critic_output_spec = CompositeSpec(
+        critic_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {
-                        "state_action_value": UnboundedContinuousTensorSpec(
-                            shape=(n_agents, 1)
-                        )
-                    },
+                group: Composite(
+                    {"state_action_value": Unbounded(shape=(n_agents, 1))},
                     shape=(n_agents,),
                 )
             }
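
isac.py keeps two critic shapes apart: the discrete branch scores every available action per agent, while the continuous branch scores only the (state, action) pair that was taken. A sketch of the two output specs, with hypothetical group name and sizes:

from torchrl.data import Composite, Unbounded

n_agents, n_actions = 3, 5  # hypothetical sizes

# Discrete critic: one Q-value per available action, per agent.
discrete_out = Composite(
    {"agents": Composite(
        {"action_value": Unbounded(shape=(n_agents, n_actions))},
        shape=(n_agents,),
    )}
)

# Continuous critic: a single Q-value for the taken action, per agent.
continuous_out = Composite(
    {"agents": Composite(
        {"state_action_value": Unbounded(shape=(n_agents, 1))},
        shape=(n_agents,),
    )}
)
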
28 changes: 12 additions & 16 deletions benchmarl/algorithms/maddpg.py
@@ -9,7 +9,7 @@

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import (
     AdditiveGaussianWrapper,
     Delta,
@@ -94,13 +94,13 @@ def _get_policy_for_loss(
         if continuous:
             n_agents = len(self.group_map[group])
             logits_shape = list(self.action_spec[group, "action"].shape)
-            actor_input_spec = CompositeSpec(
+            actor_input_spec = Composite(
                 {group: self.observation_spec[group].clone().to(self.device)}
             )
-            actor_output_spec = CompositeSpec(
+            actor_output_spec = Composite(
                 {
-                    group: CompositeSpec(
-                        {"param": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                    group: Composite(
+                        {"param": Unbounded(shape=logits_shape)},
                         shape=(n_agents,),
                     )
                 }
@@ -191,18 +191,14 @@ def get_value_module(self, group: str) -> TensorDictModule:
         modules = []

         if self.share_param_critic:
-            critic_output_spec = CompositeSpec(
-                {"state_action_value": UnboundedContinuousTensorSpec(shape=(1,))}
+            critic_output_spec = Composite(
+                {"state_action_value": Unbounded(shape=(1,))}
             )
         else:
-            critic_output_spec = CompositeSpec(
+            critic_output_spec = Composite(
                 {
-                    group: CompositeSpec(
-                        {
-                            "state_action_value": UnboundedContinuousTensorSpec(
-                                shape=(n_agents, 1)
-                            )
-                        },
+                    group: Composite(
+                        {"state_action_value": Unbounded(shape=(n_agents, 1))},
                         shape=(n_agents,),
                     )
                 }
@@ -219,7 +215,7 @@ def get_value_module(self, group: str) -> TensorDictModule:

             critic_input_spec = self.state_spec.clone().update(
                 {
-                    "global_action": UnboundedContinuousTensorSpec(
+                    "global_action": Unbounded(
                         shape=(self.action_spec[group, "action"].shape[-1] * n_agents,)
                     )
                 }
@@ -240,7 +236,7 @@ def get_value_module(self, group: str) -> TensorDictModule:
             )

         else:
-            critic_input_spec = CompositeSpec(
+            critic_input_spec = Composite(
                 {
                     group: self.observation_spec[group]
                     .clone()
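
maddpg.py branches on parameter sharing: a shared critic emits one unbatched scalar spec, while unshared critics batch the value head over the group (mappo.py repeats the branch below with a "state_value" key). A sketch of just that branch; the helper function is hypothetical:

from torchrl.data import Composite, Unbounded

def critic_output_spec(group: str, n_agents: int, share_param_critic: bool) -> Composite:
    if share_param_critic:
        # One set of critic weights scores every agent: a single scalar output.
        return Composite({"state_action_value": Unbounded(shape=(1,))})
    # Otherwise each agent gets its own value head, batched over the group.
    return Composite(
        {group: Composite(
            {"state_action_value": Unbounded(shape=(n_agents, 1))},
            shape=(n_agents,),
        )}
    )

print(critic_output_spec("agents", n_agents=3, share_param_critic=False))
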
26 changes: 10 additions & 16 deletions benchmarl/algorithms/mappo.py
@@ -12,7 +12,7 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from tensordict.nn.distributions import NormalParamExtractor
 from torch.distributions import Categorical
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Composite, Unbounded
 from torchrl.modules import (
     IndependentNormal,
     MaskedCategorical,
@@ -122,14 +122,14 @@ def _get_policy_for_loss(
                 self.action_spec[group, "action"].space.n,
             ]

-        actor_input_spec = CompositeSpec(
+        actor_input_spec = Composite(
             {group: self.observation_spec[group].clone().to(self.device)}
         )

-        actor_output_spec = CompositeSpec(
+        actor_output_spec = Composite(
             {
-                group: CompositeSpec(
-                    {"logits": UnboundedContinuousTensorSpec(shape=logits_shape)},
+                group: Composite(
+                    {"logits": Unbounded(shape=logits_shape)},
                     shape=(n_agents,),
                 )
             }
@@ -274,18 +274,12 @@ def process_loss_vals(
     def get_critic(self, group: str) -> TensorDictModule:
         n_agents = len(self.group_map[group])
         if self.share_param_critic:
-            critic_output_spec = CompositeSpec(
-                {"state_value": UnboundedContinuousTensorSpec(shape=(1,))}
-            )
+            critic_output_spec = Composite({"state_value": Unbounded(shape=(1,))})
         else:
-            critic_output_spec = CompositeSpec(
+            critic_output_spec = Composite(
                 {
-                    group: CompositeSpec(
-                        {
-                            "state_value": UnboundedContinuousTensorSpec(
-                                shape=(n_agents, 1)
-                            )
-                        },
+                    group: Composite(
+                        {"state_value": Unbounded(shape=(n_agents, 1))},
                         shape=(n_agents,),
                     )
                 }
@@ -305,7 +299,7 @@ def get_critic(self, group: str) -> TensorDictModule:
             )

         else:
-            critic_input_spec = CompositeSpec(
+            critic_input_spec = Composite(
                 {group: self.observation_spec[group].clone().to(self.device)}
             )
             value_module = self.critic_model_config.get_model(
(Diffs for the remaining 16 changed files were not loaded.)