diff --git a/optuna_dashboard/preferential/samplers/gp.py b/optuna_dashboard/preferential/samplers/gp.py index beb5b7c23..3cd980614 100644 --- a/optuna_dashboard/preferential/samplers/gp.py +++ b/optuna_dashboard/preferential/samplers/gp.py @@ -50,15 +50,15 @@ def _orthants_MVN_Gibbs_sampling(cov_inv: Tensor, cycles: int, initial_sample: T def _one_side_trunc_norm_sampling(lower: Tensor) -> Tensor: - if lower > 4.0: - r = torch.clamp_min(torch.rand(torch.Size(()), dtype=torch.float64), min=1e-300) - return (lower * lower - 2 * r.log()).sqrt() - else: - SQRT2 = math.sqrt(2) - r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2) - while 1 - r == 1: - r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2) - return torch.erfinv(1 - r) * SQRT2 + r = torch.rand(torch.Size(()), dtype=torch.float64) + ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log())) + + # If sampled random number is very small, `ret` becomes inf. + while torch.isinf(ret): + r = torch.rand(torch.Size(()), dtype=torch.float64) + ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log())) + + return ret _orthants_MVN_Gibbs_sampling_jit = torch.jit.script(_orthants_MVN_Gibbs_sampling) diff --git a/python_tests/preferential/samplers/test_gp.py b/python_tests/preferential/samplers/test_gp.py new file mode 100644 index 000000000..136959a79 --- /dev/null +++ b/python_tests/preferential/samplers/test_gp.py @@ -0,0 +1,29 @@ +import sys +from unittest.mock import patch + +import numpy as np +import pytest + + +if sys.version_info >= (3, 8): + from optuna_dashboard.preferential.samplers.gp import _one_side_trunc_norm_sampling + import torch +else: + pytest.skip("BoTorch dropped Python3.7 support", allow_module_level=True) + + +def test_one_side_trunc_norm_sampling() -> None: + for lower in np.linspace(-10, 10, 100): + assert _one_side_trunc_norm_sampling(torch.tensor([lower], dtype=torch.float64)) >= lower + + with patch.object(torch, "rand", return_value=torch.tensor([0.4], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([0.1], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), 0.899967154837563) + + with patch.object(torch, "rand", return_value=torch.tensor([0.8], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([-2.3], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), -0.8113606739551955) + + with patch.object(torch, "rand", return_value=torch.tensor([0.1], dtype=torch.float64)): + sampled_value = _one_side_trunc_norm_sampling(torch.tensor([5], dtype=torch.float64)) + assert np.allclose(sampled_value.numpy(), 5.426934003050024)