Add Rematerialization to Keras #20743

Draft · wants to merge 14 commits into base: master
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import map
from keras.src.ops.core import remat
from keras.src.ops.core import saturate_cast
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop
from keras.src.ops.core import is_tensor
from keras.src.ops.core import map
from keras.src.ops.core import remat
from keras.src.ops.core import saturate_cast
from keras.src.ops.core import scan
from keras.src.ops.core import scatter
127 changes: 127 additions & 0 deletions keras/src/backend/common/remat_scope.py
@@ -0,0 +1,127 @@
from keras.src.backend.common import global_state


class RematScope:
"""A context manager for enabling rematerialization in Keras.

    Rematerialization (gradient checkpointing) trades memory for computation
    by recomputing intermediate activations during the backward pass. This is
    particularly useful when training large models or using large batch sizes
    under tight memory constraints.

Args:
mode: Rematerialization mode to apply.
Options:
- "full": Apply rematerialization globally to all supported
operations.
- "activations": Apply rematerialization only to activation layers.
- "larger_than": Apply rematerialization to layers with output sizes
larger than `output_size_threshold`.
- "list_of_layers": Apply rematerialization to a specific list of
layer names.
- None: Disable rematerialization.
output_size_threshold: Output size threshold for the
`"larger_than"` mode. Layers producing outputs larger than this
threshold will be rematerialized. Default is `1024`.
layer_names: List of layer names for the
`"list_of_layers"` mode. Default is an empty list.

Examples:
Using "list_of_layers" mode:

```python
from keras.src.backend.common.remat_scope import RematScope

with RematScope(mode="list_of_layers", layer_names=["dense_1",
"conv2d_1"]):
layer1 = keras.layers.Dense(128, name="dense_1")
layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
layer3 = keras.layers.Dense(64, name="dense_2")

# Only layer1 and layer2 will apply rematerialization
output1 = layer1(input_tensor)
output2 = layer2(output1)
output3 = layer3(output2)
```

Using "larger_than" mode with a specific output size threshold:

```python
from keras.src.backend.common.remat_scope import RematScope

with RematScope(mode="larger_than", output_size_threshold=2048):
layer = keras.layers.Conv2D(64, (3, 3))
        output = layer(input_tensor)  # Rematerialized if output size > 2048
```

Nested scopes for fine-grained control:

```python
from keras.src.backend.common.remat_scope import RematScope

with RematScope(mode="full"):
layer1 = keras.layers.Dense(128, activation='relu')
with RematScope(mode="larger_than", output_size_threshold=512):
layer2 = keras.layers.Conv2D(32, (3, 3))
output = layer2(layer1(input_tensor))
```
"""

def __init__(
self, mode="full", output_size_threshold=1024, layer_names=None
):
if mode not in {
"full",
"activations",
"larger_than",
"list_of_layers",
None,
}:
raise ValueError(
f"Invalid mode '{mode}'. Supported modes are: "
"'full', 'activations', 'larger_than', 'list_of_layers', or "
" None."
)
self.mode = mode
self.output_size_threshold = output_size_threshold
self.layer_names = layer_names or []
self._pop_on_exit = False

def __enter__(self):
remat_scope_stack = global_state.get_global_attribute(
"remat_scope_stack", default=[], set_to_default=True
)
remat_scope_stack.append(self)
self._pop_on_exit = True
return self

def __exit__(self, *args, **kwargs):
if self._pop_on_exit:
remat_scope_stack = global_state.get_global_attribute(
"remat_scope_stack"
)
remat_scope_stack.pop()


def get_current_remat_mode():
"""Get the current rematerialization mode and associated settings.

    Returns:
        A dict describing the innermost active `RematScope`, or `None` if
        no rematerialization scope is active. Example:

        {
            "mode": "list_of_layers",
            "output_size_threshold": 1024,
            "layer_names": ["dense_1", "conv2d_1"]
        }
    """
remat_scope_stack = global_state.get_global_attribute("remat_scope_stack")
if remat_scope_stack is None or not remat_scope_stack:
return None
active_scope = remat_scope_stack[-1]
return {
"mode": active_scope.mode,
"output_size_threshold": active_scope.output_size_threshold,
"layer_names": active_scope.layer_names,
}
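As a quick illustration of how these two pieces fit together, here is a small sketch (based only on the code above) that enters nested scopes and inspects the reported settings:

```python
from keras.src.backend.common.remat_scope import RematScope
from keras.src.backend.common.remat_scope import get_current_remat_mode

print(get_current_remat_mode())  # None: no scope is active yet

with RematScope(mode="larger_than", output_size_threshold=2048):
    # The innermost active scope determines the settings.
    print(get_current_remat_mode())
    # {'mode': 'larger_than', 'output_size_threshold': 2048,
    #  'layer_names': []}

    with RematScope(mode="list_of_layers", layer_names=["dense_1"]):
        print(get_current_remat_mode()["mode"])  # 'list_of_layers'

    print(get_current_remat_mode()["mode"])  # back to 'larger_than'

print(get_current_remat_mode())  # None again after exiting all scopes
```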
75 changes: 75 additions & 0 deletions keras/src/backend/common/remat_scope_test.py
@@ -0,0 +1,75 @@
from keras.src import testing
from keras.src.backend.common import global_state
from keras.src.backend.common.remat_scope import RematScope
from keras.src.backend.common.remat_scope import get_current_remat_mode


class TestRematScope(testing.TestCase):
def setUp(self):
"""Reset global state before each test."""
global_state.clear_session()

def test_remat_scope_activation(self):
self.assertIsNone(
get_current_remat_mode()
) # Initially, no mode is active

with RematScope(mode="full"):
self.assertEqual(
get_current_remat_mode()["mode"], "full"
) # Mode is set to "full"

self.assertIsNone(
get_current_remat_mode()
) # Mode is restored to None after scope ends

def test_remat_scope_nested(self):
"""Test nested scopes with different rematerialization modes."""
with RematScope(mode="full"):
self.assertEqual(
get_current_remat_mode()["mode"], "full"
) # Outer scope is "full"

with RematScope(mode="activations"):
self.assertEqual(
get_current_remat_mode()["mode"], "activations"
) # Inner scope is "activations"

self.assertEqual(
get_current_remat_mode()["mode"], "full"
) # Back to outer scope

self.assertIsNone(
get_current_remat_mode()
) # Mode is restored to None after all scopes

def test_remat_scope_stack_management(self):
"""Test that the remat_scope_stack is managed correctly."""
self.assertIsNone(
global_state.get_global_attribute("remat_scope_stack")
) # No stack initially

with RematScope(mode="full"):
remat_stack = global_state.get_global_attribute("remat_scope_stack")
self.assertIsNotNone(remat_stack) # Stack is initialized
self.assertEqual(len(remat_stack), 1) # Stack contains one entry

with RematScope(mode="activations"):
remat_stack = global_state.get_global_attribute(
"remat_scope_stack"
)
self.assertEqual(
len(remat_stack), 2
) # Stack contains two entries

remat_stack = global_state.get_global_attribute("remat_scope_stack")
self.assertEqual(len(remat_stack), 1) # Back to one entry

self.assertEqual(
global_state.get_global_attribute("remat_scope_stack"), []
) # Stack is cleared

def test_invalid_mode(self):
"""Test that invalid rematerialization modes raise an error."""
with self.assertRaises(ValueError):
RematScope(mode="invalid") # Invalid mode should raise ValueError
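A follow-up test in the same style could also cover the settings dict reported for the threshold-based mode; a hypothetical sketch, not part of this PR:

```python
    def test_remat_scope_larger_than_settings(self):
        """Threshold settings are reported by the innermost active scope."""
        with RematScope(mode="larger_than", output_size_threshold=2048):
            mode = get_current_remat_mode()
            self.assertEqual(mode["mode"], "larger_than")
            self.assertEqual(mode["output_size_threshold"], 2048)
            self.assertEqual(mode["layer_names"], [])
        self.assertIsNone(get_current_remat_mode())
```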
12 changes: 12 additions & 0 deletions keras/src/backend/jax/core.py
@@ -365,6 +365,18 @@ def custom_gradient(fun):
return jax.custom_gradient(fun=fun)


def remat(func, *args, **kwargs):
"""Implementation of rematerialization.

Args:
Collaborator (review comment): Add return section

        func: The function or operation to rematerialize.
        *args: Positional arguments passed to `func`.
        **kwargs: Keyword arguments passed to `func`.

    Returns:
        The result of calling `func` with rematerialization enabled:
        intermediate activations from `func` are recomputed during the
        backward pass rather than stored.
    """
return jax.checkpoint(func)(*args, **kwargs)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
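For context on what the JAX backend delegates to, here is a standalone sketch of `jax.checkpoint` (illustrative only, not part of this diff); the intermediate activation `h` is recomputed when the gradient is evaluated instead of being stored:

```python
import jax
import jax.numpy as jnp


def mlp_block(x):
    # Two matmuls whose intermediate activation `h` would normally be
    # kept alive for the backward pass.
    h = jnp.tanh(x @ jnp.ones((256, 256)))
    return jnp.sum(jnp.tanh(h @ jnp.ones((256, 256))))


# With jax.checkpoint, `h` is recomputed during the backward pass.
grad_fn = jax.grad(jax.checkpoint(mlp_block))
x = jnp.ones((8, 256))
print(grad_fn(x).shape)  # (8, 256)
```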
12 changes: 12 additions & 0 deletions keras/src/backend/tensorflow/core.py
@@ -640,6 +640,18 @@ def custom_gradient(fun):
return tf.custom_gradient(f=fun)


def remat(func, *args, **kwargs):
"""Implementation of rematerialization.

Args:
        func: The function or operation to rematerialize.
        *args: Positional arguments passed to `func`.
        **kwargs: Keyword arguments passed to `func`.

    Returns:
        The result of calling `func` with rematerialization enabled:
        intermediate activations from `func` are recomputed during the
        backward pass rather than stored.
    """
return tf.recompute_grad(func)(*args, **kwargs)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
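The TensorFlow backend leans on `tf.recompute_grad` in the same way; a standalone sketch (illustrative only):

```python
import tensorflow as tf


def block(x):
    h = tf.nn.relu(tf.matmul(x, tf.ones((256, 256))))
    return tf.reduce_sum(tf.nn.relu(tf.matmul(h, tf.ones((256, 256)))))


# tf.recompute_grad returns a wrapper that re-runs `block` on the
# backward pass rather than storing its intermediate activations.
remat_block = tf.recompute_grad(block)

x = tf.ones((8, 256))
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = remat_block(x)
print(tape.gradient(loss, x).shape)  # (8, 256)
```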
12 changes: 12 additions & 0 deletions keras/src/backend/torch/core.py
@@ -661,6 +661,18 @@ def random_seed_dtype():
return "int32"


def remat(func, *args, **kwargs):
Collaborator (review comment): Also add it in the numpy and openvino backends, returning plain function results with no remat (for compatibility)

"""Implementation of rematerialization.

Args:
        func: The function or operation to rematerialize.
        *args: Positional arguments passed to `func`.
        **kwargs: Keyword arguments passed to `func`.

    Returns:
        The result of calling `func` with rematerialization enabled:
        intermediate activations from `func` are recomputed during the
        backward pass rather than stored.
    """
    return torch.utils.checkpoint.checkpoint(
        func, *args, use_reentrant=False, **kwargs
    )


class custom_gradient:
"""Decorator for custom gradients.

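The torch backend delegates to `torch.utils.checkpoint`; a standalone sketch (illustrative only):

```python
import torch
from torch.utils import checkpoint


def block(x):
    h = torch.relu(x @ torch.ones(256, 256))
    return torch.relu(h @ torch.ones(256, 256)).sum()


x = torch.ones(8, 256, requires_grad=True)
# `h` is recomputed during backward(); use_reentrant=False selects the
# recommended non-reentrant checkpointing variant.
loss = checkpoint.checkpoint(block, x, use_reentrant=False)
loss.backward()
print(x.grad.shape)  # torch.Size([8, 256])
```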
93 changes: 92 additions & 1 deletion keras/src/layers/layer.py
@@ -32,11 +32,14 @@
from keras.src.backend import KerasTensor
from keras.src.backend.common import global_state
from keras.src.backend.common.name_scope import current_path
from keras.src.backend.common.remat_scope import get_current_remat_mode
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.distribution import distribution_lib
from keras.src.dtype_policies import DTypePolicyMap
from keras.src.layers import input_spec
from keras.src.metrics.metric import Metric
from keras.src.ops.core import remat
from keras.src.ops.numpy import prod
from keras.src.ops.operation import Operation
from keras.src.saving.keras_saveable import KerasSaveable
from keras.src.utils import python_utils
@@ -1040,7 +1043,7 @@ def stateless_call(
if self.dtype_policy.quantization_mode is not None:
outputs = self.quantized_call(*args, **kwargs)
else:
outputs = self.call(*args, **kwargs)
outputs = self.remat_wrapper(self.call)(*args, **kwargs)
if return_losses:
losses = self.losses

@@ -1560,6 +1563,94 @@ def _open_name_scope(self):
self._parent_path = current_path()
return backend.name_scope(self.name, caller=self)

def remat_wrapper(self, layer_call):
"""Wrap the layer's call method to enable rematerialization dynamically.

Args:
layer_call: The original `call` method of a layer.

Returns:
callable: The wrapped method with rematerialization logic applied.
"""

def calculate_output_size(spec):
"""Calculate the total output size from the output spec.

Args:
spec: The output spec returned by compute_output_spec.

Returns:
int: The total size of the output or None if dimensions are
unknown.
"""
if isinstance(spec, KerasTensor):
shape = spec.shape
if None in shape:
return None
return int(prod(shape)) # Using Keras ops to compute size

elif isinstance(spec, (list, tuple)):
total_size = 0
for s in spec:
size = calculate_output_size(s)
if size is None:
return (
None # If any size is indeterminate, return None.
)
total_size += size
return total_size

elif isinstance(spec, dict):
total_size = 0
for k, v in spec.items():
size = calculate_output_size(v)
if size is None:
return (
None # If any size is indeterminate, return None.
)
total_size += size
return total_size

return None

def wrapped_call(*args, **kwargs):
remat_mode = get_current_remat_mode()

if remat_mode:
# Full rematerialization
if remat_mode["mode"] == "full":
return remat(layer_call)(*args, **kwargs)

# Apply rematerialization to specific layers
if remat_mode["mode"] == "list_of_layers" and (
self.name in remat_mode["layer_names"]
):
return remat(layer_call)(*args, **kwargs)

# Apply rematerialization based on output size threshold
if remat_mode["mode"] == "larger_than":
output_spec = self.compute_output_spec(*args, **kwargs)
output_size = calculate_output_size(output_spec)
if (
output_size
and output_size > remat_mode["output_size_threshold"]
):
return remat(layer_call)(*args, **kwargs)

# Apply rematerialization to activation functions only
                elif remat_mode["mode"] == "activations":
has_activation = (
hasattr(self, "activation")
and self.activation is not None
and not utils.is_default(self.activation)
)
if has_activation:
self.activation = remat(self.activation)

return layer_call(*args, **kwargs)

return wrapped_call


def is_backend_tensor_or_symbolic(x, allow_none=False):
if allow_none and x is None:
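Putting the pieces together, the intended end-to-end usage would look roughly like this (a sketch of the intent, assuming the wrapper is invoked on the layer call path as in the diff above; `RematScope` is imported from its internal module, as in the docstring examples):

```python
import numpy as np

import keras
from keras.src.backend.common.remat_scope import RematScope

x = np.random.rand(32, 256).astype("float32")

with RematScope(mode="list_of_layers", layer_names=["dense_1"]):
    layer1 = keras.layers.Dense(512, activation="relu", name="dense_1")
    layer2 = keras.layers.Dense(10, name="dense_2")
    # Only `dense_1` matches the scope, so only its call is wrapped
    # with rematerialization; `dense_2` runs normally.
    y = layer2(layer1(x))
```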