
Commit c07d5ac

Merge pull request #96 from jdb78/feature/minimal_categorical_prediction
Minimal categorical prediction
2 parents: a64ed20 + 5696ff8

File tree

6 files changed: 46 additions & 9 deletions


pytorch_forecasting/data/encoders.py

Lines changed: 14 additions & 0 deletions
@@ -130,6 +130,20 @@ def inverse_transform(self, y: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
         decoded = self.classes_vector_[y]
         return decoded
 
+    def __call__(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Extract prediction from network output. Does not map back to input
+        categories, as this would require a numpy array, which cannot carry gradients.
+
+        Args:
+            data (Dict[str, torch.Tensor]): Dictionary with entries
+                * prediction: data to de-scale
+
+        Returns:
+            torch.Tensor: prediction
+        """
+        return data["prediction"]
+
 
 class TorchNormalizer(BaseEstimator, TransformerMixin):
     """

pytorch_forecasting/data/timeseries.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def __init__(
         Args:
             data: dataframe with sequence data - each row can be identified with ``time_idx`` and the ``group_ids``
             time_idx: integer column denoting the time index
-            target: column denoting the target or list of columns denoting the target
+            target: column denoting the target or list of columns denoting the target - categorical or continuous.
             group_ids: list of column names identifying a timeseries
             weight: column name for weights or list of column names corresponding to each target
             max_encoder_length: maximum length to encode
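
To illustrate the documented change, a hedged sketch of a dataset with a categorical target; the toy columns and lengths are invented for the example, and additional constructor arguments may be needed in practice:

    import pandas as pd

    from pytorch_forecasting import TimeSeriesDataSet
    from pytorch_forecasting.data import NaNLabelEncoder

    data = pd.DataFrame(
        dict(
            time_idx=list(range(20)) * 2,
            series=["a"] * 20 + ["b"] * 20,
            status=["up", "down"] * 20,  # categorical target column
        )
    )

    dataset = TimeSeriesDataSet(
        data,
        time_idx="time_idx",
        target="status",  # categorical rather than continuous
        target_normalizer=NaNLabelEncoder(),  # label-encodes the target
        group_ids=["series"],
        max_encoder_length=4,
        max_prediction_length=2,
    )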

pytorch_forecasting/metrics.py

Lines changed: 13 additions & 0 deletions
@@ -493,6 +493,19 @@ def loss(self, y_pred, target):
         return loss
 
 
+class CrossEntropy(MultiHorizonMetric):
+    """
+    Cross entropy loss for classification.
+    """
+
+    def loss(self, y_pred, target):
+
+        loss = F.cross_entropy(y_pred.view(-1, y_pred.size(-1)), target.view(-1), reduction="none").view(
+            -1, target.size(-1)
+        )
+        return loss
+
+
 class RMSE(MultiHorizonMetric):
     """
     Root mean square error
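
The reshaping in `CrossEntropy.loss` is worth spelling out: logits are flattened to (batch * horizon, classes) for `F.cross_entropy`, and the per-element losses are folded back into one value per time step. A standalone shape check under those assumptions:

    import torch
    import torch.nn.functional as F

    batch, horizon, n_classes = 2, 6, 5
    y_pred = torch.randn(batch, horizon, n_classes)      # class logits
    target = torch.randint(n_classes, (batch, horizon))  # class indices

    # same computation as CrossEntropy.loss above
    loss = F.cross_entropy(
        y_pred.view(-1, y_pred.size(-1)), target.view(-1), reduction="none"
    ).view(-1, target.size(-1))

    assert loss.shape == (batch, horizon)  # one loss value per horizon step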

pytorch_forecasting/models/base_model.py

Lines changed: 3 additions & 3 deletions
@@ -14,12 +14,12 @@
 import torch
 import torch.nn as nn
 from torch.nn.utils import rnn
-from torch.optim.lr_scheduler import LambdaLR, OneCycleLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 from torch.utils.data import DataLoader
 from tqdm.notebook import tqdm
 
 from pytorch_forecasting.data import TimeSeriesDataSet
-from pytorch_forecasting.data.encoders import GroupNormalizer
+from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer
 from pytorch_forecasting.metrics import MASE, SMAPE, Metric
 from pytorch_forecasting.optim import Ranger
 from pytorch_forecasting.utils import groupby_apply
@@ -908,7 +908,7 @@ def plot_prediction_actual_by_variable(
         scaler = self.dataset_parameters["scalers"][name]
         x = np.linspace(-data["std"], data["std"], bins)
         # reversing normalization for group normalizer is not possible without sample level information
-        if not isinstance(scaler, GroupNormalizer):
+        if not isinstance(scaler, (GroupNormalizer, EncoderNormalizer)):
             x = scaler.inverse_transform(x)
         ax.set_xlabel(f"Normalized {name}")

tests/test_models/conftest.py

Lines changed: 5 additions & 2 deletions
@@ -79,23 +79,26 @@ def data_with_covariates():
         dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2),
         dict(target_normalizer=GroupNormalizer(log_scale=True)),
         dict(target_normalizer=GroupNormalizer(groups=["agency", "sku"], coerce_positive=1.0)),
+        dict(target="agency"),
     ]
 )
 def multiple_dataloaders_with_coveratiates(data_with_covariates, request):
     training_cutoff = "2016-09-01"
     max_encoder_length = 36
     max_prediction_length = 6
 
+    params = request.param
+    params.setdefault("target", "volume")
+
     training = TimeSeriesDataSet(
         data_with_covariates[lambda x: x.date < training_cutoff],
         time_idx="time_idx",
-        target="volume",
         # weight="weight",
         group_ids=["agency", "sku"],
         max_encoder_length=max_encoder_length,
         max_prediction_length=max_prediction_length,
         add_relative_time_idx=True,
-        **request.param  # fixture parametrization
+        **params  # fixture parametrization
     )
 
     validation = TimeSeriesDataSet.from_dataset(
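
The `setdefault` pattern here, sketched in isolation: each parametrized case may override the target, and "volume" remains the fallback (the dictionaries mirror the cases above):

    params = dict(target="agency")         # the new categorical case
    params.setdefault("target", "volume")  # no-op: target already set
    assert params["target"] == "agency"

    params = dict(min_encoder_length=2)    # a case without an explicit target
    params.setdefault("target", "volume")  # falls back to the continuous default
    assert params["target"] == "volume"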

tests/test_models/test_temporal_fusion_transformer.py

Lines changed: 10 additions & 3 deletions
@@ -9,7 +9,8 @@
 from torch.utils.data import dataloader
 
 from pytorch_forecasting import TimeSeriesDataSet
-from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss
+from pytorch_forecasting.data import NaNLabelEncoder
+from pytorch_forecasting.metrics import CrossEntropy, PoissonLoss, QuantileLoss
 from pytorch_forecasting.models import TemporalFusionTransformer
 from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
@@ -52,18 +53,24 @@ def test_integration(multiple_dataloaders_with_coveratiates, tmp_path, gpus):
     cuda_context = nullcontext()
 
     with cuda_context:
+        if isinstance(train_dataloader.dataset.target_normalizer, NaNLabelEncoder):
+            output_size = len(train_dataloader.dataset.target_normalizer.classes_)
+            loss = CrossEntropy()
+        else:
+            output_size = 7
+            loss = QuantileLoss()
         net = TemporalFusionTransformer.from_dataset(
             train_dataloader.dataset,
             learning_rate=0.15,
             hidden_size=4,
             attention_head_size=1,
             dropout=0.2,
             hidden_continuous_size=2,
-            loss=QuantileLoss(),
-            output_size=7,
+            loss=loss,
             log_interval=5,
             log_val_interval=1,
             log_gradient_flow=True,
+            output_size=output_size,
             monotone_constaints=monotone_constaints,
         )
         net.size()
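
Distilled from the test, the pattern for switching a TemporalFusionTransformer between quantile regression and classification (a sketch; `dataset` stands in for a TimeSeriesDataSet built from either kind of target, and 7 matches the number of default quantiles in QuantileLoss):

    from pytorch_forecasting.data import NaNLabelEncoder
    from pytorch_forecasting.metrics import CrossEntropy, QuantileLoss
    from pytorch_forecasting.models import TemporalFusionTransformer

    if isinstance(dataset.target_normalizer, NaNLabelEncoder):
        # categorical target: one logit per class, scored with cross entropy
        output_size = len(dataset.target_normalizer.classes_)
        loss = CrossEntropy()
    else:
        # continuous target: one output per quantile
        output_size = 7
        loss = QuantileLoss()

    net = TemporalFusionTransformer.from_dataset(dataset, loss=loss, output_size=output_size)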
