Skip to content

Commit

Permalink
Merge pull request #5 from brunoarine/fix/reference-doc
Browse files Browse the repository at this point in the history
fix: reference document issue
  • Loading branch information
brunoarine authored Jul 11, 2023
2 parents f5b46cf + f56cbd7 commit e1e05e0
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 24 deletions.
16 changes: 10 additions & 6 deletions findlike/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable

from .markup import Markup
from .utils import compress, try_read_file
from .utils import try_read_file

WORD_RE = re.compile(r"(?u)\b\w{2,}\b")
URL_RE = re.compile(r"\S*https?:\S*")
Expand Down Expand Up @@ -60,6 +60,7 @@ def _stemmize(self, tokens: list[str]) -> list[str]:
"""Get only the stems from a list of words."""
return [self.stemmer(w) for w in tokens]


class Corpus:
"""This wrapper provides easy access to a filtered corpus.
Expand All @@ -85,7 +86,7 @@ def __init__(

self.documents_: list[str] = []
self.paths_: list[Path] = []
self.reference_: str| None = None
self.reference_: str | None = None

self.add_from_paths()

Expand All @@ -94,13 +95,13 @@ def add_from_file(self, path: Path, is_reference: bool = False):
Args:
path (Path): The path to the file.
is_reference (bool, optional): Indicates if the file is a reference file.
is_reference (bool, optional): Indicates if the file is a reference file.
Defaults to False.
Notes:
- The file content is added to the corpus if it meets the minimum character
- The file content is added to the corpus if it meets the minimum character
length requirement.
- If front matter stripping is enabled, the file content is stripped of its
- If front matter stripping is enabled, the file content is stripped of its
front matter before being added to the corpus.
"""
loaded_doc = try_read_file(path)
Expand All @@ -109,14 +110,17 @@ def add_from_file(self, path: Path, is_reference: bool = False):
loaded_doc = self.strip_front_matter(
loaded_doc, extension=path.suffix
)
self.documents_.append(loaded_doc)
if is_reference:
self.reference_ = loaded_doc
if self.reference_ not in self.documents_:
self.documents_.append(self.reference_)
else:
self.documents_.append(loaded_doc)
self.paths_.append(path)

def add_from_query(self, query: str):
    """Register a raw query string as the corpus reference.

    The query text is stored as the scoring reference and also appended
    to the corpus documents so it participates in vectorization.

    Args:
        query (str): Free-form reference text supplied by the user.
    """
    self.reference_ = query
    self.documents_.append(self.reference_)

def add_from_paths(self) -> list[str | None]:
"""Load document contents from the specified paths."""
Expand Down
8 changes: 2 additions & 6 deletions findlike/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def fit(self, documents: list[str]):
self.target_embeddings_ = self._vectorizer.fit_transform(documents)

def get_scores(self, source: str):
# Since the reference has been appended to the corpus, the last
# item in the embeddings list will be the reference's.
self.reference_embeddings_ = self.target_embeddings_[-1]
self.reference_embeddings_ = self._vectorizer.transform([source])
scores = cosine_similarity(
self.reference_embeddings_, self.target_embeddings_
).flatten()
Expand Down Expand Up @@ -61,8 +59,6 @@ def fit(self, documents: list[str]):
self._model = BM25Okapi(self.tokenized_documents_)

def get_scores(self, source: str):
# Since the reference has been appended to the corpus, the last
# item in the embeddings list will be the reference's.
tokenized_source = self.tokenized_documents_[-1]
tokenized_source = self.processor.tokenizer(source)
scores = self._model.get_scores(tokenized_source)
return scores
69 changes: 58 additions & 11 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from click.testing import CliRunner
import pytest
from findlike import cli
import numpy as np
import json
from pathlib import Path

import numpy as np
import pytest
from click.testing import CliRunner
from scipy.stats import spearmanr

from findlike import cli

reference = "Hurricane Irene was a long-lived Cape Verde hurricane during the 2005 Atlantic hurricane season. The storm formed near Cape Verde on August 4 and crossed the Atlantic, turning northward around Bermuda before being absorbed by an extratropical cyclone while situated southeast of Newfoundland. Irene proved to be a difficult storm to forecast due to oscillations in strength. After almost dissipating on August 10, Irene peaked as a Category 2 hurricane on August 16. Irene persisted for 14 days as a tropical system, the longest duration of any storm of the 2005 season. It was the ninth named storm and fourth hurricane of the record-breaking season."
candidates = [
"In 2023, a sizable earthquake dubbed the 'Goliath Quake' occurred in the Pacific Ring of Fire. The seismic activity started near Japan on July 16 and sent shockwaves across the Pacific, leading to tsunamis in several coastal regions before dissipating on July 24. Forecasting the earthquake's impact was a complex task due to the frequent aftershocks. After a series of smaller tremors, the earthquake peaked with a magnitude of 8.9 on July 20. The quake lasted for 8 days, the longest seismic activity recorded in the 21st century. It was one of the most powerful earthquakes in recent history.",
Expand All @@ -21,7 +24,7 @@
]
scores = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]

