From 4a4a139c7aada9f4495620e5a1c5f7ef20d84395 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Tue, 19 Dec 2023 13:38:23 -0800 Subject: [PATCH] Fix for https://github.com/keras-team/keras/issues/18941 (#18968) This allows the static shape to propagate in RNNs in the case when the number of time steps is fixed and known at build time. --- keras/backend/tensorflow/rnn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras/backend/tensorflow/rnn.py b/keras/backend/tensorflow/rnn.py index 874c9f8e041..ffea6bb625a 100644 --- a/keras/backend/tensorflow/rnn.py +++ b/keras/backend/tensorflow/rnn.py @@ -90,7 +90,9 @@ def swap_batch_timestep(input_t): flattened_inputs = tree.flatten(inputs) time_steps = flattened_inputs[0].shape[0] - time_steps_t = tf.shape(flattened_inputs[0])[0] + time_steps_t = ( + tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps + ) for input_ in flattened_inputs: input_.shape.with_rank_at_least(3)