
[Preferential] Simplify EP implementation #633

Merged: 2 commits merged into optuna:main from simplify-ep on Oct 6, 2023

Conversation

@contramundum53 (Member) commented on Sep 29, 2023

Contributor License Agreement

This repository (optuna-dashboard) and Goptuna share common code.
This pull request may therefore be ported to Goptuna.
Make sure that you understand the consequences concerning licenses, and check the box below if you accept the terms before creating this pull request.

  • I agree this patch may be ported to Goptuna by other Goptuna contributors.

What does this implement/fix? Explain your changes.

This PR simplifies the EP implementation. It also improves numerical stability, doing away with the need for an EPS value (except when noise_var == 0).
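
As a side note (not part of the PR), the stability gain comes from evaluating the hazard phi(alpha) / Phi(-alpha) through torch.special.erfcx, as _truncnorm_mean_var_logz below does; the naive pdf/cdf ratio underflows to 0/0 for large alpha. A minimal sketch:

import math
import torch

# Naive hazard phi(a) / Phi(-a): numerator and denominator both underflow
# for large a, so the ratio becomes 0/0 = nan without an EPS guard.
def hazard_naive(a: torch.Tensor) -> torch.Tensor:
    pdf = torch.exp(-0.5 * a * a) / math.sqrt(2 * math.pi)
    cdf = torch.special.ndtr(-a)
    return pdf / cdf

# Stable form via erfcx(x) = exp(x**2) * erfc(x): the exponentials cancel
# analytically, so nothing underflows.
def hazard_stable(a: torch.Tensor) -> torch.Tensor:
    return 1 / (math.sqrt(0.5 * math.pi) * torch.special.erfcx(a * math.sqrt(0.5)))

a = torch.tensor([5.0, 40.0], dtype=torch.float64)
print(hazard_naive(a))   # second entry is nan (0/0)
print(hazard_stable(a))  # finite for both entries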

Here's a script to confirm that the results do not change and that the performance improves a little:

import torch
from torch import Tensor
import math


def _truncnorm_mean_var_logz(alpha: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    # Moments of a standard normal truncated to (alpha, inf):
    #   mean = phi(alpha) / Phi(-alpha), via erfcx so it stays finite for large alpha,
    #   var  = 1 - mean * (mean - alpha),
    #   logz = log Phi(-alpha), the log normalizing constant.
    SQRT_HALF = math.sqrt(0.5)
    SQRT_HALF_PI = math.sqrt(0.5 * math.pi)
    logz = torch.special.log_ndtr(-alpha)
    mean = 1 / (SQRT_HALF_PI * torch.special.erfcx(alpha * SQRT_HALF))
    var = 1 - mean * (mean - alpha)
    return mean, var, logz

def _orthants_MVN_EP_old(
    cov0: Tensor, preferences: Tensor, noise_var: Tensor, cycles: int
) -> tuple[Tensor, Tensor, Tensor]:
    # Original EP loop: for each preference pair, form the cavity (mean0, var0),
    # compute the tilted posterior moments (mean2, var2), then recover the site
    # updates (da, db) by dividing by var1 * var2, the step that needs clamping.
    N = cov0.shape[0]
    M = preferences.shape[0]
    mu = torch.zeros(N, dtype=cov0.dtype)
    cov = cov0.clone()
    virtual_obs_a = [torch.tensor(0.0, dtype=cov0.dtype) for _ in range(M)]
    virtual_obs_b = [torch.tensor(0.0, dtype=cov0.dtype) for _ in range(M)]
    log_zs = torch.zeros(M, dtype=cov0.dtype)

    for _ in range(cycles):
        for i in range(M):
            pref_i = preferences[i, :]
            mean1 = mu[pref_i[0]] - mu[pref_i[1]]
            Sxy = cov[pref_i[0]] - cov[pref_i[1]]
            var1 = Sxy[pref_i[0]] - Sxy[pref_i[1]]

            r0 = (1 - var1 * virtual_obs_a[i]).reciprocal()
            var0 = var1 * r0
            mean0 = (mean1 + var1 * virtual_obs_b[i]) * r0

            obs_var = var0 + noise_var
            obs_sigma = torch.sqrt(obs_var)
            alpha = -mean0 / torch.clamp_min(obs_sigma, min=1e-20)
            mean_norm, var_norm, logz = _truncnorm_mean_var_logz(alpha)

            kalman_factor = var0 / torch.clamp_min(obs_var, min=1e-20)
            mean2 = mean0 + obs_sigma * mean_norm * kalman_factor
            var2 = kalman_factor * (noise_var + var_norm * var0)

            var1_var2_inv = torch.clamp_min(var1 * var2, min=1e-20).reciprocal()
            db = (mean1 * var2 - mean2 * var1) * var1_var2_inv
            da = (var1 - var2) * var1_var2_inv
            virtual_obs_b[i] = virtual_obs_b[i] + db
            virtual_obs_a[i] = virtual_obs_a[i] + da

            dr = (1 + var1 * da).reciprocal()
            mu = mu - Sxy * ((db + mean1 * da) * dr)
            cov = cov - (Sxy[:, None] * (da * dr)) @ Sxy[None, :]
            log_zs[i] = logz
    return mu, cov, torch.sum(log_zs)



def _observation(var0: Tensor, mean0: Tensor, noise_var: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    # Map the cavity moments (mean0, var0) directly to the *new* site parameters
    # (despite the names da/db below, these are absolute values, not deltas):
    # a is the site precision and b its natural-mean shift. The remaining clamps
    # only matter when noise_var == 0.
    obs_var = var0 + noise_var
    obs_sigma = torch.sqrt(obs_var)
    alpha = -mean0 / torch.clamp_min(obs_sigma, min=1e-20)
    mean_norm, var_norm, logz = _truncnorm_mean_var_logz(alpha)

    denom_factor = 1 / torch.clamp_min(noise_var + var_norm * var0, min=1e-20)
    da = (1 - var_norm) * denom_factor
    db = (mean0 * (1 - var_norm) + obs_sigma * mean_norm) * denom_factor
    return (da, db, logz)
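
# Note (reviewer sketch, not part of the PR): in natural parameters the tilted
# posterior satisfies 1/var2 = 1/var0 + a and mean2/var2 = mean0/var0 + b, so
# _observation yields the new site parameters directly; the old code's step of
# solving (da, db) from (mean2, var2) via 1/(var1 * var2), and its clamp, drop out.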


def _orthants_MVN_EP_new(
    cov0: Tensor, preferences: Tensor, noise_var: Tensor, cycles: int
) -> tuple[Tensor, Tensor, Tensor]:
    # Simplified EP loop: same cavity computation as above, but the site
    # parameters are simply overwritten with the values from _observation
    # instead of being reconstructed from the posterior moments.
    N = cov0.shape[0]
    M = preferences.shape[0]
    mu = torch.zeros(N, dtype=cov0.dtype)
    cov = cov0.clone()
    virtual_obs_a = [torch.tensor(0.0, dtype=cov0.dtype) for _ in range(M)]
    virtual_obs_b = [torch.tensor(0.0, dtype=cov0.dtype) for _ in range(M)]
    log_zs = torch.zeros(M, dtype=cov0.dtype)

    for _ in range(cycles):
        for i in range(M):
            pref_i = preferences[i, :]
            mean1 = mu[pref_i[0]] - mu[pref_i[1]]
            Sxy = cov[pref_i[0]] - cov[pref_i[1]]
            var1 = Sxy[pref_i[0]] - Sxy[pref_i[1]]

            r0 = (1 - var1 * virtual_obs_a[i]).reciprocal()
            var0 = var1 * r0
            mean0 = (mean1 - var1 * virtual_obs_b[i]) * r0

            virtual_obs_a2, virtual_obs_b2, logz = _observation(var0, mean0, noise_var)

            da = virtual_obs_a2 - virtual_obs_a[i]
            db = virtual_obs_b2 - virtual_obs_b[i]
            virtual_obs_a[i] = virtual_obs_a2
            virtual_obs_b[i] = virtual_obs_b2

            dr = (1 + var1 * da).reciprocal()
            mu = mu + Sxy * ((db - mean1 * da) * dr)
            cov = cov - (Sxy[:, None] * (da * dr)) @ Sxy[None, :]
            log_zs[i] = logz
    return mu, cov, torch.sum(log_zs)



# Compile both versions with TorchScript before benchmarking.
_orthants_MVN_EP_old_jit = torch.jit.script(_orthants_MVN_EP_old)
_orthants_MVN_EP_new_jit = torch.jit.script(_orthants_MVN_EP_new)


N = 100
cov0 = torch.randn(N, N, dtype=torch.float64)
cov0 = cov0 @ cov0.T  # random positive semi-definite covariance
preferences = torch.randint(0, N, (300, 2), dtype=torch.int64)  # random preference pairs
noise_var = torch.tensor(0.1, dtype=torch.float64)
cycles = 5
mu_old, cov_old, logz_old = _orthants_MVN_EP_old_jit(cov0, preferences, noise_var, cycles)
mu_new, cov_new, logz_new = _orthants_MVN_EP_new_jit(cov0, preferences, noise_var, cycles)

print(logz_old, logz_new)
assert torch.allclose(mu_old, mu_new)
assert torch.allclose(cov_old, cov_new)

# IPython magics; run in a notebook or IPython shell.
%timeit _orthants_MVN_EP_old_jit(cov0, preferences, noise_var, cycles)
%timeit _orthants_MVN_EP_new_jit(cov0, preferences, noise_var, cycles)
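
Since the description flags noise_var == 0 as the one case that still needs a clamp, a quick sanity check one might append (hypothetical, not part of the PR):

noise_var_zero = torch.tensor(0.0, dtype=torch.float64)
mu_z, cov_z, logz_z = _orthants_MVN_EP_new_jit(cov0, preferences, noise_var_zero, cycles)
# The remaining clamps in _observation guard the divisions in this case.
print(torch.isfinite(mu_z).all(), torch.isfinite(cov_z).all(), torch.isfinite(logz_z))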

codecov bot commented Sep 29, 2023

Codecov Report

Merging #633 (f2f233c) into main (e40f410) will increase coverage by 6.03%.
Report is 47 commits behind head on main.
The diff coverage is 6.25%.

@@            Coverage Diff             @@
##             main     #633      +/-   ##
==========================================
+ Coverage   56.40%   62.44%   +6.03%     
==========================================
  Files          35       35              
  Lines        2193     2226      +33     
==========================================
+ Hits         1237     1390     +153     
+ Misses        956      836     -120     
Files                                           Coverage Δ
optuna_dashboard/preferential/samplers/gp.py    67.61% <6.25%> (+67.61%) ⬆️

... and 5 files with indirect coverage changes


@not522 (Member) left a comment:
LGTM!

@not522 merged commit 6f79d7d into optuna:main on Oct 6, 2023.
@contramundum53 deleted the simplify-ep branch on Oct 6, 2023 at 06:00.