Commit 2b7b3a2 (parent 8b5afb4)
fonhorst committed Jul 3, 2023
Showing 1 changed file with 23 additions and 7 deletions.
autotm/base.py: 30 changes (23 additions, 7 deletions)
@@ -2,7 +2,7 @@
 import os
 import tempfile
 import uuid
-from typing import Union, Optional, Any, Dict
+from typing import Union, Optional, Any, Dict, List
 
 import artm
 import pandas as pd
@@ -11,6 +11,7 @@
 
 from autotm.algorithms_for_tuning.bayesian_optimization import bayes_opt
 from autotm.algorithms_for_tuning.genetic_algorithm import genetic_algorithm
+from autotm.fitness.tm import TopicModel
 
 from autotm.infer import TopicsExtractor
 from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts
@@ -82,7 +83,7 @@ def __init__(self,
         self.exp_id = exp_id
         self.exp_tag = exp_tag
         self.exp_dataset_name = exp_dataset_name
-        self._model: Optional[artm.ARTM] = None
+        self._model: Optional[TopicModel] = None
 
     def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
         """
@@ -101,6 +102,8 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
             Fitted Estimator.
         """
+        self._check_if_already_fitted(fit_is_ok=False)
+
         processed_dataset_path = os.path.join(self.working_dir_path, f"{uuid.uuid4()}")
 
         logger.info("Stage 1: Dataset preparation")
@@ -138,7 +141,7 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
             **self.alg_params
         )
 
-        self._model = best_topic_model.model
+        self._model = best_topic_model
 
         return self
 
@@ -159,8 +162,10 @@ def predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> ArrayLike:
             Returns the probabilities of each topic to be in the every given text.
             Topic's probabilities are ordered according to topics ordering in 'self.topics' property.
         """
+        self._check_if_already_fitted()
+
         with tempfile.TemporaryDirectory(dir=self.working_dir_path) as extractor_working_dir:
-            topics_extractor = TopicsExtractor(self._model)
+            topics_extractor = TopicsExtractor(self._model.model)
             mixtures = topics_extractor.get_prob_mixture(dataset=dataset, working_dir=extractor_working_dir)
 
         return mixtures
@@ -183,6 +188,8 @@ def fit_predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> ArrayLike:
             Returns the probabilities of each topic to be in the every given text.
             Topic's probabilities are ordered according to topics ordering in 'self.topics' property.
         """
+        self._check_if_already_fitted(fit_is_ok=False)
+
         self.fit(dataset)
         return self.predict(dataset)
 
@@ -194,14 +201,23 @@ def save(self, path: str):
         raise NotImplementedError()
 
     @property
-    def topics(self) -> pd.DataFrame:
+    def topics(self) -> Dict[str, List[str]]:
         """
         Inferred set of topics with their respective top words.
         """
-        raise NotImplementedError()
+        self._check_if_already_fitted()
+        return self._model.get_topics()
 
     def print_topics(self):
         """
         Print topics in a human readable form in stdout
         """
-        raise NotImplementedError()
+        self._check_if_already_fitted()
+        self._model.print_topics()
+
+    def _check_if_already_fitted(self, fit_is_ok=True):
+        if fit_is_ok and self._model is None:
+            raise RuntimeError("The model is not fitted")
+
+        if not fit_is_ok and self._model is not None:
+            raise RuntimeError("The model is already fitted")
