Fixing batch_dim_name attribute #20674

Open
martin-gorner wants to merge 8 commits into base: master
Conversation

martin-gorner (Contributor) commented Dec 20, 2024

ModelParallel(batch_dim_name='batch') is currently broken: it only works when batch_dim_name corresponds to the first dimension of the mesh, which is the default anyway. There is also a problem for meshes with 3 or more dimensions.
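For reference, here is a minimal sketch of the configuration that currently fails (the device count, mesh shape, and variable names are illustrative assumptions, not taken from the repro notebooks):

```python
import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

# Assumes 8 accelerators, e.g. a single TPU v5e-8 host.
devices = keras.distribution.list_devices()

# "batch" is the *second* mesh axis here; this is the case that currently breaks,
# since only a batch axis in the first mesh position works today.
mesh = DeviceMesh(shape=(4, 2), axis_names=("model", "batch"), devices=devices)

layout_map = LayoutMap(mesh)  # empty map: variables fall back to the default layout
distribution = ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distribution)
```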

Minimal repro 1 (showing error):
https://colab.research.google.com/drive/1jzmCZ2WNlKtD4j2heSaq-mxBoG-9WeeS?usp=sharing

Minimal repro 2 with a 3D mesh (showing error):
https://colab.research.google.com/drive/1AGku4hjwhTN_2h5yiU7Q-a6vvSrc8nRH

Real-world repro 1 (showing successful run with fix):
https://colab.research.google.com/drive/1cyn_XUFwdLUJE4pRNWPgZ2H5wzKzto-T?usp=sharing

Real-world repro 2 (showing a run without errors - but unfortunately no convergence):
https://colab.research.google.com/drive/1kY9qq27YxpowqYDT3gL98U5RuN6CYQ7b?usp=sharing

The use case is not just hypothetical.
With DeviceMesh((4,2), ("model", "batch")), fine-tuning proceeds at 147ms/step.
With DeviceMesh((2,4), ("batch", "model")), fine-tuning proceeds at 205ms/step.
The fix makes the first, faster use case work, as tested with the real-world repro 1 notebook on TPU v5e.
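For reference, the two configurations compared above would be built like this (a sketch only; assumes 8 available devices, with timings copied from the runs above):

```python
import keras

# Faster layout, which requires this fix ("batch" is not the first mesh axis): ~147 ms/step.
mesh_model_first = keras.distribution.DeviceMesh((4, 2), ("model", "batch"))

# Layout that already works today ("batch" is the first mesh axis): ~205 ms/step.
mesh_batch_first = keras.distribution.DeviceMesh((2, 4), ("batch", "model"))
```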

Remaining issues:

  • Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation?
  • The fixes should work for combined data and model parallelism, where the data is sharded along one axis and the model along a different set of axes. That is the assumption in backend/jax/distribution_lib.py (see the sketch after this list):
    - num_model_replicas_total = layout.mesh.shape[batch_dim_name], i.e. the number of model replicas is the number of devices along the "batch" axis of the mesh.
    - mesh_model_dim_size computation: data is replicated as many times as there are unique model shards.
  • However, the default layout map for Gemma also shards the model along the "batch" dimension. This works as long as the "batch" dimension is 1, but is useless in that case. When the "batch" dimension is >= 2, I don't know what it means, i.e. how many model replicas there are and therefore how the input data should be split. The Keras team should chime in on this.
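To make the second bullet concrete, here is a small sketch of the replica arithmetic assumed in backend/jax/distribution_lib.py (the helper function and its signature are illustrative, not the actual code):

```python
# Hypothetical helper mirroring the assumption described above.
def _replica_counts(mesh_shape, axis_names, batch_dim_name):
    # Number of model replicas = number of devices along the "batch" axis of the mesh.
    num_model_replicas_total = mesh_shape[axis_names.index(batch_dim_name)]

    # All other mesh axes shard the model; their product is the size of one model
    # replica, i.e. how many times each slice of the input batch gets replicated.
    mesh_model_dim_size = 1
    for name, size in zip(axis_names, mesh_shape):
        if name != batch_dim_name:
            mesh_model_dim_size *= size

    return num_model_replicas_total, mesh_model_dim_size


# DeviceMesh((4, 2), ("model", "batch")): 2 model replicas, each sharded over 4 devices,
# so the global batch is split 2 ways and each split is sent to 4 devices.
print(_replica_counts((4, 2), ("model", "batch"), "batch"))  # -> (2, 4)
```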

codecov-commenter commented Dec 20, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.96%. Comparing base (7c491bd) to head (d635f47).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20674   +/-   ##
=======================================
  Coverage   81.95%   81.96%           
=======================================
  Files         543      543           
  Lines       50664    50668    +4     
  Branches     7828     7830    +2     
=======================================
+ Hits        41524    41528    +4     
  Misses       7246     7246           
  Partials     1894     1894           
Flag               Coverage Δ
keras              81.79% <100.00%> (+<0.01%) ⬆️
keras-jax          63.90% <100.00%> (+<0.01%) ⬆️
keras-numpy        58.82% <11.11%> (-0.01%) ⬇️
keras-openvino     29.94% <11.11%> (-0.01%) ⬇️
keras-tensorflow   64.65% <11.11%> (-0.01%) ⬇️
keras-torch        63.72% <11.11%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.


github-actions bot added the Gemma (Gemma model specific issues) label on Dec 20, 2024
This review thread refers to the following diff hunk (truncated):

```python
jax_distribution_lib.distribute_data_input, data, layouts
jax_dist_data_input = partial(
    jax_distribution_lib.distribute_data_input,
    batch_dim_name=distribution._batch_dim_name,
```
Member commented:
Can we do this without accessing the private variable _batch_dim_name? Could we pass batch_dim_name as an argument to the relevant functions, or could the distribution object expose a public method or property for the batch dimension name?
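For example, a public read-only property might be enough; just a sketch, not prescribing the final API:

```python
class Distribution:
    """Base distribution class, simplified for illustration."""

    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        # Public, read-only accessor so callers (e.g. the JAX trainer) can use
        # `distribution.batch_dim_name` instead of the private `_batch_dim_name`.
        return self._batch_dim_name
```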

Contributor Author replied:
Yes, I'll think of a cleaner way.

The goal at this point is to get a second pair of eyes on this fix and validate that it is correct. See the use cases at the end of the intro paragraph. Also, since you implemented the multi-host code, could you check that this fix does not break it?

Member replied:
Sure, I'll run the internal multi-host test to make sure it still works.

SamanehSaadat (Member) commented Dec 20, 2024

> Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation?

@martin-gorner Could you clarify what you meant by 'no coverage' in this context?

martin-gorner (Contributor Author) commented Dec 20, 2024

> > Real-world repro 2 does not show convergence. Maybe there is another bug in loss aggregation?
>
> @martin-gorner Could you clarify what you meant by 'no coverage' in this context?

I meant "convergence", i.e. the loss is not decreasing.

Labels: Gemma (Gemma model specific issues), size:S
4 participants