
Commit c924993

Author: Jan Beitner
Commit message: Make logging_metrics a kwarg
1 parent: e448f1b

3 files changed: +17 −12 lines


pytorch_forecasting/models/base_model.py

Lines changed: 5 additions & 3 deletions
@@ -13,6 +13,7 @@
 from pytorch_lightning.metrics.metric import TensorMetric
 from pytorch_lightning.utilities.parsing import get_init_args
 import torch
+import torch.nn as nn
 from torch.nn.utils import rnn
 from torch.optim.lr_scheduler import LambdaLR, OneCycleLR, ReduceLROnPlateau
 from torch.utils.data import DataLoader
@@ -57,7 +58,7 @@ def __init__(
         learning_rate: Union[float, List[float]] = 1e-3,
         log_gradient_flow: bool = False,
         loss: TensorMetric = SMAPE(),
-        logging_metrics: List[TensorMetric] = [],
+        logging_metrics: nn.ModuleList = nn.ModuleList([]),
         reduce_on_plateau_patience: int = 1000,
         weight_decay: float = 0.0,
         monotone_constaints: Dict[str, int] = {},
@@ -76,7 +77,8 @@ def __init__(
             log_gradient_flow (bool): If to log gradient flow, this takes time and should be only done to diagnose
                 training failures. Defaults to False.
             loss (TensorMetric, optional): metric to optimize. Defaults to SMAPE().
-            logging_metrics: (List[TensorMetric], optional): list of metrics to log.
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
+                Defaults to [].
             reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10. Defaults
                 to 1000
             weight_decay (float): weight decay. Defaults to 0.0.
@@ -102,7 +104,7 @@ def __init__(
         if not hasattr(self, "loss"):
             self.loss = loss
         if not hasattr(self, "logging_metrics"):
-            self.logging_metrics = logging_metrics
+            self.logging_metrics = nn.ModuleList([l for l in logging_metrics])
         if not hasattr(self, "output_transformer"):
             self.output_transformer = output_transformer
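
For context, a minimal sketch (plain PyTorch, not part of this commit) of why the default is now wrapped in an nn.ModuleList rather than a plain Python list: entries of an nn.ModuleList are registered as submodules, so they appear in state_dict() and follow the parent model in .to(device) calls, whereas entries of a plain list are invisible to torch. nn.Linear merely stands in for a metric module here.

import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.metrics = [nn.Linear(1, 1)]  # plain list: not registered as a submodule

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.metrics = nn.ModuleList([nn.Linear(1, 1)])  # registered as a submodule

print(len(list(WithPlainList().parameters())))   # 0 -> parameters are not tracked
print(len(list(WithModuleList().parameters())))  # 2 -> weight and bias are tracked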

pytorch_forecasting/models/nbeats/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -33,6 +33,7 @@ def __init__(
         loss=SMAPE(),
         reduce_on_plateau_patience: int = 1000,
         backcast_loss_ratio: float = 0.0,
+        logging_metrics: nn.ModuleList = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]),
         **kwargs,
     ):
         """
@@ -76,11 +77,12 @@ def __init__(
             log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training
                 failures
             reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
+                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
         """
         self.save_hyperparameters()
-        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE(), MASE()]
-        super().__init__(**kwargs)
-        self.loss = loss
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

         # setup stacks
         self.net_blocks = nn.ModuleList()
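
With this change, the set of logged metrics can be overridden per model through the new keyword argument. A hedged usage sketch, assuming `training` is a TimeSeriesDataSet you have already built; the other arguments are purely illustrative.

import torch.nn as nn
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models.nbeats import NBeats

# `training` is assumed to be an existing TimeSeriesDataSet.
# Omitting logging_metrics keeps the new default of SMAPE, MAE, RMSE, MAPE and MASE.
net = NBeats.from_dataset(
    training,
    learning_rate=3e-2,
    logging_metrics=nn.ModuleList([SMAPE(), MAE()]),  # log only these two metrics
)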

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -5,6 +5,7 @@

 from matplotlib import pyplot as plt
 import numpy as np
+from pytorch_lightning.metrics.metric import TensorMetric
 import torch
 from torch import nn
 from torch.nn.utils import rnn
@@ -55,7 +56,8 @@ def __init__(
         reduce_on_plateau_patience: int = 1000,
         monotone_constaints: Dict[str, int] = {},
         share_single_variable_networks: bool = False,
-        output_transformer: Callable = None,
+        logging_metrics: nn.ModuleList = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]),
+        **kwargs,
     ):
         """
         Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible.
@@ -122,14 +124,14 @@ def __init__(
                 This constraint significantly slows down training. Defaults to {}.
             share_single_variable_networks (bool): if to share the single variable networks between the encoder and
                 decoder. Defaults to False.
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
+                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
         """
         self.save_hyperparameters()
-        super().__init__()
         # store loss function separately as it is a module
         assert isinstance(loss, MultiHorizonMetric), "Loss has to of class `MultiHorizonMetric`"
-        self.loss = loss
-        self.output_transformer = output_transformer
-        self.logging_metrics = [SMAPE(), MAE(), RMSE(), MAPE(), MASE()]
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

         # processing inputs
         # embeddings
@@ -201,7 +203,6 @@ def __init__(
         )

         # create single variable grns that are shared across decoder and encoder
-
         if self.hparams.share_single_variable_networks:
             self.shared_single_variable_grns = nn.ModuleDict()
             for name, input_size in encoder_input_sizes.items():
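
Taken together, the three files establish one pattern: a model subclass declares logging_metrics (and loss) as keyword arguments with sensible defaults and forwards them to BaseModel, instead of assigning them on self after super().__init__(). A hypothetical subclass sketching that pattern; the class name and body are illustrative, not part of the commit.

import torch.nn as nn
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models.base_model import BaseModel

class MyForecaster(BaseModel):
    def __init__(
        self,
        loss=SMAPE(),
        logging_metrics: nn.ModuleList = nn.ModuleList([SMAPE(), MAE()]),
        **kwargs,
    ):
        self.save_hyperparameters()
        # forward loss and logging_metrics to BaseModel as keyword arguments
        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)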
