Skip to content

Commit

Permalink
---
Browse files Browse the repository at this point in the history
  • Loading branch information
fonhorst committed Jul 5, 2023
1 parent 85a23f7 commit 10c10eb
Showing 1 changed file with 48 additions and 29 deletions.
77 changes: 48 additions & 29 deletions autotm/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
import pprint
import tempfile
from contextlib import contextmanager
from typing import Optional
Expand All @@ -14,6 +16,38 @@
logger = logging.getLogger()


def obtain_autotm_params(
config_path: str,
topic_count: Optional[int],
lang: Optional[str],
alg: Optional[str],
surrogate_alg: Optional[str],
log_file: Optional[str]
):
# TODO: define clearly format of config
if config_path is not None:
logger.info(f"Reading config from path: {os.path.abspath(config_path)}")
with open(config_path, "r") as f:
config = yaml.safe_load(f)
else:
config = dict()

if topic_count is not None:
config['topic_count'] = topic_count
if lang is not None:
pp = config.get('preprocessing_params', dict())
pp['lang'] = lang
config['preprocessing_params'] = pp
if alg is not None:
config['alg_name'] = alg
if surrogate_alg is not None:
config['surrogate_alg_name'] = surrogate_alg
if log_file is not None:
config['log_file_path'] = log_file

return config


@contextmanager
def prepare_working_dir(working_dir: Optional[str] = None):
if working_dir is None:
Expand All @@ -29,9 +63,8 @@ def cli():
pass


# TODO: add lang and log_file params
@cli.command()
@click.option('--config', type=str, help="A path to config for fitting the model")
@click.option('--config', 'config_path', type=str, help="A path to config for fitting the model")
@click.option(
'--working-dir',
type=str,
Expand All @@ -48,56 +81,42 @@ def cli():
)
@click.option('--model', type=str, default='model.artm', help="A path that will contain fitted ARTM model")
@click.option('-t', '--topic-count', type=int, help="Number of topics to fit model with")
@click.option('--lang', type=str, help='Language of the dataset')
@click.option('--alg', type=str, help="Hyperparameters tuning algorithm. Available: ga, bayes")
@click.option('--surrogate-alg', type=str, help="Surrogate algorithm to use.")
@click.option('--log-file', type=str, help="Log file path")
def fit(
config: Optional[str],
config_path: Optional[str],
working_dir: Optional[str],
in_: str,
out: str,
model: str,
t: Optional[int],
topic_count: Optional[int],
lang: Optional[str],
alg: Optional[str],
surrogate_alg: Optional[str]
surrogate_alg: Optional[str],
log_file: Optional[str]
):
# TODO: define clearly format of config
if config is not None:
with open(config, "r") as f:
config = yaml.load(f)
else:
config = dict()
config = obtain_autotm_params(config_path, topic_count, lang, alg, surrogate_alg, log_file)

cli_params = dict()
if t is not None:
cli_params['topic_count'] = t
if alg is not None:
cli_params['alg_name'] = alg
if surrogate_alg is not None:
cli_params['surrogate_alg_name'] = surrogate_alg

config = {
**config,
**cli_params
}

# TODO: log all params
logger.debug(f"Running AutoTM with params: {pprint.pformat(config, indent=4)}")

df = pd.read_csv(in_)

with prepare_working_dir(working_dir) as work_dir:
logger.info(f"Using working directory {os.path.abspath(work_dir)} for AutoTM")
autotm = AutoTM(
**config,
working_dir_path=work_dir
)
mixtures = autotm.fit_predict(df)

logger.info(f"Calculated train mixtures: {mixtures.shape}\n\n{mixtures.head(10).to_string()}")

# saving artifacts
logger.info(f"Saving model to {os.path.abspath(model)}")
autotm.save(model)
logger.info(f"Saving mixtures to {os.path.abspath(out)}")
mixtures.to_csv(out)

click.echo('Initialized the database')
logger.info("Finished AutoTM")


@cli.command()
Expand Down

0 comments on commit 10c10eb

Please sign in to comment.