Skip to content

Commit

Permalink
try fixing ddp issues, cannot reproduce locally
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 12, 2024
1 parent a1d44ac commit 6043801
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/torchmetrics/regression/nrmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.nrmse import (
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
self.num_outputs = num_outputs

self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx=None)
self.add_state("total", default=torch.zeros(num_outputs), dist_reduce_fx=None)
self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None)
self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None)
self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
Expand Down Expand Up @@ -228,6 +228,7 @@ def compute(self) -> Tensor:
denom = torch.sqrt(self.var_val / self.total)
else:
denom = torch.sqrt(self.target_squared)
print(self.sum_squared_error, self.total, denom)
return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom)

def plot(
Expand Down

0 comments on commit 6043801

Please sign in to comment.