Skip to content

Commit

Permalink
add fix for ga
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jul 4, 2023
1 parent f77cb5e commit 3030288
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
7 changes: 5 additions & 2 deletions autotm/algorithms_for_tuning/genetic_algorithm/ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def apply_nelder_mead(self, starting_points_set, num_gen, num_iterations=2):
new_population.append(make_individual(dto=solution_dto))
return new_population

def run(self, verbose=False) -> Individual:
def run(self, verbose=False, visualize_results=False) -> Individual:
self.evaluations_counter = 0
ftime = str(int(time.time()))

Expand Down Expand Up @@ -1018,7 +1018,10 @@ def run(self, verbose=False) -> Individual:
best_solution = population[0]
log_best_solution(best_solution, alg_args=" ".join(sys.argv), is_tmp=True)

self.metric_collector.save_and_visualise_trace()
if visualize_results:
self.metric_collector.save_and_visualise_trace()
else:
self.metric_collector.save_trace()

logger.info(f"Y: {y}")
best_individual = population[0]
Expand Down
13 changes: 10 additions & 3 deletions autotm/visualization/dynamic_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -189,7 +190,11 @@ def get_metric_df(self):
cur_df = pd.DataFrame(cur_df_dict)
cur_df[GENERATION_COL] = gen
dfs.append(cur_df)
self.mutation_df = pd.concat(dfs)
if len(dfs) > 0:
self.mutation_df = pd.concat(dfs)
else:
warnings.warn("No mutations changes have been found to save", RuntimeWarning)
self.mutation_df = pd.DataFrame([])
if self.crossover_df is not None:
print("Crossover df already exists")
else:
Expand Down Expand Up @@ -227,11 +232,13 @@ def write_metrics_to_file(self):
)
)

def save_and_visualise_trace(self, plot_mutation_effectiveness=False):
def save_trace(self):
self.get_metric_df()
# save params
self.write_metrics_to_file()

def save_and_visualise_trace(self, plot_mutation_effectiveness=False):
self.save_trace()

# traces vis
graph_template = "plotly_white"

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def check_predictions(autotm: AutoTM, df: pd.DataFrame, mixtures: ArrayLike):

assert n_samples_mixture == n_samples
assert n_topics_mixture == n_topics

# TODO: check for nullability of the content in mixtures
assert (~mixtures.isna()).all().all()
assert (~mixtures.isnull()).all().all()


def test_fit_predict():
Expand Down

0 comments on commit 3030288

Please sign in to comment.