
Commit 9e75685

Merge pull request #995 from jdb78/feature/implicit-quantile-loss
Implicit Quantiles
2 parents e10f5dc + f5091ed commit 9e75685

File tree: 11 files changed, +182 −6 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 - MQF2 loss (multivariate quantile loss) (#949)
 - Non-causal attention for TFT (#949)
 - Tweedie loss (#949)
+- ImplicitQuantileNetworkDistributionLoss (#995)
 
 ### Fixed

pytorch_forecasting/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
     BetaDistributionLoss,
     CrossEntropy,
     DistributionLoss,
+    ImplicitQuantileNetworkDistributionLoss,
     LogNormalDistributionLoss,
     MQF2DistributionLoss,
     MultiHorizonMetric,
@@ -84,6 +85,7 @@
     "LogNormalDistributionLoss",
     "NegativeBinomialDistributionLoss",
     "NormalDistributionLoss",
+    "ImplicitQuantileNetworkDistributionLoss",
     "MultivariateNormalDistributionLoss",
     "MQF2DistributionLoss",
     "CrossEntropy",

pytorch_forecasting/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 )
 from pytorch_forecasting.metrics.distributions import (
     BetaDistributionLoss,
+    ImplicitQuantileNetworkDistributionLoss,
     LogNormalDistributionLoss,
     MQF2DistributionLoss,
     MultivariateNormalDistributionLoss,
@@ -41,6 +42,7 @@
     "NormalDistributionLoss",
     "LogNormalDistributionLoss",
     "MultivariateNormalDistributionLoss",
+    "ImplicitQuantileNetworkDistributionLoss",
     "QuantileLoss",
     "MQF2DistributionLoss",
 ]
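
Taken together, the two export changes above make the new loss importable from either the top-level package or the metrics subpackage. A quick sanity check, using nothing beyond what this diff adds:

from pytorch_forecasting import ImplicitQuantileNetworkDistributionLoss
from pytorch_forecasting.metrics import ImplicitQuantileNetworkDistributionLoss as IQNLoss

# both import paths resolve to the same class after this commit
assert ImplicitQuantileNetworkDistributionLoss is IQNLoss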

pytorch_forecasting/metrics/distributions.py

Lines changed: 133 additions & 1 deletion
@@ -4,7 +4,7 @@
 import numpy as np
 from sklearn.base import BaseEstimator
 import torch
-from torch import distributions
+from torch import distributions, nn
 import torch.nn.functional as F
 
 from pytorch_forecasting.data.encoders import TorchNormalizer, softplus_inv
@@ -405,3 +405,135 @@ def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
         )  # (batch_size, prediction_length, quantile_size)
 
         return result
+
+
+class ImplicitQuantileNetwork(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int):
+        super().__init__()
+        self.quantile_layer = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size), nn.PReLU(), nn.Linear(hidden_size, input_size)
+        )
+        self.output_layer = nn.Sequential(
+            nn.Linear(input_size, input_size),
+            nn.PReLU(),
+            nn.Linear(input_size, 1),
+        )
+        self.register_buffer("cos_multipliers", torch.arange(0, hidden_size) * torch.pi)
+
+    def forward(self, x: torch.Tensor, quantiles: torch.Tensor) -> torch.Tensor:
+        # embed quantiles
+        cos_emb_tau = torch.cos(quantiles[:, None] * self.cos_multipliers[None])  # n_quantiles x hidden_size
+        # modulates input depending on quantile
+        cos_emb_tau = self.quantile_layer(cos_emb_tau)  # n_quantiles x input_size
+
+        emb_inputs = x.unsqueeze(-2) * (1.0 + cos_emb_tau)  # ... x n_quantiles x input_size
+        emb_outputs = self.output_layer(emb_inputs).squeeze(-1)  # ... x n_quantiles
+        return emb_outputs
+
+
+class ImplicitQuantileNetworkDistributionLoss(DistributionLoss):
+    """Implicit Quantile Network Distribution Loss.
+
+    Based on `Probabilistic Time Series Forecasting with Implicit Quantile Networks
+    <https://arxiv.org/pdf/2107.03743.pdf>`_.
+    A network is used to directly map network outputs to a quantile.
+    """
+
+    def __init__(
+        self,
+        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
+        input_size: Optional[int] = 16,
+        hidden_size: Optional[int] = 32,
+        n_loss_samples: Optional[int] = 64,
+    ) -> None:
+        """
+        Args:
+            quantiles (List[float], optional): default quantiles to output.
+                Defaults to [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98].
+            input_size (int, optional): input size per prediction length. Defaults to 16.
+            hidden_size (int, optional): hidden size per prediction length. Defaults to 32.
+            n_loss_samples (int, optional): number of quantiles to sample to calculate loss. Defaults to 64.
+        """
+        super().__init__(quantiles=quantiles)
+        self.quantile_network = ImplicitQuantileNetwork(input_size=input_size, hidden_size=hidden_size)
+        self.distribution_arguments = list(range(int(input_size)))
+        self.n_loss_samples = n_loss_samples
+
+    def sample(self, y_pred, n_samples: int) -> torch.Tensor:
+        eps = 1e-3
+        # draw random quantiles (excl. 0 and 1 as they would lead to infinities)
+        quantiles = torch.rand(size=(n_samples,), device=y_pred.device).clamp(eps, 1 - eps)
+        # make prediction
+        samples = self.to_quantiles(y_pred, quantiles=quantiles)
+        return samples
+
+    def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate negative likelihood
+
+        Args:
+            y_pred: network output
+            y_actual: actual values
+
+        Returns:
+            torch.Tensor: metric value on which backpropagation can be applied
+        """
+        eps = 1e-3
+        # draw random quantiles (excl. 0 and 1 as they would lead to infinities)
+        quantiles = torch.rand(size=(self.n_loss_samples,), device=y_pred.device).clamp(eps, 1 - eps)
+        # make prediction
+        pred_quantiles = self.to_quantiles(y_pred, quantiles=quantiles)
+        # and calculate quantile loss
+        errors = y_actual[..., None] - pred_quantiles
+        loss = 2 * torch.fmax(quantiles[None] * errors, (quantiles[None] - 1) * errors).mean(dim=-1)
+        return loss
+
+    def rescale_parameters(
+        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
+    ) -> torch.Tensor:
+        self._transformation = encoder.transformation
+        return torch.concat([parameters, target_scale.unsqueeze(1).expand(-1, parameters.size(1), -1)], dim=-1)
+
+    def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
+        if n_samples is None:
+            return self.to_quantiles(y_pred, quantiles=[0.5]).squeeze(-1)
+        else:
+            # sample random quantiles (excl. 0 and 1) and average the resulting predictions
+            return self.sample(y_pred, n_samples=n_samples).mean(-1)
+
+    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
+        """
+        Convert network prediction into a quantile prediction.
+
+        Args:
+            y_pred: prediction output of network
+            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as
+                defined in the class initialization.
+
+        Returns:
+            torch.Tensor: prediction quantiles (last dimension)
+        """
+        if quantiles is None:
+            quantiles = self.quantiles
+        quantiles = torch.as_tensor(quantiles, device=y_pred.device)
+
+        # extract parameters
+        x = y_pred[..., :-2]
+        loc = y_pred[..., -2][..., None]
+        scale = y_pred[..., -1][..., None]
+
+        # predict quantiles
+        if y_pred.requires_grad:
+            predictions = self.quantile_network(x, quantiles)
+        else:
+            with torch.no_grad():
+                predictions = self.quantile_network(x, quantiles)
+        # rescale output
+        predictions = loc + predictions * scale
+        # transform output if required
+        if self._transformation is not None:
+            transform = TorchNormalizer.get_transform(self._transformation)["reverse"]
+            predictions = transform(predictions)
+
+        return predictions
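
A minimal usage sketch of the new loss outside any model, mirroring the unit test added below; the random tensors are illustrative stand-ins for real network output and target scales, and the shapes in the comments follow the code above:

import torch

from pytorch_forecasting.data.encoders import TorchNormalizer
from pytorch_forecasting.metrics import ImplicitQuantileNetworkDistributionLoss

batch_size, n_timesteps, input_size = 4, 6, 16
target = torch.rand(batch_size, n_timesteps)

# fit a normalizer so that rescale_parameters can record the output transformation
normalizer = TorchNormalizer(center=True, transformation="softplus")
normalizer.fit(target.reshape(-1))

loss = ImplicitQuantileNetworkDistributionLoss(input_size=input_size, hidden_size=32)
x = torch.rand(batch_size, n_timesteps, input_size)  # stand-in for the network embedding
target_scale = torch.rand(batch_size, 2)             # stand-in for (center, scale) per series
pred = loss.rescale_parameters(x, target_scale=target_scale, encoder=normalizer)

print(loss.loss(pred, target).shape)    # (batch_size, n_timesteps): sampled quantile loss
print(loss.to_quantiles(pred).shape)    # (batch_size, n_timesteps, len(loss.quantiles))
print(loss.to_prediction(pred).shape)   # (batch_size, n_timesteps): mean over sampled quantiles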

pytorch_forecasting/metrics/quantile.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         for i, q in enumerate(self.quantiles):
             errors = target - y_pred[..., i]
             losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
-        losses = torch.cat(losses, dim=2)
+        losses = 2 * torch.cat(losses, dim=2)
 
         return losses

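The factor of 2 introduced here matches the 2 * torch.fmax(...) form used by the sampled quantile loss in ImplicitQuantileNetworkDistributionLoss above; a plausible reading is that it keeps the two losses on the same scale, since the doubled pinball loss at the median reduces to the absolute error. A small illustrative check:

import torch

q = 0.5
errors = torch.tensor([-2.0, 0.5, 3.0])           # illustrative residuals
pinball = torch.max((q - 1) * errors, q * errors)  # classic pinball loss at the median
print(torch.allclose(2 * pinball, errors.abs()))   # True: the 2x scaling recovers |error|
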
pytorch_forecasting/models/base_model.py

Lines changed: 0 additions & 1 deletion
@@ -409,7 +409,6 @@ def training_step(self, batch, batch_idx):
         """
         x, y = batch
         log, out = self.step(x, y, batch_idx)
-        log.update(self.create_log(x, y, out, batch_idx))
         return log
 
     def training_epoch_end(self, outputs):

pytorch_forecasting/models/nhits/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 from pytorch_forecasting.models.base_model import BaseModelWithCovariates
 from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule
 from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
-from pytorch_forecasting.utils import create_mask, to_list
+from pytorch_forecasting.utils import create_mask, detach, to_list
 
 
 class NHiTS(BaseModelWithCovariates):

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -535,7 +535,7 @@ def epoch_end(self, outputs):
         """
         run at epoch end for training or validation
         """
-        if self.log_interval > 0:
+        if self.log_interval > 0 and not self.training:
             self.log_interpretation(outputs)
 
     def interpret_output(

tests/test_metrics.py

Lines changed: 25 additions & 0 deletions
@@ -10,6 +10,7 @@
     MAE,
     SMAPE,
     BetaDistributionLoss,
+    ImplicitQuantileNetworkDistributionLoss,
     LogNormalDistributionLoss,
     MultivariateNormalDistributionLoss,
     NegativeBinomialDistributionLoss,
@@ -210,3 +211,27 @@ def test_MultivariateNormalDistributionLoss(center, transformation):
     assert torch.isclose(target.mean(), samples.mean(), atol=3.0, rtol=0.5)
     if center:  # if not centered, softplus distorts std too much for testing
         assert torch.isclose(target.std(), samples.std(), atol=0.1, rtol=0.7)
+
+
+def test_ImplicitQuantileNetworkDistributionLoss():
+    batch_size = 3
+    n_timesteps = 2
+    output_size = 5
+
+    target = torch.rand((batch_size, n_timesteps))
+
+    normalizer = TorchNormalizer(center=True, transformation="softplus")
+    normalizer.fit(target.reshape(-1))
+
+    loss = ImplicitQuantileNetworkDistributionLoss(input_size=output_size)
+    x = torch.rand((batch_size, n_timesteps, output_size))
+    target_scale = torch.rand((batch_size, 2))
+    pred = loss.rescale_parameters(x, target_scale=target_scale, encoder=normalizer)
+    assert loss.loss(pred, target).shape == target.shape
+    quantiles = loss.to_quantiles(pred)
+    assert quantiles.size(-1) == len(loss.quantiles)
+    assert quantiles.size(0) == batch_size
+    assert quantiles.size(1) == n_timesteps
+
+    point_prediction = loss.to_prediction(pred, n_samples=None)
+    assert point_prediction.ndim == loss.to_prediction(pred, n_samples=100).ndim

tests/test_models/test_deepar.py

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,7 @@
 from pytorch_forecasting.data.encoders import GroupNormalizer
 from pytorch_forecasting.metrics import (
     BetaDistributionLoss,
+    ImplicitQuantileNetworkDistributionLoss,
     LogNormalDistributionLoss,
     MultivariateNormalDistributionLoss,
     NegativeBinomialDistributionLoss,
@@ -121,6 +122,9 @@ def _integration(
                 lags={"volume": [2], "discount": [2]},
             )
         ),
+        dict(
+            loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8),
+        ),
         dict(
             loss=MultivariateNormalDistributionLoss(),
         ),
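
For context, a hedged sketch of how the new loss would be passed to DeepAR outside the test suite; training_dataset is a hypothetical, pre-built TimeSeriesDataSet and the hyperparameters are illustrative only:

from pytorch_forecasting import DeepAR
from pytorch_forecasting.metrics import ImplicitQuantileNetworkDistributionLoss

# training_dataset is an assumed TimeSeriesDataSet (not defined in this diff)
model = DeepAR.from_dataset(
    training_dataset,
    hidden_size=16,                                               # illustrative RNN size
    loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8),  # as in the parametrization above
)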
