Commit f5091ed by Jan Beitner

More tests for implicit quantiles

1 parent 22f7ee0

File tree: 6 files changed, +24 -13 lines

pytorch_forecasting/metrics/distributions.py (17 additions, 9 deletions)

@@ -443,8 +443,8 @@ def __init__(
         self,
         quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
         input_size: Optional[int] = 16,
-        hidden_size: Optional[int] = 64,
-        n_loss_samples: Optional[int] = 16,
+        hidden_size: Optional[int] = 32,
+        n_loss_samples: Optional[int] = 64,
     ) -> None:
         """
         Args:
@@ -460,6 +460,14 @@ def __init__(
         self.distribution_arguments = list(range(int(input_size)))
         self.n_loss_samples = n_loss_samples
 
+    def sample(self, y_pred, n_samples: int) -> torch.Tensor:
+        eps = 1e-3
+        # for a couple of random quantiles (excl. 0 and 1 as they would lead to infinities)
+        quantiles = torch.rand(size=(n_samples,), device=y_pred.device).clamp(eps, 1 - eps)
+        # make prediction
+        samples = self.to_quantiles(y_pred, quantiles=quantiles)
+        return samples
+
     def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
         """
         Calculate negative likelihood
@@ -478,7 +486,7 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
         pred_quantiles = self.to_quantiles(y_pred, quantiles=quantiles)
         # and calculate quantile loss
         errors = y_actual[..., None] - pred_quantiles
-        loss = torch.fmax(quantiles[None] * errors, (quantiles[None] - 1) * errors).mean(dim=-1)
+        loss = 2 * torch.fmax(quantiles[None] * errors, (quantiles[None] - 1) * errors).mean(dim=-1)
         return loss
 
     def rescale_parameters(
@@ -492,9 +500,7 @@ def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
             return self.to_quantiles(y_pred, quantiles=[0.5]).squeeze(-1)
         else:
             # for a couple of random quantiles (excl. 0 and 1 as they would lead to infinities) make prediction
-            eps = 1e-3
-            quantiles = torch.rand(size=(n_samples,), device=y_pred.device).clamp(eps, 1 - eps)
-            return self.to_quantiles(y_pred, quantiles=quantiles).mean(-1)
+            return self.sample(y_pred, n_samples=n_samples).mean(-1)
 
     def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
         """
@@ -518,14 +524,16 @@ def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
         scale = y_pred[..., -1][..., None]
 
         # predict quantiles
-        predictions = self.quantile_network(x, quantiles)
+        if y_pred.requires_grad:
+            predictions = self.quantile_network(x, quantiles)
+        else:
+            with torch.no_grad():
+                predictions = self.quantile_network(x, quantiles)
         # rescale output
         predictions = loc + predictions * scale
         # transform output if required
         if self._transformation is not None:
             transform = TorchNormalizer.get_transform(self._transformation)["reverse"]
             predictions = transform(predictions)
 
-        if not y_pred.requires_grad:
-            predictions = predictions.detach()
         return predictions
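
Two changes in this file are worth spelling out. The refactored to_quantiles now runs the quantile network under torch.no_grad() when the input does not require gradients, so the autograd graph is never built during inference rather than being detached after the fact. And the new sample method factors the random-quantile draw out of to_prediction: evaluating the learned inverse CDF at uniformly drawn quantile levels is exactly inverse-transform sampling. Below is a minimal sketch of that idea with an analytic stand-in for the network; the Gaussian to_quantiles here is hypothetical and purely for illustration, not the library's quantile_network.

import torch

def to_quantiles(y_pred: torch.Tensor, quantiles: torch.Tensor) -> torch.Tensor:
    # toy stand-in: interpret y_pred as (loc, scale) of a Normal and
    # map quantile levels through its inverse CDF
    loc, scale = y_pred[..., 0, None], y_pred[..., 1, None]
    standard_normal = torch.distributions.Normal(0.0, 1.0)
    return loc + scale * standard_normal.icdf(quantiles)

def sample(y_pred: torch.Tensor, n_samples: int) -> torch.Tensor:
    eps = 1e-3
    # draw random quantile levels, excluding 0 and 1 which map to +/- infinity
    quantiles = torch.rand(size=(n_samples,), device=y_pred.device).clamp(eps, 1 - eps)
    return to_quantiles(y_pred, quantiles=quantiles)

y_pred = torch.tensor([[0.0, 1.0]])  # batch of one series: loc=0, scale=1
point = sample(y_pred, n_samples=1000).mean(-1)  # approximates the mean, close to 0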

pytorch_forecasting/metrics/quantile.py (1 addition, 1 deletion)

@@ -32,7 +32,7 @@ def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         for i, q in enumerate(self.quantiles):
             errors = target - y_pred[..., i]
             losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
-        losses = torch.cat(losses, dim=2)
+        losses = 2 * torch.cat(losses, dim=2)
 
         return losses
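
The factor of 2 applied here (and to the distribution loss above) rescales the pinball loss so that the median quantile reproduces the mean absolute error: for q = 0.5, max((q - 1)e, qe) = |e|/2, so doubling it yields |e|. A quick numerical check of this identity, as a sketch rather than library code:

import torch

target = torch.tensor([1.0, 2.0, 3.0])
median_pred = torch.tensor([1.5, 1.5, 1.5])

errors = target - median_pred
q = 0.5
# scaled pinball loss at the median equals the absolute error
pinball = 2 * torch.max((q - 1) * errors, q * errors)
assert torch.allclose(pinball, errors.abs())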

pytorch_forecasting/models/base_model.py (0 additions, 1 deletion)

@@ -409,7 +409,6 @@ def training_step(self, batch, batch_idx):
         """
         x, y = batch
         log, out = self.step(x, y, batch_idx)
-        log.update(self.create_log(x, y, out, batch_idx))
         return log
 
     def training_epoch_end(self, outputs):

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py (1 addition, 1 deletion)

@@ -535,7 +535,7 @@ def epoch_end(self, outputs):
         """
         run at epoch end for training or validation
         """
-        if self.log_interval > 0:
+        if self.log_interval > 0 and not self.training:
            self.log_interpretation(outputs)
 
     def interpret_output(
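
The added not self.training guard relies on the standard nn.Module.training flag, which PyTorch Lightning toggles automatically between the training and validation loops, so interpretation plots are now logged only outside of training. A minimal illustration of the flag itself, in plain PyTorch rather than the library's code:

import torch.nn as nn

module = nn.Linear(4, 2)
module.train()
assert module.training        # True during the training loop
module.eval()
assert not module.training    # False during validation and inference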

tests/test_metrics.py (1 addition, 1 deletion)

@@ -234,4 +234,4 @@ def test_ImplicitQuantileNetworkDistributionLoss():
     assert quantiles.size(1) == n_timesteps
 
     point_prediction = loss.to_prediction(pred, n_samples=None)
-    assert point_prediction.ndim == loss.to_prediction(pred, n_samples=100).ndim - 1
+    assert point_prediction.ndim == loss.to_prediction(pred, n_samples=100).ndim

tests/test_models/test_deepar.py (4 additions, 0 deletions)

@@ -11,6 +11,7 @@
 from pytorch_forecasting.data.encoders import GroupNormalizer
 from pytorch_forecasting.metrics import (
     BetaDistributionLoss,
+    ImplicitQuantileNetworkDistributionLoss,
     LogNormalDistributionLoss,
     MultivariateNormalDistributionLoss,
     NegativeBinomialDistributionLoss,
@@ -121,6 +122,9 @@ def _integration(
                 lags={"volume": [2], "discount": [2]},
             )
         ),
+        dict(
+            loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8),
+        ),
         dict(
             loss=MultivariateNormalDistributionLoss(),
         ),
