Skip to content

Commit

Permalink
Fix experiments running (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
zuevmaxim authored May 31, 2024
1 parent 7b6f13f commit 4330649
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 2 additions & 0 deletions autotm/algorithms_for_tuning/genetic_algorithm/ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def _sort_population(population):
population.sort(key=operator.attrgetter("fitness_value"), reverse=True)

def _calculate_uncertain_res(self, generation, iteration_num: int, proc=0.3):
if len(generation) == 0:
return []
X = np.array([individ.dto.params.to_vector() for individ in generation])
certanty = get_prediction_uncertanty(
self.surrogate.surrogate, X, self.surrogate.name
Expand Down
13 changes: 9 additions & 4 deletions autotm/algorithms_for_tuning/genetic_algorithm/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def __init__(self, surrogate_name, **kwargs):
self.gpr_kernel = None

def create(self):
kernel = self.kwargs["gpr_kernel"]
del self.kwargs["gpr_kernel"]
gpr_alpha = self.kwargs["gpr_alpha"]
del self.kwargs["gpr_alpha"]
normalize_y = self.kwargs["normalize_y"]
del self.kwargs["normalize_y"]

if self.name == "random-forest-regressor":
self.surrogate = RandomForestRegressor(**self.kwargs)
elif self.name == "mlp-regressor":
Expand All @@ -70,8 +77,6 @@ def create(self):
)
elif self.name == "GPR": # tune ??
if not self.gpr_kernel:
kernel = self.kwargs["gpr_kernel"]
del self.kwargs["gpr_kernel"]
if kernel == "RBF":
self.gpr_kernel = 1.0 * RBF(1.0)
elif kernel == "RBFwithConstant":
Expand All @@ -85,8 +90,8 @@ def create(self):
elif kernel == "RationalQuadratic":
self.gpr_kernel = RationalQuadratic(1.0)
self.kwargs["kernel"] = self.gpr_kernel
self.kwargs["alpha"] = self.kwargs["gpr_alpha"]
del self.kwargs["gpr_alpha"]
self.kwargs["alpha"] = gpr_alpha
self.kwargs["normalize_y"] = normalize_y
self.surrogate = GaussianProcessRegressor(**self.kwargs)
elif self.name == "decision-tree-regressor":
try:
Expand Down
6 changes: 5 additions & 1 deletion autotm/visualization/dynamic_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ def get_metric_df(self):
cur_df = pd.DataFrame(cur_df_dict)
cur_df[GENERATION_COL] = gen
dfs.append(cur_df)
self.crossover_df = pd.concat(dfs)
if len(dfs) > 0:
self.crossover_df = pd.concat(dfs)
else:
warnings.warn("No crossover changes have been found to save", RuntimeWarning)
self.crossover_df = pd.DataFrame([])

def write_metrics_to_file(self):
os.makedirs(self.save_path, exist_ok=True)
Expand Down

0 comments on commit 4330649

Please sign in to comment.