Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve accuracy of _one_side_trunc_norm_sampling #649

Merged
merged 5 commits into from
Oct 23, 2023

Conversation

not522
Copy link
Member

@not522 not522 commented Oct 6, 2023

Contributor License Agreement

This repository (optuna-dashboard) and Goptuna share common code.
This pull request may therefore be ported to Goptuna.
Make sure that you understand the consequences concerning licenses and check the box below if you accept the term before creating this pull request.

  • I agree this patch may be ported to Goptuna by other Goptuna contributors.

Reference Issues/PRs

What does this implement/fix? Explain your changes.

Improve the accuracy of _one_side_trunc_norm_sampling and add tests. The expected values were calculated using scipy.stats.truncnorm.

@not522 not522 force-pushed the test-one-side-trunc-norm-sampling branch 3 times, most recently from bac4981 to 00bfab9 Compare October 6, 2023 04:04
@not522 not522 force-pushed the test-one-side-trunc-norm-sampling branch from 00bfab9 to 926ebd3 Compare October 12, 2023 08:35
@not522 not522 changed the title Add test for _one_side_trunc_norm_sampling Improve accuracy of _one_side_trunc_norm_sampling Oct 12, 2023
@codecov
Copy link

codecov bot commented Oct 12, 2023

Codecov Report

Merging #649 (82cc858) into main (f088840) will increase coverage by 0.23%.
Report is 1 commits behind head on main.
The diff coverage is 66.66%.

@@            Coverage Diff             @@
##             main     #649      +/-   ##
==========================================
+ Coverage   62.65%   62.88%   +0.23%     
==========================================
  Files          35       35              
  Lines        2252     2250       -2     
==========================================
+ Hits         1411     1415       +4     
+ Misses        841      835       -6     
Files Coverage Δ
optuna_dashboard/preferential/samplers/gp.py 70.35% <66.66%> (+2.37%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@not522
Copy link
Member Author

not522 commented Oct 12, 2023

I checked the accuracy using the following code.

import matplotlib.pyplot as plt
import numpy as np
import scipy.special
import torch

def _one_side_trunc_norm_sampling(lower, r):
    r = torch.tensor([r], dtype=torch.float64)
    lower = torch.tensor([lower], dtype=torch.float64)
    return -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log()))

fig, ax= plt.subplots(3, 1)

rs = [0.5, 0.001, 0.999]
for i in range(3):
    r = rs[i]
    x = []
    y = []
    for lower in np.linspace(-10, 10, 100):
        x.append(lower)
        value = _one_side_trunc_norm_sampling(lower, r)
        expected = scipy.stats.truncnorm.ppf(1 - r, lower, float("inf"))
        y.append(float(value.numpy() - expected))
    ax[i].plot(x, y, label=f"r = {r}")
    ax[i].legend()

plt.subplots_adjust(hspace=0.6)
plt.show()

Figure_1

@not522 not522 marked this pull request as ready for review October 19, 2023 02:05
Copy link
Member

@contramundum53 contramundum53 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@contramundum53 contramundum53 merged commit c65b4cd into optuna:main Oct 23, 2023
11 checks passed
@not522 not522 deleted the test-one-side-trunc-norm-sampling branch October 23, 2023 04:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants