Skip to content

Commit

Permalink
complete example of AutoTM
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jul 3, 2023
1 parent 2b7b3a2 commit 6e38f0e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 22 deletions.
55 changes: 46 additions & 9 deletions autotm/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import pickle
import shutil
import tempfile
import uuid
from typing import Union, Optional, Any, Dict, List
Expand All @@ -11,7 +13,7 @@

from autotm.algorithms_for_tuning.bayesian_optimization import bayes_opt
from autotm.algorithms_for_tuning.genetic_algorithm import genetic_algorithm
from autotm.fitness.tm import TopicModel
from autotm.fitness.tm import TopicModel, extract_topics, print_topics

from autotm.infer import TopicsExtractor
from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts
Expand All @@ -21,6 +23,8 @@


class AutoTM(BaseEstimator):
_ARTM_MODEL_FILENAME = "artm_model"
_AUTOTM_DATA_FILENAME = "autotm_data"
_SUPPORTED_ALGS = ["ga", "baeys"]

@classmethod
Expand All @@ -29,7 +33,25 @@ def load(cls, path: str) -> 'AutoTM':
Loads AutoTM instance from a path on local filesystem.
:param path: a local filesystem path to load an AutoTM instance from.
"""
raise NotImplementedError()
assert os.path.exists(path), f"Path doesn't exist: {path}"

artm_model_path = os.path.join(path, cls._ARTM_MODEL_FILENAME)
autotm_data_path = os.path.join(path, cls._AUTOTM_DATA_FILENAME)

if not (os.path.exists(artm_model_path) and os.path.exists(autotm_data_path)):
raise FileNotFoundError(f"One or two of the follwing paths don't exist: "
f"{artm_model_path}, {autotm_data_path}")

model = artm.load_artm_model(artm_model_path)

with open(autotm_data_path, "rb") as f:
params = pickle.load(f)

autotm = AutoTM(**params)
autotm._model = model

return autotm


def __init__(self,
topic_count: int = 10,
Expand Down Expand Up @@ -83,7 +105,7 @@ def __init__(self,
self.exp_id = exp_id
self.exp_tag = exp_tag
self.exp_dataset_name = exp_dataset_name
self._model: Optional[TopicModel] = None
self._model: Optional[artm.ARTM] = None

def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
"""
Expand Down Expand Up @@ -122,6 +144,7 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
f"Only the following algorithms are supported: {self._SUPPORTED_ALGS}")

if self.alg_name == "ga":
# TODO: add checking of surrogate alg names
# TODO: make mlflow arguments optional
# exp_id and dataset_name will be needed further to store results in mlflow
best_topic_model = genetic_algorithm.run_algorithm(
Expand All @@ -141,7 +164,7 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
**self.alg_params
)

self._model = best_topic_model
self._model = best_topic_model.model

return self

Expand All @@ -165,7 +188,7 @@ def predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> ArrayLike:
self._check_if_already_fitted()

with tempfile.TemporaryDirectory(dir=self.working_dir_path) as extractor_working_dir:
topics_extractor = TopicsExtractor(self._model.model)
topics_extractor = TopicsExtractor(self._model)
mixtures = topics_extractor.get_prob_mixture(dataset=dataset, working_dir=extractor_working_dir)

return mixtures
Expand Down Expand Up @@ -193,27 +216,41 @@ def fit_predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> ArrayLike:
self.fit(dataset)
return self.predict(dataset)

def save(self, path: str):
def save(self, path: str, overwrite: bool = False):
"""
Saves AutoTM to a filesystem.
:param path: local filesystem path to save AutoTM on
"""
raise NotImplementedError()
path_exists = os.path.exists(path)
if path_exists and not overwrite:
raise RuntimeError("The path is already exists and is not allowed to overwrite")
elif path_exists:
logger.debug(f"Removing existing path: {path}")
shutil.rmtree(path)

os.makedirs(path)

artm_model_path = os.path.join(path, self._ARTM_MODEL_FILENAME)
autotm_data_path = os.path.join(path, self._AUTOTM_DATA_FILENAME)

self._model.dump_artm_model(artm_model_path)
with open(autotm_data_path, "wb") as f:
pickle.dump(self.get_params(), f)

@property
def topics(self) -> Dict[str, List[str]]:
"""
Inferred set of topics with their respective top words.
"""
self._check_if_already_fitted()
return self._model.get_topics()
return extract_topics(self._model)

def print_topics(self):
"""
Print topics in a human readable form in stdout
"""
self._check_if_already_fitted()
self._model.print_topics()
print_topics(self._model)

def _check_if_already_fitted(self, fit_is_ok=True):
if fit_is_ok and self._model is None:
Expand Down
34 changes: 21 additions & 13 deletions autotm/fitness/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@
logging.basicConfig(level="INFO")


def extract_topics(model: artm.ARTM):
if "TopTokensScore" not in model.score_tracker:
logger.warning(
f"Key 'TopTokensScore' is not presented in the model's score_tracker. "
f"Returning empty dict of topics."
)
return dict()
res = model.score_tracker["TopTokensScore"].last_tokens
topics = {topic: tokens[:50] for topic, tokens in res.items()}
return topics


def print_topics(model: artm.ARTM):
for i, (topic, top_tokens) in enumerate(extract_topics(model).items()):
print(topic)
print(top_tokens)
print()


class Dataset:
_batches_path: str = "batches"
_wv_path: str = "test_set_data_voc.txt"
Expand Down Expand Up @@ -546,21 +565,10 @@ def save_model(self, path):
self.model.dump_artm_model(path)

def print_topics(self):
for i, (topic, top_tokens) in enumerate(self.get_topics().items()):
print(topic)
print(top_tokens)
print()
print_topics(self.model)

def get_topics(self):
if "TopTokensScore" not in self.model.score_tracker:
logger.warning(
f"Key 'TopTokensScore' is not presented in the model's (uid={self.uid}) score_tracker. "
f"Returning empty dict of topics."
)
return dict()
res = self.model.score_tracker["TopTokensScore"].last_tokens
topics = {topic: tokens[:50] for topic, tokens in res.items()}
return topics
return extract_topics(self.model)

def _get_avg_coherence_score(self, for_individ_fitness=False):
coherences_main, coherences_back = self.__return_all_tokens_coherence(
Expand Down

0 comments on commit 6e38f0e

Please sign in to comment.