@@ -443,8 +443,8 @@ def __init__(
443443 self ,
444444 quantiles : List [float ] = [0.02 , 0.1 , 0.25 , 0.5 , 0.75 , 0.9 , 0.98 ],
445445 input_size : Optional [int ] = 16 ,
446- hidden_size : Optional [int ] = 64 ,
447- n_loss_samples : Optional [int ] = 16 ,
446+ hidden_size : Optional [int ] = 32 ,
447+ n_loss_samples : Optional [int ] = 64 ,
448448 ) -> None :
449449 """
450450 Args:
@@ -460,6 +460,14 @@ def __init__(
460460 self .distribution_arguments = list (range (int (input_size )))
461461 self .n_loss_samples = n_loss_samples
462462
463+ def sample (self , y_pred , n_samples : int ) -> torch .Tensor :
464+ eps = 1e-3
465+ # for a couple of random quantiles (excl. 0 and 1 as they would lead to infinities)
466+ quantiles = torch .rand (size = (n_samples ,), device = y_pred .device ).clamp (eps , 1 - eps )
467+ # make prediction
468+ samples = self .to_quantiles (y_pred , quantiles = quantiles )
469+ return samples
470+
463471 def loss (self , y_pred : torch .Tensor , y_actual : torch .Tensor ) -> torch .Tensor :
464472 """
465473 Calculate negative likelihood
@@ -478,7 +486,7 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
478486 pred_quantiles = self .to_quantiles (y_pred , quantiles = quantiles )
479487 # and calculate quantile loss
480488 errors = y_actual [..., None ] - pred_quantiles
481- loss = torch .fmax (quantiles [None ] * errors , (quantiles [None ] - 1 ) * errors ).mean (dim = - 1 )
489+ loss = 2 * torch .fmax (quantiles [None ] * errors , (quantiles [None ] - 1 ) * errors ).mean (dim = - 1 )
482490 return loss
483491
484492 def rescale_parameters (
@@ -492,9 +500,7 @@ def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Ten
492500 return self .to_quantiles (y_pred , quantiles = [0.5 ]).squeeze (- 1 )
493501 else :
494502 # for a couple of random quantiles (excl. 0 and 1 as they would lead to infinities) make prediction
495- eps = 1e-3
496- quantiles = torch .rand (size = (n_samples ,), device = y_pred .device ).clamp (eps , 1 - eps )
497- return self .to_quantiles (y_pred , quantiles = quantiles ).mean (- 1 )
503+ return self .sample (y_pred , n_samples = n_samples ).mean (- 1 )
498504
499505 def to_quantiles (self , y_pred : torch .Tensor , quantiles : List [float ] = None ) -> torch .Tensor :
500506 """
@@ -518,14 +524,16 @@ def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> t
518524 scale = y_pred [..., - 1 ][..., None ]
519525
520526 # predict quantiles
521- predictions = self .quantile_network (x , quantiles )
527+ if y_pred .requires_grad :
528+ predictions = self .quantile_network (x , quantiles )
529+ else :
530+ with torch .no_grad ():
531+ predictions = self .quantile_network (x , quantiles )
522532 # rescale output
523533 predictions = loc + predictions * scale
524534 # transform output if required
525535 if self ._transformation is not None :
526536 transform = TorchNormalizer .get_transform (self ._transformation )["reverse" ]
527537 predictions = transform (predictions )
528538
529- if not y_pred .requires_grad :
530- predictions = predictions .detach ()
531539 return predictions
0 commit comments