From a4212d96ed3ae08314205c620bf2c0bba0fe63fc Mon Sep 17 00:00:00 2001 From: mehrad Date: Wed, 6 Oct 2021 16:21:37 -0700 Subject: [PATCH] almond_translate: accommodate enforcement of unique ids differently instead change ids on the go to be unique, lifting the burden from user or dataset --- genienlp/tasks/almond_task.py | 185 ++++++++++++++++------------------ 1 file changed, 89 insertions(+), 96 deletions(-) diff --git a/genienlp/tasks/almond_task.py b/genienlp/tasks/almond_task.py index c8960e5d1..a66c3324f 100644 --- a/genienlp/tasks/almond_task.py +++ b/genienlp/tasks/almond_task.py @@ -305,51 +305,107 @@ def __init__(self, name, args): self._metrics = ['casedbleu'] self.need_attention_scores = True - def preprocess_field(self, sentence, field_name=None, answer=None, example_id=None, preprocess_entities=True): - assert example_id - if field_name != 'answer': - if field_name + '-' + example_id in self.all_ids: - logger.warning( - f'example id: {example_id} is repeated in the dataset. If using alignment, ids between all data splits have to be unique' - ) - example_id += '+' + def construct_id2span_mapping(self, example_id, sentence, field_name): + assert field_name in ['context', 'question'] + # translation task constructs a dictionary mapping ids to entity spans in the sentence + # this ensures the ids are unique + while field_name + '-' + example_id in self.all_ids: + example_id += '.' - self.all_ids.add(field_name + '-' + example_id) + self.all_ids.add(field_name + '-' + example_id) - src_quotation_symbol = '"' - src_tokens = sentence.split(" ") - src_spans_ind = [index for index, token in enumerate(src_tokens) if token == src_quotation_symbol] + src_quotation_symbol = '"' + src_tokens = sentence.split(" ") + src_spans_ind = [index for index, token in enumerate(src_tokens) if token == src_quotation_symbol] - if len(src_spans_ind) % 2 != 0: - raise ValueError(f'Corrupted span in sentence: [{sentence}]') + if len(src_spans_ind) % 2 != 0: + raise ValueError(f'Corrupted span in sentence: [{sentence}]') - if self.args.align_preserve_input_quotation: - src_spans = [(src_spans_ind[i] + 1, src_spans_ind[i + 1] - 1) for i in range(0, len(src_spans_ind), 2)] - else: - src_tokens = [token for token in src_tokens if token != src_quotation_symbol] - src_spans = [ - (src_spans_ind[i] + 1 - (i + 1), src_spans_ind[i + 1] - 1 - (i + 1)) - for i in range(0, len(src_spans_ind), 2) - ] + if self.args.align_preserve_input_quotation: + src_spans = [(src_spans_ind[i] + 1, src_spans_ind[i + 1] - 1) for i in range(0, len(src_spans_ind), 2)] + else: + src_tokens = [token for token in src_tokens if token != src_quotation_symbol] + src_spans = [ + (src_spans_ind[i] + 1 - (i + 1), src_spans_ind[i + 1] - 1 - (i + 1)) for i in range(0, len(src_spans_ind), 2) + ] + + # remove illegal src_spans (caused by inputs such as " ") + src_spans = [span for span in src_spans if span[0] <= span[1]] - # remove illegal src_spans (caused by inputs such as " ") - src_spans = [span for span in src_spans if span[0] <= span[1]] + sentence = " ".join(src_tokens) + src_spans_flatten = [val for tup in src_spans for val in tup] - sentence = " ".join(src_tokens) - src_spans_flatten = [val for tup in src_spans for val in tup] + # append question spans to context spans + if example_id in self.input_spans: + self.input_spans[example_id] += src_spans_flatten + else: + self.input_spans[example_id] = src_spans_flatten + + return example_id, sentence + + def preprocess_field(self, sentence, field_name=None, answer=None, example_id=None, preprocess_entities=True): + return super().preprocess_field(sentence, field_name, answer, preprocess_entities) - # append question spans to context spans - if example_id in self.input_spans: - self.input_spans[example_id] += src_spans_flatten + def _make_example(self, parts, dir_name=None, **kwargs): + # answer has to be provided by default unless doing prediction + no_answer = getattr(self.args, 'translate_no_answer', False) + split_sentence = getattr(self.args, 'translate_example_split', False) + src_lang = kwargs.get('src_lang', 'en') + + example_id = 'id-null' + question = 'translate from input to output' + + if no_answer: + if len(parts) == 1: + context = parts + elif len(parts) == 2: + example_id, context = parts + elif len(parts) == 4: + raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}') + else: + if len(parts) == 2: + context, answer = parts + elif len(parts) == 3: + example_id, context, answer = parts else: - self.input_spans[example_id] = src_spans_flatten + raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}') - sentence = super().preprocess_field(sentence, field_name, answer, preprocess_entities) + # no answer is provided + if no_answer: + answer = '.' - return sentence + contexts = [] + src_char_spans = None + if split_sentence: + if self.args.do_alignment: + src_quotation_symbol = '"' + src_char_spans_ind = [index for index, char in enumerate(context) if char == src_quotation_symbol] + src_char_spans = [ + (src_char_spans_ind[i], src_char_spans_ind[i + 1]) for i in range(0, len(src_char_spans_ind), 2) + ] + contexts = split_text_into_sentences(context, src_lang, src_char_spans) - def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): + if len(contexts) > 1: + examples = [] + for i, text in enumerate(contexts): + ex_id, text = self.construct_id2span_mapping(self.name + '/' + example_id + f'@{i}', text, 'context') + examples.append( + Example.from_raw( + ex_id, + text, + question, + answer, + preprocess=self.preprocess_field, + lower=False, + ) + ) + else: + ex_id, context = self.construct_id2span_mapping(self.name + '/' + example_id, context, 'context') + examples = Example.from_raw(ex_id, context, question, answer, preprocess=self.preprocess_field, lower=False) + return examples + + def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs): numericalizer = kwargs.pop('numericalizer') cross_attentions = kwargs.pop('cross_attentions') tgt_lang = kwargs.pop('tgt_lang') @@ -442,69 +498,6 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, bat return partial_batch_prediction_ids, all_text_outputs - def _make_example(self, parts, dir_name=None, **kwargs): - # answer has to be provided by default unless doing prediction - no_answer = getattr(self.args, 'translate_no_answer', False) - split_sentence = getattr(self.args, 'translate_example_split', False) - src_lang = kwargs.get('src_lang', 'en') - - example_id = 'id-null' - question = 'translate from input to output' - - if no_answer: - if len(parts) == 1: - context = parts - elif len(parts) == 2: - example_id, context = parts - elif len(parts) == 3: - example_id, context, question = parts - elif len(parts) == 4: - raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}') - else: - if len(parts) == 2: - context, answer = parts - elif len(parts) == 3: - example_id, context, answer = parts - elif len(parts) == 4: - example_id, context, question, answer = parts - else: - raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}') - - # no answer is provided - if no_answer: - answer = '.' - - contexts = [] - src_char_spans = None - if split_sentence: - if self.args.do_alignment: - src_quotation_symbol = '"' - src_char_spans_ind = [index for index, char in enumerate(context) if char == src_quotation_symbol] - src_char_spans = [ - (src_char_spans_ind[i], src_char_spans_ind[i + 1]) for i in range(0, len(src_char_spans_ind), 2) - ] - contexts = split_text_into_sentences(context, src_lang, src_char_spans) - - if len(contexts) > 1: - examples = [] - for i, text in enumerate(contexts): - examples.append( - Example.from_raw( - self.name + '/' + example_id + f'@{i}', - text, - question, - answer, - preprocess=self.preprocess_field, - lower=False, - ) - ) - else: - examples = Example.from_raw( - self.name + '/' + example_id, context, question, answer, preprocess=self.preprocess_field, lower=False - ) - - return examples - @register_task('contextual_almond') class ContextualAlmond(BaseAlmondTask):