-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Rematerialization to Keras #20743
Draft
divyashreepathihalli
wants to merge
14
commits into
keras-team:master
Choose a base branch
from
divyashreepathihalli:rematerialization
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+512
−1
Draft
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
b2e53f6
add remat op
divyashreepathihalli 5be2e6f
update test
divyashreepathihalli 7581b4a
remove print statements
divyashreepathihalli 7277150
remove memory testing
divyashreepathihalli 073f46b
run api_gen.sh
divyashreepathihalli 3be014c
update docstring
divyashreepathihalli fa1460f
add remat scope
divyashreepathihalli fef369a
code reformat
divyashreepathihalli 373a0dc
update scope to return all configs
divyashreepathihalli f1eb799
add remat wrapper to layer
divyashreepathihalli d8fa54f
add output size mode
divyashreepathihalli 38691f9
Merge branch 'keras-team:master' into rematerialization
divyashreepathihalli 62a24cf
Merge branch 'keras-team:master' into rematerialization
divyashreepathihalli 8c30ed7
add activation mode to remat
divyashreepathihalli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from keras.src.backend.common import global_state | ||
|
||
|
||
class RematScope: | ||
"""A context manager for enabling rematerialization in Keras. | ||
|
||
Rematerialization (gradient checkpointing) trades memory for computation by | ||
recomputing intermediate activations during the backward pass. This is | ||
particularly useful for training large models or large batch sizes within | ||
limited memory constraints. | ||
|
||
Args: | ||
mode: Rematerialization mode to apply. | ||
Options: | ||
- "full": Apply rematerialization globally to all supported | ||
operations. | ||
- "activations": Apply rematerialization only to activation layers. | ||
- "larger_than": Apply rematerialization to layers with output sizes | ||
larger than `output_size_threshold`. | ||
- "list_of_layers": Apply rematerialization to a specific list of | ||
layer names. | ||
- None: Disable rematerialization. | ||
output_size_threshold: Output size threshold for the | ||
`"larger_than"` mode. Layers producing outputs larger than this | ||
threshold will be rematerialized. Default is `1024`. | ||
layer_names: List of layer names for the | ||
`"list_of_layers"` mode. Default is an empty list. | ||
|
||
Examples: | ||
Using "list_of_layers" mode: | ||
|
||
```python | ||
from keras.src.backend.common.remat_scope import RematScope | ||
|
||
with RematScope(mode="list_of_layers", layer_names=["dense_1", | ||
"conv2d_1"]): | ||
layer1 = keras.layers.Dense(128, name="dense_1") | ||
layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1") | ||
layer3 = keras.layers.Dense(64, name="dense_2") | ||
|
||
# Only layer1 and layer2 will apply rematerialization | ||
output1 = layer1(input_tensor) | ||
output2 = layer2(output1) | ||
output3 = layer3(output2) | ||
``` | ||
|
||
Using "larger_than" mode with a specific output size threshold: | ||
|
||
```python | ||
from keras.src.backend.common.remat_scope import RematScope | ||
|
||
with RematScope(mode="larger_than", output_size_threshold=2048): | ||
layer = keras.layers.Conv2D(64, (3, 3)) | ||
output = layer(input_tensor) # Conv2D outputs larger than 2048 | ||
``` | ||
|
||
Nested scopes for fine-grained control: | ||
|
||
```python | ||
from keras.src.backend.common.remat_scope import RematScope | ||
|
||
with RematScope(mode="full"): | ||
layer1 = keras.layers.Dense(128, activation='relu') | ||
with RematScope(mode="larger_than", output_size_threshold=512): | ||
layer2 = keras.layers.Conv2D(32, (3, 3)) | ||
output = layer2(layer1(input_tensor)) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, mode="full", output_size_threshold=1024, layer_names=None | ||
): | ||
if mode not in { | ||
"full", | ||
"activations", | ||
"larger_than", | ||
"list_of_layers", | ||
None, | ||
}: | ||
raise ValueError( | ||
f"Invalid mode '{mode}'. Supported modes are: " | ||
"'full', 'activations', 'larger_than', 'list_of_layers', or " | ||
" None." | ||
) | ||
self.mode = mode | ||
self.output_size_threshold = output_size_threshold | ||
self.layer_names = layer_names or [] | ||
self._pop_on_exit = False | ||
|
||
def __enter__(self): | ||
remat_scope_stack = global_state.get_global_attribute( | ||
"remat_scope_stack", default=[], set_to_default=True | ||
) | ||
remat_scope_stack.append(self) | ||
self._pop_on_exit = True | ||
return self | ||
|
||
def __exit__(self, *args, **kwargs): | ||
if self._pop_on_exit: | ||
remat_scope_stack = global_state.get_global_attribute( | ||
"remat_scope_stack" | ||
) | ||
remat_scope_stack.pop() | ||
|
||
|
||
def get_current_remat_mode(): | ||
"""Get the current rematerialization mode and associated settings. | ||
|
||
Returns: | ||
dict: A dictionary containing the rematerialization mode and other | ||
settings. | ||
Example: | ||
{ | ||
"mode": "list_of_layers", | ||
"output_size_threshold": 1024, | ||
"layer_names": ["dense_1", "conv2d_1"] | ||
} | ||
""" | ||
remat_scope_stack = global_state.get_global_attribute("remat_scope_stack") | ||
if remat_scope_stack is None or not remat_scope_stack: | ||
return None | ||
active_scope = remat_scope_stack[-1] | ||
return { | ||
"mode": active_scope.mode, | ||
"output_size_threshold": active_scope.output_size_threshold, | ||
"layer_names": active_scope.layer_names, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from keras.src import testing | ||
from keras.src.backend.common import global_state | ||
from keras.src.backend.common.remat_scope import RematScope | ||
from keras.src.backend.common.remat_scope import get_current_remat_mode | ||
|
||
|
||
class TestRematScope(testing.TestCase): | ||
def setUp(self): | ||
"""Reset global state before each test.""" | ||
global_state.clear_session() | ||
|
||
def test_remat_scope_activation(self): | ||
self.assertIsNone( | ||
get_current_remat_mode() | ||
) # Initially, no mode is active | ||
|
||
with RematScope(mode="full"): | ||
self.assertEqual( | ||
get_current_remat_mode()["mode"], "full" | ||
) # Mode is set to "full" | ||
|
||
self.assertIsNone( | ||
get_current_remat_mode() | ||
) # Mode is restored to None after scope ends | ||
|
||
def test_remat_scope_nested(self): | ||
"""Test nested scopes with different rematerialization modes.""" | ||
with RematScope(mode="full"): | ||
self.assertEqual( | ||
get_current_remat_mode()["mode"], "full" | ||
) # Outer scope is "full" | ||
|
||
with RematScope(mode="activations"): | ||
self.assertEqual( | ||
get_current_remat_mode()["mode"], "activations" | ||
) # Inner scope is "activations" | ||
|
||
self.assertEqual( | ||
get_current_remat_mode()["mode"], "full" | ||
) # Back to outer scope | ||
|
||
self.assertIsNone( | ||
get_current_remat_mode() | ||
) # Mode is restored to None after all scopes | ||
|
||
def test_remat_scope_stack_management(self): | ||
"""Test that the remat_scope_stack is managed correctly.""" | ||
self.assertIsNone( | ||
global_state.get_global_attribute("remat_scope_stack") | ||
) # No stack initially | ||
|
||
with RematScope(mode="full"): | ||
remat_stack = global_state.get_global_attribute("remat_scope_stack") | ||
self.assertIsNotNone(remat_stack) # Stack is initialized | ||
self.assertEqual(len(remat_stack), 1) # Stack contains one entry | ||
|
||
with RematScope(mode="activations"): | ||
remat_stack = global_state.get_global_attribute( | ||
"remat_scope_stack" | ||
) | ||
self.assertEqual( | ||
len(remat_stack), 2 | ||
) # Stack contains two entries | ||
|
||
remat_stack = global_state.get_global_attribute("remat_scope_stack") | ||
self.assertEqual(len(remat_stack), 1) # Back to one entry | ||
|
||
self.assertEqual( | ||
global_state.get_global_attribute("remat_scope_stack"), [] | ||
) # Stack is cleared | ||
|
||
def test_invalid_mode(self): | ||
"""Test that invalid rematerialization modes raise an error.""" | ||
with self.assertRaises(ValueError): | ||
RematScope(mode="invalid") # Invalid mode should raise ValueError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -661,6 +661,18 @@ def random_seed_dtype(): | |
return "int32" | ||
|
||
|
||
def remat(func, *args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add it in the numpy and openvino backends, returning plain function results with no remat (for compatibility) |
||
"""Implementation of rematerialization. | ||
|
||
Args: | ||
func: The function or operation to rematerialize. | ||
Returns: | ||
A function wrapping func that defines a custom gradient, which | ||
recomputes f on the backwards pass of a gradient call. | ||
""" | ||
return torch.utils.checkpoint.checkpoint(func)(*args, **kwargs) | ||
|
||
|
||
class custom_gradient: | ||
"""Decorator for custom gradients. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add return section