@@ -4,7 +4,7 @@
 import numpy as np
 from sklearn.base import BaseEstimator
 import torch
-from torch import distributions
+from torch import distributions, nn
 import torch.nn.functional as F
 
 from pytorch_forecasting.data.encoders import TorchNormalizer, softplus_inv
@@ -405,3 +405,135 @@ def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
         )  # (batch_size, prediction_length, quantile_size)
 
         return result
+
+
+class ImplicitQuantileNetwork(nn.Module):
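+    """Network that maps an input embedding and quantile levels to quantile values.
+
+    Shape sketch (illustrative sizes, not part of the API)::
+
+        net = ImplicitQuantileNetwork(input_size=16, hidden_size=32)
+        net(torch.randn(4, 20, 16), torch.rand(8)).shape  # torch.Size([4, 20, 8])
+    """
+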
+    def __init__(self, input_size: int, hidden_size: int):
+        super().__init__()
+        self.quantile_layer = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size), nn.PReLU(), nn.Linear(hidden_size, input_size)
+        )
+        self.output_layer = nn.Sequential(
+            nn.Linear(input_size, input_size),
+            nn.PReLU(),
+            nn.Linear(input_size, 1),
+        )
+        # multipliers i * pi, i = 0 .. hidden_size - 1, for the cosine embedding of quantile levels
+        self.register_buffer("cos_multipliers", torch.arange(0, hidden_size) * torch.pi)
+
+    def forward(self, x: torch.Tensor, quantiles: torch.Tensor) -> torch.Tensor:
+        # embed quantile levels: tau -> cos(i * pi * tau)
+        cos_emb_tau = torch.cos(quantiles[:, None] * self.cos_multipliers[None])  # n_quantiles x hidden_size
+        # modulate the input depending on the quantile level
+        cos_emb_tau = self.quantile_layer(cos_emb_tau)  # n_quantiles x input_size
+
+        emb_inputs = x.unsqueeze(-2) * (1.0 + cos_emb_tau)  # ... x n_quantiles x input_size
+        emb_outputs = self.output_layer(emb_inputs).squeeze(-1)  # ... x n_quantiles
+        return emb_outputs
+
+
+class ImplicitQuantileNetworkDistributionLoss(DistributionLoss):
+    """Implicit Quantile Network Distribution Loss.
+
+    Based on `Probabilistic Time Series Forecasting with Implicit Quantile Networks
+    <https://arxiv.org/pdf/2107.03743.pdf>`_.
+    An auxiliary quantile network maps the network output, conditioned on an embedded
+    quantile level, directly to the corresponding quantile value.
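+
+    Example (a minimal sketch; assumes a ``TimeSeriesDataSet`` named ``training``
+    prepared as in the pytorch-forecasting tutorials)::
+
+        from pytorch_forecasting import DeepAR
+        from pytorch_forecasting.metrics import ImplicitQuantileNetworkDistributionLoss
+
+        model = DeepAR.from_dataset(
+            training,
+            loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=32),
+        )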
| 440 | + """ |
| 441 | + |
+    def __init__(
+        self,
+        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
+        input_size: Optional[int] = 16,
+        hidden_size: Optional[int] = 32,
+        n_loss_samples: Optional[int] = 64,
+    ) -> None:
+        """
+        Args:
+            quantiles (List[float], optional): default quantiles to output.
+                Defaults to [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98].
+            input_size (int, optional): input size per prediction length. Defaults to 16.
+            hidden_size (int, optional): hidden size per prediction length. Defaults to 32.
+            n_loss_samples (int, optional): number of quantiles to sample to calculate loss. Defaults to 64.
+        """
+        super().__init__(quantiles=quantiles)
+        self.quantile_network = ImplicitQuantileNetwork(input_size=input_size, hidden_size=hidden_size)
+        self.distribution_arguments = list(range(int(input_size)))
+        self.n_loss_samples = n_loss_samples
+
+    def sample(self, y_pred, n_samples: int) -> torch.Tensor:
+        """Sample from the distribution implied by ``y_pred`` by evaluating randomly drawn quantile levels."""
+        eps = 1e-3
+        # draw random quantile levels, excluding 0 and 1 as they would map to infinities
+        quantiles = torch.rand(size=(n_samples,), device=y_pred.device).clamp(eps, 1 - eps)
+        # evaluate the implied distribution at the sampled quantile levels
+        samples = self.to_quantiles(y_pred, quantiles=quantiles)
+        return samples
+
+    def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate quantile (pinball) loss at randomly sampled quantile levels.
+
+        Args:
+            y_pred: network output
+            y_actual: actual values
+
+        Returns:
+            torch.Tensor: metric value on which backpropagation can be applied
+        """
+        eps = 1e-3
+        # draw random quantile levels, excluding 0 and 1 as they would map to infinities
+        quantiles = torch.rand(size=(self.n_loss_samples,), device=y_pred.device).clamp(eps, 1 - eps)
+        # predict the corresponding quantile values
+        pred_quantiles = self.to_quantiles(y_pred, quantiles=quantiles)
+        # pinball loss averaged over the sampled levels; the factor 2 scales the
+        # tau=0.5 term to the absolute error
+        errors = y_actual[..., None] - pred_quantiles
+        loss = 2 * torch.fmax(quantiles[None] * errors, (quantiles[None] - 1) * errors).mean(dim=-1)
+        return loss
+
+    def rescale_parameters(
+        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
+    ) -> torch.Tensor:
+        self._transformation = encoder.transformation
+        # append the target normalizer's loc and scale to the raw network parameters
+        return torch.concat([parameters, target_scale.unsqueeze(1).expand(-1, parameters.size(1), -1)], dim=-1)
+
+    def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
+        if n_samples is None:
+            return self.to_quantiles(y_pred, quantiles=[0.5]).squeeze(-1)
+        else:
+            # estimate the mean by averaging predictions at randomly sampled quantile levels
+            return self.sample(y_pred, n_samples=n_samples).mean(-1)
+
+    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
+        """
+        Convert network prediction into a quantile prediction.
+
+        Args:
+            y_pred: prediction output of network
+            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles
+                as defined in the class initialization.
+
+        Returns:
+            torch.Tensor: prediction quantiles (last dimension)
+        """
+        if quantiles is None:
+            quantiles = self.quantiles
+        quantiles = torch.as_tensor(quantiles, device=y_pred.device)
+
+        # split network output into the quantile-network input and the appended loc and scale
+        x = y_pred[..., :-2]
+        loc = y_pred[..., -2][..., None]
+        scale = y_pred[..., -1][..., None]
+
+        # predict quantiles, skipping graph construction when no gradients are required
+        if y_pred.requires_grad:
+            predictions = self.quantile_network(x, quantiles)
+        else:
+            with torch.no_grad():
+                predictions = self.quantile_network(x, quantiles)
+        # rescale output
+        predictions = loc + predictions * scale
+        # transform output if required
+        if self._transformation is not None:
+            transform = TorchNormalizer.get_transform(self._transformation)["reverse"]
+            predictions = transform(predictions)
+
+        return predictions
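
For a quick sanity check of the tensor shapes, the loss can be exercised standalone. A minimal sketch, assuming the class is exported from `pytorch_forecasting.metrics` like the other metrics; a plain `TorchNormalizer` supplies the (identity) transformation that `rescale_parameters` records:

```python
import torch

from pytorch_forecasting.data.encoders import TorchNormalizer
from pytorch_forecasting.metrics import ImplicitQuantileNetworkDistributionLoss

iqn = ImplicitQuantileNetworkDistributionLoss(input_size=16, hidden_size=32)

batch_size, prediction_length = 4, 20
# raw network output: input_size parameters per time step
parameters = torch.randn(batch_size, prediction_length, 16)
# loc and scale of the target normalizer, one pair per series
target_scale = torch.stack([torch.zeros(batch_size), torch.ones(batch_size)], dim=-1)

# rescale_parameters appends loc and scale -> (batch_size, prediction_length, input_size + 2)
y_pred = iqn.rescale_parameters(parameters, target_scale=target_scale, encoder=TorchNormalizer())

iqn.to_quantiles(y_pred).shape                  # (4, 20, 7): the seven default quantiles
iqn.to_prediction(y_pred, n_samples=100).shape  # (4, 20): mean over sampled quantile levels
iqn.loss(y_pred, torch.randn(batch_size, prediction_length)).shape  # (4, 20)
```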