-
-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve accuracy of _one_side_trunc_norm_sampling
#649
Improve accuracy of _one_side_trunc_norm_sampling
#649
Conversation
bac4981
to
00bfab9
Compare
00bfab9
to
926ebd3
Compare
_one_side_trunc_norm_sampling
_one_side_trunc_norm_sampling
Codecov Report
@@ Coverage Diff @@
## main #649 +/- ##
==========================================
+ Coverage 62.65% 62.88% +0.23%
==========================================
Files 35 35
Lines 2252 2250 -2
==========================================
+ Hits 1411 1415 +4
+ Misses 841 835 -6
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
I checked the accuracy using the following code. import matplotlib.pyplot as plt
import numpy as np
import scipy.special
import torch
def _one_side_trunc_norm_sampling(lower, r):
r = torch.tensor([r], dtype=torch.float64)
lower = torch.tensor([lower], dtype=torch.float64)
return -torch.special.ndtri(torch.exp(torch.special.log_ndtr(-lower) + r.log()))
fig, ax= plt.subplots(3, 1)
rs = [0.5, 0.001, 0.999]
for i in range(3):
r = rs[i]
x = []
y = []
for lower in np.linspace(-10, 10, 100):
x.append(lower)
value = _one_side_trunc_norm_sampling(lower, r)
expected = scipy.stats.truncnorm.ppf(1 - r, lower, float("inf"))
y.append(float(value.numpy() - expected))
ax[i].plot(x, y, label=f"r = {r}")
ax[i].legend()
plt.subplots_adjust(hspace=0.6)
plt.show() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Contributor License Agreement
This repository (
optuna-dashboard
) and Goptuna share common code.This pull request may therefore be ported to Goptuna.
Make sure that you understand the consequences concerning licenses and check the box below if you accept the term before creating this pull request.
Reference Issues/PRs
What does this implement/fix? Explain your changes.
Improve the accuracy of
_one_side_trunc_norm_sampling
and add tests. The expected values were calculated usingscipy.stats.truncnorm
.