
Commit 9067422

beckerhe authored and copybara-github committed
Make Sonnet use CudnnRNNV3
CudnnRNN and CudnnRNNV2 are not compatible with cuDNN 9+, so this change makes Sonnet use CudnnRNNV3 instead. Note that this raises the minimum supported cuDNN version to 8.1, which is still below 8.9, the minimum cuDNN version supported by TensorFlow anyway.

PiperOrigin-RevId: 621192321
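For context, the gist of the migration (not part of the commit): tf.raw_ops.CudnnRNNV3 accepts the same opaque params blob as tf.raw_ops.CudnnRNN, but it requires a per-example sequence_lengths vector and returns a fifth (host-reserved) output. A minimal sketch, assuming a CUDA GPU with cuDNN; all sizes below are made up for illustration:

import tensorflow as tf

# Illustrative shapes only, not values from the commit.
seq_len, batch, input_size, hidden = 5, 3, 4, 8
x = tf.random.normal([seq_len, batch, input_size])  # time-major input
h0 = tf.zeros([1, batch, hidden])                   # initial hidden state
c0 = tf.zeros([1, batch, hidden])                   # initial cell state

# Opaque cuDNN parameter blob for a single-layer LSTM: 4 gates of input and
# recurrent weights plus two bias sets of 4 gates each (zeros for brevity).
params = tf.zeros([4 * hidden * (input_size + hidden) + 8 * hidden])

# Old op, rejected by cuDNN 9+: four outputs, no sequence lengths.
#   outputs, h, c, _ = tf.raw_ops.CudnnRNN(
#       input=x, input_h=h0, input_c=c0, params=params, rnn_mode="lstm")

# New op: sequence_lengths is mandatory and a fifth output is returned.
sequence_lengths = tf.fill([batch], seq_len)
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
    input=x, input_h=h0, input_c=c0, params=params,
    sequence_lengths=sequence_lengths, rnn_mode="lstm")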
1 parent: 26b0518

2 files changed (+21, -3)


examples/BUILD (+1)

@@ -1,3 +1,4 @@
+# buildifier: disable=out-of-order-load - Breaks copybara otherwise
 load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary")
 load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test")
 
sonnet/src/recurrent.py (+20, -3)

@@ -1069,12 +1069,20 @@ def _block_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b):
 
 def _cudnn_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b):
   """GPU/CuDNN-RNN specialization of :class:`UnrolledLSTM`."""
+  max_sequence_length = tf.shape(input_sequence)[0]
+  batch_dim = tf.expand_dims(tf.shape(input_sequence)[1], axis=0)
+
+  # cuDNN 9+ always requires the sequence_length array argument to be present,
+  # so we generate it here with the max_sequence_length in all positions.
+  sequence_lengths = tf.broadcast_to(max_sequence_length, batch_dim)
+
   # Intuitively, concat/transpose is not free but we did not see
   # it significantly affecting performance in benchmarks.
-  output_sequence, all_hidden, all_cell, _ = tf.raw_ops.CudnnRNN(
+  output_sequence, all_hidden, all_cell, _, _ = tf.raw_ops.CudnnRNNV3(
       input=input_sequence,
       input_h=tf.expand_dims(initial_state.hidden, axis=0),
       input_c=tf.expand_dims(initial_state.cell, axis=0),
+      sequence_lengths=sequence_lengths,
       params=tf.concat(
           [
               tf.reshape(tf.transpose(w_i), [-1]),
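Aside (not part of the diff): the sequence_lengths construction above just fills a [batch]-shaped vector with the padded sequence length, telling cuDNN that no example is shorter than the full unroll. A standalone sketch with assumed shapes, runnable on CPU:

import tensorflow as tf

input_sequence = tf.zeros([7, 3, 4])  # stand-in, shape [time, batch, features]

max_sequence_length = tf.shape(input_sequence)[0]                # scalar: 7
batch_dim = tf.expand_dims(tf.shape(input_sequence)[1], axis=0)  # [3], a 1-D shape
sequence_lengths = tf.broadcast_to(max_sequence_length, batch_dim)
print(sequence_lengths.numpy())  # [7 7 7] -- every example uses the full length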
@@ -1659,7 +1667,15 @@ def __call__(self, inputs, prev_state):
     w_hz, w_hr, w_ha = tf.split(self._w_h, num_or_size_splits=3, axis=1)
     b_z, b_r, b_a = tf.split(self.b, num_or_size_splits=3)
     b_h_zero = tf.zeros([self._hidden_size])
-    outputs, next_hidden, _, _ = tf.raw_ops.CudnnRNN(
+
+    max_sequence_length = tf.shape(inputs)[0]
+    batch_dim = tf.expand_dims(tf.shape(inputs)[1], axis=0)
+
+    # cuDNN 9+ always requires the sequence_length array argument to be present,
+    # so we generate it here with the max_sequence_length in all positions.
+    sequence_lengths = tf.broadcast_to(max_sequence_length, batch_dim)
+
+    outputs, next_hidden, _, _, _ = tf.raw_ops.CudnnRNNV3(
         input=inputs,
         input_h=tf.expand_dims(prev_state, axis=0),
         input_c=0,

@@ -1681,7 +1697,8 @@ def __call__(self, inputs, prev_state):
                 b_h_zero,
             ],
             axis=0),
-        rnn_mode="gru")
+        rnn_mode="gru",
+        sequence_lengths=sequence_lengths)
 
     return outputs, next_hidden
 
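Nothing changes at the Sonnet API level; the cuDNN specialization is only selected when running on a GPU. A usage sketch, assuming the snt.UnrolledLSTM API with arbitrary sizes:

import sonnet as snt
import tensorflow as tf

lstm = snt.UnrolledLSTM(hidden_size=8)
input_sequence = tf.random.normal([5, 3, 4])  # [time, batch, features]
initial_state = lstm.initial_state(batch_size=3)
# On a cuDNN-capable GPU this dispatches to _cudnn_unrolled_lstm, which
# after this commit calls tf.raw_ops.CudnnRNNV3.
output_sequence, final_state = lstm(input_sequence, initial_state)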