-
Notifications
You must be signed in to change notification settings - Fork 35
/
relational_rnn_general.py
366 lines (302 loc) · 15.3 KB
/
relational_rnn_general.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# this class largely follows the official sonnet implementation
# https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py
class RelationalMemory(nn.Module):
    """
    Relational Memory Core (RMC) without language modeling-specific variables.

    At every time step the projected input is appended to the memory as one extra
    row, multi-head self-attention followed by a small MLP is applied to the
    concatenation, the extra row is cut off again, and (optionally) LSTM-like
    input/forget gates blend the attended memory with the previous memory.

    Args:
        mem_slots: The total number of memory slots to use.
        head_size: The size of an attention head.
        input_size: The size of input per step. i.e. the dimension of each input vector
        num_heads: The number of attention heads to use. Defaults to 1.
        num_blocks: Number of times to compute attention per time step. Defaults
            to 1.
        forget_bias: Bias to use for the forget gate, assuming we are using
            some form of gating. Defaults to 1.
        input_bias: Bias to use for the input gate, assuming we are using
            some form of gating. Defaults to 0.
        gate_style: Whether to use per-element gating ('unit'),
            per-memory slot gating ('memory'), or no gating at all (None).
            Defaults to `unit`.
        attention_mlp_layers: Number of layers to use in the post-attention
            MLP. Defaults to 2.
        key_size: Size of vector to use for key & query vectors in the attention
            computation. Defaults to None, in which case we use `head_size`.
        return_all_outputs: Whether the model returns outputs for each step (like seq2seq)
            or only the final output.

    Raises:
        ValueError: gate_style not one of [None, 'memory', 'unit'].
        ValueError: num_blocks is < 1.
        ValueError: attention_mlp_layers is < 1.
    """

    def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, forget_bias=1., input_bias=0.,
                 gate_style='unit', attention_mlp_layers=2, key_size=None, return_all_outputs=False):
        super(RelationalMemory, self).__init__()

        ########## generic parameters for RMC ##########
        self.mem_slots = mem_slots
        self.head_size = head_size
        self.num_heads = num_heads
        self.mem_size = self.head_size * self.num_heads

        # the input of each time step is concatenated to the memory as one extra row
        # before self-attention, so e.g. mem_slots = 1 gives 2 rows
        self.mem_slots_plus_input = self.mem_slots + 1

        if num_blocks < 1:
            raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks))
        self.num_blocks = num_blocks

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. got: '
                '{}.'.format(gate_style))
        self.gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self.attention_mlp_layers = attention_mlp_layers

        self.key_size = key_size if key_size else self.head_size

        ########## parameters for multihead attention ##########
        # value_size is same as head_size
        self.value_size = self.head_size
        # per-head size of the concatenated query, key and value vectors
        self.qkv_size = 2 * self.key_size + self.value_size
        self.total_qkv_size = self.qkv_size * self.num_heads  # denoted as F

        # one big projection for all heads instead of one projector per head
        self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size)
        self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size])

        # used for attend_over_memory function
        # NOTE: every layer must be its own nn.Linear instance. `[layer] * n` would repeat
        # one module object and silently tie the weights of all MLP layers.
        self.attention_mlp = nn.ModuleList(
            [nn.Linear(self.mem_size, self.mem_size) for _ in range(self.attention_mlp_layers)])
        self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])
        self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])

        ########## parameters for initial embedded input projection ##########
        self.input_size = input_size
        self.input_projector = nn.Linear(self.input_size, self.mem_size)

        ########## parameters for gating ##########
        # input and forget gates are produced by one projection, hence the factor 2
        self.num_gates = 2 * self.calculate_gate_size()
        self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        # trainable scalar gate bias tensors
        self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32))
        self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32))

        ########## number of outputs returned #####
        self.return_all_outputs = return_all_outputs

    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history.

        Args:
            h: A tensor, or an iterable of (possibly nested) tensors.

        Returns:
            The detached tensor, or a tuple with every element detached recursively.
        """
        # needed for truncated BPTT, called at every batch forward pass
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(self.repackage_hidden(v) for v in h)

    def initial_state(self, batch_size, trainable=False):
        """
        Creates the initial memory.

        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity, padded with zero columns or
        truncated as necessary so that init_state is of size
        (batch_size, self.mem_slots, self.mem_size).

        Args:
            batch_size: The size of the batch.
            trainable: Accepted for API compatibility; currently ignored (the returned
                tensor is a plain tensor, not a parameter).

        Returns:
            init_state: A truncated or padded identity matrix of size
                (batch_size, self.mem_slots, self.mem_size), created with torch's
                default dtype and device.
        """
        # torch.eye(rows, cols) is the identity already padded with zero columns
        # (cols > rows) or truncated to its first `cols` columns (cols < rows)
        single_state = torch.eye(self.mem_slots, self.mem_size)
        # repeat (not expand) so every batch element owns its memory
        return single_state.unsqueeze(0).repeat(batch_size, 1, 1)

    def multihead_attention(self, memory):
        """
        Perform multi-head attention from 'Attention is All You Need'.

        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.

        Args:
            memory: Memory tensor to perform attention on, [B, N, mem_size] where N is
                the number of rows (memory slots plus the concatenated input).

        Returns:
            new_memory: New memory tensor of shape [B, N, num_heads * value_size].
        """
        # a single linear projection constructs queries, keys and values for all heads
        qkv = self.qkv_projector(memory)
        # apply layernorm for every dim except the batch dim
        qkv = self.qkv_layernorm(qkv)

        # N is read from the tensor since the memory got concatenated with the input
        mem_slots = memory.shape[1]  # denoted as N

        # split the qkv to multiple heads H
        # [B, N, F] => [B, N, H, F/H] => [B, H, N, F/H]
        qkv_heads = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size).permute(0, 2, 1, 3)

        # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size]
        q, k, v = torch.split(qkv_heads, [self.key_size, self.key_size, self.value_size], -1)

        # scale q with d_k, the dimensionality of the key vectors.
        # Out-of-place on purpose: q is one of several views returned by torch.split and
        # autograd does not allow such views to be modified in-place.
        q = q * (self.key_size ** -0.5)

        # [B, H, N, N]
        dot_product = torch.matmul(q, k.transpose(-2, -1))
        weights = F.softmax(dot_product, dim=-1)

        # [B, H, N, V]
        output = torch.matmul(weights, v)

        # [B, H, N, V] => [B, N, H, V] => [B, N, H*V]
        output = output.permute(0, 2, 1, 3).contiguous()
        return output.view(output.shape[0], output.shape[1], -1)

    @property
    def state_size(self):
        """Shape of one memory state without the batch dim: [mem_slots, mem_size]."""
        return [self.mem_slots, self.mem_size]

    @property
    def output_size(self):
        """Size of the flattened per-step output: mem_slots * mem_size."""
        return self.mem_slots * self.mem_size

    def calculate_gate_size(self):
        """
        Calculate the gate size from the gate_style.

        Returns:
            The per sample, per head parameter size of each gate.
        """
        if self.gate_style == 'unit':
            return self.mem_size
        if self.gate_style == 'memory':
            return 1
        # self.gate_style == None
        return 0

    def create_gates(self, inputs, memory):
        """
        Create input and forget gates for this step using `inputs` and `memory`.

        Args:
            inputs: Tensor input of a single time step, [B, 1, mem_size].
            memory: The current state of memory.

        Returns:
            input_gate: A LSTM-like insert gate.
            forget_gate: A LSTM-like forget gate.

        Raises:
            ValueError: `inputs` is not 3-dimensional, or holds more than one time step.
        """
        # equation 8: since there is no output gate, h is just a tanh'ed m
        memory = torch.tanh(memory)

        # sonnet flattens the input unconditionally, which is only correct for a single
        # time step; (B, T, features) with T > 1 is therefore rejected explicitly
        if inputs.dim() != 3:
            raise ValueError(
                "input of create_gate function must be 3-dimensional, got {} dimension(s)".format(inputs.dim()))
        if inputs.shape[1] > 1:
            raise ValueError(
                "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1")
        inputs = inputs.reshape(inputs.shape[0], -1)

        # matmul for equation 4 and 5
        # there is no output gate, so equation 6 is not implemented
        gate_inputs = self.input_gate_projector(inputs).unsqueeze(dim=1)
        gate_memory = self.memory_gate_projector(memory)

        # this completes the equation 4 and 5; the input contribution broadcasts over slots
        gates = gate_memory + gate_inputs
        # both gates were produced at once: first half is the input gate, second the forget gate
        input_gate, forget_gate = torch.split(gates, gates.shape[2] // 2, dim=2)

        # to be used for equation 7
        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)
        return input_gate, forget_gate

    def attend_over_memory(self, memory):
        """
        Perform multiheaded attention over `memory`.

        Args:
            memory: Current relational memory (with the input row(s) concatenated).

        Returns:
            The attended-over memory.
        """
        for _ in range(self.num_blocks):
            attended_memory = self.multihead_attention(memory)

            # Add a skip connection to the multiheaded attention's input.
            memory = self.attended_memory_layernorm(memory + attended_memory)

            # add a skip connection to the attention_mlp's input.
            mlp_output = memory
            for layer in self.attention_mlp:
                mlp_output = F.relu(layer(mlp_output))
            memory = self.attended_memory_layernorm2(memory + mlp_output)

        return memory

    def forward_step(self, inputs, memory, treat_input_as_matrix=False):
        """
        Forward step of the relational memory core.

        Args:
            inputs: Tensor input.
            memory: Memory output from the previous time step.
            treat_input_as_matrix: Optional, whether to treat `input` as a sequence
                of matrices. Default to False, in which case the input is flattened
                into a vector.

        Returns:
            output: This time step's output, the next memory flattened to [B, -1].
            next_memory: The next version of memory to use.
        """
        if treat_input_as_matrix:
            # keep (Batch, Seq, ...) dim (0, 1), flatten starting from dim 2
            inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], -1)
            # apply linear layer for dim 2
            inputs_reshape = self.input_projector(inputs)
        else:
            # keep (Batch, ...) dim (0), flatten starting from dim 1
            inputs = inputs.reshape(inputs.shape[0], -1)
            # apply linear layer for dim 1, then unsqueeze the time step to dim 1
            inputs_reshape = self.input_projector(inputs).unsqueeze(dim=1)

        memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)
        next_memory = self.attend_over_memory(memory_plus_input)

        # cut out the concatenated input vectors from the original memory slots
        n = inputs_reshape.shape[1]
        next_memory = next_memory[:, :-n, :]

        if self.gate_style == 'unit' or self.gate_style == 'memory':
            # these gates are sigmoid-applied ones for equation 7
            input_gate, forget_gate = self.create_gates(inputs_reshape, memory)
            # equation 7 calculation
            next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory

        output = next_memory.reshape(next_memory.shape[0], -1)
        return output, next_memory

    def forward(self, inputs, memory):
        """
        Run the recurrent core over every time step of a batch-first sequence.

        Args:
            inputs: Batch-first input; dim 1 is the sequence length T and each
                `inputs[:, t]` is fed to `forward_step`.
            memory: Initial memory; it is detached from its history first.

        Returns:
            outputs: If `return_all_outputs` is set, the per-step outputs concatenated
                along dim 0 (step-major, [T * B, mem_slots * mem_size]); otherwise only
                the output of the last step ([B, mem_slots * mem_size]).
            memory: The memory after the last step.

        Raises:
            ValueError: the sequence length `inputs.shape[1]` is 0.
        """
        seq_length = inputs.shape[1]
        if seq_length == 0:
            raise ValueError('inputs must contain at least one time step, got sequence length 0.')

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        memory = self.repackage_hidden(memory)

        # for loop implementation of (entire) recurrent forward pass of the model
        outputs = []
        for idx_step in range(seq_length):
            output, memory = self.forward_step(inputs[:, idx_step], memory)
            outputs.append(output)

        if self.return_all_outputs:
            # concatenate only when the full sequence of outputs is actually requested
            return torch.cat(outputs), memory
        return output, memory
# ########## DEBUG: unit test code ##########
# input_size = 44
# seq_length = 1
# batch_size = 32
# model = RelationalMemory(mem_slots=10, head_size=20, input_size=input_size, num_heads=8, num_blocks=1, forget_bias=1., input_bias=0.)
# model_memory = model.initial_state(batch_size=batch_size)
#
# # random input, batch first: (batch, seq, features)
# random_input = torch.randn((batch_size, seq_length, input_size))
# # random targets (not consumed by this class; forward() only takes inputs and memory)
# random_targets = torch.randn((batch_size, seq_length, input_size))
#
# # take a one step forward
# logit, next_memory = model(random_input, model_memory)