Update jax trainer function to save memory buffer. (#897)
* Update jax trainer function to save memory buffer.

* Address format issue.
qlzh727 authored Sep 18, 2023
1 parent 9d39e9a commit 2dfe475
Showing 1 changed file with 20 additions and 8 deletions: keras_core/backend/jax/trainer.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import jax
 import numpy as np
 import tree
@@ -237,8 +239,11 @@ def multi_train_steps(state, data):
             train_step = one_train_step
 
         if not self.run_eagerly and self.jit_compile:
-
-            @jax.jit
+            # Note that we mark the state and data to be donated to jax,
+            # so that jax will reuse the memory buffer for outputs.
+            # This will reduce the memory usage of the training function by
+            # half.
+            @partial(jax.jit, donate_argnames="state")
             def compiled_train_step(state, data):
                 return train_step(state, data)
 
@@ -266,8 +271,11 @@ def multi_test_steps(state, data):
             test_step = one_test_step
 
         if not self.run_eagerly and self.jit_compile:
-
-            @jax.jit
+            # Note that we mark the state and data to be donated to jax,
+            # so that jax will reuse the memory buffer for outputs.
+            # This will reduce the memory usage of the training function by
+            # half.
+            @partial(jax.jit, donate_argnames="state")
             def compiled_test_step(state, data):
                 return test_step(state, data)
 
@@ -578,15 +586,18 @@ def evaluate(
             )
             data = self._distribute_data(data)
             logs, state = self.test_function(state, data)
-            # Note that trainable variables are not returned since they're
-            # immutable here.
-            _, non_trainable_variables, metrics_variables = state
+            (
+                trainable_variables,
+                non_trainable_variables,
+                metrics_variables,
+            ) = state
 
             # Setting _jax_state enables callbacks to force a state sync
             # if they need to.
             self._jax_state = {
                 # I wouldn't recommend modifying non-trainable model state
                 # during evaluate(), but it's allowed.
+                "trainable_variables": trainable_variables,
                 "non_trainable_variables": non_trainable_variables,
                 "metrics_variables": metrics_variables,
             }
@@ -764,8 +775,9 @@ def test_on_batch(
         logs, state = self.test_function(state, [data])
 
         # State sync
-        _, non_trainable_variables, metrics_variables = state
+        trainable_variables, non_trainable_variables, metrics_variables = state
         self._jax_state = {
+            "trainable_variables": trainable_variables,
             "non_trainable_variables": non_trainable_variables,
             "metrics_variables": metrics_variables,
         }
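
Note on the change: jax.jit's donate_argnames argument tells XLA that the caller will not reuse the listed inputs, so their device buffers may be recycled for the outputs. Since each train/test step returns a new state with the same structure and shapes as the one it receives, donating "state" lets the old and new state share buffers rather than both being live at once, which is what the in-code comments mean by halving the memory usage of the compiled step. The snippet below is a minimal sketch of that pattern written for this note; the update function and the dict-shaped state are illustrative, not code from the repository.

from functools import partial

import jax
import jax.numpy as jnp


# Illustrative example, not repository code: donate the "state" argument so
# XLA may write the new state into the old state's buffers instead of
# allocating a second full copy.
@partial(jax.jit, donate_argnames="state")
def update(state, data):
    # The returned state has the same structure, shapes, and dtypes as the
    # input, so its buffers can alias the donated input buffers.
    return jax.tree_util.tree_map(lambda s: s + data.mean(), state)


state = {"w": jnp.ones((1024, 1024))}
data = jnp.arange(4.0)
state = update(state, data)  # the donated old `state` must not be read again

A donated argument is invalid after the call, which is why the trainer rebinds the returned value (logs, state = self.test_function(state, data)) instead of touching the old state again.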
