diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index 9050a769f..3551e4d87 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -114,7 +114,7 @@ def value(self): def assign(self, value): value = self._convert_to_tensor(value, dtype=self.dtype) - if not shape_equal(value, self.value): + if not shape_equal(value.shape, self.shape): raise ValueError( "The shape of the target variable and " "the shape of the target value in " @@ -444,11 +444,11 @@ def standardize_shape(shape): return shape -def shape_equal(a, b): - """Return whether a.shape == b.shape (allows None entries).""" - if len(a.shape) != len(b.shape): +def shape_equal(a_shape, b_shape): + """Return whether a_shape == b_shape (allows None entries).""" + if len(a_shape) != len(b_shape): return False - for e1, e2 in zip(a.shape, b.shape): + for e1, e2 in zip(a_shape, b_shape): if e1 is not None and e2 is not None and e1 != e2: return False return True diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index 77c59eb36..5a329d0ef 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -410,6 +410,7 @@ def fit( optimizer_variables = [v.value for v in self.optimizer.variables] metrics_variables = [v.value for v in self.metrics_variables] + self._purge_model_variables() for step, data in epoch_iterator.enumerate_epoch(return_type="np"): # Callbacks callbacks.on_train_batch_begin(step) @@ -568,6 +569,8 @@ def evaluate( ] metrics_variables = [v.value for v in self.metrics_variables] + self._purge_model_variables(trainable_variables=False, + optimizer_variables=False) for step, data in epoch_iterator.enumerate_epoch(return_type="np"): callbacks.on_test_batch_begin(step) @@ -895,3 +898,31 @@ def _enforce_jax_state_sharding( optimizer_variables, metrics_variables, ) + + def _purge_model_variables(self, + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + metric_variables=True): + """Remove all the model variable for memory saving. + + During JAX training, since the training function are stateless, we have + to pass in and get the model weights over and over, during which the + copy of the weights that attached to the KerasVariable are still and + occupying extra memory. We remove those variable to save memory (for + better memory utilization) at the beginning of the epoch, and reattach + the value back to variables at the end of the epoch, via + `jax_state_sync()`. + """ + if trainable_variables: + for v in self.trainable_variables: + v._value = None + if non_trainable_variables: + for v in self.non_trainable_variables: + v._value = None + if optimizer_variables: + for v in self.optimizer.variables: + v._value = None + if metric_variables: + for v in self.metrics_variables: + v._value = None diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 5028b7b28..0e2788bdf 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -261,6 +261,63 @@ def test_predict_flow(self, run_eagerly, jit_compile): self.assertAllClose(outputs["y_one"], 4 * np.ones((100, 3))) self.assertAllClose(outputs["y_two"], 4 * np.ones((100, 3))) + @pytest.mark.skipif( + backend.backend() != "jax", + reason="Memory optimization is only implemented in JAX", + ) + def test_fit_eval_flow_for_jax_model_weights(self): + model = ExampleModel(units=3) + epochs = 3 + batch_size = 20 + steps_per_epoch = 7 + dataset_size = batch_size * (steps_per_epoch - 2) + x = np.ones((dataset_size, 4)) + y = np.zeros((dataset_size, 3)) + + class ModelWeightCheck(Callback): + def __init__(self): + super().__init__() + + # Note that we access model via self._model since self.model + # will trigger a sync of the jax training state back to the model. + def on_train_batch_begin(self, batch, logs=None): + for v in self._model.trainable_variables: + assert v._value is None + for v in self._model.non_trainable_variables: + assert v._value is None + for v in self._model.optimizer.variables: + assert v._value is None + for v in self._model.metrics_variables: + assert v._value is None + + def on_test_batch_begin(self, batch, logs=None): + for v in self._model.non_trainable_variables: + assert v._value is None + for v in self._model.metrics_variables: + assert v._value is None + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + + model.fit( + x, + y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + callbacks=[ModelWeightCheck()], + ) + + model.evaluate( + x, + y, + batch_size=batch_size, + callbacks=[ModelWeightCheck()], + ) + @pytest.mark.requires_trainable_backend @pytest.mark.skipif( backend.backend() == "torch",