Memory optimization for JAX trainer.
qlzh727 committed Sep 14, 2023
1 parent e8db3b6 commit 7f5926c
Showing 3 changed files with 93 additions and 5 deletions.
10 changes: 5 additions & 5 deletions keras_core/backend/common/variables.py
@@ -114,7 +114,7 @@ def value(self):

def assign(self, value):
value = self._convert_to_tensor(value, dtype=self.dtype)
if not shape_equal(value, self.value):
if not shape_equal(value.shape, self.shape):
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
@@ -444,11 +444,11 @@ def standardize_shape(shape):
return shape


def shape_equal(a, b):
"""Return whether a.shape == b.shape (allows None entries)."""
if len(a.shape) != len(b.shape):
def shape_equal(a_shape, b_shape):
"""Return whether a_shape == b_shape (allows None entries)."""
if len(a_shape) != len(b_shape):
return False
for e1, e2 in zip(a.shape, b.shape):
for e1, e2 in zip(a_shape, b_shape):
if e1 is not None and e2 is not None and e1 != e2:
return False
return True
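For context: the refactor makes `shape_equal` operate on plain shape tuples rather than tensors, so `assign` can compare `value.shape` against `self.shape` without dereferencing `self.value`. Below is a minimal standalone sketch of the new helper together with a few illustrative checks; the assertions are not part of the commit.

# Standalone sketch mirroring the hunk above; the assertions below are
# illustrative only and not part of the commit.
def shape_equal(a_shape, b_shape):
    """Return whether a_shape == b_shape (allows None entries)."""
    if len(a_shape) != len(b_shape):
        return False
    for e1, e2 in zip(a_shape, b_shape):
        if e1 is not None and e2 is not None and e1 != e2:
            return False
    return True

# None acts as a wildcard dimension:
assert shape_equal((None, 3), (8, 3))       # unknown dim matches anything
assert not shape_equal((None, 3), (8, 4))   # known dims must agree
assert not shape_equal((8, 3), (8, 3, 1))   # ranks must match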
31 changes: 31 additions & 0 deletions keras_core/backend/jax/trainer.py
@@ -410,6 +410,7 @@ def fit(
optimizer_variables = [v.value for v in self.optimizer.variables]
metrics_variables = [v.value for v in self.metrics_variables]

self._purge_model_variables()
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Callbacks
callbacks.on_train_batch_begin(step)
@@ -568,6 +569,8 @@ def evaluate(
]
metrics_variables = [v.value for v in self.metrics_variables]

self._purge_model_variables(trainable_variables=False,
optimizer_variables=False)
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_test_batch_begin(step)

@@ -895,3 +898,31 @@ def _enforce_jax_state_sharding(
optimizer_variables,
metrics_variables,
)

def _purge_model_variables(self,
trainable_variables=True,
non_trainable_variables=True,
optimizer_variables=True,
metric_variables=True):
"""Remove all the model variable for memory saving.
During JAX training, since the training function are stateless, we have
to pass in and get the model weights over and over, during which the
copy of the weights that attached to the KerasVariable are still and
occupying extra memory. We remove those variable to save memory (for
better memory utilization) at the beginning of the epoch, and reattach
the value back to variables at the end of the epoch, via
`jax_state_sync()`.
"""
if trainable_variables:
for v in self.trainable_variables:
v._value = None
if non_trainable_variables:
for v in self.non_trainable_variables:
v._value = None
if optimizer_variables:
for v in self.optimizer.variables:
v._value = None
if metric_variables:
for v in self.metrics_variables:
v._value = None
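To illustrate the idea behind `_purge_model_variables`, here is a small, self-contained sketch (not the actual trainer code) of the purge/restore pattern: snapshot the variable values, drop `_value` so the variable objects stop holding a second device copy while a stateless jitted step threads the weights through the loop, then reattach the results at the end, analogous to `jax_state_sync()`. The `Variable` class, `train_step`, and `fit_one_epoch` names are illustrative assumptions, not keras_core internals.

# Hypothetical, self-contained illustration of the purge/restore pattern;
# `Variable` is a stand-in for KerasVariable, not the real class.
import jax
import jax.numpy as jnp


class Variable:
    def __init__(self, value):
        self._value = jnp.asarray(value)

    @property
    def value(self):
        return self._value

    def assign(self, value):
        self._value = jnp.asarray(value)


@jax.jit
def train_step(values, batch):
    # Stateless step: weights come in and go out as plain arrays.
    x, y = batch
    grads = [jnp.zeros_like(v) for v in values]  # placeholder gradients
    return [v - 0.01 * g for v, g in zip(values, grads)]


def fit_one_epoch(variables, batches):
    # Snapshot the current values, then purge so the Variable objects no
    # longer hold a second copy of each weight during the epoch.
    values = [v.value for v in variables]
    for v in variables:
        v._value = None  # same mechanism as _purge_model_variables()

    for batch in batches:
        values = train_step(values, batch)

    # Reattach the final values, analogous to jax_state_sync().
    for v, val in zip(variables, values):
        v.assign(val)


weights = [Variable(jnp.ones((4, 3))), Variable(jnp.zeros((3,)))]
data = [(jnp.ones((8, 4)), jnp.zeros((8, 3)))] * 5
fit_one_epoch(weights, data)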
57 changes: 57 additions & 0 deletions keras_core/trainers/trainer_test.py
@@ -261,6 +261,63 @@ def test_predict_flow(self, run_eagerly, jit_compile):
self.assertAllClose(outputs["y_one"], 4 * np.ones((100, 3)))
self.assertAllClose(outputs["y_two"], 4 * np.ones((100, 3)))

@pytest.mark.skipif(
backend.backend() != "jax",
reason="Memory optimization is only implemented in JAX",
)
def test_fit_eval_flow_for_jax_model_weights(self):
model = ExampleModel(units=3)
epochs = 3
batch_size = 20
steps_per_epoch = 7
dataset_size = batch_size * (steps_per_epoch - 2)
x = np.ones((dataset_size, 4))
y = np.zeros((dataset_size, 3))

class ModelWeightCheck(Callback):
def __init__(self):
super().__init__()

# Note that we access the model via self._model, since self.model
# will trigger a sync of the JAX training state back to the model.
def on_train_batch_begin(self, batch, logs=None):
for v in self._model.trainable_variables:
assert v._value is None
for v in self._model.non_trainable_variables:
assert v._value is None
for v in self._model.optimizer.variables:
assert v._value is None
for v in self._model.metrics_variables:
assert v._value is None

def on_test_batch_begin(self, batch, logs=None):
for v in self._model.non_trainable_variables:
assert v._value is None
for v in self._model.metrics_variables:
assert v._value is None

model.compile(
optimizer=optimizers.SGD(),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
)

model.fit(
x,
y,
batch_size=batch_size,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
callbacks=[ModelWeightCheck()],
)

model.evaluate(
x,
y,
batch_size=batch_size,
callbacks=[ModelWeightCheck()],
)

@pytest.mark.requires_trainable_backend
@pytest.mark.skipif(
backend.backend() == "torch",
