-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_utils_torch.py
47 lines (39 loc) · 1.8 KB
/
text_utils_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from transformers import BertForSequenceClassification
class BertWrapperTorch:
    """Thin inference wrapper around a HuggingFace BertForSequenceClassification.

    Exposes the input-embedding stage and the encoder->pooler->classifier stage
    separately (e.g. so callers can perturb or inspect embeddings), and returns
    numpy arrays from ``__call__``.

    NOTE(review): the attention mask is discarded before the encoder runs, so
    padded positions are attended to — confirm callers only pass unpadded or
    uniformly sized batches.
    """

    def __init__(self, model, device, merge_logits=False):
        """
        Args:
            model: a BertForSequenceClassification instance; it is moved to
                ``device`` and switched to eval mode (disables dropout).
            device: torch device for the model and all inputs.
            merge_logits: if True, ``__call__`` collapses the two class logits
                to a single score per example, ``logits[:, 1] - logits[:, 0]``.

        Raises:
            TypeError: if ``model`` is not a BertForSequenceClassification.

        TODO: make the model be anything
        """
        # Raise instead of `assert`: asserts are stripped under `python -O`.
        if not isinstance(model, BertForSequenceClassification):
            raise TypeError("model must be a BertForSequenceClassification")
        self.model = model.to(device)
        self.model.eval()  # inference-only wrapper
        self.device = device
        self.merge_logits = merge_logits

    @torch.no_grad()
    def get_embedding(self, **inputs):
        """Return the input embeddings for ``inputs``.

        Accepts the usual tokenizer output (``input_ids``, ``token_type_ids``,
        ...) as tensors or nested lists of ints; every value is moved to
        ``self.device``. Any ``attention_mask`` is dropped — the embedding
        layer does not take one.
        """
        # Pop before converting so we don't tensorize a value we discard.
        inputs.pop('attention_mask', None)
        # Fix vs. original: inputs that were already tensors were never moved
        # to self.device; convert (if needed) and move uniformly.
        inputs = {
            k: (v if isinstance(v, torch.Tensor) else torch.LongTensor(v)).to(self.device)
            for k, v in inputs.items()
        }
        return self.model.bert.embeddings(**inputs)

    @torch.no_grad()
    def get_predictions(self, batch_embedding):
        """Run encoder -> pooler -> classifier on precomputed embeddings.

        Args:
            batch_embedding: embedding tensor as produced by ``get_embedding``.

        Returns:
            Classification logits, moved to CPU.

        NOTE: this works only when the model is BertForSequenceClassification.
        NOTE(review): no attention mask is forwarded to the encoder, so all
        positions (including padding) are attended to — confirm intended.
        """
        encoder_outputs = self.model.bert.encoder(batch_embedding,
                                                  output_hidden_states=True,
                                                  return_dict=False)
        sequence_output = encoder_outputs[0]
        pooled_output = self.model.bert.pooler(sequence_output)
        logits = self.model.classifier(pooled_output)
        return logits.cpu()

    def __call__(self, return_embedding=False, **inputs):
        """Compute logits (as numpy) from tokenizer-style inputs.

        Args:
            return_embedding: also return the input embeddings as numpy.
            **inputs: tokenizer output (``input_ids``, optional
                ``attention_mask``/``token_type_ids``, ...).

        Returns:
            Logits array of shape (batch, num_labels) — or (batch, 1) merged
            scores when ``merge_logits`` is set; with ``return_embedding=True``,
            a ``(logits, embeddings)`` tuple of numpy arrays.
        """
        batch_embeddings = self.get_embedding(**inputs)
        batch_predictions = self.get_predictions(batch_embeddings)
        if self.merge_logits:
            # Collapse binary-classification logits into one score per example.
            batch_predictions = (
                batch_predictions[:, 1] - batch_predictions[:, 0]).unsqueeze(1)
        logits_np = batch_predictions.numpy()  # convert once, reuse below
        if return_embedding:
            return logits_np, batch_embeddings.cpu().numpy()
        return logits_np