@@ -841,10 +841,9 @@ class BetaDistributionLoss(DistributionLoss):
     Beta distribution loss for unit interval data.

     Requirements for original target normalizer:
-        * coerced to be positive
-        * not centered normalization (only rescaled)
-        * normalized target not in log space
+        * logit transformation
     """
+
     distribution_class = distributions.Beta
     distribution_arguments = ["mean", "shape"]

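Not part of the diff, purely an illustration: the new docstring requires the target normalizer to apply a logit transformation, and the asserts in the next hunk additionally require centered scaling. A minimal plain-torch sketch of what that combination means for a unit-interval target (this is not the library's normalizer implementation):

import torch

# Illustrative sketch only: "logit transformation + centered scaling"
# for a target y constrained to (0, 1).
y = torch.tensor([0.10, 0.40, 0.70, 0.90])
y_logit = torch.log(y / (1 - y))               # map (0, 1) onto the real line
center, scale = y_logit.mean(), y_logit.std()  # standard scaling in logit space
y_scaled = (y_logit - center) / scale          # what the network is trained against
# target_scale then carries (center, scale), so rescale_parameters below can map
# raw network outputs back to a valid Beta mean in (0, 1)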
@@ -854,17 +853,20 @@ def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Beta:
         return self.distribution_class(concentration0=(1 - mean) * shape, concentration1=mean * shape)

     def rescale_parameters(
-        self, parameters: torch.Tensor, target_scale: torch.Tensor, transformer: BaseEstimator
+        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
     ) -> torch.Tensor:
-        assert transformer.coerce_positive, "Beta distribution is only compatible with strictly positive data"
-        assert (
-            not transformer.log_scale
-        ), "Beta distribution is not compatible with log transformation - use LogNormal"
-        assert not transformer.center, "Beta distribution is not compatible with centered data"
-
-        scaled_mean = torch.sigmoid(parameters[..., 0] + target_scale[..., 0].unsqueeze(1))
-        return torch.stack([
-            scaled_mean,
-            F.softplus(parameters[..., 1]) * scaled_mean * (1 - scaled_mean)
-            / torch.pow(target_scale[..., 1].unsqueeze(1), 2)
-        ], dim=-1)
+        assert encoder.transformation in ["logit"], "Beta distribution is only compatible with logit transformation"
+        assert encoder.center, "Beta distribution requires normalizer to center data"
+
+        scaled_mean = encoder(dict(prediction=parameters[..., 0], target_scale=target_scale))
+        # need to first transform target scale standard deviation in logit space to real space
+        # we assume a normal distribution in logit space (we used a logit transform and a standard scaler)
+        # and know that the variance of the beta distribution is limited by `scaled_mean * (1 - scaled_mean)`
+        mean_derivative = scaled_mean * (1 - scaled_mean)
+
+        # we can approximate variance as
+        # torch.pow(torch.tanh(target_scale[..., 1].unsqueeze(1) * torch.sqrt(mean_derivative)), 2) * mean_derivative
+        # shape is (positive) parameter * mean_derivative / var
+        shape_scaler = torch.pow(torch.tanh(target_scale[..., 1].unsqueeze(1) * torch.sqrt(mean_derivative)), 2)
+        scaled_shape = F.softplus(parameters[..., 1]) / shape_scaler
+        return torch.stack([scaled_mean, scaled_shape], dim=-1)
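For reference, also not part of the diff: a small plain-torch check of the reasoning in the comments above. With concentration1 = mean * shape and concentration0 = (1 - mean) * shape (as in map_x_to_distribution), the Beta distribution has mean equal to `mean` and variance mean * (1 - mean) / (shape + 1), so its variance is indeed bounded by scaled_mean * (1 - scaled_mean); the tanh factor, always below 1, keeps the approximated variance inside that feasible range.

import torch
from torch import distributions

# Verify the mean/shape parameterization used by map_x_to_distribution.
mean, shape = torch.tensor(0.3), torch.tensor(5.0)
beta = distributions.Beta(concentration1=mean * shape, concentration0=(1 - mean) * shape)
assert torch.isclose(beta.mean, mean)
# variance = mean * (1 - mean) / (shape + 1), hence strictly less than mean * (1 - mean)
assert torch.isclose(beta.variance, mean * (1 - mean) / (shape + 1))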