diff --git a/sefaria/helper/normalization.py b/sefaria/helper/normalization.py index 1aea57e560..9969022838 100644 --- a/sefaria/helper/normalization.py +++ b/sefaria/helper/normalization.py @@ -86,7 +86,6 @@ def remove_subsets_reducer(curr_text_to_remove: list, next: tuple) -> list: def get_mapping_after_normalization(self, text, removal_list=None, reverse=False, **kwargs): """ text - unnormalized text - find_text_to_remove - function which takes text as param and return list of tuples. each tuple is of form ((start, end), replacement) where start and end are indices in unnormalized string and replacement is the string that will replace text at these indices removal_list - instead of passing `find_text_to_remove`, you can pass an already calculated list of tuples. should be in same format as return value of find_text_to_remove reverse - bool. If True, then will return mapping from unnormalized string to normalized string @@ -109,8 +108,10 @@ def get_mapping_after_normalization(self, text, removal_list=None, reverse=False # must be match object start, end = removal.start(), removal.end() normalized_text_index = start if reverse else (start + min(len(subst), end-start) - total_removed) - total_removed += (end - start - len(subst)) - removal_map[normalized_text_index] = total_removed + curr_removed = end - start - len(subst) + if curr_removed > 0: + total_removed += curr_removed + removal_map[normalized_text_index] = total_removed return removal_map @staticmethod @@ -125,7 +126,9 @@ def convert_normalized_indices_to_unnormalized_indices(normalized_indices, remov sign = -1 if reverse else 1 for start, end in normalized_indices: unnorm_start_index = bisect_right(removal_keys, start) - 1 - unnorm_end_index = bisect_right(removal_keys, (end - 1 if reverse else end)) - 1 # not sure if end-1 is specific to reverse case, but seems to be working + # special case if range is zero-length. treat end as literal and not off-by-one. + bisect_end_index = end if end == start else (end - 1) + unnorm_end_index = bisect_right(removal_keys, bisect_end_index) - 1 unnorm_start = start if unnorm_start_index < 0 else start + (sign * removal_map[removal_keys[unnorm_start_index]]) unnorm_end = end if unnorm_end_index < 0 else end + (sign * removal_map[removal_keys[unnorm_end_index]]) @@ -255,60 +258,47 @@ def find_text_to_remove(self, s, **kwargs): apply normalization steps one-by-one and keep track of mapping from one step to the next iteratively apply mappings (in reverse) on each step's removal inds to get inds in original string """ - all_text_to_remove = [] + final_text_to_remove = [] mappings = [] snorm = s for step in self.steps: - temp_text_to_remove = step.find_text_to_remove(snorm, **kwargs) - if len(temp_text_to_remove) == 0: + curr_text_to_remove = step.find_text_to_remove(snorm, **kwargs) + if len(curr_text_to_remove) == 0: text_to_remove_inds, text_to_remove_repls = [], [] else: - text_to_remove_inds, text_to_remove_repls = zip(*temp_text_to_remove) + text_to_remove_inds, text_to_remove_repls = zip(*curr_text_to_remove) for mapping in reversed(mappings): text_to_remove_inds = step.convert_normalized_indices_to_unnormalized_indices(text_to_remove_inds, mapping) - temp_text_to_remove = list(zip(text_to_remove_inds, text_to_remove_repls)) - all_text_to_remove += [temp_text_to_remove] + curr_text_to_remove = list(zip(text_to_remove_inds, text_to_remove_repls)) + + # merge any overlapping ranges + # later edits should override earlier ones + final_text_to_remove = self.merge_removal_inds(final_text_to_remove, curr_text_to_remove) mappings += [step.get_mapping_after_normalization(snorm, **kwargs)] snorm = step.normalize(snorm, **kwargs) - # merge any overlapping ranges - # later edits should override earlier ones - final_text_to_remove = reduce(lambda a, b: self.merge_removal_inds(a, b), all_text_to_remove) final_text_to_remove.sort(key=lambda x: x[0]) return final_text_to_remove @staticmethod - def merge_removal_inds(curr_removal_inds, new_removal_inds): - if isinstance(new_removal_inds, tuple): - new_removal_inds = [new_removal_inds] - curr_removal_inds.sort(key=lambda x: x[0]) - new_removal_inds.sort(key=lambda x: x[0]) - merged_inds = curr_removal_inds[:] - last_curr = 0 - for new_inds, new_repl in new_removal_inds: - inds_are_final = True - for i, (curr_inds, curr_repl) in enumerate(curr_removal_inds[last_curr:]): - if new_inds[1] <= curr_inds[0]: - # curr_inds are past new_inds indicating rest of curr_inds will also be past. break early. - break - elif curr_inds[0] >= new_inds[0] and curr_inds[1] <= new_inds[1]: # are curr_inds subset of new_inds? - # if earlier inds are a subset of later inds, later inds override - merged_inds.remove((curr_inds, curr_repl)) - elif new_inds[0] < curr_inds[1] or new_inds[1] > curr_inds[0]: - # if later inds overlap and earlier inds are not a subset, merge - if new_inds[0] >= curr_inds[0] and new_inds[1] <= curr_inds[1]: - merged_repl = curr_repl[:new_inds[0] - curr_inds[0]] + new_repl + curr_repl[new_inds[1] - - curr_inds[1]:] - merged_inds[i+last_curr] = (curr_inds, merged_repl) - inds_are_final = False - last_curr += 1 - break - else: - # overlap that's not a subset. more complicated merge that I don't want to deal with now - pass - last_curr += 1 - if inds_are_final: - merged_inds += [(new_inds, new_repl)] - return merged_inds + def merge_removal_inds(*all_removal_inds): + combined_removal_inds = reduce(lambda a, b: a + b, all_removal_inds, []) + combined_removal_inds.sort(key=lambda x: x[0][0]) + merged_removal_inds = [] + for curr_inds, curr_repl in combined_removal_inds: + if len(merged_removal_inds) == 0: + merged_removal_inds += [(curr_inds, curr_repl)] + continue + last_inds, last_repl = merged_removal_inds[-1] + if curr_inds[0] >= last_inds[1]: + # If current interval doesn't overlap with the last interval in result, append it + merged_removal_inds += [(curr_inds, curr_repl)] + else: + # some sort of overlap + curr_merged_inds = (last_inds[0], max(last_inds[1], curr_inds[1])) + curr_merged_repl = last_repl[:curr_inds[0]-last_inds[0]] + curr_repl + last_repl[(curr_inds[1]+1)-last_inds[0]:] + merged_removal_inds[-1] = (curr_merged_inds, curr_merged_repl) + + return merged_removal_inds class TableReplaceNormalizer(AbstractNormalizer): diff --git a/sefaria/helper/tests/normalization_tests.py b/sefaria/helper/tests/normalization_tests.py index 1b7d8903f0..59e4ed7767 100644 --- a/sefaria/helper/tests/normalization_tests.py +++ b/sefaria/helper/tests/normalization_tests.py @@ -49,15 +49,27 @@ def test_br_tag_html_composer(): assert text[start4:end4] == ' ' -def test_normalizer_composer(): +def test_simpler_normalizer_composer(): + text = ' [sup' + normalized = " sup" + nsc = NormalizerComposer(['brackets', 'double-space']) + assert nsc.normalize(text) == normalized + text_to_remove = nsc.find_text_to_remove(text) + assert len(text_to_remove) == 2 + (start0, end0), repl0 = text_to_remove[0] + assert text[start0:end0] == " " + assert repl0 == ' ' + + +def test_complicated_normalizer_composer(): text = """(hello other stuff) [sup] (this is) a test""" normalized = """ sup a test """ nsc = NormalizerComposer(['html', "parens-plus-contents", 'brackets', 'double-space']) assert nsc.normalize(text) == normalized text_to_remove = nsc.find_text_to_remove(text) - assert len(text_to_remove) == 5 + assert len(text_to_remove) == 6 (start0, end0), repl0 = text_to_remove[0] - assert text[start0:end0] == "(hello other stuff) [" + assert text[start0:end0] == "(hello other stuff) " assert repl0 == ' ' @@ -96,7 +108,7 @@ def test_word_to_char(): word_indices = (2, 4) result = char_indices_from_word_indices(test_string, [word_indices])[0] start, end = result - assert test_string[start:end] == 'go here\n\nhello ' # TODO used to not have trailing space. not sure how critical this is. + assert test_string[start:end] == 'go here\n\nhello' assert test_string[start:end].split() == words