
Add gradient accumulation support for all backends, and enable optimizer EMA for JAX and torch #18951

Merged
fchollet merged 10 commits into master on Dec 18, 2023

Conversation

fchollet (Collaborator)

No description provided.

@fchollet (Collaborator, Author)

@qlzh727 please review this PR -- in particular, check whether the EMA variables and the gradient accumulators are updated in a way that is correct in a tf.distribute context


codecov-commenter commented Dec 17, 2023

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison is base (d550552) 79.55% compared to head (ac794c9) 79.63%.
Report is 4 commits behind head on master.

Files                                Patch %   Lines
keras/optimizers/base_optimizer.py   92.68%    1 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18951      +/-   ##
==========================================
+ Coverage   79.55%   79.63%   +0.07%     
==========================================
  Files         337      338       +1     
  Lines       35056    35182     +126     
  Branches     6872     6908      +36     
==========================================
+ Hits        27890    28017     +127     
+ Misses       5587     5585       -2     
- Partials     1579     1580       +1     
Flag               Coverage Δ
keras              79.48% <96.00%> (+0.07%) ⬆️
keras-jax          61.28% <52.00%> (+0.09%) ⬆️
keras-numpy        55.97% <39.00%> (+<0.01%) ⬆️
keras-tensorflow   63.12% <47.00%> (-0.08%) ⬇️
keras-torch        63.80% <44.00%> (-0.05%) ⬇️

Flags with carried forward coverage won't be shown.


@fchollet (Collaborator, Author) commented Dec 17, 2023

@haifeng-jin can you please review the changes in Nadam and provide your analysis on why the torch tests are failing? I think it has to do with the fact that torch has its own custom Nadam update step.

EDIT: actually, I fixed it...

@haifeng-jin (Contributor)

Nice!

def _distributed_apply_gradients_fn(
    self, distribution, grads_and_vars, **kwargs
def _distributed_tf_update_step(
    self, distribution, grads_and_vars, learning_rate
Member

It seems that learning_rate is not used here; the value in apply_grad_to_update_var is retrieved from self._get_current_learning_rate(). Did I miss anything?

grads, trainable_variables, self.learning_rate
)

if self.use_ema:
Member

Just curious: does the existing JAX optimizer support EMA, or is it from the base_optimizer?

fchollet (Collaborator, Author)

Now JAX supports the feature -- the new unit tests check it. The previous JAX optimizer (before this PR) didn't -- only TF did.
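(For reference, this behavior is driven by the standard optimizer constructor arguments; below is a minimal usage sketch of the public API, which after this PR is expected to behave the same on the TF, JAX, and torch backends.)

import keras

# The optimizer maintains an exponential moving average (EMA) of the model
# weights alongside the regular update.
optimizer = keras.optimizers.Adam(
    learning_rate=1e-3,
    use_ema=True,                 # track an EMA of the model weights
    ema_momentum=0.99,            # decay factor of the moving average
    ema_overwrite_frequency=100,  # overwrite the weights with the EMA every 100 steps
)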

@@ -7,6 +7,36 @@


class OptimizerTest(testing.TestCase):
Member

Can we also update the test in tensorflow/optimizer_distribute_test.py to cover the distribution-related cases (for EMA and gradient accumulation)?
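(For illustration, a distribution-aware gradient accumulation check along those lines could look roughly like the sketch below. This is not the code added in this PR; it assumes the file's existing self.strategy MirroredStrategy fixture, the backend/optimizers imports already used there, and the gradient_accumulation_steps optimizer argument.)

def test_gradient_accumulation_distributed(self):
    with self.strategy.scope():
        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
        grads = backend.convert_to_tensor([[1.0, 1.0], [2.0, 2.0]])
        optimizer = optimizers.SGD(
            learning_rate=1.0, gradient_accumulation_steps=3
        )
        # On the first application the gradient should only be accumulated,
        # so the variable must stay unchanged while the accumulator holds
        # the gradient.
        optimizer.apply_gradients([(grads, v)])
        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])
        self.assertAllClose(
            optimizer._accumulated_gradients[0], [[1.0, 1.0], [2.0, 2.0]]
        )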

    for i in range(len(grads))
]
for n_g_acc, g_acc in zip(new_g_accs, self._accumulated_gradients):
    g_acc.assign(n_g_acc)
Member

In the tf.distribute context, I think this probably should use https://www.tensorflow.org/api_docs/python/tf/distribute/StrategyExtended#update

fchollet (Collaborator, Author), Dec 18, 2023

Ok, but what about self.iterations? We update it with a simple assign. When is it ok to use assign and when is it not?

Member

I think assign works when every replica is supposed to get the same value, e.g. iterations is always the same across all replicas. I don't think that's the case for the accumulated gradients (each replica should get a different value, and the overall value should eventually get a mean reduce?).
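(To make that concrete, here is a minimal standalone sketch of the suggested pattern, not code from this PR: mean-reduce the per-replica gradients, then route the assignment through StrategyExtended.update so every copy of the mirrored accumulator receives the same value.)

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Stand-in for one of the optimizer's gradient accumulator variables.
    accumulator = tf.Variable(tf.zeros([2]))

def compute_grad():
    # Each replica produces a different gradient value.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    return tf.fill([2], tf.cast(replica_id + 1, tf.float32))

per_replica_grads = strategy.run(compute_grad)
# Cross-replica context: reduce to the mean, then apply the same update to
# every copy of the mirrored variable via StrategyExtended.update.
reduced = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grads, axis=None)
strategy.extended.update(
    accumulator, lambda var, value: var.assign_add(value), args=(reduced,)
)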

@@ -163,30 +163,24 @@ def test_ema(self):
    def test_gradient_accumulation(self):
        with self.strategy.scope():
            v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
-           grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])
+           grads = backend.convert_to_tensor([[1.0, 1.0], [2.0, 2.0]])
Member

Thanks for adding the test. Can we test with gradients that have a different value on each replica?

You can create a distributed value via https://www.tensorflow.org/api_docs/python/tf/distribute/DistributedValues
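(A minimal sketch of building such per-replica values with that API; this is an illustration assuming a two-replica MirroredStrategy, not code from the PR.)

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn(ctx):
    # Give each replica a different gradient value based on its replica id.
    scale = tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)
    return tf.constant([[1.0, 1.0], [1.0, 1.0]]) * scale

# Returns a tf.distribute.DistributedValues holding one tensor per replica; it
# can be passed as an argument to strategy.run, but as noted in the reply
# below it is not itself a tensor.
per_replica_grads = strategy.experimental_distribute_values_from_function(value_fn)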

fchollet (Collaborator, Author)

As best I can tell, it is not possible to call an optimizer (or any other Keras function) with DistributedValues. The reason is that DistributedValues is a stand-in for a tensor, but it does not implement the tensor API (no .shape, no .dtype, etc.).

google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Dec 18, 2023
google-ml-butler bot removed the ready to pull (Ready to be merged into the codebase) label on Dec 18, 2023
fchollet merged commit c3d269b into master on Dec 18, 2023
8 checks passed