std_args = ["--query", reference, "-F", "json", "-s", "-m", len(scores)]
std_args = ["-F", "json", "-s", "-h", "-m", len(scores)]


@pytest.fixture
Expand All @@ -34,6 +37,8 @@ def create_directory(tmp_path):
for i, candidate in enumerate(candidates):
file = tmp_path / f"file_{i:02d}.txt"
file.write_text(candidate)
reference_file = tmp_path / "reference.txt"
reference_file.write_text(reference)
return str(tmp_path)


Expand All @@ -42,13 +47,17 @@ def create_directory_with_non_text(tmp_path):
for i, candidate in enumerate(candidates):
file = tmp_path / f"file_{i:02d}.000"
file.write_text(candidate)
reference_file = tmp_path / "reference.000"
reference_file.write_text(reference)
return str(tmp_path)


@pytest.mark.parametrize("format", cli.FORMATTER_CLASSES.keys())
def test_formats(runner, create_directory, format):
reference_path = Path(create_directory) / "reference.txt"
result = runner.invoke(
cli.cli, ["-d", create_directory, "-F", format, *std_args]
cli.cli,
[str(reference_path), "-d", create_directory, "-F", format, *std_args],
)
assert result.exit_code == 0

Expand All @@ -58,19 +67,57 @@ def test_formats(runner, create_directory, format):

@pytest.mark.parametrize("algorithm", cli.ALGORITHM_CLASSES.keys())
def test_algorithms(runner, create_directory, algorithm):
reference_path = Path(create_directory) / "reference.txt"
result = runner.invoke(
cli.cli, ["-d", create_directory, "-a", algorithm, *std_args]
cli.cli,
[
str(reference_path),
"-d",
create_directory,
"-a",
algorithm,
*std_args,
],
)
json_data = json.loads(result.output.strip())

output_scores = [float(x["score"]) for x in json_data]
assert spearmanr(output_scores, scores)[0] > 0.99
pairs = [(item["score"], item["target"]) for item in json_data]
sorted_pairs = sorted(pairs, key=lambda x: x[1])[::-1]
output_scores = [float(x[0]) for x in sorted_pairs]
corr = round(spearmanr(output_scores, scores)[0], 2)
assert corr >= 0.95


def test_other_extensions(runner, create_directory_with_non_text):
reference_path = Path(create_directory_with_non_text) / "reference.000"
result = runner.invoke(
cli.cli,
["-d", create_directory_with_non_text, "-f", "*.000", *std_args],
[
str(reference_path),
"-d",
create_directory_with_non_text,
"-f",
"*.000",
*std_args,
],
)
assert "000" in result.output.strip()


@pytest.mark.parametrize("algorithm", cli.ALGORITHM_CLASSES.keys())
def test_query(runner, create_directory, algorithm):
    """Passing the reference as raw text via -q ranks candidates correctly."""
    cli_args = [
        "-q",
        reference,
        "-d",
        create_directory,
        "-a",
        algorithm,
    ] + list(std_args)
    result = runner.invoke(cli.cli, cli_args)
    payload = json.loads(result.output.strip())

    ranked = [float(entry["score"]) for entry in payload]
    # Output ordering should track the known ground-truth scores almost exactly.
    assert spearmanr(ranked, scores)[0] > 0.99
9 changes: 8 additions & 1 deletion tests/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,11 @@ def test_strip_front_matter(self, corpus):
"""
extension = ".org"
expected = "This is some text.\n** A heading\nSome more text."
assert corpus.strip_front_matter(dedent(document), extension) == expected
assert corpus.strip_front_matter(dedent(document), extension) == expected

def test_reference_duplicity(self, corpus, temp_files):
corpus.add_from_file(temp_files[0])
corpus.add_from_file(temp_files[1])
corpus.add_from_file(temp_files[0], is_reference=True)
duplicates = set([x for x in corpus.documents_ if corpus.documents_.count(x) > 1])
assert len(duplicates) == 0

0 comments on commit e1e05e0

Please sign in to comment.