Skip to content

Commit

Permalink
Merge pull request #2 from Leviathan321/custom-ops
Browse files Browse the repository at this point in the history
Make possible to use custom operators.
  • Loading branch information
Leviathan321 authored Dec 11, 2024
2 parents 75859b0 + e090526 commit bb8e619
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 15 deletions.
13 changes: 7 additions & 6 deletions opensbt/algorithm/nsga2_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from opensbt.utils.operators import select_operator
from pymoo.core.problem import Problem
from pymoo.termination import get_termination
from pymoo.algorithms.moo.nsga2 import NSGA2
Expand All @@ -17,7 +18,6 @@ class NsgaIIOptimizer(Optimizer):

def __init__(self,
problem: Problem,

config: SearchConfiguration):
self.config = config
self.problem = problem
Expand All @@ -40,11 +40,12 @@ def __init__(self,
self.algorithm = NSGA2(
pop_size=config.population_size,
n_offsprings=config.num_offsprings,
sampling=FloatRandomSampling(),
crossover=SBX(prob=config.prob_crossover, eta=config.eta_crossover),
mutation=PM(prob=config.prob_mutation, eta=config.eta_mutation),
eliminate_duplicates=True)

sampling = select_operator("init", config),
crossover = select_operator("cx", config),
mutation = select_operator("mut", config),
eliminate_duplicates = select_operator("dup", config)
)

''' Prioritize max search time over set maximal number of generations'''
if config.maximal_execution_time is not None:
self.termination = get_termination("time", config.maximal_execution_time)
Expand Down
17 changes: 10 additions & 7 deletions opensbt/algorithm/nsga2dt_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
from opensbt.model_ga.result import SimulationResult
from opensbt.evaluation.critical import Critical
from opensbt.utils.operators import select_operator
from pymoo.termination import get_termination
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.population import Population
Expand Down Expand Up @@ -125,10 +126,11 @@ def run(self) -> SimulationResult:
inner_algorithm = NSGA2(
pop_size=None,
n_offsprings=None,
sampling=None,
crossover=SBX(prob=prob_crossover, eta=eta_crossover),
mutation=PM(prob=prob_mutation, eta=eta_mutation),
eliminate_duplicates=True)
sampling = select_operator("init", config),
crossover = select_operator( "cx", config),
mutation = select_operator( "mut", config),
eliminate_duplicates = select_operator( "dup", config)
)

tree_iteration = 0
n_func_evals = 0
Expand All @@ -153,9 +155,10 @@ def run(self) -> SimulationResult:
pop_size=pop_size,
n_offsprings=num_offsprings,
sampling=initial_population,
crossover=SBX(prob=prob_crossover, eta=eta_crossover),
mutation=PM(prob=prob_mutation, eta=eta_mutation),
eliminate_duplicates=True)
crossover = select_operator( "cx", config),
mutation = select_operator( "mut", config),
eliminate_duplicates = select_operator( "dup", config)
)

termination = get_termination("n_gen", inner_num_gen)

Expand Down
4 changes: 3 additions & 1 deletion opensbt/algorithm/pso_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from opensbt.utils.operators import select_operator
from pymoo.core.problem import Problem
from pymoo.termination import get_termination
from pymoo.algorithms.soo.nonconvex.pso import PSO
Expand Down Expand Up @@ -38,7 +39,8 @@ def __init__(self,

# initialize algorithm
self.algorithm = PSO(
pop_size=config.population_size,
pop_size = config.population_size,
sampling = select_operator("init",config)
)

''' Prioritize max search time over set maximal number of generations'''
Expand Down
26 changes: 25 additions & 1 deletion opensbt/experiment/search_configuration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from dataclasses import dataclass


@dataclass
class SearchConfiguration(object):
""" This class holds all configuration parameter related to opimization algorithms
"""
Expand All @@ -24,6 +27,16 @@ class SearchConfiguration(object):

seed = None

operators= {
"cx" : None,
"mut" : None,
"dup" : None,
"init" : None
}

custom_params = { # to be forwarded to operators

}

class DefaultSearchConfiguration(SearchConfiguration):
""" This class holds all configuration parameter initialized with default values
Expand All @@ -48,4 +61,15 @@ class DefaultSearchConfiguration(SearchConfiguration):
ref_point_hv = None
nadir = ref_point_hv

seed = None
seed = None

operators= {
"cx" : None,
"mut" : None,
"dup" : None,
"init" : None
}

custom_params = { # to be forwarded to operators

}
33 changes: 33 additions & 0 deletions opensbt/utils/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pymoo.operators.crossover.sbx import SBX # type: ignore
from pymoo.operators.mutation.pm import PM # type: ignore
from pymoo.operators.sampling.lhs import LHS # type: ignore

def select_operator(operation,
config,
**kwargs):
"""
Selects either the default operator or a custom operator based on the condition.
"""
if kwargs is not None:
kwargs = {}

if config.operators[operation] is None:
if operation == "mut":
operator = PM
if "prob" not in kwargs:
kwargs["prob"] = config.prob_mutation
if "eta" not in kwargs:
kwargs["eta"] = config.eta_mutation
elif operation == "cx":
operator = SBX
if "prob" not in kwargs:
kwargs["prob"] = config.prob_crossover
if "eta" not in kwargs:
kwargs["eta"] = config.eta_crossover
elif operation == "init":
operator = LHS
elif operation == "dup":
return True
else:
operator = config.operators[operation]
return operator(**kwargs) # Passes the keyword arguments to the operator

0 comments on commit bb8e619

Please sign in to comment.