diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 8affbe73..6023c55d 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -40,7 +40,7 @@ def __init__( super().__init__(models) self.probabilistic = probabilistic - def forward(self, x: torch.Tensor) -> Distribution: + def forward(self, x: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: r"""Return the logits of the ensemble. Args: