
Commit

Fix unit test failure.
qlzh727 committed Sep 19, 2023
1 parent 95498c0 commit 9d8b15a
Showing 3 changed files with 36 additions and 16 deletions.
18 changes: 13 additions & 5 deletions keras_core/backend/jax/rnn.py
@@ -1,8 +1,10 @@
+import contextlib
+
 import tree
 from jax import lax
 from jax import numpy as jnp
 
-from keras_core.backend.common.stateless_scope import StatelessScope
+from keras_core.backend.common import stateless_scope
 from keras_core.utils.nest import pack_sequence_as

@@ -181,10 +183,16 @@ def _step(states, current_input):

     scan_xs = inputs
 
-    with StatelessScope():
-        # We must use a stateless scope because `scan` will involve
-        # JAX tracing -- any variable update at this stage would
-        # be a leak.
+    # We must use a stateless scope because `scan` will involve
+    # JAX tracing -- any variable update at this stage would
+    # be a leak.
+    if stateless_scope.in_stateless_scope():
+        # Leverage the parent scope.
+        scope = contextlib.nullcontext()
+    else:
+        scope = stateless_scope.StatelessScope()
+
+    with scope:
         new_states, outputs = lax.scan(
             f=_step,
             init=initial_states,
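The control flow added above is a small reusable pattern: open a fresh stateless scope only when no enclosing one exists, and otherwise fall back to contextlib.nullcontext() so the parent scope stays in charge. Below is a minimal self-contained sketch of the pattern; Scope, in_scope, and run_traced_step are illustrative stand-ins, not keras_core's actual API.

import contextlib

_ACTIVE_SCOPE = None  # records the innermost open scope, if any


class Scope:
    """Toy stand-in for a stateless scope."""

    def __enter__(self):
        global _ACTIVE_SCOPE
        self._previous = _ACTIVE_SCOPE
        _ACTIVE_SCOPE = self
        return self

    def __exit__(self, *exc):
        global _ACTIVE_SCOPE
        _ACTIVE_SCOPE = self._previous


def in_scope():
    return _ACTIVE_SCOPE is not None


def run_traced_step():
    # Reuse the parent scope if one is active; nullcontext() is a no-op
    # context manager, so nothing is opened or torn down twice.
    scope = contextlib.nullcontext() if in_scope() else Scope()
    with scope:
        pass  # stand-in for the lax.scan(...) call in the diff


run_traced_step()      # no parent scope: opens and closes its own
with Scope():
    run_traced_step()  # parent scope active: nullcontext leaves it alone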
32 changes: 21 additions & 11 deletions keras_core/backend/jax/trainer.py
@@ -44,7 +44,14 @@ def compute_loss_and_updates(
             return_losses=True,
             **kwargs,
         )
-        loss = self.compute_loss(x, y, y_pred, sample_weight, allow_empty=True)
+
+        trainable_mapping = zip(self.trainable_variables, trainable_variables)
+        with backend.StatelessScope(state_mapping=trainable_mapping):
+            # Note that this is needed for the regularization loss, which
+            # needs the latest value of the trainable/non-trainable variables.
+            loss = self.compute_loss(
+                x, y, y_pred, sample_weight, allow_empty=True
+            )
         if losses:
             loss += ops.sum(losses)
         unscaled_loss = loss
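Moving compute_loss under a StatelessScope makes sense once you recall how a functional JAX train step works: the up-to-date weights travel as explicit arguments, while the values cached on the variable objects are stale, so any code that reads variables during loss computation (regularizers in particular) must be redirected to the fresh values. Here is a self-contained toy of that redirection, using hypothetical Variable and StatelessScope classes rather than keras_core's real ones:

class Variable:
    def __init__(self, value):
        self.value = value  # the "stale" value cached on the object


class StatelessScope:
    """Maps variable objects to fresh per-step values while active."""

    current = None

    def __init__(self, state_mapping=()):
        self._mapping = {id(v): value for v, value in state_mapping}

    def __enter__(self):
        StatelessScope.current = self
        return self

    def __exit__(self, *exc):
        StatelessScope.current = None

    def get_current_value(self, variable):
        return self._mapping.get(id(variable), variable.value)


def read(variable):
    # Resolve through the active scope when there is one.
    scope = StatelessScope.current
    return scope.get_current_value(variable) if scope else variable.value


w = Variable(1.0)  # the object still holds the pre-step value
with StatelessScope(state_mapping=[(w, 2.0)]):
    assert read(w) == 2.0  # a regularizer here sees the fresh value
assert read(w) == 1.0      # outside the scope, the cached value is back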
@@ -577,8 +584,9 @@ def evaluate(
         ]
         metrics_variables = [v.value for v in self.metrics_variables]
 
-        self._purge_model_variables(trainable_variables=False,
-                                    optimizer_variables=False)
+        self._purge_model_variables(
+            trainable_variables=False, optimizer_variables=False
+        )
         for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
             callbacks.on_test_batch_begin(step)

@@ -911,19 +919,21 @@ def _enforce_jax_state_sharding(
             metrics_variables,
         )
 
-    def _purge_model_variables(self,
-                               trainable_variables=True,
-                               non_trainable_variables=True,
-                               optimizer_variables=True,
-                               metric_variables=True):
+    def _purge_model_variables(
+        self,
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=True,
+        metric_variables=True,
+    ):
         """Remove all the model variables to save memory.
 
         During JAX training, the training function is stateless, so we have
         to pass the model weights in and out over and over; meanwhile, the
         copies of the weights attached to the KerasVariable objects are still
         occupying extra memory. We remove those variables at the beginning of
         the epoch (for better memory utilization) and reattach the values to
         the variables at the end of the epoch, via `jax_state_sync()`.
         """
         if trainable_variables:
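The docstring above describes a purge-and-reattach cycle: while the stateless training loop carries the real state, the copies cached on the variable objects can be dropped at the start of the epoch and written back at the end via `jax_state_sync()`. A rough self-contained sketch of that lifecycle, with hypothetical helpers in place of the trainer's actual methods:

class Variable:
    def __init__(self, value):
        self.value = value  # the copy that occupies extra memory


def purge(variables):
    for v in variables:
        v.value = None  # drop the cached copy; the loop owns the state


def state_sync(variables, final_values):
    for v, value in zip(variables, final_values):
        v.value = value  # reattach the final values after the epoch


weights = [Variable(1), Variable(2)]
state = [v.value for v in weights]  # state threaded through the loop
purge(weights)                      # epoch begins: free the copies
state = [s + 1 for s in state]      # stand-in for the training updates
state_sync(weights, state)          # epoch ends: write the state back
assert weights[0].value == 2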
2 changes: 2 additions & 0 deletions keras_core/layers/layer.py
@@ -1040,6 +1040,8 @@ def losses(self):
             losses.extend(layer._get_own_losses())
         weight_regularization_losses = []
         for v in self.trainable_weights:
+            if backend.in_stateless_scope():
+                v = backend.get_stateless_scope().get_current_value(v)
             regularizer = getattr(v, "regularizer", None)
             if regularizer:
                 weight_regularization_losses.append(regularizer(v))
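The two added lines matter because a weight's regularizer should be evaluated against the value the current step is actually using, not the stale one cached on the variable object. A toy illustration of the difference, with a plain dict standing in for the stateless-scope lookup:

class Variable:
    def __init__(self, value):
        self.value = value


def l2(value, rate=0.01):
    # A toy L2 penalty standing in for a real regularizer.
    return rate * value * value


w = Variable(10.0)           # cached, pre-step value
scope_values = {id(w): 0.1}  # fresh value the step is applying

stale_penalty = l2(w.value)                           # 1.0 -- wrong
fresh_penalty = l2(scope_values.get(id(w), w.value))  # 0.0001 -- right
assert fresh_penalty < stale_penalty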
