Commit 132e1b8
[Fix] Signatures of losses
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 4, 2023
1 parent f46aaeb commit 132e1b8
Showing 6 changed files with 12 additions and 18 deletions.
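
What changed: the _get_parameters method of each off-policy algorithm was annotated as taking a ClipPPOLoss, apparently left over from the PPO-based algorithms, even though these algorithms build DDPG, DQN, SAC, or QMixer losses. The commit widens the annotation to the common torchrl LossModule base class and drops the now-unused ClipPPOLoss imports. Annotations are only type hints, so runtime behavior is unchanged. A minimal sketch of the corrected pattern (the MyOffPolicyAlgorithm class below is hypothetical, not part of BenchMARL):

    from typing import Dict, Iterable

    from torchrl.objectives import LossModule


    class MyOffPolicyAlgorithm:
        # Annotating against the LossModule base class keeps this method
        # valid for any concrete torchrl loss (DDPGLoss, DQNLoss, SACLoss,
        # QMixerLoss, ...), unlike the copy-pasted ClipPPOLoss annotation
        # it replaces.
        def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
            # LossModule is an nn.Module, so .parameters() is always available.
            return {"loss": loss.parameters()}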
4 changes: 2 additions & 2 deletions benchmarl/algorithms/iddpg.py
@@ -12,7 +12,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhDelta
-from torchrl.objectives import ClipPPOLoss, DDPGLoss, LossModule, ValueEstimators
+from torchrl.objectives import DDPGLoss, LossModule, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -74,7 +74,7 @@ def _get_loss(
                 "Iddpg is not compatible with discrete actions yet"
             )
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
         return {
             "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
             "loss_value": list(loss.value_network_params.flatten_keys().values()),
4 changes: 2 additions & 2 deletions benchmarl/algorithms/iql.py
@@ -11,7 +11,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import EGreedyModule, QValueModule
-from torchrl.objectives import ClipPPOLoss, DQNLoss, LossModule, ValueEstimators
+from torchrl.objectives import DQNLoss, LossModule, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -70,7 +70,7 @@ def _get_loss(
 
         return loss_module, True
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
         return {"loss": loss.parameters()}
 
     def _get_policy_for_loss(
4 changes: 2 additions & 2 deletions benchmarl/algorithms/maddpg.py
@@ -12,7 +12,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhDelta
-from torchrl.objectives import ClipPPOLoss, DDPGLoss, LossModule, ValueEstimators
+from torchrl.objectives import DDPGLoss, LossModule, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -75,7 +75,7 @@ def _get_loss(
                 "MADDPG is not compatible with discrete actions yet"
            )
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
 
         return {
             "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
10 changes: 2 additions & 8 deletions benchmarl/algorithms/masac.py
@@ -13,13 +13,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import MaskedCategorical, ProbabilisticActor, TanhNormal
-from torchrl.objectives import (
-    ClipPPOLoss,
-    DiscreteSACLoss,
-    LossModule,
-    SACLoss,
-    ValueEstimators,
-)
+from torchrl.objectives import DiscreteSACLoss, LossModule, SACLoss, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -126,7 +120,7 @@ def _get_loss(
 
         return loss_module, True
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
         return {
             "loss_actor": list(loss.actor_network_params.flatten_keys().values()),
             "loss_qvalue": list(loss.qvalue_network_params.flatten_keys().values()),
4 changes: 2 additions & 2 deletions benchmarl/algorithms/qmix.py
@@ -11,7 +11,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import EGreedyModule, QMixer, QValueModule
-from torchrl.objectives import ClipPPOLoss, LossModule, QMixerLoss, ValueEstimators
+from torchrl.objectives import LossModule, QMixerLoss, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -75,7 +75,7 @@ def _get_loss(
 
         return loss_module, True
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
         return {"loss": loss.parameters()}
 
     def _get_policy_for_loss(
4 changes: 2 additions & 2 deletions benchmarl/algorithms/vdn.py
@@ -11,7 +11,7 @@
 )
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
 from torchrl.modules import EGreedyModule, QValueModule, VDNMixer
-from torchrl.objectives import ClipPPOLoss, LossModule, QMixerLoss, ValueEstimators
+from torchrl.objectives import LossModule, QMixerLoss, ValueEstimators
 
 from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
 from benchmarl.models.common import ModelConfig
@@ -72,7 +72,7 @@ def _get_loss(
 
         return loss_module, True
 
-    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
+    def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
 
         return {
             "loss": loss.parameters(),
