Skip to content

Commit

Permalink
Merge pull request #649 from not522/test-one-side-trunc-norm-sampling
Browse files Browse the repository at this point in the history
Improve accuracy of `_one_side_trunc_norm_sampling`
  • Loading branch information
contramundum53 authored Oct 23, 2023
2 parents c4ac592 + 82cc858 commit c65b4cd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
18 changes: 9 additions & 9 deletions optuna_dashboard/preferential/samplers/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def _orthants_MVN_Gibbs_sampling(cov_inv: Tensor, cycles: int, initial_sample: T


def _one_side_trunc_norm_sampling(lower: Tensor) -> Tensor:
if lower > 4.0:
r = torch.clamp_min(torch.rand(torch.Size(()), dtype=torch.float64), min=1e-300)
return (lower * lower - 2 * r.log()).sqrt()
else:
SQRT2 = math.sqrt(2)
r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2)
while 1 - r == 1:
r = torch.rand(torch.Size(()), dtype=torch.float64) * torch.erfc(lower / SQRT2)
return torch.erfinv(1 - r) * SQRT2
r = torch.rand(torch.Size(()), dtype=torch.float64)
ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log()))

# If sampled random number is very small, `ret` becomes inf.
while torch.isinf(ret):
r = torch.rand(torch.Size(()), dtype=torch.float64)
ret = -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log()))

return ret


_orthants_MVN_Gibbs_sampling_jit = torch.jit.script(_orthants_MVN_Gibbs_sampling)
Expand Down
29 changes: 29 additions & 0 deletions python_tests/preferential/samplers/test_gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys
from unittest.mock import patch

import numpy as np
import pytest


if sys.version_info >= (3, 8):
from optuna_dashboard.preferential.samplers.gp import _one_side_trunc_norm_sampling
import torch
else:
pytest.skip("BoTorch dropped Python3.7 support", allow_module_level=True)


def test_one_side_trunc_norm_sampling() -> None:
for lower in np.linspace(-10, 10, 100):
assert _one_side_trunc_norm_sampling(torch.tensor([lower], dtype=torch.float64)) >= lower

with patch.object(torch, "rand", return_value=torch.tensor([0.4], dtype=torch.float64)):
sampled_value = _one_side_trunc_norm_sampling(torch.tensor([0.1], dtype=torch.float64))
assert np.allclose(sampled_value.numpy(), 0.899967154837563)

with patch.object(torch, "rand", return_value=torch.tensor([0.8], dtype=torch.float64)):
sampled_value = _one_side_trunc_norm_sampling(torch.tensor([-2.3], dtype=torch.float64))
assert np.allclose(sampled_value.numpy(), -0.8113606739551955)

with patch.object(torch, "rand", return_value=torch.tensor([0.1], dtype=torch.float64)):
sampled_value = _one_side_trunc_norm_sampling(torch.tensor([5], dtype=torch.float64))
assert np.allclose(sampled_value.numpy(), 5.426934003050024)

0 comments on commit c65b4cd

Please sign in to comment.