
Commit 4fba73d

Merge pull request #60 from jdb78/feature/mase-metric
Add MASE metric
2 parents d1a217b + 653f9b8 commit 4fba73d


7 files changed (+184, -25 lines)


pytorch_forecasting/data/encoders.py

Lines changed: 4 additions & 1 deletion
@@ -94,7 +94,10 @@ def transform(self, y: Iterable) -> Union[torch.Tensor, np.ndarray]:
         if self.warn:
             cond = ~np.isin(y, self.classes_)
             if cond.any():
-                warnings.warn(f"Found {y[cond].nunique()} unknown classes which were set to NaN", UserWarning)
+                warnings.warn(
+                    f"Found {np.unique(np.asarray(y)[cond]).size} unknown classes which were set to NaN",
+                    UserWarning,
+                )
 
         encoded = [self.classes_.get(v, 0) for v in y]
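The fix replaces the pandas-only `nunique()` call, which breaks when `y` is a plain numpy array, with a NumPy-native count of distinct unknown labels. A minimal standalone sketch of that counting step (the `classes_` mapping below is illustrative, not the encoder's actual state):

```python
import numpy as np

# Illustrative class mapping; the real encoder stores its own fitted classes_.
classes_ = {"a": 1, "b": 2}

y = np.array(["a", "b", "x", "y", "x"])
cond = ~np.isin(y, list(classes_))               # True where the label is unknown
n_unknown = np.unique(np.asarray(y)[cond]).size  # distinct unknown labels -> 2 ("x" and "y")
print(f"Found {n_unknown} unknown classes which were set to NaN")
```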
pytorch_forecasting/metrics.py

Lines changed: 122 additions & 12 deletions
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 from torch.nn.utils import rnn
 
-from pytorch_forecasting.utils import integer_histogram
+from pytorch_forecasting.utils import integer_histogram, unpack_sequence
 
 
 class Metric(TensorMetric):
@@ -281,13 +281,7 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
         Returns:
             torch.Tensor: loss as a single number for backpropagation
         """
-        # unpack
-        if isinstance(target, rnn.PackedSequence):
-            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
-            # batch sizes reside on the CPU by default -> we need to bring them to GPU
-            lengths = lengths.to(target.device)
-        else:
-            lengths = torch.ones(target.size(0), device=target.device, dtype=torch.long) * target.size(1)
+        target, lengths = unpack_sequence(target)
         assert not target.requires_grad
 
         # calculate loss with "none" reduction
@@ -302,24 +296,43 @@
         if weight is not None:
             losses = losses * weight.unsqueeze(-1)
 
+        loss = self.reduce_loss(losses, lengths=lengths, reduction=self.reduction)
+        return loss
+
+    def reduce_loss(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: str = None) -> torch.Tensor:
+        """
+        Reduce loss.
+
+        Args:
+            losses (torch.Tensor): tensor of losses. First dimension are samples, second are timesteps.
+            lengths (torch.Tensor): tensor of lengths
+            reduction (str, optional): type of reduction. Defaults to ``self.reduction``.
+
+        Returns:
+            torch.Tensor: reduced loss
+        """
+        if reduction is None:
+            reduction = self.reduction
         # mask loss
-        mask = torch.arange(target.size(1), device=target.device).unsqueeze(0) >= lengths.unsqueeze(-1)
+        mask = torch.arange(losses.size(1), device=losses.device).unsqueeze(0) >= lengths.unsqueeze(-1)
         if losses.ndim > 2:
             mask = mask.unsqueeze(-1)
             dim_normalizer = losses.size(-1)
         else:
             dim_normalizer = 1.0
         # reduce to one number
-        if self.reduction == "none":
+        if reduction == "none":
             loss = losses.masked_fill(mask, float("nan"))
         else:
-            if self.reduction == "mean":
+            if reduction == "mean":
                 losses = losses.masked_fill(mask, 0.0)
                 loss = losses.sum() / lengths.sum() / dim_normalizer
-            elif self.reduction == "sqrt-mean":
+            elif reduction == "sqrt-mean":
                 losses = losses.masked_fill(mask, 0.0)
                 loss = losses.sum() / lengths.sum() / dim_normalizer
                 loss = loss.sqrt()
+            else:
+                raise ValueError(f"reduction {reduction} unknown")
         assert not torch.isnan(loss), (
             "Loss should not be nan - i.e. something went wrong "
             "in calculating the loss (e.g. log of a negative number)"
@@ -449,3 +462,100 @@ def __init__(self, name: str = "RMSE", reduction="sqrt-mean", *args, **kwargs):
     def loss(self, y_pred: Dict[str, torch.Tensor], target):
         loss = torch.pow(self.to_prediction(y_pred) - target, 2)
         return loss
+
+
+class MASE(MultiHorizonMetric):
+    """
+    Mean absolute scaled error
+
+    Defined as ``(y_pred - target).abs() / ((all_targets[:, :-1] - all_targets[:, 1:]).abs().mean(1))``.
+    ``all_targets`` are here the concatenated encoder and decoder targets.
+    """
+
+    def __init__(self, name: str = "MASE", *args, **kwargs):
+        super().__init__(name, *args, **kwargs)
+
+    def forward(
+        self,
+        y_pred: Dict[str, torch.Tensor],
+        target: Union[torch.Tensor, rnn.PackedSequence],
+        encoder_target: Union[torch.Tensor, rnn.PackedSequence],
+        encoder_lengths: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Forward method of metric that handles masking of values.
+
+        Args:
+            y_pred (Dict[str, torch.Tensor]): network output
+            target (Union[torch.Tensor, rnn.PackedSequence]): actual values
+            encoder_target (Union[torch.Tensor, rnn.PackedSequence]): historic actual values
+            encoder_lengths (torch.Tensor): optional encoder lengths, not necessary if encoder_target
+                is rnn.PackedSequence. If given, encoder_target is assumed to be a torch.Tensor.
+
+        Returns:
+            torch.Tensor: loss as a single number for backpropagation
+        """
+        target, lengths = unpack_sequence(target)
+        if encoder_lengths is None:
+            encoder_target, encoder_lengths = unpack_sequence(encoder_target)
+        else:
+            assert isinstance(encoder_target, torch.Tensor)
+        assert not target.requires_grad
+
+        # calculate loss with "none" reduction
+        if target.ndim == 3:
+            weight = target[..., 1]
+            target = target[..., 0]
+        else:
+            weight = None
+
+        scaling = self.calculate_scaling(target, lengths, encoder_target, encoder_lengths)
+        losses = self.loss(y_pred, target, scaling)
+        # weight samples
+        if weight is not None:
+            losses = losses * weight.unsqueeze(-1)
+
+        loss = self.reduce_loss(losses, lengths=lengths, reduction=self.reduction)
+        return loss
+
+    def loss(self, y_pred, target, scaling):
+        return (y_pred - target).abs() / scaling.unsqueeze(-1)
+
+    def calculate_scaling(self, target, lengths, encoder_target, encoder_lengths):
+        # calculate mean(abs(diff(targets)))
+        eps = 1e-6
+        batch_size = target.size(0)
+        total_lengths = lengths + encoder_lengths
+        assert (total_lengths > 1).all(), "Need at least 2 target values to be able to calculate MASE"
+        max_length = target.size(1) + encoder_target.size(1)
+        if (total_lengths != max_length).any():  # if decoder or encoder targets have sequences of different lengths
+            targets = torch.cat(
+                [
+                    encoder_target,
+                    torch.zeros(batch_size, target.size(1), device=target.device, dtype=encoder_target.dtype),
+                ],
+                dim=1,
+            )
+            target_index = torch.arange(target.size(1), device=target.device, dtype=torch.long).unsqueeze(0).expand(
+                batch_size, -1
+            ) + encoder_lengths.unsqueeze(-1)
+            targets.scatter_(dim=1, src=target, index=target_index)
+        else:
+            targets = torch.cat([encoder_target, target], dim=1)
+
+        # take absolute difference
+        diffs = (targets[:, :-1] - targets[:, 1:]).abs()
+
+        # set last difference to 0
+        not_maximum_length = total_lengths != max_length
+        zero_correction_indices = total_lengths[not_maximum_length] - 1
+        if len(zero_correction_indices) > 0:
+            diffs[
+                torch.arange(batch_size, dtype=torch.long, device=diffs.device)[not_maximum_length],
+                zero_correction_indices,
+            ] = 0.0
+
+        # calculate mean over differences
+        scaling = diffs.sum(1) / total_lengths + eps
+
+        return scaling
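For reference, MASE scales the absolute forecast error by the average absolute one-step change of the concatenated encoder and decoder targets. A toy, hand-checkable sketch of the same arithmetic for a single fully observed series (illustration only; the committed `calculate_scaling` additionally handles variable sequence lengths via scatter and masking):

```python
import torch

# Toy MASE computation for one fully observed series (values are illustrative).
encoder_target = torch.tensor([[1.0, 2.0, 4.0]])   # history
target = torch.tensor([[5.0, 7.0]])                # decoder ground truth
y_pred = torch.tensor([[6.0, 6.0]])                # point forecasts

all_targets = torch.cat([encoder_target, target], dim=1)   # [1, 2, 4, 5, 7]
diffs = (all_targets[:, :-1] - all_targets[:, 1:]).abs()    # [1, 2, 1, 2]
scaling = diffs.sum(1) / all_targets.size(1) + 1e-6         # 6 / 5 = 1.2
losses = (y_pred - target).abs() / scaling.unsqueeze(-1)    # [0.833, 0.833]
print(scaling, losses.mean())
```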

pytorch_forecasting/models/base_model.py

Lines changed: 26 additions & 5 deletions
@@ -20,7 +20,7 @@
 
 from pytorch_forecasting.data import TimeSeriesDataSet
 from pytorch_forecasting.data.encoders import GroupNormalizer
-from pytorch_forecasting.metrics import SMAPE
+from pytorch_forecasting.metrics import MASE, SMAPE
 from pytorch_forecasting.optim import Ranger
 from pytorch_forecasting.utils import groupby_apply
 
@@ -182,23 +182,40 @@ def step(self, x: Dict[str, torch.Tensor], y: torch.Tensor, batch_idx: int, labe
             # multiply monotinicity loss by large number to ensure relevance and take to the power of 2
             # for smoothness of loss function
             monotinicity_loss = 10 * torch.pow(monotinicity_loss, 2)
+            if isinstance(self.loss, MASE):
+                loss = self.loss(
+                    prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
+                )
+            else:
+                loss = self.loss(prediction, y)
 
-            loss = self.loss(prediction, y) * (1 + monotinicity_loss)
+            loss = loss * (1 + monotinicity_loss)
         else:
             out = self(x)
             out["prediction"] = self.transform_output(out)
 
             # calculate loss
             prediction = out["prediction"]
-            loss = self.loss(prediction, y)
+            if isinstance(self.loss, MASE):
+                loss = self.loss(
+                    prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
+                )
+            else:
+                loss = self.loss(prediction, y)
 
         # log loss
         tensorboard_logs = {f"{label}_loss": loss}
         # logging losses
         y_hat_detached = prediction.detach()
         y_hat_point_detached = self.loss.to_prediction(y_hat_detached)
         for metric in self.logging_metrics:
-            tensorboard_logs[f"{label}_{metric.name}"] = metric(y_hat_point_detached, y)
+            if isinstance(metric, MASE):
+                loss_value = metric(
+                    y_hat_point_detached, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
+                )
+            else:
+                loss_value = metric(y_hat_point_detached, y)
+            tensorboard_logs[f"{label}_{metric.name}"] = loss_value
         log = {f"{label}_loss": loss, "log": tensorboard_logs, "n_samples": x["decoder_lengths"].size(0)}
         if label == "train":
             log["loss"] = loss
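Since MASE needs the encoder history to compute its scaling, the training/validation step now dispatches on the loss type. A minimal sketch of the pattern — the helper name and argument layout below are illustrative, not part of the library API:

```python
from pytorch_forecasting.metrics import MASE

# Illustrative dispatch helper; mirrors the branching added in step() above.
def compute_loss(loss_fn, prediction, y, x):
    if isinstance(loss_fn, MASE):
        # MASE derives its scaling term from the encoder history
        return loss_fn(
            prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
        )
    return loss_fn(prediction, y)
```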
@@ -354,7 +371,11 @@ def plot_prediction(
             else:
                 loss = add_loss_to_title
                 loss.quantiles = self.loss.quantiles
-            ax.set_title(f"Loss {loss(y_hat[None], y[-n_pred:][None]):.3g}")
+            if isinstance(loss, MASE):
+                loss_value = loss(y_hat[None], y[-n_pred:][None], y[:n_pred][None])
+            else:
+                loss_value = loss(y_hat[None], y[-n_pred:][None])
+            ax.set_title(f"Loss {loss_value:.3g}")
         ax.set_xlabel("Time index")
         fig.legend()
         return fig

pytorch_forecasting/models/nbeats/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -8,7 +8,7 @@
 from torch import nn
 
 from pytorch_forecasting.data import TimeSeriesDataSet
-from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE
 from pytorch_forecasting.models.base_model import BaseModel
 from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock
 
@@ -78,7 +78,7 @@ def __init__(
             reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
         """
         self.save_hyperparameters()
-        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE()]
+        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE(), MASE()]
         super().__init__(**kwargs)
         self.loss = loss
 
@@ -218,7 +218,10 @@ def step(self, x, y, batch_idx, label) -> Dict[str, torch.Tensor]:
             )
             backcast_weight = backcast_weight / (backcast_weight + 1)  # normalize
             forecast_weight = 1 - backcast_weight
-            backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight
+            if isinstance(self.loss, MASE):
+                backcast_loss = self.loss(backcast, x["encoder_target"], x["decoder_target"]) * backcast_weight
+            else:
+                backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight
             if label == "train":
                 log["loss"] = log["loss"] * forecast_weight + backcast_loss
                 log["log"]["train_loss"] = log["log"]["train_loss"] * forecast_weight + backcast_loss
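For the backcast term the argument roles are mirrored relative to the forecast loss: the backcast is scored against the encoder targets, and the decoder targets supply the history used for MASE scaling. A small sketch with toy tensors, assuming this commit's MASE class is available (values illustrative only):

```python
import torch
from pytorch_forecasting.metrics import MASE

# Toy backcast scoring: backcast vs. encoder targets, decoder targets as scaling history.
mase = MASE()
encoder_target = torch.tensor([[1.0, 2.0, 4.0]])  # ground truth over the encoder span
decoder_target = torch.tensor([[5.0, 7.0]])       # ground truth over the decoder span
backcast = torch.tensor([[1.5, 2.5, 3.5]])        # network backcast for the encoder span

backcast_loss = mase(backcast, encoder_target, decoder_target)
print(backcast_loss)
```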

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from torch.nn.utils import rnn
 
 from pytorch_forecasting.data import TimeSeriesDataSet
-from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss
 from pytorch_forecasting.models.base_model import BaseModel, CovariatesMixin
 from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
     AddNorm,
@@ -126,7 +126,7 @@ def __init__(
         assert isinstance(loss, MultiHorizonMetric), "Loss has to of class `MultiHorizonMetric`"
         self.loss = loss
         self.output_transformer = output_transformer
-        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE()]
+        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE(), MASE()]
 
         # processing inputs
         # embeddings

pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py

Lines changed: 2 additions & 2 deletions
@@ -351,15 +351,15 @@ def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
                     variable_embedding = self.prescalers[name](variable_embedding)
                 weight_inputs.append(variable_embedding)
                 var_outputs.append(self.single_variable_grns[name](variable_embedding))
-            var_outputs = torch.stack(var_outputs, axis=-1)
+            var_outputs = torch.stack(var_outputs, dim=-1)
 
             # calculate variable weights
             flat_embedding = torch.cat(weight_inputs, dim=-1)
             sparse_weights = self.flattened_grn(flat_embedding, context)
             sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)
 
             outputs = var_outputs * sparse_weights
-            outputs = outputs.sum(axis=-1)
+            outputs = outputs.sum(dim=-1)
         else:  # for one input, do not perform variable selection but just encoding
             name = next(iter(self.single_variable_grns.keys()))
             variable_embedding = x[name]
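The two edits here only replace the NumPy-style `axis` keyword with PyTorch's `dim`. A quick standalone check of the equivalent calls:

```python
import torch

# torch.stack and Tensor.sum take `dim`; sanity check of the replaced calls.
a, b = torch.ones(2, 3), torch.zeros(2, 3)
stacked = torch.stack([a, b], dim=-1)  # shape (2, 3, 2)
summed = stacked.sum(dim=-1)           # shape (2, 3), equals a + b
print(stacked.shape, torch.allclose(summed, a + b))
```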

pytorch_forecasting/utils.py

Lines changed: 22 additions & 0 deletions
@@ -7,6 +7,8 @@
 from typing import Callable, Tuple, Union
 
 import torch
+from torch.nn.utils import rnn
+from torch.tensor import Tensor
 
 
 def integer_histogram(
@@ -179,3 +181,23 @@ def autocorrelation(input, dim=0):
     autocorr = autocorr / torch.tensor(range(N, 0, -1), dtype=input.dtype, device=input.device)
     autocorr = autocorr / autocorr[..., :1]
     return autocorr.transpose(dim, -1)
+
+
+def unpack_sequence(sequence: Union[torch.Tensor, rnn.PackedSequence]) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Unpack RNN sequence.
+
+    Args:
+        sequence (Union[torch.Tensor, rnn.PackedSequence]): RNN packed sequence or tensor of which
+            first index are samples and second are timesteps
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: tuple of unpacked sequence and length of samples
+    """
+    if isinstance(sequence, rnn.PackedSequence):
+        sequence, lengths = rnn.pad_packed_sequence(sequence, batch_first=True)
+        # batch sizes reside on the CPU by default -> we need to bring them to GPU
+        lengths = lengths.to(sequence.device)
+    else:
+        lengths = torch.ones(sequence.size(0), device=sequence.device, dtype=torch.long) * sequence.size(1)
+    return sequence, lengths
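A short usage sketch of the new helper, assuming this commit is installed so `unpack_sequence` is importable from `pytorch_forecasting.utils`:

```python
import torch
from torch.nn.utils import rnn
from pytorch_forecasting.utils import unpack_sequence

# Pack two sequences of different lengths, then recover a padded tensor plus lengths.
padded = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]])
packed = rnn.pack_padded_sequence(padded, lengths=torch.tensor([3, 2]), batch_first=True)

sequence, lengths = unpack_sequence(packed)     # padded tensor and lengths tensor([3, 2])
plain, plain_lengths = unpack_sequence(padded)  # plain tensors get full lengths tensor([3, 3])
print(sequence.shape, lengths, plain_lengths)
```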
