-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve MAPElites performance using torch_scatter #93
Comments
Runnable comparison:

import math
import time
from typing import NamedTuple, List, Iterable
import torch
from torch_scatter import scatter_max, scatter_min
from evotorch import Problem
from evotorch.algorithms import MAPElites, SearchAlgorithm
from evotorch.algorithms.ga import ExtendedPopulationMixin
from evotorch.algorithms.searchalgorithm import SinglePopulationAlgorithmMixin
from evotorch.operators import GaussianMutation, SimulatedBinaryCrossOver
class FeatureGrid(NamedTuple):
    """Description of the feature-space grid consumed by ``MAPElitesScatter``.

    The three sequences are parallel: entry ``i`` of each one describes the
    ``i``-th feature dimension. The total number of cells is the product of
    ``bins``.
    """

    # Lower edge of the covered range, one value per feature.
    lower_bounds: List[float]
    # Upper edge of the covered range, one value per feature.
    upper_bounds: List[float]
    # Number of bins along each feature dimension.
    bins: List[int]
class MAPElitesScatter(MAPElites):
    """MAP-Elites variant that picks per-cell elites with ``torch_scatter``.

    Each member of the extended population (parents + children) is assigned a
    single flat cell index computed from its features. Then ``scatter_max`` or
    ``scatter_min`` (chosen from the objective sense) selects the best member
    of every cell in one pass.

    The problem must be single-objective. In each evaluation row, column 0 is
    the fitness and columns ``1:`` are the features, one per entry of
    ``feature_grid``.
    """

    def __init__(
        self,
        problem: Problem,
        *,
        operators: Iterable,
        feature_grid: FeatureGrid,
        re_evaluate: bool = True,
        re_evaluate_parents_first: "bool | None" = None,
    ):
        """
        Args:
            problem: Problem to solve; must be single-objective and numeric.
            operators: Variation operators used to create children.
            feature_grid: Per-feature lower bounds, upper bounds and bin counts.
            re_evaluate: Forwarded to ``ExtendedPopulationMixin``.
            re_evaluate_parents_first: Forwarded to ``ExtendedPopulationMixin``.

        Raises:
            ValueError: If the grid's sequences differ in length, a bin count
                is smaller than 1, or an upper bound is not above its lower
                bound.
        """
        problem.ensure_single_objective()
        problem.ensure_numeric()
        self._validate_feature_grid(feature_grid)

        SearchAlgorithm.__init__(self, problem)

        self._sense = self._problem.senses[0]
        self._feature_grid = feature_grid
        # One population slot per cell of the flattened feature grid.
        self._popsize = math.prod(feature_grid.bins)
        self._population = problem.generate_batch(self._popsize)
        self._filled = torch.zeros(self._popsize, dtype=torch.bool, device=self._population.device)
        self._scatter_best = scatter_max if self._sense == "max" else scatter_min

        ExtendedPopulationMixin.__init__(
            self,
            re_evaluate=re_evaluate,
            re_evaluate_parents_first=re_evaluate_parents_first,
            operators=operators,
            allow_empty_operators_list=False,
        )

        SinglePopulationAlgorithmMixin.__init__(self)

    @staticmethod
    def _validate_feature_grid(feature_grid: FeatureGrid) -> None:
        """Raise ``ValueError`` if ``feature_grid`` cannot describe a grid."""
        lower_bounds, upper_bounds, bins = feature_grid
        if not (len(lower_bounds) == len(upper_bounds) == len(bins)):
            raise ValueError(
                "feature_grid.lower_bounds, upper_bounds and bins must have the same length,"
                f" got {len(lower_bounds)}, {len(upper_bounds)} and {len(bins)}"
            )
        for i, (lb, ub, n_bins) in enumerate(zip(lower_bounds, upper_bounds, bins)):
            if n_bins < 1:
                raise ValueError(f"feature {i}: bins must be >= 1, got {n_bins}")
            if not ub > lb:
                raise ValueError(f"feature {i}: upper bound {ub} must be greater than lower bound {lb}")

    def _cell_indices(self, features: torch.Tensor) -> torch.Tensor:
        """Map each row of ``features`` to the flat index of its grid cell.

        Values outside ``[lower_bound, upper_bound]`` are clamped into the
        first or last bin of their dimension. The first feature varies fastest
        in the flat index. ``features`` itself is not modified.
        """
        cell_index = torch.zeros(features.shape[0], device=features.device, dtype=torch.long)
        stride = 1
        for i, (lb, ub, n_bins) in enumerate(zip(*self._feature_grid)):
            # Fractional bin position. It is computed out of place because
            # ``features`` may be a view into the population's evals.
            position = (features[:, i] - lb) * (n_bins / (ub - lb))
            # Clamp into [0, n_bins - 1]; truncation of the non-negative
            # result then floors it to an integer bin.
            bin_index = torch.clamp(position, 0, n_bins - 1).long()
            cell_index += bin_index * stride
            stride *= n_bins
        return cell_index

    def _step(self):
        # Form an extended population from the parents and from the children
        extended_population = self._make_extended_population(split=False)
        extended_pop_size = extended_population.eval_shape[0]

        all_evals = extended_population.evals.as_subclass(torch.Tensor)
        all_values = extended_population.values.as_subclass(torch.Tensor)
        all_fitnesses = all_evals[:, 0]
        cell_index = self._cell_indices(all_evals[:, 1:])

        # argbest[c] is the row of the best member that fell into cell c.
        # torch_scatter sets it to the source length (extended_pop_size) for
        # cells that received no member. dim_size guarantees exactly one entry
        # per cell, even when the highest-numbered cells are empty.
        _, argbest = self._scatter_best(all_fitnesses, cell_index, dim=0, dim_size=self._popsize)
        suitable = argbest < extended_pop_size
        best_rows = argbest[suitable]

        # Cells without any member keep zeros in both values and evals;
        # `suitable` records which cells are real.
        values = all_values.new_zeros((self._popsize, all_values.shape[1]))
        evals = all_evals.new_zeros((self._popsize, all_evals.shape[1]))
        values[suitable] = all_values[best_rows]
        evals[suitable] = all_evals[best_rows]

        # Place the most suitable decision values and evaluation results into the current population.
        self._population.access_values(keep_evals=True)[:] = values
        self._population.access_evals()[:] = evals

        # If there was a suitable solution for the i-th cell, filled[i] is set to True.
        self._filled[:] = suitable
