-
Notifications
You must be signed in to change notification settings - Fork 0
/
nn.py
292 lines (264 loc) · 9.37 KB
/
nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
from __future__ import annotations

import numbers
from functools import partial, wraps
from typing import Optional, Union, Callable

import haiku as hk
import jax
import jax.numpy as jnp
from chex import ArrayTree, PRNGKey
from einops import rearrange, repeat
from jax import Array

from .common import Config, get_logger
# Variance-scaling initializer with a small scale (0.01). It is used for the
# token embedding table and for the output projections of the attention,
# feed-forward and model heads, so those weights start out small.
_SMALL_INIT = hk.initializers.VarianceScaling(0.01)
# Module-level logger from the project's `get_logger` helper; used by
# `Model.get_params` to report the parameter count and size.
logger = get_logger()
def full_precision(fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
    """Wrap a unary array function so that it is evaluated in float32.

    The argument is upcast to ``jnp.float32`` before ``fn`` runs, and the
    result is cast back to the argument's original dtype, so the caller never
    sees a dtype change. The wrapper keeps ``fn``'s metadata via ``wraps``.
    """

    @wraps(fn)
    def wrapped(arr: Array) -> Array:
        original_dtype = arr.dtype
        result = fn(arr.astype(jnp.float32))
        return result.astype(original_dtype)

    return wrapped
def rotary_pos_emb(
    x: Array,  # B H S D
    base: float = 10_000.0,
) -> Array:
    """Apply rotary positional embeddings (RoPE) along the last two axes.

    Consecutive feature pairs ``(x[..., s, 2i], x[..., s, 2i + 1])`` at sequence
    position ``s`` are rotated by the angle ``s * base ** (-2i / D)``
    (eq. 34 in https://arxiv.org/abs/2104.09864, interleaved-pair layout).

    Args:
        x: Array whose last two axes are (sequence, features), e.g.
            [batch, heads, sequence, features]; leading axes are broadcast
            over. The feature size must be even.
        base: Base of the geometric progression of rotation frequencies.
            Defaults to 10,000 as in the original papers.

    Returns:
        Array with the same shape and dtype as ``x``.

    Raises:
        ValueError: If the feature dimension is odd.
    """
    dim, seq = x.shape[-1], x.shape[-2]
    if dim % 2 != 0:
        # Features are rotated in pairs; an odd size would leave one feature
        # unpaired and the interleaving below would fail with a shape error.
        raise ValueError(f"Rotary embedding dimension must be even, got {dim}")
    # Near eq. 15 in https://arxiv.org/abs/2104.09864, equivalent to those
    # in https://arxiv.org/abs/1706.03762. Angles are computed in float32
    # regardless of x's dtype and only cast at the end.
    ts = jnp.arange(0, dim, 2, dtype=jnp.float32)  # D/2
    inv_freqs = base ** (-ts / dim)  # D/2
    grid = jnp.einsum("s, d -> s d", jnp.arange(seq), inv_freqs)  # S D/2
    # Eq. 34: duplicate each angle so it lines up with both members of its
    # pair, i.e. [t0, t0, t1, t1, ...]. Shape (S, D) broadcasts against x.
    sin_embs = jnp.repeat(jnp.sin(grid), 2, axis=-1).astype(x.dtype)  # S D
    cos_embs = jnp.repeat(jnp.cos(grid), 2, axis=-1).astype(x.dtype)  # S D
    # Pairwise swap with alternating signs
    x1, x2 = x[..., ::2], x[..., 1::2]  # [x1, x3, x5, ...], [x2, x4, x6, ...]
    # Stacking on a new last axis and flattening interleaves the two halves:
    # [-x2, x1, -x4, x3, ...]
    xs = jnp.stack([-x2, x1], axis=-1).reshape(x.shape)
    return x * cos_embs + xs * sin_embs
class MultiHeadAttention(hk.Module):
    """Multi-head self-attention with partial rotary positional embeddings."""

    def __init__(
        self,
        *,
        num_heads: int,
        pos_emb_portion: float,
        name: str,
    ) -> None:
        """Multi-head attention.

        Args:
            num_heads: Number of attention heads.
            pos_emb_portion: Portion of each head's dimension to use for rotary
                positional embeddings. The resulting feature count is clamped
                to [0, head_dim] and rounded down to an even number, because
                rotary embeddings rotate features in pairs.
            name: Name of the module.
        """
        super().__init__(name=name)
        self.num_heads = num_heads
        self.pos_emb_portion = pos_emb_portion

    def __call__(
        self,
        x: Array,
        mask: Optional[Array] = None,
    ) -> Array:
        """Applies multi-head attention.

        Args:
            x: Input array of shape [batch, sequence, features].
            mask: Boolean mask broadcastable to the attention logits of shape
                [batch, heads, sequence, sequence] (e.g. [sequence, sequence]).
                Where the mask is True the logit is kept; where it is False the
                logit is replaced by -inf. If None, nothing is masked.

        Returns:
            Output array of shape [batch, sequence, features].

        Raises:
            ValueError: If the feature dimension is not divisible by the
                number of heads.

        Note:
            A query row whose keys are all masked out has only -inf logits,
            and softmax returns NaN for such a row.
        """
        # Constants
        D, H = x.shape[-1], self.num_heads
        if D % H != 0:
            raise ValueError(f"Dimension {D} must be divisible by number of heads {H}")
        K = D // H
        # Projections
        projection = partial(hk.Linear, with_bias=False)
        q_proj = projection(K * H, name="q_proj")
        k_proj = projection(K * H, name="k_proj")
        v_proj = projection(K * H, name="v_proj")
        o_proj = projection(D, name="o_proj", w_init=_SMALL_INIT)
        # Number of leading per-head features that receive rotary embeddings.
        # Clamp to [0, K] and round down to an even count: rotary embeddings
        # rotate feature pairs, so an odd count would make `rotary_pos_emb`
        # fail with a shape mismatch.
        p = min(max(int(K * self.pos_emb_portion), 0), K) // 2 * 2
        # Q, K, V (queries are pre-scaled by 1/sqrt(K))
        q: Array = q_proj(x) / K**0.5  # B L (H K)
        q = rearrange(q, "b l (h k) -> b h l k", h=H)
        q = jnp.concatenate([rotary_pos_emb(q[..., :p]), q[..., p:]], axis=-1)
        k: Array = k_proj(x)  # B L (H K)
        k = rearrange(k, "b l (h k) -> b h l k", h=H)
        k = jnp.concatenate([rotary_pos_emb(k[..., :p]), k[..., p:]], axis=-1)
        v: Array = v_proj(x)  # B L (H V)
        v = rearrange(v, "b l (h v) -> b h l v", h=H)
        # Attention weights
        logits: Array = jnp.einsum("b h i k, b h j k -> b h i j", q, k)  # B H L L

        def _apply_mask(logits_: Array, mask_: Optional[Array]) -> Array:
            return logits_ if mask_ is None else jnp.where(mask_, logits_, -jnp.inf)

        # -inf logits are intentional here, so JAX's inf-debugging check is
        # disabled while masking. The masking is wrapped in hk.remat so its
        # intermediates are recomputed on the backward pass.
        with jax.debug_infs(False):
            logits = hk.remat(_apply_mask)(logits, mask)
        # Softmax runs in float32 and is cast back to the logits' dtype.
        weights = full_precision(jax.nn.softmax)(logits)  # B H L L
        # Attention output
        y: Array = jnp.einsum("b h i j, b h j v -> b h i v", weights, v)  # B H L V
        y = rearrange(y, "b h l v -> b l (h v)")  # B L (H V)
        return o_proj(y)  # B L D
class FeedForward(hk.Module):
    """SiLU-gated feed-forward layer built from bias-free linear maps."""

    def __init__(self, hidden_dim: int, name: Optional[str] = None) -> None:
        """Store the hidden width.

        Args:
            hidden_dim: Width of the gated hidden layer.
            name: Optional module name.
        """
        super().__init__(name=name)
        self.hidden_dim = hidden_dim

    def __call__(self, x: Array) -> Array:
        """Compute ``out_proj(silu(gate_proj(x)) * in_proj(x))``.

        The result has the same last-axis size as ``x``.
        """
        linear = partial(hk.Linear, with_bias=False)
        up = linear(self.hidden_dim, name="in_proj")
        gate = linear(self.hidden_dim, name="gate_proj")
        down = linear(x.shape[-1], name="out_proj", w_init=_SMALL_INIT)

        def _silu_gate(gate_act: Array, up_act: Array) -> Array:
            return jax.nn.silu(gate_act) * up_act

        # The gating is rematerialized on the backward pass instead of storing
        # its intermediates. The gate branch is evaluated before the up branch.
        hidden = hk.remat(_silu_gate)(gate(x), up(x))
        return down(hidden)  # B L M
class Block(hk.Module):
    """Pre-norm residual block: self-attention followed by a feed-forward layer.

    Each sub-layer sees a layer-normalized copy of the residual stream. Its
    output goes through dropout (training only) and is added back to the stream.
    """

    def __init__(
        self,
        *,
        num_heads: int,
        hidden_dim: int,
        pos_emb_portion: float,
        dropout: float,
        name: Optional[str] = None,
    ) -> None:
        """Store hyper-parameters; submodules are built in ``__call__``.

        Args:
            num_heads: Number of attention heads.
            hidden_dim: Hidden width of the feed-forward layer.
            pos_emb_portion: Portion of each head's dimension that receives
                rotary positional embeddings.
            dropout: Dropout rate applied to both sub-layer outputs in training.
            name: Optional module name.
        """
        super().__init__(name=name)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.pos_emb_portion = pos_emb_portion
        self.dropout = dropout

    def __call__(
        self,
        x: Array,
        is_training: bool,
        mask: Optional[Array] = None,
    ) -> Array:
        """Run the block and return an array with the same shape as ``x``.

        Args:
            x: Residual stream, presumably [batch, sequence, features].
            is_training: When True, dropout is applied, drawing one RNG key
                per sub-layer from the Haiku transform.
            mask: Optional attention mask forwarded to the attention layer.
        """
        attention = MultiHeadAttention(
            num_heads=self.num_heads,
            pos_emb_portion=self.pos_emb_portion,
            name="mha",
        )
        attention_norm = hk.remat(
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=False, name="mha_ln")
        )
        feed_forward = FeedForward(self.hidden_dim, name="ff")
        feed_forward_norm = hk.remat(
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=False, name="ff_ln")
        )

        def regularize(h: Array) -> Array:
            # Dropout (and its RNG key draw) happens only during training.
            if not is_training:
                return h
            return hk.dropout(hk.next_rng_key(), self.dropout, h)

        # Multi-head attention sub-layer
        x = x + regularize(attention(attention_norm(x), mask))
        # Feed-forward sub-layer
        return x + regularize(feed_forward(feed_forward_norm(x)))
class Model(hk.Module):
    """Transformer over token indices that returns vocabulary logits.

    Token embeddings are optionally projected to ``model_dim``, then passed
    through ``num_layers`` blocks. The result is layer-normalized, projected
    back to ``embedding_dim`` and scored against the embedding table itself,
    so the input embedding and the output layer share weights.
    """

    def __init__(
        self,
        *,
        num_layers: int,
        vocabulary_size: int,
        embedding_dim: int,
        model_dim: int,
        num_heads: int,
        pos_emb_portion: float,
        hidden_dim: int,
        dropout: float,
        name: Optional[str] = None,
    ) -> None:
        """Store hyper-parameters; submodules are created in ``__call__``.

        Args:
            num_layers: Number of transformer blocks.
            vocabulary_size: Number of rows in the token embedding table.
            embedding_dim: Size of each token embedding.
            model_dim: Width of the residual stream inside the blocks. When it
                differs from ``embedding_dim``, a bias-free linear layer maps
                embeddings to this width.
            num_heads: Attention heads per block.
            pos_emb_portion: Portion of each head's dimension that receives
                rotary positional embeddings.
            hidden_dim: Hidden width of each block's feed-forward layer.
            dropout: Dropout rate used in each block during training.
            name: Optional module name.
        """
        super().__init__(name=name)
        self.num_layers = num_layers
        self.vocabulary_size = vocabulary_size
        self.embedding_dim = embedding_dim
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.pos_emb_portion = pos_emb_portion
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    @classmethod
    def from_config(cls, config: Config) -> Model:
        """Build a model from ``config.model``, coercing each field's type."""
        cfg = config.model
        return cls(
            num_layers=int(cfg.num_layers),
            vocabulary_size=int(cfg.vocabulary_size),
            embedding_dim=int(cfg.embedding_dim),
            model_dim=int(cfg.model_dim),
            num_heads=int(cfg.num_heads),
            pos_emb_portion=float(cfg.pos_emb_portion),
            hidden_dim=int(cfg.hidden_dim),
            dropout=float(cfg.dropout),
        )

    def __call__(
        self,
        indices: Array,
        is_training: bool,
        mask: Optional[Array] = None,
    ) -> Array:
        """Compute logits over the vocabulary for each input position.

        Args:
            indices: Integer token indices; the final einsum expects them to
                have shape [batch, sequence].
            is_training: Enables dropout inside the blocks when True. Dropout
                then requires an RNG key in the Haiku transform.
            mask: Optional attention mask, forwarded unchanged to every block.

        Returns:
            Logits of shape [batch, sequence, vocabulary_size].
        """
        embedding = hk.Embed(
            self.vocabulary_size,
            self.embedding_dim,
            w_init=_SMALL_INIT,
            name="embedding",
        )
        # Only add a projection when the embedding and model widths differ.
        embedding_proj = (
            hk.Linear(
                self.model_dim,
                with_bias=False,
                name="embedding_proj",
            )
            if self.embedding_dim != self.model_dim
            else lambda x: x
        )
        blocks = [
            Block(
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                pos_emb_portion=self.pos_emb_portion,
                dropout=self.dropout,
                name=f"block_{i}",
            )
            for i in range(self.num_layers)
        ]
        out_ln = hk.LayerNorm(-1, True, False, name="out_ln")
        # Projects back to embedding_dim so the result can be scored against
        # the (tied) embedding table.
        out_proj = hk.Linear(
            self.embedding_dim,
            with_bias=False,
            w_init=_SMALL_INIT,
            name="out_proj",
        )
        # Execution
        embeddings = embedding(indices)
        h = embedding_proj(embeddings)  # type: ignore
        for block in blocks:
            h = block(h, is_training, mask)
        final_hidden = out_proj(out_ln(h))
        # Weight tying: logits are dot products with every embedding row.
        logits: Array = jnp.einsum(
            "b s m, v m -> b s v", final_hidden, embedding.embeddings
        )
        return logits

    @classmethod
    def get_params(
        cls,
        config: Config,
        rng_or_seed: Union[int, PRNGKey],
        log_size: bool = True,
    ) -> ArrayTree:
        """Initialize and return the parameters of a model built from ``config``.

        Args:
            config: Project configuration; only ``config.model`` is read.
            rng_or_seed: An integer seed (a Python ``int`` or any other
                ``numbers.Integral``, e.g. a NumPy integer) or a PRNG key.
            log_size: Whether to log the parameter count and size in MB.

        Returns:
            The Haiku parameter tree.
        """

        def init_fn() -> None:
            model = cls.from_config(config)
            # A single-token dummy batch is enough to create every parameter:
            # no parameter shape depends on batch size or sequence length.
            model(jnp.zeros((1, 1), dtype=jnp.int32), False)

        # numbers.Integral also covers NumPy integer scalars, which
        # `isinstance(..., int)` misses and which are not valid PRNG keys.
        rng = (
            jax.random.PRNGKey(int(rng_or_seed))
            if isinstance(rng_or_seed, numbers.Integral)
            else rng_or_seed
        )
        params = hk.transform(init_fn).init(rng)
        if log_size:
            params_n = hk.data_structures.tree_size(params)
            params_mb = hk.data_structures.tree_bytes(params) / 1e6
            logger.info(f"Model parameters: {params_n:,} ({params_mb:.2f} MB)")
        return params