Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Oct 23, 2024
1 parent 643d6aa commit 23a897d
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
Categorical,
Composite,
DiscreteTensorSpec,
LazyTensorStorage,
OneHotDiscreteTensorSpec,
OneHot,
ReplayBuffer,
TensorDictReplayBuffer,
)
Expand Down Expand Up @@ -122,9 +122,7 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
"""
if group not in self._losses_and_updaters.keys():
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
)
continuous = not isinstance(action_space, (Categorical, OneHot))
loss, use_target = self._get_loss(
group=group,
policy_for_loss=self.get_policy_for_loss(group),
Expand Down Expand Up @@ -193,9 +191,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
"""
if group not in self._policies_for_loss.keys():
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
)
continuous = not isinstance(action_space, (Categorical, OneHot))
self._policies_for_loss.update(
{
group: self._get_policy_for_loss(
Expand All @@ -220,9 +216,7 @@ def get_policy_for_collection(self) -> TensorDictSequential:
if group not in self._policies_for_collection.keys():
policy_for_loss = self.get_policy_for_loss(group)
action_space = self.action_spec[group, "action"]
continuous = not isinstance(
action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
)
continuous = not isinstance(action_space, (Categorical, OneHot))
policy_for_collection = self._get_policy_for_collection(
policy_for_loss,
group,
Expand Down

0 comments on commit 23a897d

Please sign in to comment.