diff --git a/evaluation/setup.py b/evaluation/setup.py index 2a5a695..c3568a2 100644 --- a/evaluation/setup.py +++ b/evaluation/setup.py @@ -42,6 +42,8 @@ 'scipy', 'networkx', 'blanc', + 'scikit-learn', + 'wmd', ], entry_points={ 'console_scripts': [ diff --git a/evaluation/summ_eval/data_stats_metric.py b/evaluation/summ_eval/data_stats_metric.py index b6a4850..dcc86d3 100644 --- a/evaluation/summ_eval/data_stats_metric.py +++ b/evaluation/summ_eval/data_stats_metric.py @@ -2,15 +2,17 @@ from collections import Counter from multiprocessing import Pool import gin +import logging import spacy from summ_eval.data_stats_utils import Fragments from summ_eval.metric import Metric +logger = logging.getLogger(__name__) + try: _en = spacy.load('en_core_web_sm') except OSError: - print('Downloading the spacy en_core_web_sm model\n' - "(don't worry, this will only happen once)", file=stderr) + logger.info("Downloading the spacy en_core_web_sm model\n (don't worry, this will only happen once)") from spacy.cli import download download('en_core_web_sm') _en = spacy.load('en_core_web_sm') diff --git a/evaluation/summ_eval/sentence_transformers/SentenceTransformer.py b/evaluation/summ_eval/sentence_transformers/SentenceTransformer.py index e112696..6726555 100644 --- a/evaluation/summ_eval/sentence_transformers/SentenceTransformer.py +++ b/evaluation/summ_eval/sentence_transformers/SentenceTransformer.py @@ -50,7 +50,7 @@ def __init__(self, model_name_or_path: str = None, modules: Iterable[nn.Module] if not os.listdir(model_path): - if model_url[-1] is "/": + if model_url[-1] == "/": model_url = model_url[:-1] logging.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path)) try: diff --git a/evaluation/summ_eval/sentence_transformers/losses/test_batch_hard_triplet_loss.py b/evaluation/summ_eval/sentence_transformers/losses/test_batch_hard_triplet_loss.py index bcbe2e7..16af948 100644 --- a/evaluation/summ_eval/sentence_transformers/losses/test_batch_hard_triplet_loss.py +++ b/evaluation/summ_eval/sentence_transformers/losses/test_batch_hard_triplet_loss.py @@ -1,6 +1,6 @@ import numpy as np import torch -from sentence_transformers.losses import BatchHardTripletLoss +from summ_eval.sentence_transformers.losses import BatchHardTripletLoss # Test-suite from https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/tests/test_triplet_loss.py # Skipped the `test_gradients_pairwise_distances()` test since it's trivial to see if your model loss turns NaN