support hashables in fuzz module
maxbachmann committed Aug 15, 2023
1 parent 86db592 · commit 99ab6a4
Showing 3 changed files with 85 additions and 26 deletions.
77 changes: 58 additions & 19 deletions src/rapidfuzz/fuzz_py.py
@@ -3,6 +3,7 @@
from __future__ import annotations

from math import ceil
import itertools
from typing import Any, Callable, Hashable, Sequence

from rapidfuzz._common_py import conv_sequences
@@ -33,6 +34,36 @@ def _norm_distance(dist: int, lensum: int, score_cutoff: float) -> float:
    return score if score >= score_cutoff else 0


+def _split_sequence(seq: Sequence[Hashable]) -> list[Sequence[Hashable]]:
+    if isinstance(seq, str) or isinstance(seq, bytes):
+        return seq.split()
+
+    splitted_seq = [[]]
+    for x in seq:
+        ch = x if isinstance(x, str) else chr(x)
+        if ch.isspace():
+            splitted_seq.append([])
+        else:
+            splitted_seq[-1].append(x)
+
+    return [tuple(x) for x in splitted_seq if x]
+
+
+def _join_splitted_sequence(seq_list: list[Sequence[Hashable]]):
+    if not seq_list:
+        return ""
+    if isinstance(next(iter(seq_list)), str):
+        return " ".join(seq_list)
+    if isinstance(next(iter(seq_list)), bytes):
+        return b" ".join(seq_list)
+
+    joined = []
+    for seq in seq_list:
+        joined += seq
+        joined += [ord(' ')]
+    return joined[:-1]
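
Illustration (not part of the commit): the helpers keep str and bytes on their native split/join, while a generic sequence of hashables is tokenized at whitespace code points and rejoined with ord(" ") as the separator:

    _split_sequence("new york mets")     # ['new', 'york', 'mets']
    _split_sequence(b"new york mets")    # [b'new', b'york', b'mets']
    ords = [ord(c) for c in "new york"]  # [110, 101, 119, 32, 121, 111, 114, 107]
    _split_sequence(ords)                # [(110, 101, 119), (121, 111, 114, 107)]
    _join_splitted_sequence(_split_sequence(ords))  # back to the flat int list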


def ratio(
    s1: Sequence[Hashable] | None,
    s2: Sequence[Hashable] | None,
@@ -356,8 +387,9 @@ def token_sort_ratio(
        s1 = processor(s1)
        s2 = processor(s2)

-    sorted_s1 = " ".join(sorted(s1.split()))
-    sorted_s2 = " ".join(sorted(s2.split()))
+    s1, s2 = conv_sequences(s1, s2)
+    sorted_s1 = _join_splitted_sequence(sorted(_split_sequence(s1)))
+    sorted_s2 = _join_splitted_sequence(sorted(_split_sequence(s2)))
    return ratio(sorted_s1, sorted_s2, score_cutoff=score_cutoff)
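
A sketch of what this enables, assuming the pure-Python module is importable as rapidfuzz.fuzz_py (values come from tracing the code above, not from the commit):

    from rapidfuzz import fuzz_py

    fuzz_py.token_sort_ratio("mets new york", "new york mets")    # 100.0
    fuzz_py.token_sort_ratio([ord(c) for c in "mets new york"],
                             [ord(c) for c in "new york mets"])   # 100.0, identical after token sort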


@@ -412,8 +444,10 @@ def token_set_ratio(
    if score_cutoff is None:
        score_cutoff = 0

-    tokens_a = set(s1.split())
-    tokens_b = set(s2.split())
+    s1, s2 = conv_sequences(s1, s2)
+
+    tokens_a = set(_split_sequence(s1))
+    tokens_b = set(_split_sequence(s2))

    # in FuzzyWuzzy this returns 0. For sake of compatibility return 0 here as well
    # see https://github.com/maxbachmann/RapidFuzz/issues/110
@@ -427,13 +461,13 @@
    if intersect and (not diff_ab or not diff_ba):
        return 100

-    diff_ab_joined = " ".join(sorted(diff_ab))
-    diff_ba_joined = " ".join(sorted(diff_ba))
+    diff_ab_joined = _join_splitted_sequence(sorted(diff_ab))
+    diff_ba_joined = _join_splitted_sequence(sorted(diff_ba))

    ab_len = len(diff_ab_joined)
    ba_len = len(diff_ba_joined)
    # todo is length sum without joining faster?
-    sect_len = len(" ".join(intersect))
+    sect_len = len(_join_splitted_sequence(intersect))

    # string length sect+ab <-> sect and sect+ba <-> sect
    sect_ab_len = sect_len + (sect_len != 0) + ab_len
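
The set arithmetic already worked on any hashable tokens; only the joins changed. For instance, duplicated tokens in bytes input still collapse to a perfect score (illustrative, same fuzz_py assumption as above):

    fuzz_py.token_set_ratio(b"fuzzy was a bear", b"fuzzy fuzzy was a bear")   # 100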
@@ -550,8 +584,9 @@ def partial_token_sort_ratio(
        s1 = processor(s1)
        s2 = processor(s2)

-    sorted_s1 = " ".join(sorted(s1.split()))
-    sorted_s2 = " ".join(sorted(s2.split()))
+    s1, s2 = conv_sequences(s1, s2)
+    sorted_s1 = _join_splitted_sequence(sorted(_split_sequence(s1)))
+    sorted_s2 = _join_splitted_sequence(sorted(_split_sequence(s2)))
    return partial_ratio(sorted_s1, sorted_s2, score_cutoff=score_cutoff)
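
The partial variant sorts and rejoins tokens the same way before the sliding-window comparison (illustrative, same fuzz_py assumption):

    fuzz_py.partial_token_sort_ratio(b"b a", b"a b c")   # 100.0, the sorted join b"a b" occurs in b"a b c"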


@@ -596,8 +631,10 @@ def partial_token_set_ratio(
        s1 = processor(s1)
        s2 = processor(s2)

-    tokens_a = set(s1.split())
-    tokens_b = set(s2.split())
+    s1, s2 = conv_sequences(s1, s2)
+
+    tokens_a = set(_split_sequence(s1))
+    tokens_b = set(_split_sequence(s2))
    # in FuzzyWuzzy this returns 0. For sake of compatibility return 0 here as well
    # see https://github.com/maxbachmann/RapidFuzz/issues/110
    if not tokens_a or not tokens_b:
@@ -607,8 +644,8 @@
    if tokens_a.intersection(tokens_b):
        return 100

-    diff_ab = " ".join(sorted(tokens_a.difference(tokens_b)))
-    diff_ba = " ".join(sorted(tokens_b.difference(tokens_a)))
+    diff_ab = _join_splitted_sequence(sorted(tokens_a.difference(tokens_b)))
+    diff_ba = _join_splitted_sequence(sorted(tokens_b.difference(tokens_a)))
    return partial_ratio(diff_ab, diff_ba, score_cutoff=score_cutoff)
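
A single shared token still short-circuits to a perfect score, so the new join only runs when the token sets are disjoint (illustrative, same fuzz_py assumption):

    fuzz_py.partial_token_set_ratio(b"new york mets", b"new jersey nets")   # 100, shares the token b"new"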


@@ -656,8 +693,10 @@ def partial_token_ratio(
    if score_cutoff is None:
        score_cutoff = 0

-    tokens_split_a = s1.split()
-    tokens_split_b = s2.split()
+    s1, s2 = conv_sequences(s1, s2)
+
+    tokens_split_a = _split_sequence(s1)
+    tokens_split_b = _split_sequence(s2)
    tokens_a = set(tokens_split_a)
    tokens_b = set(tokens_split_b)

@@ -669,8 +708,8 @@
    diff_ba = tokens_b.difference(tokens_a)

    result = partial_ratio(
-        " ".join(sorted(tokens_split_a)),
-        " ".join(sorted(tokens_split_b)),
+        _join_splitted_sequence(sorted(tokens_split_a)),
+        _join_splitted_sequence(sorted(tokens_split_b)),
        score_cutoff=score_cutoff,
    )

@@ -682,8 +721,8 @@
    return max(
        result,
        partial_ratio(
-            " ".join(sorted(diff_ab)),
-            " ".join(sorted(diff_ba)),
+            _join_splitted_sequence(sorted(diff_ab)),
+            _join_splitted_sequence(sorted(diff_ba)),
            score_cutoff=score_cutoff,
        ),
    )
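
End to end, a sequence of code points now flows through the same tokenize/sort/join pipeline (illustrative, same fuzz_py assumption):

    fuzz_py.partial_token_ratio([ord(c) for c in "york new"],
                                [ord(c) for c in "new"])   # 100, "new" is a full token of both sides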
10 changes: 10 additions & 0 deletions tests/distance/test_distance.py
@@ -68,6 +68,16 @@ def test_similar_array(scorer):
    assert scorer.normalized_similarity(array("u", "the wonderful new york mets"), "the wonderful new york mets") == 1.0


@pytest.mark.parametrize("scorer", all_scorer_modules)
def test_similar_bytes(scorer):
"""
bytes should be supported and treated in a compatible way to strings
"""
assert scorer.normalized_similarity(b"the wonderful new york mets", b"the wonderful new york mets") == 1.0
assert scorer.normalized_similarity("the wonderful new york mets", b"the wonderful new york mets") == 1.0
assert scorer.normalized_similarity(b"the wonderful new york mets", "the wonderful new york mets") == 1.0


@pytest.mark.parametrize("scorer", all_scorer_modules)
def test_similar_ord_array(scorer):
"""
24 changes: 17 additions & 7 deletions tests/test_fuzz.py
@@ -109,13 +109,13 @@ def QRatio(*args, **kwargs):
hashable_scorers = [
    fuzz.ratio,
    fuzz.partial_ratio,
-    fuzz_cpp.token_sort_ratio,
-    fuzz_cpp.token_set_ratio,
-    fuzz_cpp.token_ratio,
-    fuzz_cpp.partial_token_sort_ratio,
-    fuzz_cpp.partial_token_set_ratio,
-    fuzz_cpp.partial_token_ratio,
-    fuzz_cpp.WRatio,
+    fuzz.token_sort_ratio,
+    fuzz.token_set_ratio,
+    fuzz.token_ratio,
+    fuzz.partial_token_sort_ratio,
+    fuzz.partial_token_set_ratio,
+    fuzz.partial_token_ratio,
+    fuzz.WRatio,
    fuzz.QRatio,
]

@@ -277,6 +277,16 @@ def test_array(scorer):
    assert scorer(array("u", "the wonderful new york mets"), "the wonderful new york mets")


@pytest.mark.parametrize("scorer", hashable_scorers)
def test_bytes(scorer):
"""
bytes should be supported and treated in a compatible way to strings
"""
assert scorer(b"the wonderful new york mets", b"the wonderful new york mets") == 100
assert scorer("the wonderful new york mets", b"the wonderful new york mets") == 100
assert scorer(b"the wonderful new york mets", "the wonderful new york mets") == 100


@pytest.mark.parametrize("scorer", scorers)
def test_none_string(scorer):
"""
