diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index e429b407..87e92bbd 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -12,10 +12,10 @@
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torchrl.data import (
+    Categorical,
     Composite,
-    DiscreteTensorSpec,
     LazyTensorStorage,
-    OneHotDiscreteTensorSpec,
+    OneHot,
     ReplayBuffer,
     TensorDictReplayBuffer,
 )
@@ -122,9 +122,7 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
         """
         if group not in self._losses_and_updaters.keys():
             action_space = self.action_spec[group, "action"]
-            continuous = not isinstance(
-                action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-            )
+            continuous = not isinstance(action_space, (Categorical, OneHot))
             loss, use_target = self._get_loss(
                 group=group,
                 policy_for_loss=self.get_policy_for_loss(group),
@@ -193,9 +191,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
         """
         if group not in self._policies_for_loss.keys():
             action_space = self.action_spec[group, "action"]
-            continuous = not isinstance(
-                action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-            )
+            continuous = not isinstance(action_space, (Categorical, OneHot))
             self._policies_for_loss.update(
                 {
                     group: self._get_policy_for_loss(
@@ -220,9 +216,7 @@ def get_policy_for_collection(self) -> TensorDictSequential:
             if group not in self._policies_for_collection.keys():
                 policy_for_loss = self.get_policy_for_loss(group)
                 action_space = self.action_spec[group, "action"]
-                continuous = not isinstance(
-                    action_space, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
-                )
+                continuous = not isinstance(action_space, (Categorical, OneHot))
                 policy_for_collection = self._get_policy_for_collection(
                     policy_for_loss,
                     group,
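
For context: all three hunks collapse the same discrete-vs-continuous check
into a single line, using the spec classes renamed in torchrl
(DiscreteTensorSpec -> Categorical, OneHotDiscreteTensorSpec -> OneHot).
Below is a minimal standalone sketch of that check, assuming torchrl >= 0.5
where the new names are available; is_continuous is a hypothetical helper,
not part of this patch:

    # Hypothetical helper mirroring the check updated in this diff.
    # Assumes torchrl >= 0.5, where Categorical/OneHot replace the
    # deprecated DiscreteTensorSpec/OneHotDiscreteTensorSpec names.
    from torchrl.data import Bounded, Categorical, OneHot

    def is_continuous(action_space) -> bool:
        # Discrete action spaces are Categorical (integer actions) or
        # OneHot (one-hot encoded actions); everything else is treated
        # as continuous.
        return not isinstance(action_space, (Categorical, OneHot))

    assert not is_continuous(Categorical(n=5))  # integer-indexed discrete
    assert not is_continuous(OneHot(n=5))       # one-hot discrete
    assert is_continuous(Bounded(low=-1.0, high=1.0, shape=(3,)))  # box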