Fixing batch_dim_name attribute #20674
base: master
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@            Coverage Diff            @@
##           master    #20674    +/-   ##
=========================================
  Coverage   81.95%    81.96%
=========================================
  Files         543       543
  Lines       50664     50668       +4
  Branches     7828      7830       +2
=========================================
+ Hits        41524     41528       +4
  Misses       7246      7246
  Partials     1894      1894
    jax_distribution_lib.distribute_data_input, data, layouts

    jax_dist_data_input = partial(
        jax_distribution_lib.distribute_data_input,
        batch_dim_name=distribution._batch_dim_name,
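For context, here is a minimal sketch of how this partial binding could be used when distributing inputs on the JAX backend. The `_distribute_data` helper name and the exact import paths are assumptions for illustration, not the trainer's actual code:

```python
from functools import partial

from keras.src import tree
from keras.src.backend.jax import distribution_lib as jax_distribution_lib
from keras.src.distribution import distribution_lib


def _distribute_data(data, layouts):
    # Hypothetical helper: fetch the active distribution, bind its batch
    # dimension name once, then map the distribution function over the
    # matching (data, layout) pairs.
    distribution = distribution_lib.distribution()
    jax_dist_data_input = partial(
        jax_distribution_lib.distribute_data_input,
        batch_dim_name=distribution._batch_dim_name,
    )
    return tree.map_structure(jax_dist_data_input, data, layouts)
```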
Can we do this without accessing the private variable `_batch_dim_name`? Could we consider passing the `batch_dim_name` as an argument to the relevant functions? Or, maybe the distribution object provides a public method or property to access the batch dimension name?
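One way this suggestion could look, sketched as an assumption rather than the actual Keras API: expose the batch dimension name as a read-only property on the distribution class, so callers no longer reach into the private attribute.

```python
class Distribution:
    """Sketch of a distribution base class with a public accessor."""

    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        # Public, read-only accessor so callers (e.g. the JAX trainer)
        # can use `distribution.batch_dim_name` instead of the private
        # `_batch_dim_name` attribute.
        return self._batch_dim_name
```

The trainer would then pass `batch_dim_name=distribution.batch_dim_name` when building the partial.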
Yes, I'll think of a cleaner way.
The goal at this point is to get a second pair of eyes on this fix and validate that it is correct. See the use cases at the end of the intro paragraph. Also, since you implemented the multi-host code, could you check that this fix does not break it?
Sure, I'll run the internal multi-host test to make sure it still works.
@martin-gorner Could you clarify what you meant by 'no coverage' in this context?

I meant "convergence", i.e. the loss is not decreasing.
`ModelParallel(batch_dim_name='batch')` is currently dysfunctional and will work only if `batch_dim_name` corresponds to the first dimension of the mesh, which is the default anyway. There is also a problem for meshes with 3 or more dimensions.

Minimal repro 1 (showing the error):
https://colab.research.google.com/drive/1jzmCZ2WNlKtD4j2heSaq-mxBoG-9WeeS?usp=sharing

Minimal repro 2, with a 3D mesh (showing the error):
https://colab.research.google.com/drive/1AGku4hjwhTN_2h5yiU7Q-a6vvSrc8nRH

Real-world repro 1 (showing a successful run with the fix):
https://colab.research.google.com/drive/1cyn_XUFwdLUJE4pRNWPgZ2H5wzKzto-T?usp=sharing

Real-world repro 2 (showing a run without errors, but unfortunately no convergence):
https://colab.research.google.com/drive/1kY9qq27YxpowqYDT3gL98U5RuN6CYQ7b?usp=sharing
The use case is not just hypothetical.
With `DeviceMesh((4, 2), ("model", "batch"))`, fine-tuning proceeds at 147 ms/step.
With `DeviceMesh((2, 4), ("batch", "model"))`, fine-tuning proceeds at 205 ms/step.
The fix makes the first, faster configuration work, as tested with the real-world repro 1 notebook on TPU v5e.
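For illustration, roughly how the faster of the two configurations above could be set up with the Keras 3 distribution API. The 8-device layout mirrors the mesh shapes quoted in the timings; the helper names used here are the public `keras.distribution` API as I understand it, not code taken from the repro notebooks.

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel, set_distribution

devices = keras.distribution.list_devices()  # e.g. 8 TPU v5e cores

# Faster configuration from the description: "model" axis first, "batch"
# axis second. Before the fix, batch_dim_name="batch" only worked when
# "batch" was the first mesh dimension.
mesh = DeviceMesh(shape=(4, 2), axis_names=("model", "batch"), devices=devices)
layout_map = LayoutMap(mesh)

distribution = ModelParallel(layout_map=layout_map, batch_dim_name="batch")
set_distribution(distribution)
```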
Remaining issues:
- `num_model_replicas_total = layout.mesh.shape[batch_dim_name]`, i.e. the number of model replicas is the number of devices along the "batch" axis of the mesh.
- The `mesh_model_dim_size` computation: data is replicated as many times as there are unique model shards.
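To make the first concern concrete, a small worked example under an assumed 3D mesh (the axis names and sizes are hypothetical): counting model replicas only along the "batch" axis can disagree with the actual number of replicas once the mesh has more than two dimensions.

```python
import math

# Hypothetical 3D mesh: 8 devices laid out as (batch=2, seq=2, model=2).
mesh_shape = {"batch": 2, "seq": 2, "model": 2}
model_dims = {"model"}  # axes along which the model weights are sharded

# Assumption questioned above: replicas counted along the "batch" axis only.
replicas_from_batch_axis = mesh_shape["batch"]  # -> 2

# Actual number of model replicas: total devices divided by the devices
# holding one full copy of the model, i.e. the product of all non-model axes.
total_devices = math.prod(mesh_shape.values())  # -> 8
devices_per_replica = math.prod(mesh_shape[d] for d in model_dims)  # -> 2
actual_model_replicas = total_devices // devices_per_replica  # -> 4

print(replicas_from_batch_axis, actual_model_replicas)  # 2 4
```

Under these assumptions the two counts differ (2 vs. 4), which is the kind of discrepancy the remaining-issues list is flagging for meshes with 3 or more dimensions.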