Skip to content

Commit 022147c

Browse files
author
Jan Beitner
committed
Remove syncing for MultiLoss
1 parent 1caf6fe commit 022147c

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pytorch_forecasting/metrics.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Implementation of metrics for (mulit-horizon) timeseries forecasting.
33
"""
4-
from typing import Callable, Dict, List, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55
import warnings
66

77
import scipy.stats
@@ -11,7 +11,6 @@
1111
import torch.nn.functional as F
1212
from torch.nn.utils import rnn
1313
from torchmetrics import Metric as LightningMetric
14-
from torchmetrics.metric import CompositionalMetric
1514

1615
from pytorch_forecasting.utils import create_mask, unpack_sequence, unsqueeze_like
1716

@@ -152,6 +151,9 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None) -> None:
152151
# No syncing required here. syncing will be done in metric_a and metric_b
153152
pass
154153

154+
def _wrap_compute(self, compute: Callable) -> Callable:
155+
return compute
156+
155157
def reset(self) -> None:
156158
self.torchmetric.reset()
157159

@@ -340,6 +342,10 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
340342
def _wrap_compute(self, compute: Callable) -> Callable:
341343
return compute
342344

345+
def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
346+
# No syncing required here. syncing will be done in metrics
347+
pass
348+
343349
def reset(self) -> None:
344350
for metric in self.metrics:
345351
metric.reset()

0 commit comments

Comments
 (0)