Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rematerialization to Keras #20743

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

divyashreepathihalli
Copy link
Collaborator

No description provided.

@codecov-commenter
Copy link

codecov-commenter commented Jan 9, 2025

Codecov Report

Attention: Patch coverage is 49.45055% with 46 lines in your changes missing coverage. Please review.

Project coverage is 72.13%. Comparing base (90568da) to head (8c30ed7).

Files with missing lines Patch % Lines
keras/src/layers/layer.py 21.73% 35 Missing and 1 partial ⚠️
keras/src/ops/core.py 46.15% 7 Missing ⚠️
keras/api/_tf_keras/keras/ops/__init__.py 0.00% 1 Missing ⚠️
keras/src/backend/common/remat_scope.py 95.83% 0 Missing and 1 partial ⚠️
keras/src/backend/torch/core.py 50.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (90568da) and HEAD (8c30ed7). Click for more details.

HEAD has 4 uploads less than BASE
Flag BASE (90568da) HEAD (8c30ed7)
keras 5 3
keras-numpy 1 0
keras-torch 1 0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20743      +/-   ##
==========================================
- Coverage   82.01%   72.13%   -9.88%     
==========================================
  Files         557      558       +1     
  Lines       52016    52106      +90     
  Branches     8037     8056      +19     
==========================================
- Hits        42659    37587    -5072     
- Misses       7403    12671    +5268     
+ Partials     1954     1848     -106     
Flag Coverage Δ
keras 72.05% <49.45%> (-9.78%) ⬇️
keras-jax 64.14% <48.35%> (-0.09%) ⬇️
keras-numpy ?
keras-openvino 29.91% <39.56%> (+0.01%) ⬆️
keras-tensorflow 64.76% <48.35%> (-0.03%) ⬇️
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth implementing the higher level APIs in this PR too, as a way to validate that the low-level and high level can work together (and that this low-level API gives us what we want).

And maybe write a colab showing this functionality in action.

keras/src/ops/core.py Outdated Show resolved Hide resolved
keras/src/ops/core.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

keras/src/backend/jax/core.py Outdated Show resolved Hide resolved
keras/src/backend/tensorflow/core.py Outdated Show resolved Hide resolved
"""
Implementation of rematerialization.

Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add return section

@@ -658,6 +658,16 @@ def random_seed_dtype():
return "int32"


def remat(func, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add it in the numpy and openvino backends, returning plain function results with no remat (for compatibility)

keras/src/ops/core.py Outdated Show resolved Hide resolved
@innat
Copy link

innat commented Jan 10, 2025

@divyashreepathihalli
Thanks for the initial pull request for this. I noticed you're using the tf.recompute_grad API and attempt to provide high-level API for ease of use, which is great. A while back, I came across cases where people used various hacks with tf.recompute_grad for advanced modeling, like in EfficientDet; please check this comment.

To ensure broader usability, it might be worth considering such scenarios when designing the high-level API, so others don't have to rely on hacks, or at least minimize the need to do so.

@divyashreepathihalli divyashreepathihalli marked this pull request as draft January 10, 2025 20:53
@divyashreepathihalli divyashreepathihalli changed the title Add Rematerialization op to Keras Add Rematerialization to Keras Jan 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants