
Commit b69d036
WIP: TensorFlow 2.x API updates.
Signed-off-by: format 2020.06.15 <github.com/ChrisCummins/format>
ChrisCummins committed Aug 17, 2020
1 parent d129ba9 commit b69d036
Showing 3 changed files with 25 additions and 18 deletions.
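
Every change in this commit follows one mechanical pattern: references to Keras through the TensorFlow 1.x compatibility namespace (tf.compat.v1.keras.*) are replaced by the keras package imported directly from TensorFlow. A minimal sketch of the before/after pattern (illustrative only, not part of the diff):

# Before: Keras reached through the TF1 compatibility shim.
import tensorflow as tf
tf.compat.v1.keras.backend.clear_session()

# After: Keras imported directly, as TensorFlow 2.x intends.
from tensorflow import keras
keras.backend.clear_session()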
29 changes: 15 additions & 14 deletions programl/models/lstm/lstm.py
@@ -25,6 +25,7 @@
 from labm8.py import app
 from labm8.py.progress import NullContext
 from labm8.py.progress import ProgressContext
+from tensorflow import keras
 
 from programl.models.batch_data import BatchData
 from programl.models.batch_results import BatchResults
@@ -99,7 +100,7 @@ def __init__(

         # Reset any previous Tensorflow session. This is required when running
         # consecutive LSTM models in the same process.
-        tf.compat.v1.keras.backend.clear_session()
+        keras.backend.clear_session()
 
     @staticmethod
     def MakeLstmLayer(*args, **kwargs):
@@ -110,32 +111,32 @@ def MakeLstmLayer(*args, **kwargs):
         much slower but works on CPU.
         """
         if FLAGS.cudnn_lstm and tf.compat.v1.test.is_gpu_available():
-            return tf.compat.v1.keras.layers.CuDNNLSTM(*args, **kwargs)
+            return keras.layers.CuDNNLSTM(*args, **kwargs)
         else:
-            return tf.compat.v1.keras.layers.LSTM(*args, **kwargs, implementation=1)
+            return keras.layers.LSTM(*args, **kwargs, implementation=1)
 
-    def CreateKerasModel(self) -> tf.compat.v1.keras.Model:
+    def CreateKerasModel(self): # -> keras.Model:
         """Construct the tensorflow computation graph."""
-        vocab_ids = tf.compat.v1.keras.layers.Input(
+        vocab_ids = keras.layers.Input(
             batch_shape=(self.batch_size, self.padded_sequence_length,),
             dtype="int32",
             name="sequence_in",
         )
-        embeddings = tf.compat.v1.keras.layers.Embedding(
+        embeddings = keras.layers.Embedding(
             input_dim=len(self.vocabulary) + 2,
             input_length=self.padded_sequence_length,
             output_dim=FLAGS.hidden_size,
             name="embedding",
             trainable=FLAGS.trainable_embeddings,
         )(vocab_ids)
 
-        selector_vectors = tf.compat.v1.keras.layers.Input(
+        selector_vectors = keras.layers.Input(
             batch_shape=(self.batch_size, self.padded_sequence_length, 2),
             dtype="float32",
             name="selector_vectors",
         )
 
-        lang_model_input = tf.compat.v1.keras.layers.Concatenate(
+        lang_model_input = keras.layers.Concatenate(
             axis=2, name="embeddings_and_selector_vectorss"
         )([embeddings, selector_vectors],)
 
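A caveat on MakeLstmLayer above: stock TensorFlow 2.x exposes CuDNNLSTM only under the tf.compat.v1 namespace, so keras.layers.CuDNNLSTM raises AttributeError on TF 2.x. The fused cuDNN kernel is instead selected automatically by keras.layers.LSTM when a GPU is available and the layer keeps its default arguments. A hedged sketch of a TF2-native equivalent, reusing the names and flags from the file above (this is not what the commit does):

def MakeLstmLayer(*args, **kwargs):
    # With default activations (tanh/sigmoid), recurrent_dropout=0, and
    # unroll=False, TF 2.x LSTM dispatches to the fused cuDNN kernel on GPU.
    if FLAGS.cudnn_lstm and tf.config.list_physical_devices("GPU"):
        return keras.layers.LSTM(*args, **kwargs)
    else:
        # implementation=1 splits the cell into smaller ops: slower, CPU-safe.
        return keras.layers.LSTM(*args, **kwargs, implementation=1)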
@@ -152,18 +153,18 @@ def CreateKerasModel(self) -> tf.compat.v1.keras.Model:

         # Dense layers.
         for i in range(1, FLAGS.hidden_dense_layer_count + 1):
-            lang_model = tf.compat.v1.keras.layers.Dense(
+            lang_model = keras.layers.Dense(
                 FLAGS.hidden_size, activation="relu", name=f"dense_{i}",
             )(lang_model)
-        node_out = tf.compat.v1.keras.layers.Dense(
+        node_out = keras.layers.Dense(
             self.node_y_dimensionality, activation="sigmoid", name="node_out",
         )(lang_model)
 
-        model = tf.compat.v1.keras.Model(
+        model = keras.Model(
             inputs=[vocab_ids, selector_vectors], outputs=[node_out],
         )
         model.compile(
-            optimizer=tf.compat.v1.keras.optimizers.Adam(
+            optimizer=keras.optimizers.Adam(
                 learning_rate=FLAGS.learning_rate
             ),
             metrics=["accuracy"],
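One detail worth noting in the compile() call: TF 2.x Keras optimizers take learning_rate as the keyword argument (the older lr alias is deprecated), so the line above is already TF2-idiomatic. An equivalent standalone construction, with an assumed illustrative rate:

optimizer = keras.optimizers.Adam(learning_rate=0.001)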
@@ -297,13 +298,13 @@ def LoadModelData(self, data_to_load: Any) -> None:
         tf.compat.v1.reset_default_graph()
         SetAllowedGrowthOnKerasSession()
 
-        self.model = tf.compat.v1.keras.models.load_model(path)
+        self.model = keras.models.load_model(path)
 
 
 def SetAllowedGrowthOnKerasSession():
     """Allow growth on GPU for Keras."""
     config = tf.compat.v1.ConfigProto()
     config.gpu_options.allow_growth = True
     session = tf.compat.v1.Session(config=config)
-    tf.compat.v1.keras.backend.set_session(session)
+    # set_session(session)
     return session
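SetAllowedGrowthOnKerasSession still builds a TF1 Session, and the commented-out set_session call reflects that keras.backend.set_session is gone from the TF 2.x API surface. A possible TF2-native replacement, sketched under the assumption of stock TensorFlow 2.x (the function name here is hypothetical):

def SetAllowedGrowthOnGpus():
    # Must run before the runtime initializes any GPU.
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)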
8 changes: 4 additions & 4 deletions programl/task/dataflow/lstm_batch_builder.py
@@ -18,7 +18,7 @@
 from typing import Optional
 
 import numpy as np
-import tensorflow as tf
+from tensorflow import keras
 from labm8.py import app
 
 from programl.graph.format.py import graph_serializer
@@ -84,23 +84,23 @@ def _Build(self) -> BatchData:
             graph_count=len(self.graph_node_sizes),
             model_data=LstmBatchData(
                 graph_node_sizes=np.array(self.graph_node_sizes, dtype=np.int32),
-                encoded_sequences=tf.compat.v1.keras.preprocessing.sequence.pad_sequences(
+                encoded_sequences=keras.preprocessing.sequence.pad_sequences(
                     self.vocab_ids,
                     maxlen=self.padded_sequence_length,
                     dtype="int32",
                     padding="pre",
                     truncating="post",
                     value=self._vocab_id_pad,
                 ),
-                selector_vectors=tf.compat.v1.keras.preprocessing.sequence.pad_sequences(
+                selector_vectors=keras.preprocessing.sequence.pad_sequences(
                     self.selector_vectors,
                     maxlen=self.padded_sequence_length,
                     dtype="float32",
                     padding="pre",
                     truncating="post",
                     value=np.zeros(2, dtype=np.float32),
                 ),
-                node_labels=tf.compat.v1.keras.preprocessing.sequence.pad_sequences(
+                node_labels=keras.preprocessing.sequence.pad_sequences(
                     self.targets,
                     maxlen=self.padded_sequence_length,
                     dtype="float32",
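All three padded tensors above use the same keras.preprocessing.sequence.pad_sequences call shape: pre-padding up to padded_sequence_length and post-truncation beyond it. A small illustration of those semantics (toy values, not from the dataset):

from tensorflow import keras

seqs = [[4, 7], [1, 2, 3, 4, 5]]
padded = keras.preprocessing.sequence.pad_sequences(
    seqs, maxlen=4, dtype="int32", padding="pre", truncating="post", value=0
)
# padded == [[0, 0, 4, 7],   <- "pre": pad values added at the front
#            [1, 2, 3, 4]]   <- "post": excess trimmed from the end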
6 changes: 6 additions & 0 deletions programl/task/dataflow/train_lstm.py
@@ -19,11 +19,13 @@
 classification targets for data flow problems.
 """
 import pathlib
+import sys
 import time
 from typing import Dict
 
 import numpy as np
 from labm8.py import app
+from labm8.py import bazelutil
 from labm8.py import gpu_scheduler
 from labm8.py import humanize
 from labm8.py import pbutil
@@ -35,6 +37,10 @@
 from programl.task.dataflow import dataflow
 from programl.task.dataflow.graph_loader import DataflowGraphLoader
 from programl.task.dataflow.lstm_batch_builder import DataflowLstmBatchBuilder
 
+# NOTE(cec): Workaround to prevent third_party package name shadowing from
+# labm8.
+sys.path.insert(0, str(bazelutil.DataPath("programl")))
 from third_party.py.ncc import vocabulary
 

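The sys.path workaround in train_lstm.py works because Python resolves a package from the first sys.path entry that contains it; prepending programl's runfiles root makes its third_party package win over the identically named one shipped by labm8. A generic sketch of the pattern (path hypothetical):

import sys

# Put the preferred tree first so its packages shadow same-named packages
# that appear later on sys.path.
sys.path.insert(0, "/runfiles/programl")  # hypothetical path
from third_party.py.ncc import vocabulary  # now resolves from the preferred tree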
