diff --git a/keras_core/backend/jax/rnn.py b/keras_core/backend/jax/rnn.py
index 2d67d5385..3550268e8 100644
--- a/keras_core/backend/jax/rnn.py
+++ b/keras_core/backend/jax/rnn.py
@@ -1,8 +1,10 @@
+import contextlib
+
 import tree
 from jax import lax
 from jax import numpy as jnp
 
-from keras_core.backend.common.stateless_scope import StatelessScope
+from keras_core.backend.common import stateless_scope
 from keras_core.utils.nest import pack_sequence_as
 
 
@@ -181,10 +183,16 @@ def _step(states, current_input):
 
         scan_xs = inputs
 
-    with StatelessScope():
-        # We must use a stateless scope because `scan` will involve
-        # JAX tracing -- any variable update at this stage would
-        # be a leak.
+    # We must use a stateless scope because `scan` will involve
+    # JAX tracing -- any variable update at this stage would
+    # be a leak.
+    if stateless_scope.in_stateless_scope():
+        # Leverage the parent scope.
+        scope = contextlib.nullcontext()
+    else:
+        scope = stateless_scope.StatelessScope()
+
+    with scope:
         new_states, outputs = lax.scan(
             f=_step,
             init=initial_states,
diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py
index 430bdcaab..2868e3dd4 100644
--- a/keras_core/backend/jax/trainer.py
+++ b/keras_core/backend/jax/trainer.py
@@ -44,7 +44,14 @@ def compute_loss_and_updates(
             return_losses=True,
             **kwargs,
         )
-        loss = self.compute_loss(x, y, y_pred, sample_weight, allow_empty=True)
+
+        trainable_mapping = zip(self.trainable_variables, trainable_variables)
+        with backend.StatelessScope(state_mapping=trainable_mapping):
+            # Note that this is needed for the regularization loss, which needs
+            # the latest values of the trainable/non-trainable variables.
+            loss = self.compute_loss(
+                x, y, y_pred, sample_weight, allow_empty=True
+            )
         if losses:
             loss += ops.sum(losses)
         unscaled_loss = loss
@@ -577,8 +584,9 @@ def evaluate(
         ]
         metrics_variables = [v.value for v in self.metrics_variables]
 
-        self._purge_model_variables(trainable_variables=False,
-                                    optimizer_variables=False)
+        self._purge_model_variables(
+            trainable_variables=False, optimizer_variables=False
+        )
         for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
             callbacks.on_test_batch_begin(step)
 
@@ -911,19 +919,21 @@ def _enforce_jax_state_sharding(
             metrics_variables,
         )
 
-    def _purge_model_variables(self,
-                               trainable_variables=True,
-                               non_trainable_variables=True,
-                               optimizer_variables=True,
-                               metric_variables=True):
+    def _purge_model_variables(
+        self,
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=True,
+        metric_variables=True,
+    ):
         """Remove all the model variable for memory saving.
-        
+
         During JAX training, since the training function are stateless, we have
         to pass in and get the model weights over and over, during which the
         copy of the weights that attached to the KerasVariable are still and
-        occupying extra memory. We remove those variable to save memory (for 
+        occupying extra memory. We remove those variable to save memory (for
         better memory utilization) at the beginning of the epoch, and reattach
-        the value back to variables at the end of the epoch, via 
+        the value back to variables at the end of the epoch, via
         `jax_state_sync()`.
         """
         if trainable_variables:
diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py
index 6d3c487b1..165a8edcd 100644
--- a/keras_core/layers/layer.py
+++ b/keras_core/layers/layer.py
@@ -1040,6 +1040,8 @@ def losses(self):
             losses.extend(layer._get_own_losses())
         weight_regularization_losses = []
         for v in self.trainable_weights:
+            if backend.in_stateless_scope():
+                v = backend.get_stateless_scope().get_current_value(v)
             regularizer = getattr(v, "regularizer", None)
             if regularizer:
                 weight_regularization_losses.append(regularizer(v))