|
| 1 | +""" |
| 2 | +Simple models based on fully connected networks |
| 3 | +""" |
| 4 | + |
| 5 | + |
| 6 | +from typing import Dict, List, Tuple, Union |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import torch |
| 10 | +from torch import nn |
| 11 | + |
| 12 | +from pytorch_forecasting.data import TimeSeriesDataSet |
| 13 | +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss |
| 14 | +from pytorch_forecasting.models.base_model import BaseModelWithCovariates |
| 15 | +from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule |
| 16 | +from pytorch_forecasting.models.nn.embeddings import MultiEmbedding |
| 17 | + |
| 18 | + |
| 19 | +class DecoderMLP(BaseModelWithCovariates): |
| 20 | + """ |
| 21 | + MLP on the decoder. |
| 22 | +
|
| 23 | + MLP that predicts output only based on information available in the decoder. |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + activation_class: str = "ReLU", |
| 29 | + hidden_size: int = 300, |
| 30 | + n_hidden_layers: int = 3, |
| 31 | + dropout: float = 0.1, |
| 32 | + norm: bool = True, |
| 33 | + static_categoricals: List[str] = [], |
| 34 | + static_reals: List[str] = [], |
| 35 | + time_varying_categoricals_encoder: List[str] = [], |
| 36 | + time_varying_categoricals_decoder: List[str] = [], |
| 37 | + categorical_groups: Dict[str, List[str]] = {}, |
| 38 | + time_varying_reals_encoder: List[str] = [], |
| 39 | + time_varying_reals_decoder: List[str] = [], |
| 40 | + embedding_sizes: Dict[str, Tuple[int, int]] = {}, |
| 41 | + embedding_paddings: List[str] = [], |
| 42 | + embedding_labels: Dict[str, np.ndarray] = {}, |
| 43 | + x_reals: List[str] = [], |
| 44 | + x_categoricals: List[str] = [], |
| 45 | + output_size: Union[int, List[int]] = 1, |
| 46 | + target: Union[str, List[str]] = None, |
| 47 | + loss: MultiHorizonMetric = None, |
| 48 | + logging_metrics: nn.ModuleList = None, |
| 49 | + **kwargs, |
| 50 | + ): |
| 51 | + """ |
| 52 | + Args: |
| 53 | + activation_class (str, optional): PyTorch activation class. Defaults to "ReLU". |
| 54 | + hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with |
| 55 | + ``n_hidden_layers``. Defaults to 10. |
| 56 | + n_hidden_layers (int, optional): Number of hidden layers - important hyperparameter. Defaults to 2. |
| 57 | + dropout (float, optional): Dropout. Defaults to 0.1. |
| 58 | + norm (bool, optional): if to use normalization in the MLP. Defaults to True. |
| 59 | + static_categoricals: integer of positions of static categorical variables |
| 60 | + static_reals: integer of positions of static continuous variables |
| 61 | + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder |
| 62 | + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder |
| 63 | + time_varying_reals_encoder: integer of positions of continuous variables for encoder |
| 64 | + time_varying_reals_decoder: integer of positions of continuous variables for decoder |
| 65 | + categorical_groups: dictionary where values |
| 66 | + are list of categorical variables that are forming together a new categorical |
| 67 | + variable which is the key in the dictionary |
| 68 | + x_reals: order of continuous variables in tensor passed to forward function |
| 69 | + x_categoricals: order of categorical variables in tensor passed to forward function |
| 70 | + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and |
| 71 | + embedding size |
| 72 | + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector |
| 73 | + embedding_labels: dictionary mapping (string) indices to list of categorical labels |
| 74 | + output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for |
| 75 | + QuantileLoss and one target or list of output sizes). |
| 76 | + target (str, optional): Target variable or list of target variables. Defaults to None. |
| 77 | + loss (MultiHorizonMetric, optional): loss: loss function taking prediction and targets. |
| 78 | + Defaults to QuantileLoss. |
| 79 | + logging_metrics (nn.ModuleList, optional): Metrics to log during training. |
| 80 | + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). |
| 81 | + """ |
| 82 | + if loss is None: |
| 83 | + loss = QuantileLoss() |
| 84 | + if logging_metrics is None: |
| 85 | + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) |
| 86 | + self.save_hyperparameters() |
| 87 | + # store loss function separately as it is a module |
| 88 | + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) |
| 89 | + |
| 90 | + self.input_embeddings = MultiEmbedding( |
| 91 | + embedding_sizes={ |
| 92 | + name: val |
| 93 | + for name, val in embedding_sizes.items() |
| 94 | + if name in self.decoder_variables + self.static_variables |
| 95 | + }, |
| 96 | + embedding_paddings=embedding_paddings, |
| 97 | + categorical_groups=categorical_groups, |
| 98 | + x_categoricals=x_categoricals, |
| 99 | + ) |
| 100 | + # define network |
| 101 | + if isinstance(self.hparams.output_size, int): |
| 102 | + mlp_output_size = self.hparams.output_size |
| 103 | + else: |
| 104 | + mlp_output_size = sum(self.hparams.output_size) |
| 105 | + |
| 106 | + cont_size = len(self.decoder_reals_positions) |
| 107 | + cat_size = sum([emb.embedding_dim for emb in self.input_embeddings.values()]) |
| 108 | + input_size = cont_size + cat_size |
| 109 | + |
| 110 | + self.mlp = FullyConnectedModule( |
| 111 | + dropout=dropout, |
| 112 | + norm=self.hparams.norm, |
| 113 | + activation_class=getattr(nn, self.hparams.activation_class), |
| 114 | + input_size=input_size, |
| 115 | + output_size=mlp_output_size, |
| 116 | + hidden_size=self.hparams.hidden_size, |
| 117 | + n_hidden_layers=self.hparams.n_hidden_layers, |
| 118 | + ) |
| 119 | + |
| 120 | + @property |
| 121 | + def decoder_reals_positions(self) -> List[int]: |
| 122 | + return [ |
| 123 | + self.hparams.x_reals.index(name) |
| 124 | + for name in self.reals |
| 125 | + if name in self.decoder_variables + self.static_variables |
| 126 | + ] |
| 127 | + |
| 128 | + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: |
| 129 | + """ |
| 130 | + Forward network |
| 131 | + """ |
| 132 | + # x is a batch generated based on the TimeSeriesDataset |
| 133 | + batch_size = x["decoder_lengths"].size(0) |
| 134 | + embeddings = self.input_embeddings(x["decoder_cat"]) # returns dictionary with embedding tensors |
| 135 | + network_input = torch.cat( |
| 136 | + [x["decoder_cont"][..., self.decoder_reals_positions]] + list(embeddings.values()), |
| 137 | + dim=-1, |
| 138 | + ) |
| 139 | + prediction = self.mlp(network_input.view(-1, self.mlp.input_size)).view( |
| 140 | + batch_size, network_input.size(1), self.mlp.output_size |
| 141 | + ) |
| 142 | + |
| 143 | + # cut prediction into pieces for multiple targets |
| 144 | + if self.n_targets > 1: |
| 145 | + prediction = torch.split(prediction, self.hparams.output_size, dim=-1) |
| 146 | + |
| 147 | + # We need to return a dictionary that at least contains the prediction and the target_scale. |
| 148 | + # The parameter can be directly forwarded from the input. |
| 149 | + return dict(prediction=prediction, target_scale=x["target_scale"]) |
| 150 | + |
| 151 | + @classmethod |
| 152 | + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): |
| 153 | + new_kwargs = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) |
| 154 | + kwargs.update(new_kwargs) |
| 155 | + return super().from_dataset(dataset, **kwargs) |
0 commit comments