Skip to content

Commit

Permalink
[DOC] improve and add tide model to docs (#1762)
Browse files Browse the repository at this point in the history
### Description

Fixes #1758
  • Loading branch information
PranavBhatP authored Feb 6, 2025
1 parent 097403e commit 8ff36c2
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and you should take into account. Here is an overview over the pros and cons of
:py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4

:py:class:`~pytorch_forecasting.models.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3

.. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model.
Expand Down
2 changes: 2 additions & 0 deletions pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
NHiTS,
RecurrentNetwork,
TemporalFusionTransformer,
TiDEModel,
get_rnn,
)
from pytorch_forecasting.utils import (
Expand All @@ -70,6 +71,7 @@
"NaNLabelEncoder",
"MultiNormalizer",
"TemporalFusionTransformer",
"TiDEModel",
"NBeats",
"NHiTS",
"Baseline",
Expand Down
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytorch_forecasting.models.temporal_fusion_transformer import (
TemporalFusionTransformer,
)
from pytorch_forecasting.models.tide import TiDEModel

__all__ = [
"NBeats",
Expand All @@ -35,4 +36,5 @@
"GRU",
"MultiEmbedding",
"DecoderMLP",
"TiDEModel",
]
59 changes: 38 additions & 21 deletions pytorch_forecasting/models/tide/_tide.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Implements the TiDE (Time-series Dense Encoder-decoder) model, which is designed for
long-term time-series forecasting.
"""

from copy import copy
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -44,30 +49,39 @@ def __init__(
):
"""An implementation of the TiDE model.
TiDE shares similarities with Transformers (implemented in :class:TransformerModel), but aims to deliver
better performance with reduced computational requirements by utilizing MLP-based encoder-decoder architectures
without attention mechanisms.
TiDE shares similarities with Transformers
(implemented in :class:TransformerModel), but aims to deliver better performance
with reduced computational requirements by utilizing MLP-based encoder-decoder
architectures without attention mechanisms.
This model supports future covariates (known for output_chunk_length points after the prediction time) and
static covariates.
This model supports future covariates (known for output_chunk_length points
after the prediction time) and static covariates.
The encoder and decoder are constructed using residual blocks. The number of residual blocks in the encoder and
decoder can be specified with `num_encoder_layers` and `num_decoder_layers` respectively. The layer width in the
residual blocks can be adjusted using `hidden_size`, while the layer width in the temporal decoder can be
controlled via `temporal_decoder_hidden`.
The encoder and decoder are constructed using residual blocks. The number of
residual blocks in the encoder and decoder can be specified with
`num_encoder_layers` and `num_decoder_layers` respectively. The layer width in
the residual blocks can be adjusted using `hidden_size`, while the layer width
in the temporal decoder can be controlled via `temporal_decoder_hidden`.
Parameters
----------
input_chunk_length (int): Number of past time steps to use as input for the model (per chunk).
This applies to the target series and future covariates (if supported by the model).
output_chunk_length (int): Number of time steps the internal model predicts simultaneously (per chunk).
This also determines how many future values from future covariates are used as input
input_chunk_length : int
Number of past time steps to use as input for the model (per chunk).
This applies to the target series and future covariates
(if supported by the model).
num_encoder_layers (int): Number of residual blocks in the encoder. Defaults to 2.
num_decoder_layers (int): Number of residual blocks in the decoder. Defaults to 2.
decoder_output_dim (int): Dimensionality of the decoder's output. Defaults to 16.
hidden_size (int): Size of hidden layers in the encoder and decoder. Typically ranges from 32 to 128 when
no covariates are used. Defaults to 128.
output_chunk_length : int
Number of time steps the internal model predicts simultaneously (per chunk).
This also determines how many future values from future covariates
are used as input (if supported by the model).
num_encoder_layers : int, default=2
Number of residual blocks in the encoder
num_decoder_layers : int, default=2
Number of residual blocks in the decoder
decoder_output_dim : int, default=16
Dimensionality of the decoder's output
hidden_size : int, default=128
Size of hidden layers in the encoder and decoder.
Typically ranges from 32 to 128 when no covariates are used.
temporal_width_future (int): Width of the output layer in the residual block for future covariate projections.
If set to 0, bypasses feature projection and uses raw feature data. Defaults to 4.
temporal_hidden_size_future (int): Width of the hidden layer in the residual block for future covariate
Expand Down Expand Up @@ -98,8 +112,10 @@ def __init__(
**kwargs
Allows optional arguments to configure pytorch_lightning.Module, pytorch_lightning.Trainer, and
pytorch-forecasting's :class:BaseModelWithCovariates.
""" # noqa: E501
Note:
The model supports future covariates and static covariates.
""" # noqa: E501
if static_categoricals is None:
static_categoricals = []
if static_reals is None:
Expand Down Expand Up @@ -200,15 +216,16 @@ def static_size(self) -> int:
@classmethod
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
"""
Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
Convenience function to create network from
:py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
Args:
dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
**kwargs: additional arguments to be passed to `__init__` method.
Returns:
TiDE
""" # noqa: E501
"""

# validate arguments
assert not isinstance(
Expand Down
3 changes: 2 additions & 1 deletion pytorch_forecasting/models/tide/sub_modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Time-series Dense Encoder (TiDE)
------
--------------------------------
"""

from typing import Optional, Tuple
Expand Down Expand Up @@ -226,6 +226,7 @@ def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
) -> torch.Tensor:
"""TiDE model forward pass.
Parameters
----------
x_in
Expand Down

0 comments on commit 8ff36c2

Please sign in to comment.