Skip to content

Commit

Permalink
cli: add flask cli for training and prediction
Browse files Browse the repository at this point in the history
Signed-off-by: Salman Maqbool [email protected]
  • Loading branch information
salmanmaq committed Jan 18, 2019
1 parent a35a66a commit 5b737de
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 3 deletions.
9 changes: 6 additions & 3 deletions inspire_classifier/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_directories():
path_for('classifier_data').mkdir(parents=True, exist_ok=True)
path_for('language_model_data').mkdir(parents=True, exist_ok=True)
path_for('classifier_model').mkdir(parents=True, exist_ok=True)
(path_for('language_model') / 'wikitext_103').mkdir(exist_ok=True, parents=True)
(path_for('language_model') / 'wikitext_103').mkdir(parents=True, exist_ok=True)


def preprocess_and_save_data():
Expand All @@ -61,7 +61,9 @@ def preprocess_and_save_data():
val_fraction=current_app.config['CLASSIFIER_VALIDATION_DATA_FRACTION']
)
except IOError as error:
raise IOError('Training dataframe not found.') from error
raise IOError('Training dataframe not found. Make sure the file is present in the right directory. '
'Please use the path specified in config.py for CLASSIFIER_DATAFRAME_PATH relative to the '
'CLASSIFIER_BASE_PATH.') from error

try:
generate_and_save_language_model_tokens(language_model_data_dir=path_for('language_model_data'))
Expand Down Expand Up @@ -129,7 +131,8 @@ def finetune_and_save_language_model():
language_model.train(finetuned_language_model_encoder_save_path=path_for('finetuned_language_model_encoder'),
cycle_length=current_app.config['CLASSIFIER_LANGUAGE_MODEL_CYCLE_LENGTH'])
except IOError as error:
raise IOError('Unable to save the finetuned language model.') from error
raise IOError('Unable to save the finetuned language model. Please check that the language model data directory '
'exists.') from error


def train_and_save_classifier():
Expand Down
66 changes: 66 additions & 0 deletions inspire_classifier/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
#
# This file is part of INSPIRE.
# Copyright (C) 2014-2019 CERN.
#
# INSPIRE is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# INSPIRE is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with INSPIRE. If not, see <http://www.gnu.org/licenses/>.
#
# In applying this license, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization
# or submit itself to any jurisdiction.

import click
import click_spinner
from flask import current_app
from flask.cli import FlaskGroup, with_appcontext
from inspire_classifier.api import (
train,
predict_coreness
)
from inspire_classifier.app import create_app


@click.group(cls=FlaskGroup, create_app=create_app)
def inspire_classifier():
"INSPIRE Classifier commands"


@inspire_classifier.command('predict-coreness')
@with_appcontext
@click.argument('title', type=str, required=True, nargs=1)
@click.argument('abstract', type=str, required=True, nargs=1)
@click.option('-b', '--base-path', type=click.Path(exists=True), required=False, nargs=1)
def predict(title, abstract, base_path):
with click_spinner.spinner():
with current_app.app_context():
if base_path:
current_app.config['CLASSIFIER_BASE_PATH'] = base_path
click.echo(predict_coreness(title, abstract))


@inspire_classifier.command('train')
@with_appcontext
@click.option('-l', '--language-model-epochs', type=int, required=False, nargs=1)
@click.option('-c', '--classifier-epochs', type=int, required=False, nargs=1)
@click.option('-b', '--base-path', type=click.Path(exists=True), required=False, nargs=1)
def train_classifier(language_model_epochs, classifier_epochs, base_path):
with click_spinner.spinner():
with current_app.app_context():
if language_model_epochs:
current_app.config['CLASSIFIER_LANGUAGE_MODEL_CYCLE_LENGTH'] = language_model_epochs
if classifier_epochs:
current_app.config['CLASSIFIER_CLASSIFIER_CYCLE_LENGTH'] = classifier_epochs
if base_path:
current_app.config['CLASSIFIER_BASE_PATH'] = base_path
train()
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
]

install_requires = [
'click~=7.0,>=7.0',
'click-spinner~=0.1,>=0.1.8',
'fastai==0.7.0',
'Flask~=1.0,>=1.0.2',
'flask-apispec~=0.0,>=0.7.0',
Expand Down Expand Up @@ -97,4 +99,9 @@
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules',
],
entry_points={
'console_scripts': [
'inspire-classifier = inspire_classifier.cli:inspire_classifier'
]
}
)
58 changes: 58 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
#
# This file is part of INSPIRE.
# Copyright (C) 2014-2018 CERN.
#
# INSPIRE is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# INSPIRE is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with INSPIRE. If not, see <http://www.gnu.org/licenses/>.
#
# In applying this license, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization
# or submit itself to any jurisdiction.

from click.exceptions import MissingParameter
from inspire_classifier.cli import (
predict,
train_classifier
)
import pytest


def test_classifier_predict_cli_with_classifier_base_path():
input_arguments = predict.make_context('predict-coreness', args=['foo', 'bar', '-b', '.'])
assert input_arguments.params['title'] == 'foo'
assert input_arguments.params['abstract'] == 'bar'
assert input_arguments.params['base_path'] == '.'


def test_classifier_predict_cli_without_classifier_base_path():
input_arguments = predict.make_context('predict-coreness', args=['foo', 'bar'])
assert input_arguments.params['title'] == 'foo'
assert input_arguments.params['abstract'] == 'bar'
assert input_arguments.params['base_path'] is None


def test_classifier_predict_cli_fails_without_title_and_abstract():
with pytest.raises(MissingParameter):
predict.make_context('predict-coreness', args=[])


def test_classifier_train_cli_correctly_parses_arguments():
input_arguments = train_classifier.make_context('train', args=['-l', '15', '-c', '14', '-b', '.'])
assert input_arguments.params['language_model_epochs'] == 15
assert input_arguments.params['classifier_epochs'] == 14
assert input_arguments.params['base_path'] == '.'


def test_classifier_train_cli_works_with_no_arguments():
assert train_classifier.make_context('train', args=[])

0 comments on commit 5b737de

Please sign in to comment.