def kursawe(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the Kursawe function.

    Args:
        x: Decision values with the variables along the last dimension, e.g.
            shape ``(num_solutions, n)`` for a batch or ``(n,)`` for a single
            solution. Any ``n`` is accepted; the original form used ``n = 3``.

    Returns:
        A tensor with the same leading dimensions as ``x`` and a last
        dimension of size 3 holding ``[f1 + f2, f1, f2]``.
    """
    # f1: coupled term over every pair of consecutive variables (x_i, x_{i+1}).
    f1 = torch.sum(
        -10 * torch.exp(-0.2 * torch.sqrt(x[..., :-1] ** 2.0 + x[..., 1:] ** 2.0)),
        dim=-1,
    )
    # f2: separable term over every variable.
    f2 = torch.sum(torch.abs(x) ** 0.8 + 5 * torch.sin(x ** 3), dim=-1)
    # The sum comes first so it can act as the single fitness value; f1 and f2
    # follow as extra evaluation data.
    return torch.stack([f1 + f2, f1, f2], dim=-1)
if __name__ == "__main__":
    # Shared grid description so both implementations cover the same feature ranges.
    LOWER_BOUNDS = [-20, -14]
    UPPER_BOUNDS = [-10, 4]
    NUM_BINS = 50

    tensor_feature_grid = MAPElites.make_feature_grid(
        lower_bounds=LOWER_BOUNDS,
        upper_bounds=UPPER_BOUNDS,
        num_bins=NUM_BINS,
        dtype="float32",
    )
    scatter_feature_grid = FeatureGrid(LOWER_BOUNDS, UPPER_BOUNDS, [NUM_BINS] * len(LOWER_BOUNDS))

    # NOTE(review): MAPElites.make_feature_grid may place bin edges differently
    # from MAPElitesScatter's equal-width bins; confirm before comparing the
    # filled-cell counts of the two runs.
    for algorithm_cls, feature_grid in [
        (MAPElitesScatter, scatter_feature_grid),
        (MAPElites, tensor_feature_grid),
    ]:
        # A fresh Problem per run so the two searchers never share one instance.
        problem = Problem(
            "min",
            kursawe,
            solution_length=3,
            eval_data_length=2,
            bounds=(-5.0, 5.0),
            vectorized=True,
        )
        searcher = algorithm_cls(
            problem,
            feature_grid=feature_grid,
            operators=[
                SimulatedBinaryCrossOver(problem, tournament_size=4, cross_over_rate=1.0, eta=8),
                GaussianMutation(problem, stdev=0.03),
            ],
        )
        # Time only the search itself, not the status printing below.
        start = time.perf_counter()
        searcher.run(100)
        elapsed = time.perf_counter() - start

        print("Final status:\n", searcher.status)
        print("Impl: ", algorithm_cls.__name__)
        print("Time spent (secs): ", elapsed)
        print("Filled hypervolumes: ", int(searcher.filled.sum()))
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Would you be interested in contributions that rework MAPElites to use `torch_scatter` instead of the vmapped `extended_population` × `feature_grid` operation? The general gist is:
The text was updated successfully, but these errors were encountered: