From f2c37668377f09b09a2830e68a8e97dc882cdb75 Mon Sep 17 00:00:00 2001 From: Qianli Scott Zhu Date: Tue, 19 Sep 2023 15:37:54 -0700 Subject: [PATCH] Fix JAX RNN backend issue. (#924) --- keras_core/backend/jax/rnn.py | 11 +++- keras_core/layers/__init__.py | 3 ++ keras_core/layers/rnn/rnn.py | 18 +++++++ keras_core/layers/rnn/stacked_rnn_cells.py | 3 +- .../layers/rnn/stacked_rnn_cells_test.py | 51 +++++++++++++++++++ 5 files changed, 83 insertions(+), 3 deletions(-) diff --git a/keras_core/backend/jax/rnn.py b/keras_core/backend/jax/rnn.py index 2d67d5385..9b1e0b49f 100644 --- a/keras_core/backend/jax/rnn.py +++ b/keras_core/backend/jax/rnn.py @@ -1,8 +1,10 @@ +import contextlib + import tree from jax import lax from jax import numpy as jnp -from keras_core.backend.common.stateless_scope import StatelessScope +from keras_core.backend.common import stateless_scope from keras_core.utils.nest import pack_sequence_as @@ -181,7 +183,12 @@ def _step(states, current_input): scan_xs = inputs - with StatelessScope(): + if stateless_scope.in_stateless_scope(): + # Reuse the existing parent stateless scope. + scope = contextlib.nullcontext() + else: + scope = stateless_scope.StatelessScope() + with scope: # We must use a stateless scope because `scan` will involve # JAX tracing -- any variable update at this stage would # be a leak. diff --git a/keras_core/layers/__init__.py b/keras_core/layers/__init__.py index 444845c22..2f549f4d2 100644 --- a/keras_core/layers/__init__.py +++ b/keras_core/layers/__init__.py @@ -121,9 +121,12 @@ from keras_core.layers.rnn.conv_lstm2d import ConvLSTM2D from keras_core.layers.rnn.conv_lstm3d import ConvLSTM3D from keras_core.layers.rnn.gru import GRU +from keras_core.layers.rnn.gru import GRUCell from keras_core.layers.rnn.lstm import LSTM +from keras_core.layers.rnn.lstm import LSTMCell from keras_core.layers.rnn.rnn import RNN from keras_core.layers.rnn.simple_rnn import SimpleRNN +from keras_core.layers.rnn.simple_rnn import SimpleRNNCell from keras_core.layers.rnn.stacked_rnn_cells import StackedRNNCells from keras_core.layers.rnn.time_distributed import TimeDistributed from keras_core.saving import serialization_lib diff --git a/keras_core/layers/rnn/rnn.py b/keras_core/layers/rnn/rnn.py index 9f6aa24cf..e41e299f1 100644 --- a/keras_core/layers/rnn/rnn.py +++ b/keras_core/layers/rnn/rnn.py @@ -390,6 +390,10 @@ def call( initial_state, ) + # Prepopulate the dropout state so that the inner_loop is stateless + # this is particularly important for JAX backend. + self._maybe_config_dropout_masks(self.cell, sequences, initial_state) + last_output, outputs, states = self.inner_loop( sequences=sequences, initial_state=initial_state, @@ -421,6 +425,20 @@ def call( return output, *states return output + def _maybe_config_dropout_masks(self, cell, input_sequence, input_state): + step_input = input_sequence[:, 0, :] + state = ( + input_state[0] + if isinstance(input_state, (list, tuple)) + else input_state + ) + if isinstance(cell, DropoutRNNCell): + cell.get_dropout_mask(step_input) + cell.get_recurrent_dropout_mask(state) + if isinstance(cell, StackedRNNCells): + for c, s in zip(cell.cells, input_state): + self._maybe_config_dropout_masks(c, input_sequence, s) + def _maybe_reset_dropout_masks(self, cell): if isinstance(cell, DropoutRNNCell): cell.reset_dropout_mask() diff --git a/keras_core/layers/rnn/stacked_rnn_cells.py b/keras_core/layers/rnn/stacked_rnn_cells.py index deb808462..f4e257628 100644 --- a/keras_core/layers/rnn/stacked_rnn_cells.py +++ b/keras_core/layers/rnn/stacked_rnn_cells.py @@ -89,6 +89,7 @@ def call(self, inputs, states, training=False, **kwargs): # Call the cells in order and store the returned states. new_states = [] for cell, states in zip(self.cells, states): + state_is_list = tree.is_nested(states) states = list(states) if tree.is_nested(states) else [states] if isinstance(cell, Layer) and cell._call_has_training_arg: kwargs["training"] = training @@ -96,7 +97,7 @@ def call(self, inputs, states, training=False, **kwargs): kwargs.pop("training", None) cell_call_fn = cell.__call__ if callable(cell) else cell.call inputs, states = cell_call_fn(inputs, states, **kwargs) - if len(states) == 1: + if len(states) == 1 and not state_is_list: states = states[0] new_states.append(states) diff --git a/keras_core/layers/rnn/stacked_rnn_cells_test.py b/keras_core/layers/rnn/stacked_rnn_cells_test.py index e9a0e0f71..edb51fbef 100644 --- a/keras_core/layers/rnn/stacked_rnn_cells_test.py +++ b/keras_core/layers/rnn/stacked_rnn_cells_test.py @@ -81,6 +81,57 @@ def test_basics(self): supports_masking=True, custom_objects={"TwoStatesRNNCell": TwoStatesRNNCell}, ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.SimpleRNNCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.SimpleRNNCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.SimpleRNNCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.GRUCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.GRUCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.GRUCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.LSTMCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.LSTMCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) def test_correctness_single_state_stack(self): sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")