model.py
import os

import tensorflow as tf
import yaml
from numpy import load
from tensorflow.keras.optimizers import Adam

# Load hyperparameters from flags.yaml. safe_load avoids the unsafe default
# loader, and yaml.load without an explicit Loader is an error in PyYAML >= 6.
with open("flags.yaml", "r") as stream:
    FLAGS = yaml.safe_load(stream)
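# For reference, a minimal flags.yaml covering the keys read in this file
# might look like the following (the values are illustrative assumptions,
# not the project's actual settings):
#
#   batch_size: 64
#   epochs: 10
#   vocab_file: vocab.npy
#   embedding_dim: 256
#   rnn_units: 1024
#   min_lr: 1.0e-4
#   max_lr: 1.0e-3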
## Build model
BATCH_SIZE = FLAGS['batch_size']
EPOCHS = FLAGS['epochs']

# Length of the vocabulary in chars
# TODO: ensure vocab is in namespace or find way to get access
vocab = load(os.path.join('data', FLAGS['vocab_file']))
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = FLAGS['embedding_dim']

# Number of RNN units
rnn_units = FLAGS['rnn_units']
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Stack an embedding layer, three stateful GRUs, and a dense head."""
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        tf.keras.layers.GRU(rnn_units,
                            return_sequences=True,
                            stateful=True,
                            recurrent_initializer='glorot_uniform'),
        tf.keras.layers.GRU(rnn_units,
                            return_sequences=True,
                            stateful=True,
                            recurrent_initializer='glorot_uniform'),
        tf.keras.layers.GRU(rnn_units,
                            return_sequences=True,
                            stateful=True,
                            recurrent_initializer='glorot_uniform'),
        # tf.keras.layers.Dense(rnn_units // 2, activation='relu'),
        tf.keras.layers.Dense(vocab_size * 2, activation='relu'),
        # Final layer has no activation: it emits raw logits over the vocabulary.
        tf.keras.layers.Dense(vocab_size)
    ])
    return model
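# Note: stateful GRUs tie the model to the batch size fixed at construction
# time. A common pattern (assumed here, not shown in this file) is to rebuild
# with batch_size=1 for generation and load the trained weights into that copy:
#
#   gen_model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
#   gen_model.load_weights(checkpoint_path)  # checkpoint_path is hypothetical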
model = build_model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE)

# summary() prints the architecture itself and returns None, so there is no
# need to wrap it in print().
model.summary()
# Train the model
def loss_fn(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
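# from_logits=True matches the final Dense layer above, which returns raw
# logits rather than softmax probabilities.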
min_lr = FLAGS['min_lr']  # loaded but unused here; presumably for a LR schedule
max_lr = FLAGS['max_lr']

optimizer = Adam(learning_rate=max_lr)
model.compile(optimizer=optimizer, loss=loss_fn)
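# A minimal training sketch, assuming a tf.data.Dataset of (input, target)
# character-index batches named `dataset` is built elsewhere; the dataset and
# checkpoint path below are assumptions, not part of this file:
#
#   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
#       filepath=os.path.join('checkpoints', 'ckpt_{epoch}'),
#       save_weights_only=True)
#   history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_cb])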