
Commit e5c8698

Merge pull request #160 from JakeForsey/feature/beta-distribution-metric
Feature/beta distribution metric
2 parents: caa4362 + 9c4d068

3 files changed: +72 -3 lines changed

pytorch_forecasting/data/encoders.py

Lines changed: 10 additions & 3 deletions
@@ -230,10 +230,17 @@ def preprocess(self, y: Union[pd.Series, np.ndarray, torch.Tensor]) -> Union[np.
         Returns:
             Union[np.ndarray, torch.Tensor]: return rescaled series with type depending on input type
         """
-        y = y + self.eps
         if self.transformation is None:
-            pass
-        elif isinstance(y, torch.Tensor):
+            return y
+
+        # protect against numerical instabilities
+        if isinstance(self.transformation, str) and self.transformation == "logit":
+            # need to apply eps slightly differently
+            y = y / (1 + 2 * self.eps) + self.eps
+        else:
+            y = y + self.eps
+
+        if isinstance(y, torch.Tensor):
             y = self.TRANSFORMATIONS.get(self.transformation, self.transformation)[0](y)
         else:
             # convert first to tensor, then transform and then convert to numpy array
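
Two things are worth noting about this change: when no transformation is configured, preprocess now returns the series untouched instead of adding eps, and for the logit transformation eps is applied multiplicatively as well as additively so that values at both ends of [0, 1] stay strictly inside the open interval. A minimal, illustrative sketch of the second point (not part of the commit; the eps value is chosen for readability, the encoder uses its own self.eps):

    import torch

    eps = 1e-4  # illustrative value; the encoder uses self.eps
    y = torch.tensor([0.0, 0.5, 1.0])

    # a plain shift pushes the upper boundary above 1, where the logit is undefined
    print(torch.logit(y + eps))                  # last entry is nan
    # the committed variant squeezes [0, 1] into roughly (eps, 1 - eps)
    print(torch.logit(y / (1 + 2 * eps) + eps))  # all entries finite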

pytorch_forecasting/metrics.py

Lines changed: 36 additions & 0 deletions
@@ -834,3 +834,39 @@ def rescale_parameters(
         loc = parameters[..., 0] * target_scale[..., 1].unsqueeze(-1) + target_scale[..., 0].unsqueeze(-1)
 
         return torch.stack([loc, scale], dim=-1)
+
+
+class BetaDistributionLoss(DistributionLoss):
+    """
+    Beta distribution loss for unit interval data.
+
+    Requirements for original target normalizer:
+        * logit transformation
+    """
+
+    distribution_class = distributions.Beta
+    distribution_arguments = ["mean", "shape"]
+
+    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Beta:
+        mean = x[..., 0]
+        shape = x[..., 1]
+        return self.distribution_class(concentration0=(1 - mean) * shape, concentration1=mean * shape)
+
+    def rescale_parameters(
+        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
+    ) -> torch.Tensor:
+        assert encoder.transformation in ["logit"], "Beta distribution is only compatible with logit transformation"
+        assert encoder.center, "Beta distribution requires normalizer to center data"
+
+        scaled_mean = encoder(dict(prediction=parameters[..., 0], target_scale=target_scale))
+        # need to first transform target scale standard deviation in logit space to real space
+        # we assume a normal distribution in logit space (we used a logit transform and a standard scaler)
+        # and know that the variance of the beta distribution is limited by `scaled_mean * (1 - scaled_mean)`
+        mean_derivative = scaled_mean * (1 - scaled_mean)
+
+        # we can approximate variance as
+        # torch.pow(torch.tanh(target_scale[..., 1].unsqueeze(1) * torch.sqrt(mean_derivative)), 2) * mean_derivative
+        # shape is (positive) parameter * mean_derivative / var
+        shape_scaler = torch.pow(torch.tanh(target_scale[..., 1].unsqueeze(1) * torch.sqrt(mean_derivative)), 2)
+        scaled_shape = F.softplus(parameters[..., 1]) / shape_scaler
+        return torch.stack([scaled_mean, scaled_shape], dim=-1)
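
For orientation (not part of the commit): map_x_to_distribution parametrizes the Beta distribution by a mean and a shape, with concentration1 = mean * shape and concentration0 = (1 - mean) * shape, so the distribution mean equals the mean parameter and the variance is mean * (1 - mean) / (shape + 1). A quick illustrative check with values matching the new test:

    import torch
    from torch import distributions

    mean, shape = torch.tensor(0.1), torch.tensor(10.0)
    dist = distributions.Beta(concentration0=(1 - mean) * shape, concentration1=mean * shape)

    print(dist.mean)              # tensor(0.1000) -> the mean parameter is recovered
    samples = dist.sample((100_000,))
    print(samples.std())          # roughly sqrt(0.1 * 0.9 / 11), i.e. about 0.09

The tanh factor in rescale_parameters keeps the implied variance below mean * (1 - mean), which is the largest variance any random variable supported on the unit interval with that mean can have.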

tests/test_metrics.py

Lines changed: 26 additions & 0 deletions
@@ -10,6 +10,7 @@
     MAE,
     SMAPE,
     AggregationMetric,
+    BetaDistributionLoss,
     CompositeMetric,
     LogNormalDistributionLoss,
     NegativeBinomialDistributionLoss,
@@ -156,3 +157,28 @@ def test_NegativeBinomialDistributionLoss(center, transformation):
         samples = loss.sample_n(rescaled_parameters, 1)
         assert torch.isclose(torch.as_tensor(mean), samples.mean(), atol=0.1, rtol=0.5)
         assert torch.isclose(torch.as_tensor(std), samples.std(), atol=0.1, rtol=0.5)
+
+
+@pytest.mark.parametrize(
+    ["center", "transformation"],
+    itertools.product([True, False], ["log", "log1p", "softplus", "relu", "logit", None]),
+)
+def test_BetaDistributionLoss(center, transformation):
+    initial_mean = 0.1
+    initial_shape = 10
+    n = 100000
+    target = BetaDistributionLoss().map_x_to_distribution(torch.tensor([initial_mean, initial_shape])).sample_n(n)
+    normalizer = TorchNormalizer(center=center, transformation=transformation)
+    normalized_target = normalizer.fit_transform(target).view(1, -1)
+    target_scale = normalizer.get_parameters().unsqueeze(0)
+    parameters = torch.stack([normalized_target, 1.0 * torch.ones_like(normalized_target)], dim=-1)
+    loss = BetaDistributionLoss()
+
+    if transformation not in ["logit"] or not center:
+        with pytest.raises(AssertionError):
+            loss.rescale_parameters(parameters, target_scale=target_scale, encoder=normalizer)
+    else:
+        rescaled_parameters = loss.rescale_parameters(parameters, target_scale=target_scale, encoder=normalizer)
+        samples = loss.sample_n(rescaled_parameters, 1)
+        assert torch.isclose(torch.as_tensor(initial_mean), samples.mean(), atol=0.01, rtol=0.01) # mean=0.1
+        assert torch.isclose(target.std(), samples.std(), atol=0.02, rtol=0.3) # std=0.09
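
The expected statistics in the new test follow directly from the mean/shape parametrization: with initial_mean = 0.1 and initial_shape = 10 the sampled target has mean 0.1 and a theoretical standard deviation of about 0.09, which is what the inline comments on the two assertions refer to. A small arithmetic sketch (illustrative, not part of the commit):

    import math

    initial_mean, initial_shape = 0.1, 10
    # Beta variance in the mean/shape parametrization: mean * (1 - mean) / (shape + 1)
    variance = initial_mean * (1 - initial_mean) / (initial_shape + 1)
    print(round(math.sqrt(variance), 4))  # 0.0905

The parametrize grid deliberately includes non-logit transformations and center=False so that the assertion guards in rescale_parameters are exercised as well.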
