
Commit

Remove experimental_run_tf_function in AttentionWrapper test (#791)
guillaumekln authored and seanpmorgan committed Dec 20, 2019
1 parent a74309b commit ed1d909
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -125,7 +125,6 @@ def test_passing_memory_from_call(self, attention_cls):
         ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
     )
     def test_save_load_layer(self, attention_cls):
-        self.skipTest("Attention not working with single code path.")
         vocab = 20
         embedding_dim = 6
         inputs = tf.keras.Input(shape=[self.timestep])
@@ -146,7 +145,7 @@ def test_save_load_layer(self, attention_cls):
         model = tf.keras.Model([inputs, query, state], score)
         # Fall back to v1 style Keras training loop until issue with
         # using outputs of a layer in another layer's constructor.
-        model.compile("rmsprop", "mse", experimental_run_tf_function=False)
+        model.compile("rmsprop", "mse")
         model.fit([x, self.query, self.state], (y, y))
         y_ref = model.predict_on_batch([x_test, self.query, self.state])

@@ -158,8 +157,7 @@ def test_save_load_layer(self, attention_cls):
 
         # Fall back to v1 style Keras training loop until issue with
         # using outputs of a layer in another layer's constructor.
-        loaded_model.compile(
-            "rmsprop", "mse", experimental_run_tf_function=False)
+        loaded_model.compile("rmsprop", "mse")
 
         y = loaded_model.predict_on_batch([x_test, self.query, self.state])

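For context, a minimal hypothetical sketch (not part of the commit; the toy model and data below are placeholders, not the AttentionWrapper test) of the compile pattern the updated test now uses: no experimental_run_tf_function argument is passed, so Keras runs the training step through its default tf.function path, and run_eagerly=True remains the supported compile() switch when eager execution is needed for debugging.

# Minimal hypothetical sketch, not from attention_wrapper_test.py: a toy model
# compiled the same way as the updated test, without the experimental flag.
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# Mirrors the change in the diff: plain compile, no experimental_run_tf_function.
model.compile("rmsprop", "mse")

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
preds = model.predict_on_batch(x)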
