Skip to content

Commit

Permalink
small fixes after rebasing on main
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Dec 20, 2023
1 parent 1112f18 commit fad95ba
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions autotm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import warnings

from autotm.preprocessing import PREPOCESSED_DATASET_FILENAME

# TODO: Suppressing of DeprecationWarnings that are raise if we are running with __main__, need to research further
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -204,7 +206,7 @@ def predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
extractor_working_dir,
**self.preprocessing_params
)
preprocessed_dataset = pd.read_csv(os.path.join(extractor_working_dir, "prep_df.csv"))
preprocessed_dataset = pd.read_csv(os.path.join(extractor_working_dir, PREPOCESSED_DATASET_FILENAME))
else:
preprocessed_dataset = dataset
topics_extractor = TopicsExtractor(self._model)
Expand Down Expand Up @@ -237,7 +239,7 @@ def fit_predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
processed_dataset_path = os.path.join(self.working_dir_path, f"{uuid.uuid4()}")
self.fit(dataset, processed_dataset_path=processed_dataset_path)

preprocessed_dataset = pd.read_csv(os.path.join(processed_dataset_path, "prep_df.csv"))
preprocessed_dataset = pd.read_csv(os.path.join(processed_dataset_path, PREPOCESSED_DATASET_FILENAME))
return self.predict(preprocessed_dataset)

def save(self, path: str, overwrite: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion autotm/fitness/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Dataset:
_ppmi_dict_df_path: str = "ppmi_df.txt"
_ppmi_dict_tf_path: str = "ppmi_tf.txt"
_mutual_info_dict_path: str = "mutual_info_dict.pkl"
_texts_path: str = "ppp.csv"
_texts_path: str = PREPOCESSED_DATASET_FILENAME
_labels_path = "labels.pkl"

def __init__(self, base_path: str, topic_count: int):
Expand Down
2 changes: 1 addition & 1 deletion examples/autotm_fit_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main():
alg_name=alg_name,
alg_params={
"num_iterations": 2,
"num_individuals": 10,
"num_individuals": 2,
"use_nelder_mead_in_mutation": False,
"use_nelder_mead_in_crossover": False,
"use_nelder_mead_in_selector": False,
Expand Down

0 comments on commit fad95ba

Please sign in to comment.