-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
223 lines (176 loc) · 8.42 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import torch as th
import numpy as np
from torch import tensor, nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, PretrainedConfig, AutoConfig
from typing import NamedTuple, Tuple, Union
from copy import deepcopy
from collections import defaultdict
from accelerate.utils import compute_module_sizes
from itertools import chain
import accelerate
import deepspeed
def topk_mask(xs: th.FloatTensor, k: int):
mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1)
return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs)
class QVOutput(Tuple):
logits: th.FloatTensor
qs: th.FloatTensor
target_qs: th.FloatTensor
vs: th.FloatTensor
past_key_values: Tuple[th.FloatTensor]
def make_head(n_embd: int, out: int):
return nn.Sequential(
nn.Linear(n_embd, n_embd * 2),
nn.ReLU(),
nn.Linear(n_embd * 2, out)
)
class QVModel(nn.Module):
def __init__(self, config: Union[PretrainedConfig, str], params):
super().__init__()
# enable zero3 init within from_pretrained
if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
config_path = os.environ.get('DEEPSPEED_CONFIG_FILE', '')
if config_path:
_hfconfig = transformers.deepspeed.HfDeepSpeedConfig(config_path)
if isinstance(config, PretrainedConfig):
self.gpt = AutoModelForCausalLM.from_config(config)
else:
self.gpt = AutoModelForCausalLM.from_pretrained(config)
if hasattr(self.gpt.config, 'hidden_size'):
self.n_embd = self.gpt.config.hidden_size
else:
self.n_embd = self.gpt.config.n_embd
self.vocab_size = self.gpt.config.vocab_size
self.v_head = make_head(self.n_embd, 1)
self.q1_head = make_head(self.n_embd, self.vocab_size)
self.target_q1_head = deepcopy(self.q1_head)
self.target_q1_head.requires_grad_(False)
self.tau = params['tau']
self.alpha = params['alpha']
self.gamma = params['gamma']
self.awac_scale = params['awac_scale']
self.cql_scale = params['cql_scale']
self.two_qs = params['two_qs']
if self.two_qs:
self.q2_head = make_head(self.n_embd, self.vocab_size)
self.target_q2_head = deepcopy(self.q2_head)
self.target_q2_head.requires_grad_(False)
def forward(self, **x):
if hasattr(self.gpt, 'gpt_neox'):
out = self.gpt.gpt_neox(**x)
else:
out = self.gpt.transformer(**x)
hs = out.last_hidden_state
if self.two_qs:
qs = (self.q1_head(hs), self.q2_head(hs))
target_qs = (self.target_q1_head(hs), self.target_q2_head(hs))
else:
qs = self.q1_head(hs)
target_qs = self.target_q1_head(hs)
if hasattr(self.gpt, 'gpt_neox'):
logits = self.gpt.embed_out(hs)
else:
logits = self.gpt.lm_head(hs)
return QVOutput((logits, qs, target_qs, self.v_head(hs), out.past_key_values))
def loss(self, batch):
tokens, attn, rewards = batch
actions = tokens[:, 1:, None]
isterminal = attn[:, :-1]
logits, qs, target_qs, vs, _ = self(input_ids=tokens, attention_mask=attn)
bsize, ntokens, dsize = logits.shape
if self.two_qs:
Q1 = qs[0][:, :-1].gather(-1, actions).squeeze(-1)
Q2 = qs[1][:, :-1].gather(-1, actions).squeeze(-1)
targetQ1 = target_qs[0][:, :-1].gather(-1, actions).squeeze(-1).detach()
targetQ2 = target_qs[1][:, :-1].gather(-1, actions).squeeze(-1).detach()
targetQ = th.minimum(targetQ1, targetQ2)
else:
Q = qs[:, :-1].gather(-1, actions).squeeze(-1)
targetQ = target_qs[:, :-1].gather(-1, actions).squeeze(-1).detach()
n_nonterminal = max(1, isterminal.sum())
V = vs[:, 1:].squeeze() * isterminal
Q_ = rewards + self.gamma * V
if self.two_qs:
loss_q1 = ((Q1 - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
loss_q2 = ((Q2 - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
loss_q = loss_q1 + loss_q2
else:
loss_q = ((Q - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
loss_v = (((targetQ >= V).int() * self.tau * (targetQ - V).pow(2) + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2)) * isterminal).sum() / n_nonterminal
if self.two_qs:
loss_cql_q1 = (F.cross_entropy(qs[0][:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
loss_cql_q2 = (F.cross_entropy(qs[1][:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
loss_cql = loss_cql_q1 + loss_cql_q2
else:
loss_cql = (F.cross_entropy(qs[:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
loss_awac = (F.cross_entropy(logits[:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac
stats = {
k: v for k, v in locals().items() if k in
['loss', 'loss_v', 'loss_q', 'loss_cql', 'loss_awac']
}
return loss, stats
def _sync_target_q_heads(self, alpha):
for target_param, copy_param in zip(self.target_q1_head.parameters(), self.q1_head.parameters()):
target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
if self.two_qs:
for target_param, copy_param in zip(self.target_q2_head.parameters(), self.q2_head.parameters()):
target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
def sync_target_q_heads(self):
if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
params = chain(self.q1_head.parameters(),
self.target_q1_head.parameters(),
self.q2_head.parameters() if self.two_qs else [],
self.target_q2_head.parameters() if self.two_qs else [])
with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0):
if deepspeed.comm.get_rank() == 0:
self._sync_target_q_heads(self.alpha)
else:
self._sync_target_q_heads(self.alpha)
@th.inference_mode()
def sample(self, query, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, eos_token_id=50256):
input = query.clone()
past_key_values = None
tensors = defaultdict(list)
finished = th.zeros(input.shape[0], 1, dtype=th.long, device=query.device)
for _ in range(max_length-1):
logits, _, target_qs, vs, past_key_values = self.forward(input_ids=input, past_key_values=past_key_values)
if self.two_qs:
qs = th.minimum(target_qs[0][:, -1], target_qs[1][:, -1])
else:
qs = target_qs[:, -1]
logits = logits[:, -1]
if logit_mask is not None:
logits[th.where(logit_mask[input[:, -1]])] = -np.inf
adv = qs - vs[:, -1, :]
pi = F.log_softmax(logits, -1)
modpi = topk_mask(pi + beta * adv, top_k)
ps = F.softmax(modpi / temperature, -1)
tokens = th.multinomial(ps, 1)
tokens = (1 - finished) * tokens + finished * eos_token_id
query = th.hstack((query, tokens))
input = tokens
finished = (tokens == eos_token_id).long()
if logs:
tensors['qs'].append(qs)
tensors['vs'].append(vs)
tensors['adv'].append(adv)
stats = {}
for name, xs in tensors.items():
xs = th.vstack(xs)
stats.update({
f'{name}-min': xs.min(),
f'{name}-max': xs.max(),
f'{name}-std': xs.std(),
f'{name}-avg': xs.mean(),
})
return query, stats
@property
def dummy_inputs(self):
return {'input_ids': th.ones(1, 1, device=self.gpt.device, dtype=th.long)}
@property
def device(self):
return self.gpt.device