Skip to content

Commit

Permalink
only capture exception when they are expected
Browse files Browse the repository at this point in the history
  • Loading branch information
maxbachmann committed Dec 28, 2023
1 parent a7c3f6e commit a4c4523
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 51 deletions.
140 changes: 91 additions & 49 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def is_none(s):
return False


def call_and_maybe_catch(call, *args, **kwargs):
def call_and_maybe_catch(call, *args, catch_exceptions=False, **kwargs):
if not catch_exceptions:
return call(*args, **kwargs)

try:
return call(*args, **kwargs)
except AssertionError as e:
Expand All @@ -51,7 +54,7 @@ def compare_exceptions(e1, e2):
return False


def scorer_tester(scorer, s1, s2, **kwargs):
def scorer_tester(scorer, s1, s2, catch_exceptions=False, **kwargs):
score1 = call_and_maybe_catch(scorer, s1, s2, **kwargs)
exception = isinstance(score1, Exception)

Expand All @@ -69,12 +72,24 @@ def scorer_tester(scorer, s1, s2, **kwargs):
if temp_kwargs:
process_kwargs["scorer_kwargs"] = temp_kwargs

extractOne_res1 = call_and_maybe_catch(process_cpp.extractOne, s1, [s2], scorer=scorer, **process_kwargs)
extractOne_res2 = call_and_maybe_catch(process_py.extractOne, s1, [s2], scorer=scorer, **process_kwargs)
extract_res1 = call_and_maybe_catch(process_cpp.extract, s1, [s2], scorer=scorer, **process_kwargs)
extract_res2 = call_and_maybe_catch(process_py.extract, s1, [s2], scorer=scorer, **process_kwargs)
extract_iter_res1 = call_and_maybe_catch(list, process_cpp.extract_iter(s1, [s2], scorer=scorer, **process_kwargs))
extract_iter_res2 = call_and_maybe_catch(list, process_py.extract_iter(s1, [s2], scorer=scorer, **process_kwargs))
extractOne_res1 = call_and_maybe_catch(
process_cpp.extractOne, s1, [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
extractOne_res2 = call_and_maybe_catch(
process_py.extractOne, s1, [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
extract_res1 = call_and_maybe_catch(
process_cpp.extract, s1, [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
extract_res2 = call_and_maybe_catch(
process_py.extract, s1, [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
extract_iter_res1 = call_and_maybe_catch(
list, process_cpp.extract_iter(s1, [s2], scorer=scorer, **process_kwargs), catch_exceptions=catch_exceptions
)
extract_iter_res2 = call_and_maybe_catch(
list, process_py.extract_iter(s1, [s2], scorer=scorer, **process_kwargs), catch_exceptions=catch_exceptions
)

if exception:
assert compare_exceptions(extractOne_res1, score1)
Expand Down Expand Up @@ -123,11 +138,19 @@ def scorer_tester(scorer, s1, s2, **kwargs):
np = None

if np is not None:
cdist_scores1 = call_and_maybe_catch(process_cpp.cdist, [s1], [s2], scorer=scorer, **process_kwargs)
cdist_scores2 = call_and_maybe_catch(process_py.cdist, [s1], [s2], scorer=scorer, **process_kwargs)
cdist_scores1 = call_and_maybe_catch(
process_cpp.cdist, [s1], [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
cdist_scores2 = call_and_maybe_catch(
process_py.cdist, [s1], [s2], catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
# probably trigger multi match / simd implementations
cdist_scores3 = call_and_maybe_catch(process_cpp.cdist, [s1] * 2, [s2] * 4, scorer=scorer, **process_kwargs)
cdist_scores4 = call_and_maybe_catch(process_py.cdist, [s1] * 2, [s2] * 4, scorer=scorer, **process_kwargs)
cdist_scores3 = call_and_maybe_catch(
process_cpp.cdist, [s1] * 2, [s2] * 4, catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)
cdist_scores4 = call_and_maybe_catch(
process_py.cdist, [s1] * 2, [s2] * 4, catch_exceptions=catch_exceptions, scorer=scorer, **process_kwargs
)

if exception:
assert compare_exceptions(cdist_scores1, score1)
Expand All @@ -146,9 +169,9 @@ def scorer_tester(scorer, s1, s2, **kwargs):
return score1


def symmetric_scorer_tester(scorer, s1, s2, **kwargs):
score1 = call_and_maybe_catch(scorer_tester, scorer, s1, s2, **kwargs)
score2 = call_and_maybe_catch(scorer_tester, scorer, s2, s1, **kwargs)
def symmetric_scorer_tester(scorer, s1, s2, catch_exceptions=False, **kwargs):
score1 = call_and_maybe_catch(scorer_tester, scorer, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
score2 = call_and_maybe_catch(scorer_tester, scorer, s2, s1, catch_exceptions=catch_exceptions, **kwargs)

if isinstance(score1, Exception):
assert compare_exceptions(score1, score2)
Expand Down Expand Up @@ -195,8 +218,11 @@ def validate_attrs(func1, func2):

self.get_scorer_flags = get_scorer_flags

def _editops(self, s1, s2, **kwargs):
results = [call_and_maybe_catch(scorer.editops, s1, s2, **kwargs) for scorer in self.scorers]
def _editops(self, s1, s2, catch_exceptions=False, **kwargs):
results = [
call_and_maybe_catch(scorer.editops, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
for scorer in self.scorers
]

for result in results:
assert compare_exceptions(result, results[0])
Expand All @@ -206,8 +232,11 @@ def _editops(self, s1, s2, **kwargs):

return results[0]

def _opcodes(self, s1, s2, **kwargs):
results = [call_and_maybe_catch(scorer.opcodes, s1, s2, **kwargs) for scorer in self.scorers]
def _opcodes(self, s1, s2, catch_exceptions=False, **kwargs):
results = [
call_and_maybe_catch(scorer.opcodes, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
for scorer in self.scorers
]

for result in results:
assert compare_exceptions(result, results[0])
Expand All @@ -217,11 +246,14 @@ def _opcodes(self, s1, s2, **kwargs):

return results[0]

def _distance(self, s1, s2, **kwargs):
def _distance(self, s1, s2, catch_exceptions=False, **kwargs):
symmetric = self.get_scorer_flags(s1, s2, **kwargs)["symmetric"]
tester = symmetric_scorer_tester if symmetric else scorer_tester

scores = [call_and_maybe_catch(tester, scorer.distance, s1, s2, **kwargs) for scorer in self.scorers]
scores = [
call_and_maybe_catch(tester, scorer.distance, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
for scorer in self.scorers
]

if any(isinstance(score, Exception) for score in scores):
for score in scores:
Expand All @@ -232,7 +264,7 @@ def _distance(self, s1, s2, **kwargs):
assert pytest.approx(scores[0]) == scores[-1]
return scores[0]

def _similarity(self, s1, s2, **kwargs):
def _similarity(self, s1, s2, catch_exceptions=False, **kwargs):
symmetric = self.get_scorer_flags(s1, s2, **kwargs)["symmetric"]
tester = symmetric_scorer_tester if symmetric else scorer_tester

Expand All @@ -247,11 +279,16 @@ def _similarity(self, s1, s2, **kwargs):
assert pytest.approx(scores[0]) == scores[-1]
return scores[0]

def _normalized_distance(self, s1, s2, **kwargs):
def _normalized_distance(self, s1, s2, catch_exceptions=False, **kwargs):
symmetric = self.get_scorer_flags(s1, s2, **kwargs)["symmetric"]
tester = symmetric_scorer_tester if symmetric else scorer_tester

scores = [call_and_maybe_catch(tester, scorer.normalized_distance, s1, s2, **kwargs) for scorer in self.scorers]
scores = [
call_and_maybe_catch(
tester, scorer.normalized_distance, s1, s2, catch_exceptions=catch_exceptions, **kwargs
)
for scorer in self.scorers
]

if any(isinstance(score, Exception) for score in scores):
for score in scores:
Expand All @@ -262,12 +299,15 @@ def _normalized_distance(self, s1, s2, **kwargs):
assert pytest.approx(scores[0]) == scores[-1]
return scores[0]

def _normalized_similarity(self, s1, s2, **kwargs):
def _normalized_similarity(self, s1, s2, catch_exceptions=False, **kwargs):
symmetric = self.get_scorer_flags(s1, s2, **kwargs)["symmetric"]
tester = symmetric_scorer_tester if symmetric else scorer_tester

scores = [
call_and_maybe_catch(tester, scorer.normalized_similarity, s1, s2, **kwargs) for scorer in self.scorers
call_and_maybe_catch(
tester, scorer.normalized_similarity, s1, s2, catch_exceptions=catch_exceptions, **kwargs
)
for scorer in self.scorers
]

if any(isinstance(score, Exception) for score in scores):
Expand All @@ -279,17 +319,19 @@ def _normalized_similarity(self, s1, s2, **kwargs):
assert pytest.approx(scores[0]) == scores[-1]
return scores[0]

def _validate(self, s1, s2, **kwargs):
def _validate(self, s1, s2, catch_exceptions=False, **kwargs):
# todo requires more complex test handling
# score_cutoff = kwargs.get("score_cutoff")
kwargs = {k: v for k, v in kwargs.items() if k != "score_cutoff"}

maximum = self.get_scorer_flags(s1, s2, **kwargs)["maximum"]

dist = call_and_maybe_catch(self._distance, s1, s2, **kwargs)
sim = call_and_maybe_catch(self._similarity, s1, s2, **kwargs)
norm_dist = call_and_maybe_catch(self._normalized_distance, s1, s2, **kwargs)
norm_sim = call_and_maybe_catch(self._normalized_similarity, s1, s2, **kwargs)
dist = call_and_maybe_catch(self._distance, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
sim = call_and_maybe_catch(self._similarity, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
norm_dist = call_and_maybe_catch(self._normalized_distance, s1, s2, catch_exceptions=catch_exceptions, **kwargs)
norm_sim = call_and_maybe_catch(
self._normalized_similarity, s1, s2, catch_exceptions=catch_exceptions, **kwargs
)

if isinstance(dist, Exception):
assert compare_exceptions(dist, sim)
Expand All @@ -307,45 +349,45 @@ def _validate(self, s1, s2, **kwargs):

return dist, sim, norm_dist, norm_sim

def distance(self, s1, s2, **kwargs):
dist, _, _, _ = self._validate(s1, s2, **kwargs)
def distance(self, s1, s2, catch_exceptions=False, **kwargs):
dist, _, _, _ = self._validate(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
if "score_cutoff" not in kwargs:
return dist

return self._distance(s1, s2, **kwargs)
return self._distance(s1, s2, catch_exceptions=catch_exceptions, **kwargs)

def similarity(self, s1, s2, **kwargs):
_, sim, _, _ = self._validate(s1, s2, **kwargs)
def similarity(self, s1, s2, catch_exceptions=False, **kwargs):
_, sim, _, _ = self._validate(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
if "score_cutoff" not in kwargs:
return sim

return self._similarity(s1, s2, **kwargs)
return self._similarity(s1, s2, catch_exceptions=catch_exceptions, **kwargs)

def normalized_distance(self, s1, s2, **kwargs):
def normalized_distance(self, s1, s2, catch_exceptions=False, **kwargs):
if not is_none(s1) and not is_none(s2):
_, _, norm_dist, _ = self._validate(s1, s2, **kwargs)
_, _, norm_dist, _ = self._validate(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
# todo we should be able to handle this in a nicer way
if "score_cutoff" not in kwargs:
return norm_dist
return self._normalized_distance(s1, s2, **kwargs)
return self._normalized_distance(s1, s2, catch_exceptions=catch_exceptions, **kwargs)

def normalized_similarity(self, s1, s2, **kwargs):
def normalized_similarity(self, s1, s2, catch_exceptions=False, **kwargs):
if not is_none(s1) and not is_none(s2):
_, _, _, norm_sim = self._validate(s1, s2, **kwargs)
_, _, _, norm_sim = self._validate(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
if "score_cutoff" not in kwargs:
return norm_sim
return self._normalized_similarity(s1, s2, **kwargs)
return self._normalized_similarity(s1, s2, catch_exceptions=catch_exceptions, **kwargs)

def editops(self, s1, s2, **kwargs):
editops_ = self._editops(s1, s2, **kwargs)
opcodes_ = self._opcodes(s1, s2, **kwargs)
def editops(self, s1, s2, catch_exceptions=False, **kwargs):
editops_ = self._editops(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
opcodes_ = self._opcodes(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
assert opcodes_.as_editops() == editops_
assert opcodes_ == editops_.as_opcodes()
return editops_

def opcodes(self, s1, s2, **kwargs):
editops_ = self._editops(s1, s2, **kwargs)
opcodes_ = self._opcodes(s1, s2, **kwargs)
def opcodes(self, s1, s2, catch_exceptions=False, **kwargs):
editops_ = self._editops(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
opcodes_ = self._opcodes(s1, s2, catch_exceptions=catch_exceptions, **kwargs)
assert opcodes_.as_editops() == editops_
assert opcodes_ == editops_.as_opcodes()
return opcodes_
3 changes: 2 additions & 1 deletion tests/distance/test_Hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def test_disable_padding():
assert Hamming.distance("aaaa", "bbbb", pad=False) == 4

with pytest.raises(ValueError, match="Sequences are not the same length."):
Hamming.distance("aaaa", "aaaaa", pad=False)
Hamming.distance("aaaa", "aaaaa", catch_exceptions=True, pad=False)

# todo
with pytest.raises(ValueError, match="Sequences are not the same length."):
metrics_cpp.hamming_editops("aaaa", "aaaaa", pad=False)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_invalid_input(scorer):
when invalid types are passed to a scorer an exception should be thrown
"""
with pytest.raises(TypeError):
scorer(1, 1)
scorer(1, 1, catch_exceptions=True)


@pytest.mark.parametrize("scorer", scorers)
Expand Down

0 comments on commit a4c4523

Please sign in to comment.