diff --git a/keras/backend/tensorflow/rnn.py b/keras/backend/tensorflow/rnn.py index 874c9f8e041e..ffea6bb625a1 100644 --- a/keras/backend/tensorflow/rnn.py +++ b/keras/backend/tensorflow/rnn.py @@ -90,7 +90,9 @@ def swap_batch_timestep(input_t): flattened_inputs = tree.flatten(inputs) time_steps = flattened_inputs[0].shape[0] - time_steps_t = tf.shape(flattened_inputs[0])[0] + time_steps_t = ( + tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps + ) for input_ in flattened_inputs: input_.shape.with_rank_at_least(3)