-
Notifications
You must be signed in to change notification settings - Fork 0
/
jax_transformer.py
363 lines (290 loc) · 12.8 KB
/
jax_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import numpy as np
import math
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import random
from flax.training import train_state
from data import SpecialTokens
import os
import orbax
from flax.training import orbax_utils
import torch
def dotprod_attention(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: arrays whose last two axes are (positions, features); in this
            file they are called with (batch, heads, seq, head_dim).
        mask: optional array broadcastable to the attention logits; entries
            equal to 0 (or False) are excluded by setting their logit to -9e15.

    Returns:
        (values, attention): the attention-weighted values and the softmax
        weights over the last axis.
    """
    scale = k.shape[-1] ** 0.5
    logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / scale  # (..., sl_q, sl_k)
    if mask is not None:
        # Large negative (rather than -inf) so a fully-masked row still yields
        # a finite, uniform softmax instead of NaN.
        logits = jnp.where(mask == 0, -9e15, logits)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.matmul(weights, v), weights
class MultiheadSelfAttention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    Attributes:
        dim: model width d; the input is projected to 3*d features and the
            output has d features.
        num_heads: number of attention heads h.
    """
    dim: int  # d
    num_heads: int  # h

    def setup(self):
        # Both projections share the same initialisation scheme.
        dense_init = dict(kernel_init=nn.initializers.xavier_uniform(),
                          bias_init=nn.initializers.zeros)
        self.qkv_layer = nn.Dense(features=3 * self.dim, **dense_init)
        self.output_layer = nn.Dense(features=self.dim, **dense_init)

    def __call__(self, x, mask=None):
        """Attend over x.

        Args:
            x: input of shape (b, sl, dx).
            mask: optional mask forwarded to dotprod_attention (0 = masked).

        Returns:
            (out, attention): out is (b, sl, d); attention holds the per-head
            softmax weights, (b, h, sl, sl).
        """
        b, sl = x.shape[0], x.shape[1]
        fused = self.qkv_layer(x)  # (b, sl, 3d)
        # Give each head a contiguous 3d/h slice of channels, then put heads
        # in front of the sequence axis.
        fused = jnp.swapaxes(fused.reshape(b, sl, self.num_heads, -1), 1, 2)  # (b, h, sl, 3d/h)
        q, k, v = jnp.array_split(fused, 3, axis=-1)  # each (b, h, sl, d/h)
        per_head, attention = dotprod_attention(q, k, v, mask=mask)
        merged = jnp.swapaxes(per_head, 1, 2).reshape(b, sl, self.dim)  # (b, sl, d)
        return self.output_layer(merged), attention
class MultiheadAttention(nn.Module):
    """Multi-head attention with separate query and key/value inputs.

    Used as encoder-decoder (cross) attention. Queries and keys/values may
    have different sequence lengths.

    Attributes:
        dim: model width d; the q/k/v projections and the output all have d
            features. Assumes d is divisible by num_heads.
        num_heads: number of attention heads h.
    """
    dim: int  # d
    num_heads: int  # h

    def setup(self):
        self.q_layer = nn.Dense(features=self.dim,
                                kernel_init=nn.initializers.xavier_uniform(),
                                bias_init=nn.initializers.zeros)
        self.k_layer = nn.Dense(features=self.dim,
                                kernel_init=nn.initializers.xavier_uniform(),
                                bias_init=nn.initializers.zeros)
        self.v_layer = nn.Dense(features=self.dim,
                                kernel_init=nn.initializers.xavier_uniform(),
                                bias_init=nn.initializers.zeros)
        self.output_layer = nn.Dense(features=self.dim,
                                     kernel_init=nn.initializers.xavier_uniform(),
                                     bias_init=nn.initializers.zeros)

    def __call__(self, q, k, v, mask=None):
        """Attend from q to (k, v).

        Args:
            q: queries, (b, q_len, dq).
            k, v: keys and values, (b, kv_len, dk) / (b, kv_len, dv).
            mask: optional mask broadcastable to (b, h, q_len, kv_len);
                0/False entries are masked (see dotprod_attention).

        Returns:
            (out, attention): out is (b, q_len, d); attention is
            (b, h, q_len, kv_len).
        """
        batch_size = q.shape[0]
        q_len = q.shape[1]

        def split_heads(t):
            # (b, len, d) -> (b, h, len, d/h). The Fortran-order reshape maps
            # channel c = head + h * j to (head, j), so each head takes a
            # strided set of channels. This is exactly the per-head layout the
            # earlier "concatenate q/k/v then reshape(order='F')" version
            # produced, so trained checkpoints give identical outputs. Each
            # tensor is split separately so q and k/v may differ in length.
            t = t.reshape((t.shape[0], t.shape[1], self.num_heads, -1), order='F')
            return t.transpose(0, 2, 1, 3)

        q = split_heads(self.q_layer(q))  # (b, h, q_len, d/h)
        k = split_heads(self.k_layer(k))  # (b, h, kv_len, d/h)
        v = split_heads(self.v_layer(v))  # (b, h, kv_len, d/h)
        values, attention = dotprod_attention(q, k, v, mask=mask)  # values: (b, h, q_len, d/h)
        values = values.transpose(0, 2, 1, 3)  # (b, q_len, h, d/h)
        values = values.reshape(batch_size, q_len, self.dim)
        out = self.output_layer(values)
        return out, attention
class EncoderBlock(nn.Module):
    """One post-LayerNorm Transformer encoder layer.

    Self-attention and a position-wise feed-forward network, each applied as
    ``x = LayerNorm(x + Dropout(sublayer(x)))``.
    """
    dim: int
    num_heads: int
    feed_forward_dim: int
    dropout_prob: float

    def setup(self):
        self.mh_attn = MultiheadSelfAttention(dim=self.dim, num_heads=self.num_heads)
        self.norm1 = nn.LayerNorm()
        # The layout of this list determines the submodule names in the
        # parameter tree, so reordering it would break existing checkpoints.
        self.feed_forward = [
            nn.Dense(self.feed_forward_dim),
            nn.Dropout(self.dropout_prob),
            nn.relu,
            nn.Dense(self.dim),
        ]
        self.norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_prob)

    def __call__(self, x, mask=None, train=True):
        """Encode x; dropout is active only when ``train`` is True."""
        deterministic = not train

        attended, _ = self.mh_attn(x, mask=mask)
        x = self.norm1(x + self.dropout(attended, deterministic=deterministic))

        hidden = x
        for layer in self.feed_forward:
            if isinstance(layer, nn.Dropout):
                hidden = layer(hidden, deterministic=deterministic)
            else:
                hidden = layer(hidden)
        return self.norm2(x + self.dropout(hidden, deterministic=deterministic))
class Encoder(nn.Module):
    """Stack of ``num_blocks`` EncoderBlocks applied one after another."""
    num_blocks: int
    dim: int
    num_heads: int
    feed_forward_dim: int
    dropout_prob: float

    def setup(self):
        block_kwargs = dict(dim=self.dim,
                            num_heads=self.num_heads,
                            feed_forward_dim=self.feed_forward_dim,
                            dropout_prob=self.dropout_prob)
        self.blocks = [EncoderBlock(**block_kwargs) for _ in range(self.num_blocks)]

    def __call__(self, x, mask=None, train=True):
        """Run x through every block, passing the same mask to each."""
        hidden = x
        for encoder_block in self.blocks:
            hidden = encoder_block(hidden, mask=mask, train=train)
        return hidden
class DecoderBlock(nn.Module):
    """One post-LayerNorm Transformer decoder layer.

    Three sublayers, each applied as ``x = LayerNorm(x + Dropout(sublayer(x)))``:
    self-attention over the target (masked by ``tgt_mask``), cross-attention
    to the encoder output (masked by ``src_mask``), and a feed-forward network.
    """
    dim: int
    num_heads: int
    feed_forward_dim: int
    dropout_prob: float

    def setup(self):
        self.m_mh_attn = MultiheadSelfAttention(dim=self.dim, num_heads=self.num_heads)
        self.norm1 = nn.LayerNorm()
        self.e_mh_attn = MultiheadAttention(dim=self.dim, num_heads=self.num_heads)
        self.norm2 = nn.LayerNorm()
        # The layout of this list determines the submodule names in the
        # parameter tree, so reordering it would break existing checkpoints.
        self.feed_forward = [
            nn.Dense(self.feed_forward_dim),
            nn.Dropout(self.dropout_prob),
            nn.relu,
            nn.Dense(self.dim),
        ]
        self.norm3 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_prob)

    def __call__(self, x, encoder_output, src_mask=None, tgt_mask=None, train=True):
        """Decode x against encoder_output; dropout is active only when ``train`` is True."""
        deterministic = not train

        # Self-attention over the target sequence.
        self_attended, _ = self.m_mh_attn(x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(self_attended, deterministic=deterministic))

        # Cross-attention: queries from x, keys/values from the encoder output.
        cross_attended, _ = self.e_mh_attn(x, encoder_output, encoder_output, mask=src_mask)
        x = self.norm2(x + self.dropout(cross_attended, deterministic=deterministic))

        # Position-wise feed-forward network.
        hidden = x
        for layer in self.feed_forward:
            if isinstance(layer, nn.Dropout):
                hidden = layer(hidden, deterministic=deterministic)
            else:
                hidden = layer(hidden)
        return self.norm3(x + self.dropout(hidden, deterministic=deterministic))
class Decoder(nn.Module):
    """Stack of ``num_blocks`` DecoderBlocks applied one after another."""
    num_blocks: int
    dim: int
    num_heads: int
    feed_forward_dim: int
    dropout_prob: float

    def setup(self):
        block_kwargs = dict(dim=self.dim,
                            num_heads=self.num_heads,
                            feed_forward_dim=self.feed_forward_dim,
                            dropout_prob=self.dropout_prob)
        self.blocks = [DecoderBlock(**block_kwargs) for _ in range(self.num_blocks)]

    def __call__(self, x, encoder_output, src_mask=None, tgt_mask=None, train=True):
        """Run x through every block; every block sees the same encoder output and masks."""
        hidden = x
        for decoder_block in self.blocks:
            hidden = decoder_block(hidden, encoder_output,
                                   src_mask=src_mask, tgt_mask=tgt_mask, train=train)
        return hidden
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input.

    Even feature 2i gets sin(pos / 10000^(2i/dim)) and odd feature 2i+1 gets
    cos(pos / 10000^(2i/dim)).

    Attributes:
        dim: feature size of the inputs (may be odd).
        max_len: number of positions precomputed; longer inputs are rejected.
    """
    dim: int
    max_len: int = 500

    def setup(self):
        pe = np.zeros((self.max_len, self.dim))
        pos = np.arange(self.max_len)[:, None]  # (max_len, 1)
        # den[i] == 10000^(2i/dim) for every even feature index 2i.
        den = 1 / np.exp(np.arange(0, self.dim, 2) * (-math.log(10000.0) / self.dim))
        angles = pos / den[None, :]  # (max_len, ceil(dim/2))
        pe[:, 0::2] = np.sin(angles)
        # For odd dim there is one fewer odd column than even columns.
        pe[:, 1::2] = np.cos(angles[:, : self.dim // 2])
        self.pe = jax.device_put(pe[None])  # (1, max_len, dim)

    def __call__(self, x):
        """Return x + encodings for the first x.shape[1] positions.

        Raises:
            ValueError: if x.shape[1] exceeds max_len.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")
        return x + self.pe[:, :seq_len]
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing per-position target-vocabulary logits.

    Attributes:
        src_vocab_size / tgt_vocab_size: sizes of the embedding tables; the
            output head has tgt_vocab_size features.
        num_blocks: number of layers in both the encoder and decoder stacks.
        dim: model width.
        num_heads: attention heads per attention layer.
        feed_forward_dim: hidden width of each feed-forward network.
        dropout_prob: dropout rate used throughout.
    """
    src_vocab_size: int
    tgt_vocab_size: int
    num_blocks: int
    dim: int
    num_heads: int
    feed_forward_dim: int
    dropout_prob: float

    def setup(self):
        self.src_emb = nn.Embed(num_embeddings=self.src_vocab_size, features=self.dim)
        self.tgt_emb = nn.Embed(num_embeddings=self.tgt_vocab_size, features=self.dim)
        self.src_pos = PositionalEncoding(dim=self.dim)
        self.tgt_pos = PositionalEncoding(dim=self.dim)
        stack_kwargs = dict(num_blocks=self.num_blocks,
                            dim=self.dim,
                            num_heads=self.num_heads,
                            feed_forward_dim=self.feed_forward_dim,
                            dropout_prob=self.dropout_prob)
        self.encoder = Encoder(**stack_kwargs)
        self.decoder = Decoder(**stack_kwargs)
        # The layout of this list determines the submodule names in the
        # parameter tree, so reordering it would break existing checkpoints.
        self.output_net = [
            nn.Dense(self.dim),
            nn.LayerNorm(),
            nn.relu,
            nn.Dropout(self.dropout_prob),
            nn.Dense(self.tgt_vocab_size),
        ]

    def encode(self, src, src_mask=None, train=True):
        """Embed source token ids, add positions, and run the encoder stack."""
        embedded = self.src_pos(self.src_emb(src))
        return self.encoder(embedded, mask=src_mask, train=train)

    def decode(self, encoder_output, tgt, src_mask=None, tgt_mask=None, train=True):
        """Embed target token ids, add positions, and run the decoder stack."""
        embedded = self.tgt_pos(self.tgt_emb(tgt))
        return self.decoder(embedded, encoder_output,
                            src_mask=src_mask, tgt_mask=tgt_mask, train=train)

    def project(self, x, train=True):
        """Map decoder states to target-vocabulary logits."""
        for layer in self.output_net:
            if isinstance(layer, nn.Dropout):
                x = layer(x, deterministic=not train)
            else:
                x = layer(x)
        return x

    def __call__(self, src=None, tgt=None, src_mask=None, tgt_mask=None, train=True):
        """Full forward pass: encode src, decode tgt against it, and project to logits."""
        memory = self.encode(src, src_mask, train=train)
        decoded = self.decode(memory, tgt, src_mask, tgt_mask, train=train)
        return self.project(decoded, train=train)
def make_transformer(config):
    """Build the Transformer, its initial train state, and data helpers.

    Args:
        config: mapping with keys 'src_vocab_size', 'tgt_vocab_size',
            'num_blocks', 'emb_size', 'num_heads', 'ffn_hid_dim',
            'dropout_prob', 'seed' and 'seq_len'.

    Returns:
        (transformer, state, generate_mask, pad_to_seq_len):
        - transformer: the unbound Transformer module.
        - state: a TrainState with parameters initialised from a random
          (1, seq_len) src/tgt pair, and an optimiser chaining global-norm
          clipping at 1.0 with Adam at learning rate 1e-4.
        - generate_mask: vmapped function (src, tgt) -> (src_mask, tgt_mask)
          for batched token arrays; each mask is (b, num_heads, sl, sl) bool,
          where True means "may attend" and False means masked.
        - pad_to_seq_len: right-pads a 1-D torch token tensor to seq_len.
    """
    transformer = Transformer(src_vocab_size=config['src_vocab_size'],
                              tgt_vocab_size=config['tgt_vocab_size'],
                              num_blocks=config['num_blocks'],
                              dim=config['emb_size'],
                              num_heads=config['num_heads'],
                              feed_forward_dim=config['ffn_hid_dim'],
                              dropout_prob=config['dropout_prob'])
    main_rng = random.PRNGKey(config['seed'])
    main_rng, rng = random.split(main_rng)
    src = random.randint(rng, (1, config['seq_len']), 0, config['src_vocab_size'])
    main_rng, rng = random.split(main_rng)
    tgt = random.randint(rng, (1, config['seq_len']), 0, config['tgt_vocab_size'])
    main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)
    params = transformer.init(
        {'params': init_rng, 'dropout': dropout_init_rng}, src, tgt, train=True)['params']
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(1e-4)
    )
    state = train_state.TrainState.create(apply_fn=transformer.apply, params=params, tx=optimizer)

    def _mask_single(src, tgt):
        # Unbatched inputs: src (sl_src,), tgt (sl_tgt,); jax.vmap below adds
        # the batch axis. Mask values: True/1 = keep, False/0 = masked.
        src_keep = src != SpecialTokens.PAD
        tgt_keep = tgt != SpecialTokens.PAD
        src_len = src.shape[0]
        tgt_len = tgt.shape[0]
        # Key-padding mask: entry [i, j] depends only on whether source token j
        # is real. The same mask feeds decoder cross-attention, whose row i is
        # a *target* position, so it must not depend on source token i.
        # (An outer product pad[i] & pad[j] fully masks target rows beyond the
        # source length.) Encoder outputs at real positions are unaffected.
        src_mask = jnp.broadcast_to(src_keep[None, :], (src_len, src_len))
        tgt_padding_mask = tgt_keep[:, None] & tgt_keep[None, :]
        positions = jnp.arange(tgt_len)
        # Causal: query i may attend to key j only when j <= i.
        tgt_causal_mask = positions[None, :] <= positions[:, None]
        tgt_mask = tgt_padding_mask & tgt_causal_mask
        # One copy per head: (h, sl, sl).
        src_mask = jnp.repeat(src_mask[None, :, :], config['num_heads'], axis=0)
        tgt_mask = jnp.repeat(tgt_mask[None, :, :], config['num_heads'], axis=0)
        return src_mask, tgt_mask

    generate_mask = jax.vmap(jax.jit(_mask_single))

    def pad_to_seq_len(item):
        """Right-pad 1-D token tensor ``item`` with PAD to ``config['seq_len']``.

        Returns an int32 tensor of length seq_len. Raises AssertionError when
        ``item`` is longer than seq_len. The explicit raise (rather than
        ``assert``) keeps the check under ``python -O``.
        """
        seq_len = config['seq_len']
        pad_len = seq_len - item.shape[0]
        # A sequence of exactly seq_len tokens is valid (pad_len == 0).
        if pad_len < 0:
            raise AssertionError(f"sentence too long {item.shape[0]} > {seq_len}")
        # Pad in the item's own dtype so token ids never round-trip through float32.
        padding = torch.full((pad_len,), int(SpecialTokens.PAD), dtype=item.dtype)
        return torch.cat((item, padding)).int()

    return transformer, state, generate_mask, pad_to_seq_len
def load_state(config, state):
    """Create an Orbax checkpoint manager and restore the latest checkpoint, if any.

    Args:
        config: mapping providing 'ckpt_dir'; checkpoints are managed under
            '<ckpt_dir>/managed' (created if missing), keeping at most one.
        state: pytree (e.g. a TrainState) used both as the save-args target
            and as the restore template. Returned unchanged if nothing is
            restored.

    Returns:
        (state, checkpoint_manager, save_args, step), where step is the
        restored checkpoint's step, or 1 when training from scratch.
    """
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(state)
    options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=1, create=True)
    checkpoint_manager = orbax.checkpoint.CheckpointManager(
        os.path.join(config['ckpt_dir'], 'managed'), orbax_checkpointer, options)
    try:
        latest = checkpoint_manager.latest_step()
        # An empty directory (None) is the normal fresh-run case; don't rely on
        # restore(None) throwing to detect it.
        if latest is not None:
            state = checkpoint_manager.restore(latest, items=state)
    except Exception as e:
        # Best-effort: an unreadable or incompatible checkpoint falls back to
        # training from scratch instead of aborting.
        print('Could not load checkpoint. Training from scratch.')
        print(e)
        return state, checkpoint_manager, save_args, 1
    if latest is None:
        print('Could not load checkpoint. Training from scratch.')
        return state, checkpoint_manager, save_args, 1
    print(f'Loaded checkpoint from epoch {latest}')
    return state, checkpoint_manager, save_args, latest