Commit a4212d9

almond_translate: accommodate enforcement of unique ids differently
Instead, change ids on the go to be unique, lifting the burden from the user or the dataset.
Mehrad0711 committed Oct 6, 2021
1 parent 845d620 commit a4212d9
Showing 1 changed file with 89 additions and 96 deletions.

genienlp/tasks/almond_task.py
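
The heart of the change is visible in the diff below: where the old preprocess_field warned about a repeated id and appended a single '+', the new construct_id2span_mapping loops, appending '.' until the id is actually unique, so duplicated ids in a dataset no longer have to be fixed by the user. A minimal standalone sketch of that dedup behavior (the helper name make_unique and the module-level seen_ids set are illustrative, not part of the commit):

# Sketch of the commit's on-the-go id uniquification, outside the class for clarity.
seen_ids = set()

def make_unique(example_id: str, field_name: str) -> str:
    # Append '.' until the (field, id) key is unseen, then record it.
    while field_name + '-' + example_id in seen_ids:
        example_id += '.'
    seen_ids.add(field_name + '-' + example_id)
    return example_id

assert make_unique('ex1', 'context') == 'ex1'
assert make_unique('ex1', 'context') == 'ex1.'   # first duplicate gets a suffix
assert make_unique('ex1', 'context') == 'ex1..'  # and so does the next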
@@ -305,51 +305,107 @@ def __init__(self, name, args):
         self._metrics = ['casedbleu']
         self.need_attention_scores = True
 
-    def preprocess_field(self, sentence, field_name=None, answer=None, example_id=None, preprocess_entities=True):
-        assert example_id
-        if field_name != 'answer':
-            if field_name + '-' + example_id in self.all_ids:
-                logger.warning(
-                    f'example id: {example_id} is repeated in the dataset. If using alignment, ids between all data splits have to be unique'
-                )
-                example_id += '+'
+    def construct_id2span_mapping(self, example_id, sentence, field_name):
+        assert field_name in ['context', 'question']
+        # translation task constructs a dictionary mapping ids to entity spans in the sentence
+        # this ensures the ids are unique
+        while field_name + '-' + example_id in self.all_ids:
+            example_id += '.'
 
-            self.all_ids.add(field_name + '-' + example_id)
+        self.all_ids.add(field_name + '-' + example_id)
 
-            src_quotation_symbol = '"'
-            src_tokens = sentence.split(" ")
-            src_spans_ind = [index for index, token in enumerate(src_tokens) if token == src_quotation_symbol]
+        src_quotation_symbol = '"'
+        src_tokens = sentence.split(" ")
+        src_spans_ind = [index for index, token in enumerate(src_tokens) if token == src_quotation_symbol]
 
-            if len(src_spans_ind) % 2 != 0:
-                raise ValueError(f'Corrupted span in sentence: [{sentence}]')
+        if len(src_spans_ind) % 2 != 0:
+            raise ValueError(f'Corrupted span in sentence: [{sentence}]')
 
-            if self.args.align_preserve_input_quotation:
-                src_spans = [(src_spans_ind[i] + 1, src_spans_ind[i + 1] - 1) for i in range(0, len(src_spans_ind), 2)]
-            else:
-                src_tokens = [token for token in src_tokens if token != src_quotation_symbol]
-                src_spans = [
-                    (src_spans_ind[i] + 1 - (i + 1), src_spans_ind[i + 1] - 1 - (i + 1))
-                    for i in range(0, len(src_spans_ind), 2)
-                ]
+        if self.args.align_preserve_input_quotation:
+            src_spans = [(src_spans_ind[i] + 1, src_spans_ind[i + 1] - 1) for i in range(0, len(src_spans_ind), 2)]
+        else:
+            src_tokens = [token for token in src_tokens if token != src_quotation_symbol]
+            src_spans = [
+                (src_spans_ind[i] + 1 - (i + 1), src_spans_ind[i + 1] - 1 - (i + 1)) for i in range(0, len(src_spans_ind), 2)
+            ]
 
-            # remove illegal src_spans (caused by inputs such as " ")
-            src_spans = [span for span in src_spans if span[0] <= span[1]]
+        # remove illegal src_spans (caused by inputs such as " ")
+        src_spans = [span for span in src_spans if span[0] <= span[1]]
 
-            sentence = " ".join(src_tokens)
-            src_spans_flatten = [val for tup in src_spans for val in tup]
+        sentence = " ".join(src_tokens)
+        src_spans_flatten = [val for tup in src_spans for val in tup]
 
-            # append question spans to context spans
-            if example_id in self.input_spans:
-                self.input_spans[example_id] += src_spans_flatten
-            else:
-                self.input_spans[example_id] = src_spans_flatten
+        # append question spans to context spans
+        if example_id in self.input_spans:
+            self.input_spans[example_id] += src_spans_flatten
+        else:
+            self.input_spans[example_id] = src_spans_flatten
 
-        sentence = super().preprocess_field(sentence, field_name, answer, preprocess_entities)
+        return example_id, sentence
 
-        return sentence
+    def preprocess_field(self, sentence, field_name=None, answer=None, example_id=None, preprocess_entities=True):
+        return super().preprocess_field(sentence, field_name, answer, preprocess_entities)
+
+    def _make_example(self, parts, dir_name=None, **kwargs):
+        # answer has to be provided by default unless doing prediction
+        no_answer = getattr(self.args, 'translate_no_answer', False)
+        split_sentence = getattr(self.args, 'translate_example_split', False)
+        src_lang = kwargs.get('src_lang', 'en')
+
+        example_id = 'id-null'
+        question = 'translate from input to output'
+
+        if no_answer:
+            if len(parts) == 1:
+                context = parts
+            elif len(parts) == 2:
+                example_id, context = parts
+            elif len(parts) == 4:
+                raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}')
+        else:
+            if len(parts) == 2:
+                context, answer = parts
+            elif len(parts) == 3:
+                example_id, context, answer = parts
+            else:
+                raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}')
+
+        # no answer is provided
+        if no_answer:
+            answer = '.'
+
+        contexts = []
+        src_char_spans = None
+        if split_sentence:
+            if self.args.do_alignment:
+                src_quotation_symbol = '"'
+                src_char_spans_ind = [index for index, char in enumerate(context) if char == src_quotation_symbol]
+                src_char_spans = [
+                    (src_char_spans_ind[i], src_char_spans_ind[i + 1]) for i in range(0, len(src_char_spans_ind), 2)
+                ]
+            contexts = split_text_into_sentences(context, src_lang, src_char_spans)
+
+        if len(contexts) > 1:
+            examples = []
+            for i, text in enumerate(contexts):
+                ex_id, text = self.construct_id2span_mapping(self.name + '/' + example_id + f'@{i}', text, 'context')
+                examples.append(
+                    Example.from_raw(
+                        ex_id,
+                        text,
+                        question,
+                        answer,
+                        preprocess=self.preprocess_field,
+                        lower=False,
+                    )
+                )
+        else:
+            ex_id, context = self.construct_id2span_mapping(self.name + '/' + example_id, context, 'context')
+            examples = Example.from_raw(ex_id, context, question, answer, preprocess=self.preprocess_field, lower=False)
+
+        return examples
 
     def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
         numericalizer = kwargs.pop('numericalizer')
         cross_attentions = kwargs.pop('cross_attentions')
         tgt_lang = kwargs.pop('tgt_lang')
@@ -442,69 +498,6 @@ def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
 
         return partial_batch_prediction_ids, all_text_outputs
 
-    def _make_example(self, parts, dir_name=None, **kwargs):
-        # answer has to be provided by default unless doing prediction
-        no_answer = getattr(self.args, 'translate_no_answer', False)
-        split_sentence = getattr(self.args, 'translate_example_split', False)
-        src_lang = kwargs.get('src_lang', 'en')
-
-        example_id = 'id-null'
-        question = 'translate from input to output'
-
-        if no_answer:
-            if len(parts) == 1:
-                context = parts
-            elif len(parts) == 2:
-                example_id, context = parts
-            elif len(parts) == 3:
-                example_id, context, question = parts
-            elif len(parts) == 4:
-                raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}')
-        else:
-            if len(parts) == 2:
-                context, answer = parts
-            elif len(parts) == 3:
-                example_id, context, answer = parts
-            elif len(parts) == 4:
-                example_id, context, question, answer = parts
-            else:
-                raise ValueError(f'Input file contains a line with {len(parts)} parts: {str(parts)}')
-
-        # no answer is provided
-        if no_answer:
-            answer = '.'
-
-        contexts = []
-        src_char_spans = None
-        if split_sentence:
-            if self.args.do_alignment:
-                src_quotation_symbol = '"'
-                src_char_spans_ind = [index for index, char in enumerate(context) if char == src_quotation_symbol]
-                src_char_spans = [
-                    (src_char_spans_ind[i], src_char_spans_ind[i + 1]) for i in range(0, len(src_char_spans_ind), 2)
-                ]
-            contexts = split_text_into_sentences(context, src_lang, src_char_spans)
-
-        if len(contexts) > 1:
-            examples = []
-            for i, text in enumerate(contexts):
-                examples.append(
-                    Example.from_raw(
-                        self.name + '/' + example_id + f'@{i}',
-                        text,
-                        question,
-                        answer,
-                        preprocess=self.preprocess_field,
-                        lower=False,
-                    )
-                )
-        else:
-            examples = Example.from_raw(
-                self.name + '/' + example_id, context, question, answer, preprocess=self.preprocess_field, lower=False
-            )
-
-        return examples
 
 
 @register_task('contextual_almond')
 class ContextualAlmond(BaseAlmondTask):
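
For readers tracing the alignment logic: construct_id2span_mapping records, per example id, the token spans that were wrapped in double quotes, shifting the indices when the quote tokens themselves are stripped. A standalone re-implementation of just that span arithmetic (the function entity_token_spans is hypothetical, named here for illustration; the logic mirrors the committed code):

def entity_token_spans(sentence: str, preserve_quotes: bool = False):
    # Find quote tokens, pair them up, and compute the token span between each pair.
    quote = '"'
    tokens = sentence.split(" ")
    quote_idx = [i for i, tok in enumerate(tokens) if tok == quote]
    if len(quote_idx) % 2 != 0:
        raise ValueError(f'Corrupted span in sentence: [{sentence}]')
    if preserve_quotes:
        spans = [(quote_idx[i] + 1, quote_idx[i + 1] - 1) for i in range(0, len(quote_idx), 2)]
    else:
        tokens = [tok for tok in tokens if tok != quote]
        # for the k-th pair (i = 2k), i + 1 quote tokens have been removed to the
        # left of the span, so both endpoints shift left by i + 1
        spans = [
            (quote_idx[i] + 1 - (i + 1), quote_idx[i + 1] - 1 - (i + 1))
            for i in range(0, len(quote_idx), 2)
        ]
    spans = [s for s in spans if s[0] <= s[1]]  # drop empty spans caused by inputs such as " "
    return " ".join(tokens), spans

sent, spans = entity_token_spans('show me " italian " restaurants near " new york "')
# sent  == 'show me italian restaurants near new york'
# spans == [(2, 2), (5, 6)], i.e. 'italian' and 'new york'

Note also that when translate_example_split is on, _make_example suffixes each split sentence's id with @{i} before calling construct_id2span_mapping, so the per-sentence spans land under distinct keys in input_spans.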