Skip to content

Commit

Permalink
Merge pull request #103 from GPflow/feat_print
Browse files Browse the repository at this point in the history
Add a verbose option to BOptimizer that prints progress information. …
  • Loading branch information
icouckuy authored Sep 3, 2018
2 parents ef924f5 + b65fb3d commit 4835f02
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 213 deletions.
123 changes: 85 additions & 38 deletions doc/source/notebooks/constrained_bo.ipynb

Large diffs are not rendered by default.

71 changes: 46 additions & 25 deletions doc/source/notebooks/firststeps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
{
"data": {
"text/html": [
"<table id='domain' width=100%><tr><td>Name</td><td>Type</td><td>Values</td></tr><tr><td>x1</td><td>Continuous</td><td>[-2. 2.]</td></tr><tr><td>x2</td><td>Continuous</td><td>[-1. 2.]</td></tr></table>"
"<table id='domain' width=100%><tr><td>Name</td><td>Type</td><td>Values</td></tr><tr><td>x1</td><td>Continuous</td><td>[-5. 10.]</td></tr><tr><td>x2</td><td>Continuous</td><td>[ 0. 15.]</td></tr></table>"
],
"text/plain": [
"<GPflowOpt.domain.Domain at 0x7f1f613eef60>"
"<gpflowopt.domain.Domain at 0x23142436c50>"
]
},
"execution_count": 1,
Expand All @@ -55,12 +55,21 @@
"import numpy as np\n",
"from gpflowopt.domain import ContinuousParameter\n",
"\n",
"def branin(x):\n",
" x = np.atleast_2d(x)\n",
" x1 = x[:, 0]\n",
" x2 = x[:, 1]\n",
" a = 1.\n",
" b = 5.1 / (4. * np.pi ** 2)\n",
" c = 5. / np.pi\n",
" r = 6.\n",
" s = 10.\n",
" t = 1. / (8. * np.pi)\n",
" ret = a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1 - t) * np.cos(x1) + s\n",
" return ret[:, None]\n",
"\n",
"def fx(X):\n",
" X = np.atleast_2d(X)\n",
" return np.sum(np.square(X), axis=1)[:, None]\n",
"\n",
"domain = ContinuousParameter('x1', -2, 2) + ContinuousParameter('x2', -1, 2)\n",
"domain = ContinuousParameter('x1', -5, 10) + \\\n",
" ContinuousParameter('x2', 0, 15)\n",
"domain"
]
},
Expand All @@ -73,22 +82,31 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: optimization restart 1/5 failed\n",
"Warning: optimization restart 2/5 failed\n",
"Warning: optimization restart 3/5 failed\n",
"Warning: optimization restart 2/5 failed\n",
" fun: array([ 0.01])\n",
" message: 'OK'\n",
" nfev: 15\n",
" success: True\n",
" x: array([[ 0. , -0.1]])\n"
"iter # 0 - MLL [-13.1] - fmin [4.42]\n",
"iter # 1 - MLL [-13.4] - fmin [4.42]\n",
"iter # 2 - MLL [-10.6] - fmin [0.723]\n",
"iter # 3 - MLL [-9.09] - fmin [0.486]\n",
"iter # 4 - MLL [-7.01] - fmin [0.486]\n",
"iter # 5 - MLL [-2.69] - fmin [0.446]\n",
"iter # 6 - MLL [1.96] - fmin [0.446]\n",
"iter # 7 - MLL [4.6] - fmin [0.446]\n",
"iter # 8 - MLL [7.37] - fmin [0.4]\n",
"iter # 9 - MLL [12.6] - fmin [0.4]\n",
" constraints: array([], dtype=float64)\n",
" fun: array([0.39970302])\n",
" message: 'OK'\n",
" nfev: 10\n",
" success: True\n",
" x: array([[9.40798299, 2.43938799]])\n"
]
}
],
Expand All @@ -97,30 +115,33 @@
"from gpflowopt.bo import BayesianOptimizer\n",
"from gpflowopt.design import LatinHyperCube\n",
"from gpflowopt.acquisition import ExpectedImprovement\n",
"from gpflowopt.optim import SciPyOptimizer\n",
"from gpflowopt.optim import SciPyOptimizer, StagedOptimizer, MCOptimizer\n",
"\n",
"# Use standard Gaussian process Regression\n",
"lhd = LatinHyperCube(21, domain)\n",
"X = lhd.generate()\n",
"Y = fx(X)\n",
"Y = branin(X)\n",
"model = gpflow.gpr.GPR(X, Y, gpflow.kernels.Matern52(2, ARD=True))\n",
"model.kern.lengthscales.transform = gpflow.transforms.Log1pe(1e-3)\n",
"\n",
"# Now create the Bayesian Optimizer\n",
"alpha = ExpectedImprovement(model)\n",
"optimizer = BayesianOptimizer(domain, alpha)\n",
"\n",
"acquisition_opt = StagedOptimizer([MCOptimizer(domain, 200),\n",
" SciPyOptimizer(domain)])\n",
"\n",
"optimizer = BayesianOptimizer(domain, alpha, optimizer=acquisition_opt, verbose=True)\n",
"\n",
"# Run the Bayesian optimization\n",
"with optimizer.silent():\n",
" r = optimizer.optimize(fx, n_iter=15)\n",
"r = optimizer.optimize(branin, n_iter=10)\n",
"print(r)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's all! Your objective function has now been optimized for 15 iterations."
"That's all! Your objective function has now been optimized for 10 iterations."
]
}
],
Expand All @@ -140,7 +161,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.6"
}
},
"nbformat": 4,
Expand Down
40 changes: 21 additions & 19 deletions doc/source/notebooks/mes_benchmark.ipynb

