From 3e0346f51befc048ed090da087dbe9239c257b5a Mon Sep 17 00:00:00 2001 From: Bruno Arine Date: Tue, 11 Jul 2023 14:32:36 -0300 Subject: [PATCH 1/3] fix: query and reference setting on Corpus --- findlike/preprocessing.py | 16 ++++++++++------ findlike/wrappers.py | 8 ++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/findlike/preprocessing.py b/findlike/preprocessing.py index be16f82..d55dfd8 100644 --- a/findlike/preprocessing.py +++ b/findlike/preprocessing.py @@ -5,7 +5,7 @@ from typing import Callable from .markup import Markup -from .utils import compress, try_read_file +from .utils import try_read_file WORD_RE = re.compile(r"(?u)\b\w{2,}\b") URL_RE = re.compile(r"\S*https?:\S*") @@ -60,6 +60,7 @@ def _stemmize(self, tokens: list[str]) -> list[str]: """Get only the stems from a list of words.""" return [self.stemmer(w) for w in tokens] + class Corpus: """This wrapper provides easy access to a filtered corpus. @@ -85,7 +86,7 @@ def __init__( self.documents_: list[str] = [] self.paths_: list[Path] = [] - self.reference_: str| None = None + self.reference_: str | None = None self.add_from_paths() @@ -94,13 +95,13 @@ def add_from_file(self, path: Path, is_reference: bool = False): Args: path (Path): The path to the file. - is_reference (bool, optional): Indicates if the file is a reference file. + is_reference (bool, optional): Indicates if the file is a reference file. Defaults to False. Notes: - - The file content is added to the corpus if it meets the minimum character + - The file content is added to the corpus if it meets the minimum character length requirement. - - If front matter stripping is enabled, the file content is stripped of its + - If front matter stripping is enabled, the file content is stripped of its front matter before being added to the corpus. """ loaded_doc = try_read_file(path) @@ -109,14 +110,17 @@ def add_from_file(self, path: Path, is_reference: bool = False): loaded_doc = self.strip_front_matter( loaded_doc, extension=path.suffix ) - self.documents_.append(loaded_doc) if is_reference: self.reference_ = loaded_doc + if self.reference_ not in self.documents_: + self.documents_.append(self.reference_) else: + self.documents_.append(loaded_doc) self.paths_.append(path) def add_from_query(self, query: str): self.documents_.append(query) + self.reference_ = query def add_from_paths(self) -> list[str | None]: """Load document contents from the specified paths.""" diff --git a/findlike/wrappers.py b/findlike/wrappers.py index cd4875a..8af774b 100644 --- a/findlike/wrappers.py +++ b/findlike/wrappers.py @@ -30,9 +30,7 @@ def fit(self, documents: list[str]): self.target_embeddings_ = self._vectorizer.fit_transform(documents) def get_scores(self, source: str): - # Since the reference has been appended to the corpus, the last - # item in the embeddings list will be the reference's. - self.reference_embeddings_ = self.target_embeddings_[-1] + self.reference_embeddings_ = self._vectorizer.transform([source]) scores = cosine_similarity( self.reference_embeddings_, self.target_embeddings_ ).flatten() @@ -61,8 +59,6 @@ def fit(self, documents: list[str]): self._model = BM25Okapi(self.tokenized_documents_) def get_scores(self, source: str): - # Since the reference has been appended to the corpus, the last - # item in the embeddings list will be the reference's. - tokenized_source = self.tokenized_documents_[-1] + tokenized_source = self.processor.tokenizer(source) scores = self._model.get_scores(tokenized_source) return scores From cd6be24dc1deb73811e779c076831e8d49a6854f Mon Sep 17 00:00:00 2001 From: Bruno Arine Date: Tue, 11 Jul 2023 14:34:57 -0300 Subject: [PATCH 2/3] test: adapt unit tests to refactored Corpus --- tests/test_corpus.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_corpus.py b/tests/test_corpus.py index d9ba43d..64bf8c8 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -112,4 +112,11 @@ def test_strip_front_matter(self, corpus): """ extension = ".org" expected = "This is some text.\n** A heading\nSome more text." - assert corpus.strip_front_matter(dedent(document), extension) == expected \ No newline at end of file + assert corpus.strip_front_matter(dedent(document), extension) == expected + + def test_reference_duplicity(self, corpus, temp_files): + corpus.add_from_file(temp_files[0]) + corpus.add_from_file(temp_files[1]) + corpus.add_from_file(temp_files[0], is_reference=True) + duplicates = set([x for x in corpus.documents_ if corpus.documents_.count(x) > 1]) + assert len(duplicates) == 0 From f56cbd7d4023a906da137e7880520450c77557c6 Mon Sep 17 00:00:00 2001 From: Bruno Arine Date: Tue, 11 Jul 2023 14:35:17 -0300 Subject: [PATCH 3/3] test: fix correlation test --- tests/test_cli.py | 69 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index edd65ba..c03ab94 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,13 @@ -from click.testing import CliRunner -import pytest -from findlike import cli -import numpy as np import json +from pathlib import Path + +import numpy as np +import pytest +from click.testing import CliRunner from scipy.stats import spearmanr +from findlike import cli + reference = "Hurricane Irene was a long-lived Cape Verde hurricane during the 2005 Atlantic hurricane season. The storm formed near Cape Verde on August 4 and crossed the Atlantic, turning northward around Bermuda before being absorbed by an extratropical cyclone while situated southeast of Newfoundland. Irene proved to be a difficult storm to forecast due to oscillations in strength. After almost dissipating on August 10, Irene peaked as a Category 2 hurricane on August 16. Irene persisted for 14 days as a tropical system, the longest duration of any storm of the 2005 season. It was the ninth named storm and fourth hurricane of the record-breaking season." candidates = [ "In 2023, a sizable earthquake dubbed the 'Goliath Quake' occurred in the Pacific Ring of Fire. The seismic activity started near Japan on July 16 and sent shockwaves across the Pacific, leading to tsunamis in several coastal regions before dissipating on July 24. Forecasting the earthquake's impact was a complex task due to the frequent aftershocks. After a series of smaller tremors, the earthquake peaked with a magnitude of 8.9 on July 20. The quake lasted for 8 days, the longest seismic activity recorded in the 21st century. It was one of the most powerful earthquakes in recent history.", @@ -21,7 +24,7 @@ ] scores = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0] -std_args = ["--query", reference, "-F", "json", "-s", "-m", len(scores)] +std_args = ["-F", "json", "-s", "-h", "-m", len(scores)] @pytest.fixture @@ -34,6 +37,8 @@ def create_directory(tmp_path): for i, candidate in enumerate(candidates): file = tmp_path / f"file_{i:02d}.txt" file.write_text(candidate) + reference_file = tmp_path / "reference.txt" + reference_file.write_text(reference) return str(tmp_path) @@ -42,13 +47,17 @@ def create_directory_with_non_text(tmp_path): for i, candidate in enumerate(candidates): file = tmp_path / f"file_{i:02d}.000" file.write_text(candidate) + reference_file = tmp_path / "reference.000" + reference_file.write_text(reference) return str(tmp_path) @pytest.mark.parametrize("format", cli.FORMATTER_CLASSES.keys()) def test_formats(runner, create_directory, format): + reference_path = Path(create_directory) / "reference.txt" result = runner.invoke( - cli.cli, ["-d", create_directory, "-F", format, *std_args] + cli.cli, + [str(reference_path), "-d", create_directory, "-F", format, *std_args], ) assert result.exit_code == 0 @@ -58,19 +67,57 @@ def test_formats(runner, create_directory, format): @pytest.mark.parametrize("algorithm", cli.ALGORITHM_CLASSES.keys()) def test_algorithms(runner, create_directory, algorithm): + reference_path = Path(create_directory) / "reference.txt" result = runner.invoke( - cli.cli, ["-d", create_directory, "-a", algorithm, *std_args] + cli.cli, + [ + str(reference_path), + "-d", + create_directory, + "-a", + algorithm, + *std_args, + ], ) json_data = json.loads(result.output.strip()) - - output_scores = [float(x["score"]) for x in json_data] - assert spearmanr(output_scores, scores)[0] > 0.99 + pairs = [(item["score"], item["target"]) for item in json_data] + sorted_pairs = sorted(pairs, key=lambda x: x[1])[::-1] + output_scores = [float(x[0]) for x in sorted_pairs] + corr = round(spearmanr(output_scores, scores)[0], 2) + assert corr >= 0.95 def test_other_extensions(runner, create_directory_with_non_text): + reference_path = Path(create_directory_with_non_text) / "reference.000" result = runner.invoke( cli.cli, - ["-d", create_directory_with_non_text, "-f", "*.000", *std_args], + [ + str(reference_path), + "-d", + create_directory_with_non_text, + "-f", + "*.000", + *std_args, + ], ) assert "000" in result.output.strip() + +@pytest.mark.parametrize("algorithm", cli.ALGORITHM_CLASSES.keys()) +def test_query(runner, create_directory, algorithm): + result = runner.invoke( + cli.cli, + [ + "-q", + reference, + "-d", + create_directory, + "-a", + algorithm, + *std_args, + ], + ) + json_data = json.loads(result.output.strip()) + + output_scores = [float(x["score"]) for x in json_data] + assert spearmanr(output_scores, scores)[0] > 0.99