@@ -1069,12 +1069,20 @@ def _block_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b):
1069
1069
1070
1070
def _cudnn_unrolled_lstm (input_sequence , initial_state , w_i , w_h , b ):
1071
1071
"""GPU/CuDNN-RNN specialization of :class:`UnrolledLSTM`."""
1072
+ max_sequence_length = tf .shape (input_sequence )[0 ]
1073
+ batch_dim = tf .expand_dims (tf .shape (input_sequence )[1 ], axis = 0 )
1074
+
1075
+ # cuDNN 9+ always requires the sequence_length array argument to be present,
1076
+ # so we generate it here with the max_sequence_length in all positions.
1077
+ sequence_lengths = tf .broadcast_to (max_sequence_length , batch_dim )
1078
+
1072
1079
# Intuitively, concat/transpose is not free but we did not see
1073
1080
# it significantly affecting performance in benchmarks.
1074
- output_sequence , all_hidden , all_cell , _ = tf .raw_ops .CudnnRNN (
1081
+ output_sequence , all_hidden , all_cell , _ , _ = tf .raw_ops .CudnnRNNV3 (
1075
1082
input = input_sequence ,
1076
1083
input_h = tf .expand_dims (initial_state .hidden , axis = 0 ),
1077
1084
input_c = tf .expand_dims (initial_state .cell , axis = 0 ),
1085
+ sequence_lengths = sequence_lengths ,
1078
1086
params = tf .concat (
1079
1087
[
1080
1088
tf .reshape (tf .transpose (w_i ), [- 1 ]),
@@ -1659,7 +1667,15 @@ def __call__(self, inputs, prev_state):
1659
1667
w_hz , w_hr , w_ha = tf .split (self ._w_h , num_or_size_splits = 3 , axis = 1 )
1660
1668
b_z , b_r , b_a = tf .split (self .b , num_or_size_splits = 3 )
1661
1669
b_h_zero = tf .zeros ([self ._hidden_size ])
1662
- outputs , next_hidden , _ , _ = tf .raw_ops .CudnnRNN (
1670
+
1671
+ max_sequence_length = tf .shape (inputs )[0 ]
1672
+ batch_dim = tf .expand_dims (tf .shape (inputs )[1 ], axis = 0 )
1673
+
1674
+ # cuDNN 9+ always requires the sequence_length array argument to be present,
1675
+ # so we generate it here with the max_sequence_length in all positions.
1676
+ sequence_lengths = tf .broadcast_to (max_sequence_length , batch_dim )
1677
+
1678
+ outputs , next_hidden , _ , _ , _ = tf .raw_ops .CudnnRNNV3 (
1663
1679
input = inputs ,
1664
1680
input_h = tf .expand_dims (prev_state , axis = 0 ),
1665
1681
input_c = 0 ,
@@ -1681,7 +1697,8 @@ def __call__(self, inputs, prev_state):
1681
1697
b_h_zero ,
1682
1698
],
1683
1699
axis = 0 ),
1684
- rnn_mode = "gru" )
1700
+ rnn_mode = "gru" ,
1701
+ sequence_lengths = sequence_lengths )
1685
1702
1686
1703
return outputs , next_hidden
1687
1704
0 commit comments