Large diffs are not rendered by default.

198 changes: 110 additions & 88 deletions doc/source/notebooks/multiobjective.ipynb

Large diffs are not rendered by default.

94 changes: 71 additions & 23 deletions gpflowopt/bo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class BayesianOptimizer(Optimizer):
"""

def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=True, hyper_draws=None,
callback=jitchol_callback):
callback=jitchol_callback, verbose=False):
"""
:param Domain domain: The optimization space.
:param Acquisition acquisition: The acquisition function to optimize over the domain.
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr
self.set_initial(initial.generate())

self._model_callback = callback
self.verbose = verbose

@Optimizer.domain.setter
def domain(self, dom):
Expand Down Expand Up @@ -154,8 +155,16 @@ def _evaluate_objectives(self, X, fxs):

def _create_bo_result(self, success, message):
"""
Analyzes all data evaluated during the optimization, and return an OptimizeResult. Outputs of constraints
are used to remove all infeasible points.
Analyzes all data evaluated during the optimization, and return an `OptimizeResult`. Constraints are taken
into account. The contents of x, fun, and constraints depend on the detected scenario:
- single-objective: the best optimum of the feasible samples (if none, optimum of the infeasible samples)
- multi-objective: the Pareto set of the feasible samples
- only constraints: all the feasible samples (can be empty)
In all cases, if not one sample satisfies all the constraints a message will be given and success=False.
Do note that the feasibility check is based on the model predictions, but the constrained field contains
actual data values.
:param success: Optimization successful? (True/False)
:param message: return message
Expand All @@ -166,24 +175,31 @@ def _create_bo_result(self, success, message):
# Filter on constraints
valid = self.acquisition.feasible_data_index()

if not np.any(valid):
return OptimizeResult(success=False,
message="No evaluations satisfied the constraints")

valid_X = X[valid, :]
valid_Y = Y[valid, :]
valid_Yo = valid_Y[:, self.acquisition.objective_indices()]

# Differentiate between single- and multiobjective optimization results
if valid_Y.shape[1] > 1:
_, dom = non_dominated_sort(valid_Yo)
idx = dom == 0 # Return the non-dominated points
# Extract the samples that satisfies all constraints
if np.any(valid):
X = X[valid, :]
Y = Y[valid, :]
else:
idx = np.argmin(valid_Yo)

return OptimizeResult(x=valid_X[idx, :],
success = False
message = "No evaluations satisfied all the constraints"

# Split between objectives and constraints
Yo = Y[:, self.acquisition.objective_indices()]
Yc = Y[:, self.acquisition.constraint_indices()]

# Differentiate between different scenarios
if Yo.shape[1] == 1: # Single-objective: minimum
idx = np.argmin(Yo)
elif Yo.shape[1] > 1: # Multi-objective: Pareto set
_, dom = non_dominated_sort(Yo)
idx = dom == 0
else: # Constraint satisfaction problem: all samples satisfying the constraints
idx = np.arange(Yc.shape[0])

return OptimizeResult(x=X[idx, :],
success=success,
fun=valid_Yo[idx, :],
fun=Yo[idx, :],
constraints=Yc[idx, :],
message=message)

def optimize(self, objectivefx, n_iter=20):
Expand Down Expand Up @@ -232,10 +248,42 @@ def inverse_acquisition(x):
for i in range(n_iter):
# If a callback is specified, and acquisition has the setup flag enabled (indicating an upcoming
# compilation), run the callback.
if self._model_callback and self.acquisition._needs_setup:
self._model_callback([m.wrapped for m in self.acquisition.models])
result = self.optimizer.optimize(inverse_acquisition)
self._update_model_data(result.x, fx(result.x))
with self.silent():
if self._model_callback and self.acquisition._needs_setup:
self._model_callback([m.wrapped for m in self.acquisition.models])

result = self.optimizer.optimize(inverse_acquisition)
self._update_model_data(result.x, fx(result.x))

if self.verbose:
metrics = []

with self.silent():
bo_result = self._create_bo_result(True, 'Monitor')
metrics += ['MLL [' + ', '.join('{:.3}'.format(model.compute_log_likelihood()) for model in self.acquisition.models) + ']']

# fmin
n_points = bo_result.fun.shape[0]
if n_points > 0:
funs = np.atleast_1d(np.min(bo_result.fun, axis=0))
fmin = 'fmin [' + ', '.join('{:.3}'.format(fun) for fun in funs) + ']'
if n_points > 1:
fmin += ' (size {0})'.format(n_points)

metrics += [fmin]

# constraints
n_points = bo_result.constraints.shape[0]
if n_points > 0:
constraints = np.atleast_1d(np.min(bo_result.constraints, axis=0))
metrics += ['constraints [' + ', '.join('{:.3}'.format(constraint) for constraint in constraints) + ']']

# error messages
metrics += [r.message.decode('utf-8') if isinstance(r.message, bytes) else r.message for r in [bo_result, result] if not r.success]

print('iter #{0:>3} - {1}'.format(
i,
' - '.join(metrics)))

return self._create_bo_result(True, "OK")

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
raise RuntimeError("Unable to find version string in %s." % (VERSIONFILE,))

# Dependencies of GPflowOpt
dependencies = ['numpy>=1.9', 'scipy>=0.16', 'GPflow==0.4.0']
dependencies = ['numpy>=1.9', 'scipy>=0.16', 'GPflow==0.5.0']
min_tf_version = '1.0.0'

# Detect if TF is installed or outdated.
Expand Down Expand Up @@ -65,7 +65,7 @@
extras_require={'gpu': ['tensorflow-gpu>=1.0.0'],
'docs': ['sphinx', 'sphinx_rtd_theme', 'numpydoc', 'nbsphinx', 'jupyter'],
},
dependency_links=['https://github.com/GPflow/GPflow/archive/0.4.0.tar.gz#egg=GPflow-0.4.0'],
dependency_links=['https://github.com/GPflow/GPflow/archive/0.5.0.tar.gz#egg=GPflow-0.5.0'],
classifiers=['License :: OSI Approved :: Apache Software License',
'Natural Language :: English',
'Operating System :: POSIX :: Linux',
Expand Down
64 changes: 46 additions & 18 deletions testing/unit/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gpflowopt
import numpy as np
import pytest
import gpflow
import six
import sys
Expand Down Expand Up @@ -192,38 +193,65 @@ def test_set_domain(self):
class TestBayesianOptimizer(_TestOptimizer, GPflowOptTestCase):
def setUp(self):
super(TestBayesianOptimizer, self).setUp()

acquisition = gpflowopt.acquisition.ExpectedImprovement(create_parabola_model(self.domain))
self.optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition)

def test_default_initial(self):
self.assertTupleEqual(self.optimizer._initial.shape, (0, 2), msg="Invalid shape of initial points array")

def test_optimize(self):
with self.test_session():
result = self.optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=20)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 20, "Only 20 evaluations permitted")
self.assertTrue(np.allclose(result.x, 0), msg="Optimizer failed to find optimum")
self.assertTrue(np.allclose(result.fun, 0), msg="Incorrect function value returned")
for verbose in [False, True]:
with self.test_session():
acquisition = gpflowopt.acquisition.ExpectedImprovement(create_parabola_model(self.domain))
optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition, verbose=verbose)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=20)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 20, "Only 20 evaluations permitted")
self.assertTrue(np.allclose(result.x, 0), msg="Optimizer failed to find optimum")
self.assertTrue(np.allclose(result.fun, 0), msg="Incorrect function value returned")

def test_optimize_multi_objective(self):
with self.test_session():
m1, m2 = create_vlmop2_model()
acquisition = gpflowopt.acquisition.ExpectedImprovement(m1) + gpflowopt.acquisition.ExpectedImprovement(m2)
optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition)
result = optimizer.optimize(vlmop2, n_iter=2)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 2, "Only 2 evaluations permitted")
self.assertTupleEqual(result.x.shape, (7, 2))
self.assertTupleEqual(result.fun.shape, (7, 2))
_, dom = gpflowopt.pareto.non_dominated_sort(result.fun)
self.assertTrue(np.all(dom==0))
for verbose in [False, True]:
with self.test_session():
m1, m2 = create_vlmop2_model()
acquisition = gpflowopt.acquisition.ExpectedImprovement(m1) + gpflowopt.acquisition.ExpectedImprovement(m2)
optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition, verbose=verbose)
result = optimizer.optimize(vlmop2, n_iter=2)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 2, "Only 2 evaluations permitted")
self.assertTupleEqual(result.x.shape, (7, 2))
self.assertTupleEqual(result.fun.shape, (7, 2))
_, dom = gpflowopt.pareto.non_dominated_sort(result.fun)
self.assertTrue(np.all(dom == 0))

def test_optimize_constraint(self):
for verbose in [False, True]:
with self.test_session():
acquisition = gpflowopt.acquisition.ProbabilityOfFeasibility(create_parabola_model(self.domain), threshold=-1)
optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition, verbose=verbose)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertFalse(result.success)
self.assertEqual(result.message, 'No evaluations satisfied all the constraints')
self.assertEqual(result.nfev, 1, "Only 1 evaluations permitted")
self.assertTupleEqual(result.x.shape, (17, 2))
self.assertTupleEqual(result.fun.shape, (17, 0))
self.assertTupleEqual(result.constraints.shape, (17, 1))

acquisition = gpflowopt.acquisition.ProbabilityOfFeasibility(create_parabola_model(self.domain), threshold=0.3)
optimizer = gpflowopt.BayesianOptimizer(self.domain, acquisition, verbose=verbose)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 1, "Only 1 evaluation permitted")
self.assertTupleEqual(result.x.shape, (5, 2))
self.assertTupleEqual(result.fun.shape, (5, 0))
self.assertTupleEqual(result.constraints.shape, (5, 1))

def test_optimizer_interrupt(self):
with self.test_session():
result = self.optimizer.optimize(KeyboardRaiser(3, lambda X: parabola2d(X)[0]), n_iter=20)
self.assertFalse(result.success, msg="After 2 evaluations, a keyboard interrupt is raised, "
"non-succesfull result expected.")
"failed result expected.")
self.assertTrue(np.allclose(result.x, 0.0), msg="The optimum will not be identified nonetheless")

def test_failsafe(self):
Expand Down

0 comments on commit 4835f02

Please sign in to comment.