# enc_dec.py
import tensorflow as tf
from tensorflow.contrib.rnn import DropoutWrapper
from tensorflow.contrib.layers import xavier_initializer as glorot
from utils import dense
from pydoc import locate
from bnlstm import BNLSTMCell
class EncDec():
""" Encoder Decoder """
def __init__(self, params, embedding, emb_dim, num_classes=None, output_layer=None):
"""
Args:
params: hyperparameter instance (stored in the module-level `hparams`)
embedding: embedding matrix as a numpy array
emb_dim: size of a single embedding
num_classes: number of class labels (classification models only)
output_layer: layer projecting decoder outputs to the vocabulary (generation models only)
"""
global hparams
hparams = params
self.num_classes = num_classes
self.floatX = tf.float32
self.intX = tf.int32
# self.final_emb_dim = emb_dim + num_classes
self.bi_encoder_hidden = hparams.cell_units * 2
if hparams.bidirectional == True:
decoder_num_units = self.bi_encoder_hidden # double if bidirectional
else:
decoder_num_units = hparams.cell_units
# helper variable to keep track of steps
self.global_step = tf.Variable(0, name='global_step', trainable=False)
############################
# Inputs
############################
self.keep_prob = tf.placeholder(self.floatX)
self.mode = tf.placeholder(tf.bool, name="mode") # True indicates training mode
# self.max_infer_len = tf.placeholder(tf.intX) # max steps in inferences
self.vocab_size = embedding.shape[0]
# Embedding tensor is of shape [vocab_size x embedding_size]
self.embedding_tensor = self.embedding_setup(embedding, hparams.emb_trainable)
# Encoder inputs
with tf.name_scope("encoder_input"):
self.enc_input = tf.placeholder(self.intX, shape=[None, hparams.max_seq_len])
self.enc_embedded = self.embedded(self.enc_input, self.embedding_tensor)
# self.enc_embedded = tf.layers.batch_normalization(enc_embedded, training=self.mode)
self.enc_input_len = tf.placeholder(self.intX, shape=[None,])
# Condition on Y ==> Embeddings + labels
# self.enc_embedded = self.emb_add_class(self.enc_embedded, self.classes)
# Decoder inputs and targets
with tf.name_scope("decoder_input"):
self.dec_targets = tf.placeholder(self.intX, shape=[None, hparams.max_seq_len])
self.dec_input = tf.placeholder(self.intX, shape=[None, hparams.max_seq_len])
self.dec_embedded = self.embedded(self.dec_input, self.embedding_tensor)
# self.dec_embedded = tf.layers.batch_normalization(dec_embedded, training=self.mode)
# self.dec_embedded = self.emb_add_class(self.dec_embedded, self.classes)
self.dec_input_len = tf.placeholder(self.intX, shape=[None,])
self.batch_size = tf.shape(self.enc_input)[0]
############################
# Build Model
############################
# Setup cells
cell_enc_fw, cell_enc_bw, cell_enc, cell_dec = \
self.build_cell(hparams.cell_units, decoder_num_units, cell_type=hparams.cell_type)
# Get encoder data
with tf.name_scope("encoder"):
if hparams.bidirectional == True:
self.encoded_outputs, self.encoded_state = self.encoder_bi(cell_enc_fw, \
cell_enc_bw, self.enc_embedded, self.enc_input_len)
else:
self.encoded_outputs, self.encoded_state = self.encoder_one_way(\
cell_enc, self.enc_embedded, self.enc_input_len)
# Get decoder data
with tf.name_scope("decoder"):
# Get attention
self.attn_cell, self.initial_state = self.decoder_attn(
self.batch_size,
cell=cell_dec,
mem_units=self.bi_encoder_hidden,
attention_states=self.encoded_outputs,
seq_len_enc=self.enc_input_len,
attn_units=hparams.dec_out_units,
encoder_state=self.encoded_state)
# Get decoder output hidden states
self.decoded_outputs, self.decoded_final_state, self.decoded_final_seq_len=\
self.decoder_train(
self.batch_size,
attn_cell=self.attn_cell,
initial_state=self.initial_state,
decoder_inputs=self.dec_embedded,
seq_len_dec=self.dec_input_len,
output_layer=output_layer)
self.alignment_history = self.decoded_final_state.alignment_history.stack()
# Merged summary ops
self.merged_summary_ops = tf.summary.merge_all()
def optimize_step(self, loss, glbl_step):
""" Locate optimizer from hparams, take a step """
Opt = locate("tensorflow.train." + hparams.optimizer)
if Opt is None:
raise ValueError("Invalid optimizer: " + hparams.optimizer)
optimizer = Opt(hparams.l_rate)
grads_vars = optimizer.compute_gradients(loss)
capped_grads = [(None if grad is None else tf.clip_by_value(grad, -1., 1.), var)\
for grad, var in grads_vars]
take_step = optimizer.apply_gradients(capped_grads, global_step=glbl_step)
return take_step
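# The optimizer is looked up by name in tf.train via pydoc.locate, so
# hparams.optimizer should name a tf.train optimizer class; for example
# "AdamOptimizer" or "GradientDescentOptimizer" (illustrative values, any
# valid tf.train optimizer class name should work).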
def embedding_setup(self, embedding, emb_trainable):
""" If trainable, returns variable, otherwise the original embedding """
if emb_trainable == True:
emb_variable = tf.get_variable(
name="embedding_matrix", shape=embedding.shape,
initializer = tf.constant_initializer(embedding))
return emb_variable
else:
return embedding
def embedded(self, word_ids, embedding_tensor, scope="embedding"):
"""Swap ints for dense embeddings, on cpu.
word_ids correspond the proper row index of the embedding_tensor
Args:
words_ids: array of [batch_size x sequence of word ids]
embedding_tensor: tensor from which to retrieve the embedding, word id
takes corresponding tensor row
Returns:
tensor of shape [batch_size, sequence length, embedding size]
"""
with tf.variable_scope(scope):
with tf.device("/cpu:0"):
inputs = tf.nn.embedding_lookup(embedding_tensor, word_ids)
return inputs
def build_cell(self, num_units, decoder_num_units, cell_type="LSTMCell"):
if cell_type == "BNLSTMCell":
Cell = BNLSTMCell
cell_enc_fw = Cell(num_units, is_training=self.mode)
cell_enc_bw = Cell(num_units, is_training=self.mode)
cell_enc = Cell(num_units, is_training=self.mode)
cell_dec = Cell(decoder_num_units, is_training=self.mode)
else:
Cell = locate("tensorflow.contrib.rnn." + cell_type)
if Cell is None:
raise ValueError("Invalid cell type " + cell_type)
cell_enc_fw = Cell(num_units)
cell_enc_bw = Cell(num_units)
cell_enc = Cell(num_units)
cell_dec = Cell(decoder_num_units)
# Dropout wrapper
cell_enc_fw = DropoutWrapper(cell_enc_fw, output_keep_prob=self.keep_prob)
cell_enc_bw = DropoutWrapper(cell_enc_bw, output_keep_prob=self.keep_prob)
cell_enc = DropoutWrapper(cell_enc, output_keep_prob=self.keep_prob)
cell_dec = DropoutWrapper(cell_dec, output_keep_prob=self.keep_prob)
return cell_enc_fw, cell_enc_bw, cell_enc, cell_dec
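# keep_prob and mode are plain placeholders, so dropout and the batch-norm
# training flag are controlled per session.run call. A hypothetical sketch
# (model/sess names and the feed values are assumptions; optimize and cost
# are defined on the subclasses below):
#   sess.run(model.optimize, feed_dict={model.keep_prob: 0.75, model.mode: True, ...})
#   sess.run(model.cost,     feed_dict={model.keep_prob: 1.0,  model.mode: False, ...})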
def encoder_one_way(self, cell, x, seq_len, init_state=None):
""" Dynamic encoder for one direction
Returns:
outputs: all sequence hidden states as Tensor of shape [batch,time,units]
state: last hidden state
"""
# Output is the outputs at all time steps, state is the last state
with tf.variable_scope("dynamic_rnn"):
outputs, state = tf.nn.dynamic_rnn(\
cell, x, sequence_length=seq_len, initial_state=init_state,
dtype=self.floatX)
# state is a StateTuple class with properties StateTuple.c and StateTuple.h
return outputs, state
def encoder_bi(self, cell_fw, cell_bw, x, seq_len, init_state_fw=None,
init_state_bw=None):
""" Dynamic encoder for two directions
Returns:
outputs: a tuple(output_fw, output_bw), all sequence hidden states, each
as tensor of shape [batch,time,units]
state: tuple(output_state_fw, output_state_bw) containing the forward
and the backward final states of bidirectional rnlast hidden state
"""
# Output is the outputs at all time steps, state is the last state
with tf.variable_scope("bidirectional_dynamic_rnn"):
outputs, state = tf.nn.bidirectional_dynamic_rnn(\
cell_fw=cell_fw,
cell_bw=cell_bw,
inputs=x,
sequence_length=seq_len,
initial_state_fw=init_state_fw,
initial_state_bw=init_state_bw,
dtype=self.floatX)
# outputs: a tuple(output_fw, output_bw), all sequence hidden states,
# each as tensor of shape [batch,time,units]
# Since we don't need the outputs separate, we concat here
outputs = tf.concat(outputs,2)
outputs.set_shape([None, None, self.bi_encoder_hidden])
# If the cell is an LSTM, "state" is not a tuple of Tensors but an
# LSTMStateTuple of "c" and "h": concatenate c and h separately, then build a new LSTMStateTuple
if "LSTMStateTuple" in str(type(state[0])):
c = tf.concat([state[0][0],state[1][0]],axis=1)
h = tf.concat([state[0][1],state[1][1]],axis=1)
state = tf.contrib.rnn.LSTMStateTuple(c,h)
else:
state = tf.concat(state,1)
# Manually set shape to Tensor or all hell breaks loose
state.set_shape([None, self.bi_encoder_hidden])
return outputs, state
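# Shape sketch (assuming batch size B, time T, cell_units U): each direction
# yields outputs of shape [B, T, U], so the concatenated outputs are [B, T, 2U]
# and the concatenated final state is [B, 2U], matching
# self.bi_encoder_hidden = 2 * hparams.cell_units set in __init__.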
def emb_add_class(self, enc_embedded, classes):
""" Concatenate input and classes. Do not use for classification """
num_classes = tf.shape(classes)[1]
# final_emb_dim = tf.to_int32(tf.shape(enc_embedded)[2] + num_classes)
time_steps = tf.shape(enc_embedded)[1]
classes = tf.tile(classes, [1, time_steps]) # copy along axis=1 only
classes = tf.reshape(classes, [-1, time_steps, num_classes]) # match input
classes = tf.cast(classes, self.floatX)
concat = tf.concat([enc_embedded, classes], 2) # concat 3rd dimension
# Hard-set the shape. This is hacky, but after tf.reshape the tensor loses its
# static shape, which causes problems with contrib.rnn, which relies on it.
# Note: self.final_emb_dim must be set in __init__ (currently commented out
# there) before this method can be used.
concat.set_shape([None, None, self.final_emb_dim])
return concat
def add_classes_to_state(self, state_tuple, classes):
""" Concatenate hidden state with class labels
Args:
encoded_state: An LSTMStateTuple with properties c and h
classes: one-hot encoded labels to be concatenated to StateTuple.h
"""
# h is shape [batch_size, num_units]
classes = tf.cast(classes, self.floatX)
h_new = tf.concat([state_tuple.h, classes], 1) # concat along 1st axis
new_state_tuple = tf.contrib.rnn.LSTMStateTuple(state_tuple.c, h_new)
return new_state_tuple
def decoder_attn(self, batch_size, cell, mem_units, attention_states,
seq_len_enc, attn_units, encoder_state):
"""
Args:
cell: an instance of RNNCell.
mem_units: num of units in attention_states
attention_states: hidden states (from encoder) to attend over.
seq_len_dec: seq. len. of decoder input
attn_units: depth of attention (output) tensor
encoder_state: initial state for decoder
"""
# Attention Mechanisms. Bahdanau is additive style attention
attn_mech = tf.contrib.seq2seq.BahdanauAttention(
num_units = mem_units, # depth of query mechanism
memory = attention_states, # hidden states to attend (output of RNN)
memory_sequence_length=seq_len_enc, # masks false memories
normalize=True, # normalize energy term
name='BahdanauAttention')
# Attention Wrapper: adds the attention mechanism to the cell
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell = cell,# Instance of RNNCell
attention_mechanism = attn_mech, # Instance of AttentionMechanism
attention_layer_size = attn_units, # Int, depth of attention (output) tensor
alignment_history = True, # whether to store history in final output
name="attention_wrapper")
# Initial state for decoder
# Clone attention state from current attention, but use encoder_state
initial_state = attn_cell.zero_state(\
batch_size=batch_size, dtype=self.floatX)
initial_state = initial_state.clone(cell_state = encoder_state)
return attn_cell, initial_state
def decoder_train(self, batch_size, attn_cell, initial_state, decoder_inputs,
seq_len_dec, output_layer=None):
"""
Args:
attn_cell: cell wrapped with attention
initial_state: initial_state for decoder
decoder_inputs: decoder inputs for training
seq_len_enc: seq. len. of encoder input, will ignore memories beyond this
output_layer: Dense layer to project output units to vocab
Returns:
outputs: a BasicDecoderOutput with properties:
rnn_output: outputs across time,
if output_layer, then [batch_size,dec_seq_len, out_size]
otherwise output is [batch_size,dec_seq_len, cell_num_units]
sample_id: an argmax over time of rnn_output, Tensor of shape
[batch_size, dec_seq_len]
final_state: an AttentionWrapperState, a namedtuple which contains:
cell_state: such as LSTMStateTuple
attention: attention emitted at previous time step
time: current time step (the last one)
alignments: Tensor of alignments emitted at previous time step for
each attention mechanism
alignment_history: TensorArray of laignment matrices from all time
steps for each attention mechanism. Call stack() on each to convert
to Tensor
"""
# TrainingHelper does no sampling, only uses sequence inputs
helper = tf.contrib.seq2seq.TrainingHelper(
inputs = decoder_inputs, # decoder inputs
sequence_length = seq_len_dec, # decoder input length
name = "decoder_training_helper")
# Decoder setup. This decoder takes inputs and states and feeds it to the
# RNN cell at every timestep
decoder = tf.contrib.seq2seq.BasicDecoder(
cell = attn_cell,
helper = helper, # A Helper instance
initial_state = initial_state, # initial state of decoder
output_layer = output_layer) # instance of tf.layers.Layer, like Dense
# Perform dynamic decoding with the decoder object.
# impute_finished=True ensures finished states are copied through and the
# corresponding outputs are zeroed out, which is needed for correct backprop.
# maximum_iterations: should be fixed for training; use a different value for generation
outputs, final_state, final_sequence_lengths= \
tf.contrib.seq2seq.dynamic_decode(\
decoder=decoder,
impute_finished=True,
maximum_iterations=hparams.max_seq_len) # if None, decode till decoder is done
return outputs, final_state, final_sequence_lengths
def output_logits(self, decoded_outputs, num_units, vocab_size, scope):
""" Output projection function
To be used for single timestep in RNN decoder
"""
with tf.variable_scope(scope):
w = tf.get_variable("weights", [num_units, vocab_size],
dtype=self.floatX, initializer=glorot())
b = tf.get_variable("biases", [vocab_size],
dtype=self.floatX, initializer=tf.constant_initializer(0.0))
logits = tf.matmul(decoded_outputs, w) + b
return logits
def log_prob(self, logits, targets):
""" Calculate the perplexity of a sequence:
\left(\prod_{i=1}^{N} \frac{1}{P(w_i|past)} \right)^{1/n}
that is, the total product of 1 over the probability of each word, and n
root of that total
For language model, lower perplexity means better model
"""
# Probability of entire vocabulary over time
probs = tf.nn.softmax(logits)
# Get the model probability of only the targets
# Targets are the vocabulary index
# probs = tf.gather(probs, targets)
return probs
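# A possible completion (a sketch, not part of the original graph), assuming
# probs has shape [batch, vocab_size] and targets is a vector of vocabulary ids:
#   batch_idx = tf.range(tf.shape(targets)[0])
#   target_probs = tf.gather_nd(probs, tf.stack([batch_idx, targets], axis=1))  # P(w_i | past) per example
#   perplexity = tf.exp(-tf.reduce_mean(tf.log(target_probs)))  # exp of mean negative log prob,
#                                                               # i.e. (prod 1/P)^{1/N} from the docstring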
class EncDecClass(EncDec):
"""
EncDec for classification. Classification based on last decoded hidden state.
To use, must provide encoder/decoder inputs + class label
"""
def __init__(self, hparams, embedding, emb_dim):
super().__init__(hparams, embedding, emb_dim, output_layer=None)
self.model_type = "classification"
# Class label
with tf.name_scope("class_labels"):
# Labels for classification, single label per sample
self.classes = tf.placeholder(self.intX, shape=[None, hparams.num_classes])
with tf.name_scope("classification"):
if hparams.class_over_sequence == True:
# Classification over entire sequence output
self.class_logits = self.sequence_class_logits(\
decoded_outputs=self.decoded_outputs,
pool_size=hparams.dec_out_units,
max_seq_len=hparams.max_seq_len,
num_classes=hparams.num_classes)
else:
# Classification input uses only sequence final state
self.class_logits = self.output_logits(self.decoded_final_state.attention,
hparams.dec_out_units, hparams.num_classes, "class_softmax")
# Classification loss
self.loss = self.classification_loss(self.classes, self.class_logits)
self.cost = tf.reduce_mean(self.loss) # average across batch
tf.summary.scalar("class_cost", self.cost)
self.y_pred, self.y_true = self.predict(self.class_logits, self.classes)
# Optimize ###################
self.optimize = self.optimize_step(self.cost, self.global_step)
def sequence_class_logits(self, decoded_outputs, pool_size, max_seq_len, num_classes):
""" Logits for the sequence """
with tf.variable_scope("pooling"):
features = tf.expand_dims(decoded_outputs.rnn_output, axis=-1)
pooled = tf.nn.max_pool(
value=features, # [batch, height, width, channels]
ksize=[1, 1, pool_size, 1],
strides=[1, 1, 1, 1],
padding='VALID',
name="pool")
# Get rid of last 2 empty dimensions
pooled = tf.squeeze(pooled, axis=[2,3], name="pool_squeeze")
# Pad
pad_len = max_seq_len - tf.shape(pooled)[1]
paddings = [[0,0],[0, pad_len]]
x = tf.pad(pooled, paddings=paddings, mode='CONSTANT', name="padding")
with tf.variable_scope("dense_layers"):
# FC layers
out_dim = hparams.hidden_size
in_dim = max_seq_len
for i in range(hparams.fc_num_layers):
layer_name = "fc_{}".format(i + 1)
x = dense(x, in_dim, out_dim, act=tf.nn.relu, scope=layer_name)
x = tf.nn.dropout(x, self.keep_prob)
in_dim = out_dim
# Logits
logits = dense(x, out_dim, num_classes, act=None, scope="class_log")
return logits
def classification_loss(self, classes_true, classes_logits):
""" Class loss. If binary, two outputs"""
entropy_fn = tf.nn.sparse_softmax_cross_entropy_with_logits
classes_max = tf.argmax(classes_true, axis=1)
class_loss = entropy_fn(
labels=classes_max,
logits=classes_logits)
return class_loss
def predict(self, pred_logits, classes):
""" Returns class label (int) for prediction and gold
Args:
pred_logits : predicted logits, not yet softmax
classes : labels as one-hot vectors
"""
y_pred = tf.nn.softmax(pred_logits)
y_pred = tf.argmax(y_pred, axis=1)
y_true = tf.argmax(classes, axis=1)
return y_pred, y_true
class EncDecGen(EncDec):
"""
EncDec for text generation
"""
def __init__(self, hparams, embedding, emb_dim):
# Must train output_layer and recycle later for inference
vocab_size = embedding.shape[0]
output_layer = tf.contrib.keras.layers.Dense(vocab_size, use_bias=False)
super().__init__(hparams, embedding, emb_dim, output_layer=output_layer)
self.model_type="generative"
# Sequence outputs over vocab, training
self.seq_logits = self.decoded_outputs.rnn_output
# Generator loss
self.loss = self.sequence_loss(\
self.seq_logits, self.dec_targets, self.dec_input_len)
self.cost = tf.reduce_mean(self.loss) # average across batch
# Optimize ###################
self.optimize = self.optimize_step(self.cost, self.global_step)
# Generated text ###################
# Sequence outputs over vocab, inferred
self.infer_outputs, self.infer_final_state, self.infer_final_seq_len=\
self.decoder_infer(self.batch_size, self.attn_cell, self.initial_state, output_layer)
self.sample_id = self.infer_outputs.sample_id
def sequence_loss(self, logits, targets, seq_len):
""" Loss on sequence, given logits and one-hot targets
Default loss below is softmax cross ent on logits
Arguments:
logits : logits over predictions, [batch, seq_len, num_decoder_symb]
targets : the class id, shape is [batch_size, seq_len], dtype int
"""
# creates mask [batch_size, seq_len]
mask = tf.sequence_mask(seq_len, dtype=tf.float32)
# We need to delete zeroed elements in targets, beyond max sequence
max_seq = tf.reduce_max(seq_len)
max_seq = tf.to_int32(max_seq)
# Slice time dimension to max_seq
logits = tf.slice(logits, [0, 0, 0], [-1, max_seq, -1])
targets = tf.slice(targets, [0, 0], [-1, max_seq])
# weight_mask = tf.slice(weight_mask, [0,0], [-1, max_seq])
loss = tf.contrib.seq2seq.sequence_loss(logits, targets, mask,
average_across_batch=False)
return loss
def decoder_infer(self, batch_size, attn_cell, initial_state, output_layer):
"""
Args:
attn_cell: cell wrapped with attention
initial_state: initial_state for decoder
output_layer: Trained dense layer to project output units to vocab
Returns:
see decoder_train() above
"""
# Greedy decoder
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding=self.embedding_tensor,
start_tokens=tf.tile([hparams.start_token], [batch_size]),
end_token=hparams.end_token)
# Decoder setup. This decoder takes inputs and states and feeds it to the
# RNN cell at every timestep
decoder = tf.contrib.seq2seq.BasicDecoder(
cell = attn_cell,
helper = helper, # A Helper instance
initial_state = initial_state, # initial state of decoder
output_layer = output_layer) # instance of tf.layers.Layer, like Dense
# Perform dynamic decoding with the decoder object.
# impute_finished=True ensures finished states are copied through and the
# corresponding outputs are zeroed out, which is needed for correct backprop.
# maximum_iterations: fixed for training; can be None for inference
outputs, final_state, final_sequence_lengths= \
tf.contrib.seq2seq.dynamic_decode(\
decoder=decoder,
impute_finished=True,
maximum_iterations=hparams.max_seq_len) # if None, decode till stop token
return outputs, final_state, final_sequence_lengths
def sequence_output_logits(self, decoded_outputs, num_units, vocab_size):
""" Output projection over all timesteps
Returns:
logit tensor of shape [batch_size, timesteps, vocab_size]
"""
# We need the sequence length for *this* batch; it is not the same for every
# batch since the decoder is dynamic: the length equals the longest sequence
# in the batch, not the maximum over the whole dataset
max_seq_len = tf.shape(decoded_outputs)[1]
# Reshape to rank 2 tensor so timestep is no longer a dimension
output = tf.reshape(decoded_outputs, [-1, num_units])
# Get the logits
logits = self.output_logits(output, num_units, vocab_size, "seq_softmax")
# Reshape back to the original tensor shape
logits = tf.reshape(logits, [-1, max_seq_len, vocab_size])
return logits
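# Minimal construction sketch (illustrative only): the hparams values and the
# random embedding below are assumptions, not settings from the original project;
# only the field names are taken from this file.
if __name__ == "__main__":
    import numpy as np
    demo_hparams = tf.contrib.training.HParams(
        cell_units=64, bidirectional=True, cell_type="LSTMCell",
        emb_trainable=False, max_seq_len=20, dec_out_units=128,
        optimizer="AdamOptimizer", l_rate=1e-3,
        start_token=1, end_token=2)
    demo_embedding = np.random.randn(1000, 50).astype("float32")  # [vocab_size, emb_dim]
    model = EncDecGen(demo_hparams, demo_embedding, emb_dim=50)
    # Training would then feed model.enc_input, model.enc_input_len, model.dec_input,
    # model.dec_targets, model.dec_input_len, model.keep_prob and model.mode, and run
    # model.optimize / model.cost inside a tf.Session.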