1111import torch .nn .functional as F
1212from torch .nn .utils import rnn
1313
14- from pytorch_forecasting .utils import integer_histogram
14+ from pytorch_forecasting .utils import integer_histogram , unpack_sequence
1515
1616
1717class Metric (TensorMetric ):
@@ -281,13 +281,7 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
281281 Returns:
282282 torch.Tensor: loss as a single number for backpropagation
283283 """
284- # unpack
285- if isinstance (target , rnn .PackedSequence ):
286- target , lengths = rnn .pad_packed_sequence (target , batch_first = True )
287- # batch sizes reside on the CPU by default -> we need to bring them to GPU
288- lengths = lengths .to (target .device )
289- else :
290- lengths = torch .ones (target .size (0 ), device = target .device , dtype = torch .long ) * target .size (1 )
284+ target , lengths = unpack_sequence (target )
291285 assert not target .requires_grad
292286
293287 # calculate loss with "none" reduction
@@ -302,24 +296,43 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
302296 if weight is not None :
303297 losses = losses * weight .unsqueeze (- 1 )
304298
299+ loss = self .reduce_loss (losses , lengths = lengths , reduction = self .reduction )
300+ return loss
301+
302+ def reduce_loss (self , losses : torch .Tensor , lengths : torch .Tensor , reduction : str = None ) -> torch .Tensor :
303+ """
304+ Reduce loss.
305+
306+ Args:
307+ losses (torch.Tensor): tensor of losses. first dimenion are samples, second timesteps
308+ lengths (torch.Tensor): tensor of lengths
309+ reduction (str, optional): type of reduction. Defaults to ``self.reduction``.
310+
311+ Returns:
312+ torch.Tensor: reduced loss
313+ """
314+ if reduction is None :
315+ reduction = self .reduction
305316 # mask loss
306- mask = torch .arange (target .size (1 ), device = target .device ).unsqueeze (0 ) >= lengths .unsqueeze (- 1 )
317+ mask = torch .arange (losses .size (1 ), device = losses .device ).unsqueeze (0 ) >= lengths .unsqueeze (- 1 )
307318 if losses .ndim > 2 :
308319 mask = mask .unsqueeze (- 1 )
309320 dim_normalizer = losses .size (- 1 )
310321 else :
311322 dim_normalizer = 1.0
312323 # reduce to one number
313- if self . reduction == "none" :
324+ if reduction == "none" :
314325 loss = losses .masked_fill (mask , float ("nan" ))
315326 else :
316- if self . reduction == "mean" :
327+ if reduction == "mean" :
317328 losses = losses .masked_fill (mask , 0.0 )
318329 loss = losses .sum () / lengths .sum () / dim_normalizer
319- elif self . reduction == "sqrt-mean" :
330+ elif reduction == "sqrt-mean" :
320331 losses = losses .masked_fill (mask , 0.0 )
321332 loss = losses .sum () / lengths .sum () / dim_normalizer
322333 loss = loss .sqrt ()
334+ else :
335+ raise ValueError (f"reduction { reduction } unknown" )
323336 assert not torch .isnan (loss ), (
324337 "Loss should not be nan - i.e. something went wrong "
325338 "in calculating the loss (e.g. log of a negative number)"
@@ -449,3 +462,100 @@ def __init__(self, name: str = "RMSE", reduction="sqrt-mean", *args, **kwargs):
449462 def loss (self , y_pred : Dict [str , torch .Tensor ], target ):
450463 loss = torch .pow (self .to_prediction (y_pred ) - target , 2 )
451464 return loss
465+
466+
467+ class MASE (MultiHorizonMetric ):
468+ """
469+ Mean absolute scaled error
470+
471+ Defined as ``(y_pred - target).abs() / all_targets[:, :-1] - all_targets[:, 1:]).mean(1)``.
472+ ``all_targets`` are here the concatenated encoder and decoder targets
473+ """
474+
475+ def __init__ (self , name : str = "MASE" , * args , ** kwargs ):
476+ super ().__init__ (name , * args , ** kwargs )
477+
478+ def forward (
479+ self ,
480+ y_pred : Dict [str , torch .Tensor ],
481+ target : Union [torch .Tensor , rnn .PackedSequence ],
482+ encoder_target : Union [torch .Tensor , rnn .PackedSequence ],
483+ encoder_lengths : torch .Tensor = None ,
484+ ) -> torch .Tensor :
485+ """
486+ Forward method of metric that handles masking of values.
487+
488+ Args:
489+ y_pred (Dict[str, torch.Tensor]): network output
490+ target (Union[torch.Tensor, rnn.PackedSequence]): actual values
491+ encoder_target (Union[torch.Tensor, rnn.PackedSequence]): historic actual values
492+ encoder_lengths (torch.Tensor): optional encoder lengths, not necessary if encoder_target
493+ is rnn.PackedSequence. Assumed encoder_target is torch.Tensor
494+
495+ Returns:
496+ torch.Tensor: loss as a single number for backpropagation
497+ """
498+ target , lengths = unpack_sequence (target )
499+ if encoder_lengths is None :
500+ encoder_target , encoder_lengths = unpack_sequence (target )
501+ else :
502+ assert isinstance (encoder_target , torch .Tensor )
503+ assert not target .requires_grad
504+
505+ # calculate loss with "none" reduction
506+ if target .ndim == 3 :
507+ weight = target [..., 1 ]
508+ target = target [..., 0 ]
509+ else :
510+ weight = None
511+
512+ scaling = self .calculate_scaling (target , lengths , encoder_target , encoder_lengths )
513+ losses = self .loss (y_pred , target , scaling )
514+ # weight samples
515+ if weight is not None :
516+ losses = losses * weight .unsqueeze (- 1 )
517+
518+ loss = self .reduce_loss (losses , lengths = lengths , reduction = self .reduction )
519+ return loss
520+
521+ def loss (self , y_pred , target , scaling ):
522+ return (y_pred - target ).abs () / scaling .unsqueeze (- 1 )
523+
524+ def calculate_scaling (self , target , lengths , encoder_target , encoder_lengths ):
525+ # calcualte mean(abs(diff(targets)))
526+ eps = 1e-6
527+ batch_size = target .size (0 )
528+ total_lengths = lengths + encoder_lengths
529+ assert (total_lengths > 1 ).all (), "Need at least 2 target values to be able to calculate MASE"
530+ max_length = target .size (1 ) + encoder_target .size (1 )
531+ if (total_lengths != max_length ).any (): # if decoder or encoder targets have sequences of different lengths
532+ targets = torch .cat (
533+ [
534+ encoder_target ,
535+ torch .zeros (batch_size , target .size (1 ), device = target .device , dtype = encoder_target .dtype ),
536+ ],
537+ dim = 1 ,
538+ )
539+ target_index = torch .arange (target .size (1 ), device = target .device , dtype = torch .long ).unsqueeze (0 ).expand (
540+ batch_size , - 1
541+ ) + encoder_lengths .unsqueeze (- 1 )
542+ targets .scatter_ (dim = 1 , src = target , index = target_index )
543+ else :
544+ targets = torch .cat ([encoder_target , target ], dim = 1 )
545+
546+ # take absolute difference
547+ diffs = (targets [:, :- 1 ] - targets [:, 1 :]).abs ()
548+
549+ # set last difference to 0
550+ not_maximum_length = total_lengths != max_length
551+ zero_correction_indices = total_lengths [not_maximum_length ] - 1
552+ if len (zero_correction_indices ) > 0 :
553+ diffs [
554+ torch .arange (batch_size , dtype = torch .long , device = diffs .device )[not_maximum_length ],
555+ zero_correction_indices ,
556+ ] = 0.0
557+
558+ # calculate mean over differences
559+ scaling = diffs .sum (1 ) / total_lengths + eps
560+
561+ return scaling
0 commit comments