Add gradient accumulation support for all backends, and enable optimizer EMA for JAX and torch #18951
Conversation
@qlzh727 please review this PR -- in particular, check whether the EMA variables and the gradient accumulators are updated in a way that is correct in a tf.distribute context
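For context on the feature under review, here is a minimal, hypothetical sketch of the gradient-accumulation idea (the class name, structure, and `gradient_accumulation_steps` knob are illustrative, not Keras's actual implementation): incoming gradients are averaged into accumulator variables and only applied to the weights once per cycle.

```python
import tensorflow as tf

class AccumulatingSGD:
    """Toy optimizer: average grads over N steps, apply once per cycle."""

    def __init__(self, learning_rate=0.01, gradient_accumulation_steps=4):
        self.learning_rate = learning_rate
        self.steps = gradient_accumulation_steps
        self.iterations = tf.Variable(0, dtype=tf.int64)
        self._acc_grads = None

    def apply_gradients(self, grads_and_vars):
        grads, variables = zip(*grads_and_vars)
        if self._acc_grads is None:
            self._acc_grads = [tf.Variable(tf.zeros_like(g)) for g in grads]
        # Fold the new gradient into a running average for this cycle.
        n = tf.cast(self.iterations % self.steps, tf.float32)
        for acc, g in zip(self._acc_grads, grads):
            acc.assign((acc * n + g) / (n + 1.0))
        # On the last step of the cycle, apply the averaged grads and reset.
        if (self.iterations + 1) % self.steps == 0:
            for acc, var in zip(self._acc_grads, variables):
                var.assign_sub(self.learning_rate * acc)
                acc.assign(tf.zeros_like(acc))
        self.iterations.assign_add(1)
```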
Codecov Report

```
@@            Coverage Diff             @@
##           master   #18951      +/-   ##
==========================================
+ Coverage   79.55%   79.63%   +0.07%
==========================================
  Files         337      338       +1
  Lines       35056    35182     +126
  Branches     6872     6908      +36
==========================================
+ Hits        27890    28017     +127
+ Misses       5587     5585       -2
- Partials     1579     1580       +1
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
@haifeng-jin can you please review the changes in
EDIT: actually, I fixed it...
Nice!
```diff
-    def _distributed_apply_gradients_fn(
-        self, distribution, grads_and_vars, **kwargs
+    def _distributed_tf_update_step(
+        self, distribution, grads_and_vars, learning_rate
```
Seems that the learning_rate is not used here; the value in apply_grad_to_update_var is retrieved from self._get_current_learning_rate(). Did I miss anything?
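A hypothetical reconstruction of the pattern the comment describes (a simplified skeleton, not the actual diff; method bodies are stubbed):

```python
class _OptimizerSketch:
    """Hypothetical skeleton; only the distributed update path is shown."""

    def _get_current_learning_rate(self):
        ...  # returns the (possibly scheduled) learning rate

    def update_step(self, grad, var, learning_rate):
        ...  # per-variable update rule

    def _distributed_tf_update_step(
        self, distribution, grads_and_vars, learning_rate
    ):
        def apply_grad_to_update_var(var, grad):
            # Re-reads the learning rate itself, so the `learning_rate`
            # argument to this wrapper goes unused.
            lr = self._get_current_learning_rate()
            return self.update_step(grad, var, lr)

        for grad, var in grads_and_vars:
            distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False
            )
```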
```python
                grads, trainable_variables, self.learning_rate
            )
```

```python
        if self.use_ema:
```
Just curious, does the existing JAX optimizer support EMA? Or is it from the base_optimizer?
Now JAX supports the feature -- the new unit tests check it. The previous JAX optimizer (before this PR) didn't -- only TF did.
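For reference, a minimal sketch of the EMA rule such tests would exercise (an assumed form based on the optimizer's `use_ema` / `ema_momentum` options, not the exact Keras code): the optimizer keeps one shadow variable per trainable variable and blends it toward the live value after each update.

```python
# Assumed EMA update: blend each shadow variable toward its model variable.
def update_ema(ema_vars, variables, ema_momentum=0.99):
    for ema_var, var in zip(ema_vars, variables):
        ema_var.assign(ema_momentum * ema_var + (1.0 - ema_momentum) * var)
```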
```diff
@@ -7,6 +7,36 @@

 class OptimizerTest(testing.TestCase):
```
Can we also update the test in tensorflow/optimizer_distribute_test.py with distribution-related test cases (for EMA and gradient accumulation)?
```python
                for i in range(len(grads))
            ]
            for n_g_acc, g_acc in zip(new_g_accs, self._accumulated_gradients):
                g_acc.assign(n_g_acc)
```
In the tf.distribute context, I think this probably should use https://www.tensorflow.org/api_docs/python/tf/distribute/StrategyExtended#update
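A hedged sketch of what that suggestion could look like, reusing `new_g_accs`, `self._accumulated_gradients`, and `distribution` from the hunk above (`_assign_acc` is an illustrative helper, not code from the PR):

```python
# Route the accumulator write through StrategyExtended.update so it runs
# with the right replica/variable placement under tf.distribute.
def _assign_acc(acc_var, value):
    return acc_var.assign(value)

for n_g_acc, g_acc in zip(new_g_accs, self._accumulated_gradients):
    distribution.extended.update(g_acc, _assign_acc, args=(n_g_acc,))
```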
Ok, but what about self.iterations? We update it with a simple assign. When is it ok to use assign and when is it not?
I think assign works when each replica is supposed to get the same value, e.g. iterations is always the same across all replicas. I don't think that's the case for the accumulated grads (each replica should get a different value, and eventually the overall value should get a mean reduce?).
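To illustrate the distinction with a self-contained sketch (not PR code): a per-replica value has to be mean-reduced before it can be written into a single shared variable, while a replica-uniform counter like `iterations` can be assigned directly.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def per_replica_grad():
    # Stand-in for gradient computation: each replica produces a different
    # value, so a plain assign of this value would diverge across replicas.
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)

per_replica = strategy.run(per_replica_grad)
# Mean-reduce across replicas before writing into a shared accumulator.
mean_grad = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)
```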
```diff
@@ -163,30 +163,24 @@ def test_ema(self):
     def test_gradient_accumulation(self):
         with self.strategy.scope():
             v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
-            grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])
+            grads = backend.convert_to_tensor([[1.0, 1.0], [2.0, 2.0]])
```
Thanks for adding the test. Can we test with grads that have a different value on each replica? You can create a distributed value via https://www.tensorflow.org/api_docs/python/tf/distribute/DistributedValues
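For reference, a hedged sketch of how such per-replica values can be built with that API (shapes and values are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def value_fn(ctx):
    # Give each replica a different gradient tensor (illustrative values).
    return tf.fill(
        [2, 2], tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)
    )

# Yields a tf.distribute.DistributedValues with one tensor per replica.
distributed_grads = strategy.experimental_distribute_values_from_function(
    value_fn
)
```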
Best I can tell, it is not possible to call an optimizer (or any other Keras function) with DistributedValues. The reason is that DistributedValues is a stand-in for a tensor, but it does not implement the tensor API (no .shape, no .dtype, etc.).