Skip to content

Commit

Permalink
---
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jun 3, 2024
1 parent 5ccf312 commit 227044c
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions autotm/fitness/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,13 @@ def _estimate(self, iter_num: int, population: List[Individual]) -> List[Individ
class SurrogateEnabledFitnessEstimatorMixin(FitnessEstimator):
SUPPORTED_CALC_SCHEMES = ["type1", "type2"]

def __init__(self,
ibuilder: IndividualBuilder,
surrogate: Surrogate,
calc_scheme: str,
speedup: bool = True,
num_fitness_evaluations: Optional[int] = None,
statistics_collector: Optional[StatisticsCollector] = None):
self.ibuilder = ibuilder
self.surrogate = surrogate
self.calc_scheme = calc_scheme
self.speedup = speedup

self.all_params: List[AbstractParams] = []
self.all_fitness: List[float] = []
ibuilder: IndividualBuilder
surrogate: Surrogate
calc_scheme: str
speedup: bool

if calc_scheme not in self.SUPPORTED_CALC_SCHEMES:
raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}")
super().__init__(num_fitness_evaluations, statistics_collector)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def surrogate_iteration(iter_num: int) -> bool:
Expand Down Expand Up @@ -237,4 +226,21 @@ def _estimate(self, iter_num: int, population: List[Individual]) -> List[Individ


class SurrogateEnabledComputableFitnessEstimator(ComputableFitnessEstimator, SurrogateEnabledFitnessEstimatorMixin):
pass
def __init__(self,
ibuilder: IndividualBuilder,
surrogate: Surrogate,
calc_scheme: str,
speedup: bool = True,
num_fitness_evaluations: Optional[int] = None,
statistics_collector: Optional[StatisticsCollector] = None):
self.ibuilder = ibuilder
self.surrogate = surrogate
self.calc_scheme = calc_scheme
self.speedup = speedup

self.all_params: List[AbstractParams] = []
self.all_fitness: List[float] = []

if calc_scheme not in self.SUPPORTED_CALC_SCHEMES:
raise ValueError(f"Unexpected surrogate scheme! {self.calc_scheme}")
super().__init__(ibuilder, num_fitness_evaluations, statistics_collector)

0 comments on commit 227044c

Please sign in to comment.