Skip to content

Commit

Permalink
It's a test
Browse files Browse the repository at this point in the history
  • Loading branch information
DonHaul committed Oct 9, 2024
1 parent b24bb59 commit 66b1c10
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 34 deletions.
34 changes: 16 additions & 18 deletions inspire_classifier/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,10 @@ def inspire_classifier():
"-b", "--base-path", type=click.Path(exists=True), required=False, nargs=1
)
def predict(title, abstract, base_path):
with click_spinner.spinner():
with current_app.app_context():
if base_path:
current_app.config["CLASSIFIER_BASE_PATH"] = base_path
click.echo(predict_coreness(title, abstract))
with click_spinner.spinner(),current_app.app_context():
if base_path:
current_app.config["CLASSIFIER_BASE_PATH"] = base_path
click.echo(predict_coreness(title, abstract))


@inspire_classifier.command("train")
Expand All @@ -58,19 +57,18 @@ def predict(title, abstract, base_path):
"-b", "--base-path", type=click.Path(exists=True), required=False, nargs=1
)
def train_classifier(language_model_epochs, classifier_epochs, base_path):
with click_spinner.spinner():
with current_app.app_context():
if language_model_epochs:
current_app.config["CLASSIFIER_LANGUAGE_MODEL_CYCLE_LENGTH"] = (
language_model_epochs
)
if classifier_epochs:
current_app.config["CLASSIFIER_CLASSIFIER_CYCLE_LENGTH"] = (
classifier_epochs
)
if base_path:
current_app.config["CLASSIFIER_BASE_PATH"] = base_path
train()
with click_spinner.spinner(),current_app.app_context():
if language_model_epochs:
current_app.config["CLASSIFIER_LANGUAGE_MODEL_CYCLE_LENGTH"] = (
language_model_epochs
)
if classifier_epochs:
current_app.config["CLASSIFIER_CLASSIFIER_CYCLE_LENGTH"] = (
classifier_epochs
)
if base_path:
current_app.config["CLASSIFIER_BASE_PATH"] = base_path
train()


@inspire_classifier.command("validate")
Expand Down
5 changes: 4 additions & 1 deletion inspire_classifier/domain/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ def initialize_learner(
self,
dropout_multiplier=0.5,
weight_decay=1e-6,
learning_rates=np.array([1e-4, 1e-4, 1e-4, 1e-3, 1e-2]),
learning_rates=None,
):
if learning_rates is None:
learning_rates = np.array([1e-4, 1e-4, 1e-4, 1e-3, 1e-2])

self.learner = text_classifier_learner(
self.dataloader,
AWD_LSTM,
Expand Down
8 changes: 5 additions & 3 deletions inspire_classifier/domain/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
def split_and_save_data_for_training(dataframe_path, dest_dir, val_fraction=0.1):
"""
Args:
dataframe_path: The path to the pandas dataframe containing the records. The dataframe should have one
column containing the title and abstract text appended (title + abstract). The second
column should contain the label as an integer (0: Rejected, 1: Non-Core, 2: Core).
dataframe_path: The path to the pandas dataframe containing the records.
The dataframe should have one column containing the title and
abstract text appended (title + abstract). The second column
should contain the label as an integer
(0: Rejected, 1: Non-Core, 2: Core).
dest_dir: Directory to save the training/validation csv.
val_fraction: the fraction of data to use as the validation set.
"""
Expand Down
4 changes: 3 additions & 1 deletion scripts/train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def train_classifier(
print("-----------------")

os.system(
f"inspire-classifier train -b classifier --classifier-epochs {number_of_classifier_epochs} --language-model-epochs {number_of_lanuage_model_epochs}"
f"inspire-classifier train -b classifier "
f"--classifier-epochs {number_of_classifier_epochs} "
f"--language-model-epochs {number_of_lanuage_model_epochs}"
)
print("training finished successfully!")
os.system(
Expand Down
9 changes: 5 additions & 4 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class Mock_Learner(Learner):
"""
Mocks the fit method of the Learner.
This is done to reduce the model training time during testing by making the fit run once (as opposed to 2 times and
3 times for the LanguageModel and Classifier respectively). It stores the result of the first run and then returns
the same result for the other times fit is run.
This is done to reduce the model training time during testing by making the fit
run once (as opposed to 2 times and 3 times for the LanguageModel and Classifier
respectively). It stores the result of the first run and then returns the same
result for the other times fit is run.
"""

def fit(self, *args, **kwargs):
Expand All @@ -70,7 +71,7 @@ def fit(self, *args, **kwargs):

@pytest.fixture(scope="session")
@patch("fastai.text.learner.text_classifier_learner", Mock_Learner)
def trained_pipeline(app, tmp_path_factory):
def _trained_pipeline(app, tmp_path_factory):
app.config["CLASSIFIER_BASE_PATH"] = tmp_path_factory.getbasetemp()
create_directories()
shutil.copy(
Expand Down
20 changes: 13 additions & 7 deletions tests/integration/test_classifier_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from math import isclose

import pandas as pd
import pytest

from inspire_classifier.api import predict_coreness
from inspire_classifier.utils import path_for
Expand All @@ -42,11 +43,13 @@
" numerical range.")


def test_create_directories(trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_create_directories():
assert path_for("classifier_model").exists()


def test_preprocess_and_save_data(app, trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_preprocess_and_save_data(app):
dataframe = pd.read_pickle(path_for("dataframe"))

training_valid__csv = pd.read_csv(path_for("train_valid_data"))
Expand All @@ -64,8 +67,8 @@ def test_preprocess_and_save_data(app, trained_pipeline):
abs_tol=1,
)


def test_vocab(app, trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_vocab(app):
with open(path_for("data_itos"), "rb") as file:
data_itos = pickle.load(file)
# For performance when using mixed precision, the vocabulary is always made of
Expand All @@ -78,15 +81,18 @@ def test_vocab(app, trained_pipeline):
assert len(data_itos) == adjusted_max_vocab


def test_save_language_model(trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_save_language_model():
assert path_for("finetuned_language_model_encoder").exists()


def test_train_and_save_classifier(trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_train_and_save_classifier():
assert path_for("trained_classifier").exists()


def test_predict_coreness(trained_pipeline):
@pytest.mark.usefixtures("_trained_pipeline")
def test_predict_coreness():
assert path_for("data_itos").exists()
assert path_for("trained_classifier").exists()
output_dict = predict_coreness(title=TEST_TITLE, abstract=TEST_ABSTRACT)
Expand Down

0 comments on commit 66b1c10

Please sign in to comment.