Merge pull request #154 from ro1205/refactor-plmbo
Refactor to address mypy problem in PLMBO
y0z authored Sep 10, 2024
2 parents 6b9c1bc + a953435 commit 733bce7
Showing 2 changed files with 39 additions and 26 deletions.
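The pattern of the refactor, per the commit message, is to drop the blanket `# mypy: ignore-errors` headers and the per-line `# type: ignore` comments and make the code type-checkable instead: function signatures gain annotations, and the sampler's lazily initialized attributes are narrowed with `assert ... is not None` before use. Below is a minimal sketch of that narrowing pattern, using an illustrative class and attribute rather than the actual PLMBO code:

from __future__ import annotations

import numpy as np


class Holder:
    def __init__(self) -> None:
        # Filled lazily, so the declared type is "np.ndarray | None".
        self.Y: np.ndarray | None = None

    def fit(self, y: np.ndarray) -> None:
        self.Y = y

    def best(self) -> float:
        # Without this narrowing, mypy flags the attribute access because self.Y may be None.
        assert self.Y is not None
        return float(self.Y.min())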
11 changes: 5 additions & 6 deletions package/samplers/plmbo/example.py
@@ -1,4 +1,3 @@
-# mypy: ignore-errors
from __future__ import annotations

import matplotlib.pyplot as plt
@@ -7,28 +6,28 @@
import optunahub


-PLMBOSampler = optunahub.load_module(  # type: ignore
+PLMBOSampler = optunahub.load_module(
"samplers/plmbo",
).PLMBOSampler

if __name__ == "__main__":
f_sigma = 0.01

-def obj_func1(x):
+def obj_func1(x: np.ndarray) -> np.ndarray:
return np.sin(x[0]) + x[1]

-def obj_func2(x):
+def obj_func2(x: np.ndarray) -> np.ndarray:
return -np.sin(x[0]) - x[1] + 0.1

-def obs_obj_func(x):
+def obs_obj_func(x: np.ndarray) -> np.ndarray:
return np.array(
[
obj_func1(x) + np.random.normal(0, f_sigma),
obj_func2(x) + np.random.normal(0, f_sigma),
]
)

-def objective(trial: optuna.Trial):
+def objective(trial: optuna.Trial) -> tuple[float, float]:
x1 = trial.suggest_float("x1", 0, 1)
x2 = trial.suggest_float("x2", 0, 1)
values = obs_obj_func(np.array([x1, x2]))
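For context, a minimal sketch of how the typed objective above is typically wired into a two-objective study with the loaded sampler. The rest of example.py is truncated here, so the directions, trial count, and constructor arguments below are assumptions, not the file's actual contents:

import optuna
import optunahub

PLMBOSampler = optunahub.load_module("samplers/plmbo").PLMBOSampler


def objective(trial: optuna.Trial) -> tuple[float, float]:
    # Placeholder standing in for the objective defined in example.py.
    x1 = trial.suggest_float("x1", 0, 1)
    x2 = trial.suggest_float("x2", 0, 1)
    return x1 + x2, -x1 - x2


if __name__ == "__main__":
    study = optuna.create_study(
        sampler=PLMBOSampler(),  # default arguments; PLMBO asks the user to compare candidates interactively
        directions=["minimize", "minimize"],  # assumed directions for the two objectives
    )
    study.optimize(objective, n_trials=20)  # assumed trial count
    print(study.best_trials)  # Pareto-optimal trials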
54 changes: 34 additions & 20 deletions package/samplers/plmbo/plmbo.py
@@ -1,13 +1,12 @@
-# mypy: ignore-errors
from typing import Any

-import GPy  # type: ignore
+import GPy
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
import numpy as np
-import numpyro  # type: ignore
-from numpyro.infer import init_to_value  # type: ignore
+import numpyro
+from numpyro.infer import init_to_value
import optuna
from optuna import Study
from optuna.distributions import BaseDistribution
@@ -16,10 +15,10 @@
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
import optunahub
-from scipy import optimize  # type: ignore
+from scipy import optimize


-class PLMBOSampler(optunahub.load_module("samplers/simple").SimpleBaseSampler):  # type: ignore
+class PLMBOSampler(optunahub.samplers.SimpleBaseSampler):
def __init__(
self,
search_space: dict[str, BaseDistribution] | None = None,
@@ -38,7 +37,7 @@ def __init__(
self.pc: np.ndarray | None = None
self.ir: list | None = None
self.w: np.ndarray | None = None
-self.gp_models = None
+self.gp_models: list | None = None
self.sample_w = None
self._n_startup_trials = n_startup_trials
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
@@ -108,7 +107,7 @@ def sample_relative(
print(params)
return params

-def __add_comparison(self):
+def __add_comparison(self) -> None:
y_rnd_1 = np.random.rand(self.obj_dim)
y_rnd_2 = np.random.rand(self.obj_dim)

@@ -125,9 +124,9 @@ def __add_comparison(self):
elif winner == 2:
self.pc = np.r_[self.pc, [[y_rnd_2, y_rnd_1]]]

-y_rnd_1 = np.random.rand(self.obj_dim)
+y_rnd = np.random.rand(self.obj_dim)

-print(y_rnd_1)
+print(y_rnd)

while True:
try:
@@ -143,37 +142,46 @@ def __add_comparison(self):
except ValueError:
print("Invalid input!")

+assert self.obj_dim is not None
if winner >= 0 and winner < self.obj_dim:
for i in range(self.obj_dim):
if i != winner:
-self.ir.append([y_rnd_1, winner, i])
+assert self.ir is not None
+self.ir.append([y_rnd, winner, i])

-def __fit_gp(self):
+def __fit_gp(self) -> None:
self.gp_models = []
+assert self.obj_dim is not None
for i in range(self.obj_dim):
kernel = GPy.kern.RBF(self.input_dim)
+assert self.Y is not None
model = GPy.models.GPRegression(
self.X, self.Y[:, i].reshape(self.Y.shape[0], 1), kernel
)
model[".*Gaussian_noise.variance"].constrain_bounded(0.000001, 0.001, warning=False)
model[".*rbf.variance"].constrain_bounded(0.01, 3, warning=False)
model[".*rbf.lengthscale"].constrain_bounded(0.2, 50, warning=False)
model.optimize(messages=False, max_iters=1e5)
+assert self.gp_models is not None
self.gp_models.append(model)

-def __update_w(self):
+def __update_w(self) -> None:
u_sigma = 0.01
# preference information
+assert self.pc is not None
+assert self.ir is not None
y_pc = np.ones((len(self.pc)))
y_ir = np.ones((len(self.ir)))

-def mcmc_model():
+def mcmc_model() -> None:
# prior
w = numpyro.sample("w", numpyro.distributions.Dirichlet(np.full(self.obj_dim, 2)))

+assert self.pc is not None
u_w = self.__u_est(self.pc[:, 0], w)
u_l = self.__u_est(self.pc[:, 1], w)

+assert self.ir is not None
l_f = [l_[0] for l_ in self.ir[:]]
para = (u_w - u_l) / (np.sqrt(2) * u_sigma)
para = jnp.where(para < -30, -30, para)
@@ -194,7 +202,7 @@ def mcmc_model():

if len(y_pc) == 0:

-def mcmc_model():
+def mcmc_model() -> None:
w = numpyro.sample("w", numpyro.distributions.Dirichlet(np.full(self.obj_dim, 2))) # noqa: F841

# sampling
@@ -214,12 +222,14 @@ def mcmc_model():

mean2 = np.mean(sample_n_2, axis=0)

-def ll(w):
+def ll(w: np.ndarray) -> np.ndarray:
+assert self.pc is not None
+assert self.ir is not None
u_w = self.__u_est(self.pc[:, 0], w)
u_l = self.__u_est(self.pc[:, 1], w)
l_f = [l_[0] for l_ in self.ir[:]]
para = (u_w - u_l) / (np.sqrt(2) * u_sigma)
-para = jnp.where(para < -20, -20, para)
+para = np.where(para < -20, -20, para)
para = norm.cdf(para, 0, 1)
para = np.maximum(para, 1e-14)

@@ -238,13 +248,15 @@ def ll(w):

self.w = np.mean(self.sample_w, axis=0)

-def __acq(self, x):
+def __acq(self, x: np.ndarray) -> np.ndarray:
# initialize acquisition
alpha = 0

# current best
+assert self.sample_w is not None
n = len(self.sample_w)
ubest = np.zeros(len(self.sample_w))
+assert self.Y is not None
ubest_tmp = np.zeros((self.Y.shape[0], n))
for i in range(self.Y.shape[0]):
ubest_tmp[i, :] = np.min(np.tile(self.Y[i], (n, 1)) / self.sample_w, axis=1)
@@ -253,6 +265,7 @@ def __acq(self, x):
ubest = np.max(ubest_tmp, axis=0)

normal = []
+assert self.obj_dim is not None
for i in range(self.obj_dim):
normal.append(np.random.normal(0, 1, n))

@@ -275,19 +288,20 @@ def __acq(self, x):
alpha = np.mean(uaftermax)
return -alpha

-def __u_est(self, x, w):
+def __u_est(self, x: np.ndarray, w: np.ndarray) -> jnp.ndarray:
x = jnp.array(x)
w = jnp.array(w)
x = (x / w).T
x = jnp.where(x < -500, -500, x)
x = jnp.where(x > 500, 500, x)
re = x[0]
+assert self.obj_dim is not None
for i in range(1, self.obj_dim):
re = jnp.minimum(x[i], re)
return re

# differential of U
-def __dudf(self, x, f, weight):
+def __dudf(self, x: np.ndarray, f: list, weight: np.ndarray) -> jnp.ndarray:
x = jnp.array(x)
weight = jnp.array(weight)
f = jnp.array(f)
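Beyond the annotations and asserts, the one structural change in plmbo.py above is the base class: `optunahub.load_module("samplers/simple").SimpleBaseSampler`, which mypy cannot resolve statically (hence the old `# type: ignore`), is replaced with the packaged `optunahub.samplers.SimpleBaseSampler`. A minimal sketch of a sampler subclassing that base follows, assuming its usual `sample_relative` interface; the random parameter handling is illustrative and is not PLMBO's actual logic:

from __future__ import annotations

from typing import Any

import numpy as np
import optuna
from optuna.distributions import BaseDistribution
from optuna.distributions import FloatDistribution
import optunahub


class UniformSampler(optunahub.samplers.SimpleBaseSampler):
    def sample_relative(
        self,
        study: optuna.Study,
        trial: optuna.trial.FrozenTrial,
        search_space: dict[str, BaseDistribution],
    ) -> dict[str, Any]:
        # Return a concrete value for every float parameter in the search space.
        params: dict[str, Any] = {}
        for name, dist in search_space.items():
            if isinstance(dist, FloatDistribution):
                params[name] = float(np.random.uniform(dist.low, dist.high))
        return params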
