mt5.py
# coding:utf-8
from typing import List
from pytorch_lightning import LightningModule
from transformers import MT5ForConditionalGeneration, AutoTokenizer


class MT5(LightningModule):
    """
    Google MT5 transformer class.
    """

    def __init__(self, model_name_or_path: str = None):
        """
        Initialize module.
        :param model_name_or_path: model name
        """
        super().__init__()
        # Load model and tokenizer
        self.save_hyperparameters()
        self.model = MT5ForConditionalGeneration.from_pretrained(
            model_name_or_path) if model_name_or_path is not None else None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                                       use_fast=True) if model_name_or_path is not None else None

    def forward(self, **inputs):
        """
        Forward inputs.
        :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
        """
        return self.model(**inputs)

    def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
        """
        Question answering prediction.
        :param batch: batch of dict {question: q, context: c}
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"question: {context['question']} context: {context['context']}" for context in batch]
        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
        return outputs

    def qg(self, batch: List[str] = None, max_length: int = 512, **kwargs):
        """
        Question generation prediction.
        :param batch: batch of contexts with highlighted elements
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"generate: {context}" for context in batch]
        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
        return outputs

    def ae(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction prediction.
        :param batch: list of contexts
        :param max_length: max length of output
        """
        # Transform inputs
        inputs = [f"extract: {context}" for context in batch]
        # Predict
        outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
        return outputs

    def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
        """
        Answer extraction + question generation + question answering.
        :param batch: list of contexts
        :param max_length: max length of outputs
        """
        # Build output dict
        dict_batch = {'context': list(batch), 'answers': [], 'questions': [], 'answers_bis': []}
        # Iterate over contexts
        for context in batch:
            # Extract candidate answers ("<sep>"-separated) and drop empty fragments
            answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
            answers = [ans.strip() for ans in answers.split('<sep>') if ans.strip()]
            dict_batch['answers'].append(answers)
            # Highlight each answer in the context and generate one question per answer
            for_qg = [context.replace(ans, f'<hl> {ans} <hl> ') for ans in answers]
            questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
            dict_batch['questions'].append(questions)
            # Answer the generated questions to cross-check the extracted answers
            new_answers = self.qa([{'context': context, 'question': question} for question in questions],
                                  max_length=max_length, **kwargs)
            dict_batch['answers_bis'].append(new_answers)
        return dict_batch

    def predict(self, inputs, max_length, **kwargs):
        """
        Inference processing.
        :param inputs: list of inputs
        :param max_length: max length of outputs
        """
        # Tokenize inputs
        inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
                                return_tensors="pt")
        # Retrieve input_ids and attention_mask
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)
        # Predict
        outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
                                      **kwargs)
        # Decode outputs
        predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return predictions
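

# Minimal usage sketch (not part of the original module). The checkpoint name below is a
# placeholder: a model fine-tuned on the "question:/generate:/extract:" prompt formats used
# by qa/qg/ae is assumed; with a plain base checkpoint the code runs but the outputs are untrained.
if __name__ == "__main__":
    model = MT5(model_name_or_path="google/mt5-small")  # placeholder checkpoint name
    model.eval()

    context = "Paris is the capital of France. It is known for the Eiffel Tower."

    # Run answer extraction, question generation and question answering on one context
    result = model.multitask(batch=[context], max_length=64)
    print(result['answers'])      # candidate answers extracted from the context
    print(result['questions'])    # one generated question per extracted answer
    print(result['answers_bis'])  # answers predicted back from the generated questions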