Skip to content

Commit

Permalink
Fix for #18941 (#18968)
Browse files Browse the repository at this point in the history
This allows the static shape to propagate in RNNs in the case when the number of time steps is fixed and known at build time.
  • Loading branch information
hertschuh authored Dec 19, 2023
1 parent 2fbc05d commit 4a4a139
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras/backend/tensorflow/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def swap_batch_timestep(input_t):

flattened_inputs = tree.flatten(inputs)
time_steps = flattened_inputs[0].shape[0]
time_steps_t = tf.shape(flattened_inputs[0])[0]
time_steps_t = (
tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps
)

for input_ in flattened_inputs:
input_.shape.with_rank_at_least(3)
Expand Down

0 comments on commit 4a4a139

Please sign in to comment.