model.py
from __future__ import unicode_literals, print_function, division
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration

class Model(nn.Module):
    def __init__(self, args, logger, **kwargs):
        super().__init__()
        # Instantiate the baseline seq2seq module from pretrained weights.
        logger.info("Creating seq2seq model from pretrained weights.")
        self.seq2seq = BartForConditionalGeneration.from_pretrained(
            args.base_model_pretrained_name,
            cache_dir=args.pretrained_model_cache_dir)
        # Resize the embedding matrix to match the tokenizer's vocabulary,
        # e.g. after special tokens have been added.
        vocab_size = kwargs.pop("vocab_size")
        self.seq2seq.resize_token_embeddings(vocab_size)
        # Bilinear weight for scoring encoder states against decoder states.
        h_dim = kwargs.pop("h_dim")
        s_dim = kwargs.pop("s_dim")
        # torch.FloatTensor(...) allocates uninitialized memory; initialize
        # the parameter explicitly so it does not start from garbage values.
        self.bl = nn.Parameter(torch.empty(1, h_dim, s_dim))
        nn.init.xavier_uniform_(self.bl)

    def forward(self, batch):
        outputs = self.seq2seq(
            **batch,
            output_attentions=True,
            output_hidden_states=True,
        )
        m_output = {}
        m_output["cost"] = outputs.loss.cpu()
        # Last-layer hidden states: h_x is (batch, src_len, h_dim) and
        # h_y is (batch, tgt_len, h_dim).
        h_x = outputs.encoder_hidden_states[-1]
        h_y = outputs.decoder_hidden_states[-1]
        # Bilinear encoder-decoder affinity scores of shape
        # (batch, src_len, tgt_len); note the second matmul is only
        # well-formed when s_dim == h_dim.
        x = h_x @ self.bl.to(device=h_x.device)
        x = x @ h_y.transpose(-1, -2)
        m_output["scores"] = x
        return m_output

    @torch.no_grad()
    def generate(
        self,
        batch,
        options,
        **model_kwargs,
    ):
        # The first element of `batch` holds the tokenized source inputs;
        # `options` is accepted for interface compatibility but unused here.
        inputs = batch[0]
        return self.seq2seq.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **model_kwargs,
        )
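

# Minimal usage sketch (not part of the original module). Assumptions:
# "facebook/bart-base" is only an illustrative checkpoint, an
# argparse.Namespace stands in for the project's real `args` object, and
# h_dim/s_dim are set to bart-base's hidden size (768) so the bilinear
# matmuls in forward() are well-formed.
if __name__ == "__main__":
    import logging
    from argparse import Namespace
    from transformers import BartTokenizer

    logging.basicConfig(level=logging.INFO)
    args = Namespace(base_model_pretrained_name="facebook/bart-base",
                     pretrained_model_cache_dir=None)
    tokenizer = BartTokenizer.from_pretrained(args.base_model_pretrained_name)
    model = Model(args, logging.getLogger(__name__),
                  vocab_size=len(tokenizer), h_dim=768, s_dim=768)

    # Build a tiny self-supervised batch (labels = inputs) for a smoke test.
    enc = tokenizer("a tiny smoke test", return_tensors="pt")
    batch = {"input_ids": enc["input_ids"],
             "attention_mask": enc["attention_mask"],
             "labels": enc["input_ids"]}
    with torch.no_grad():
        out = model(batch)
    print("loss:", out["cost"].item(), "scores shape:", out["scores"].shape)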