
Feature/transformer sequence sharding #90

Draft
japols wants to merge 3 commits into develop

Conversation

@japols (Member) commented Nov 28, 2024

This PR adds a new sharding strategy shard_sequence for the transformer processor.

The current implementation (shard_heads) alternates between sharding across the sequence and sharding across heads for the sliding window attention mechanism. This requires two all-to-all communication steps per layer.

The shard_sequence strategy simplifies this process by keeping a sequence shard on each GPU and computing the sliding window attention locally. This requires a halo exchange that shares the overlapping window segments (halos) between neighbouring sequence shards.

Instead of two all-to-all communication steps per layer, the halo exchange requires only a single point-to-point communication between neighbouring GPUs, which should reduce communication time and improve the scalability of model sharding across multiple GPUs.
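
To make the idea concrete, here is a minimal sketch of such a halo exchange (not the code in this PR). It assumes a 1D process group in which each rank holds a contiguous sequence shard with the sequence as its first dimension; the helper name halo_exchange, the zero-padding at the sequence boundaries, and the equal shard sizes are illustrative assumptions. It uses torch.distributed batched isend/irecv so each rank exchanges halos only with its immediate neighbours.

```python
# Sketch only: a simple halo exchange for sequence-sharded sliding-window
# attention. Assumes all ranks hold equally sized shards of shape
# (seq_shard, ...) and that `halo` <= seq_shard.
import torch
import torch.distributed as dist


def halo_exchange(x: torch.Tensor, halo: int, group=None) -> torch.Tensor:
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    send_left = x[:halo].contiguous()    # rows the left neighbour needs
    send_right = x[-halo:].contiguous()  # rows the right neighbour needs
    recv_left = torch.empty_like(send_left)    # halo arriving from the left
    recv_right = torch.empty_like(send_right)  # halo arriving from the right

    ops = []
    if rank > 0:  # exchange with the left neighbour
        ops.append(dist.P2POp(dist.isend, send_left, rank - 1, group))
        ops.append(dist.P2POp(dist.irecv, recv_left, rank - 1, group))
    if rank < world - 1:  # exchange with the right neighbour
        ops.append(dist.P2POp(dist.isend, send_right, rank + 1, group))
        ops.append(dist.P2POp(dist.irecv, recv_right, rank + 1, group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    # Boundary ranks have no neighbour on one side; pad with zeros so the
    # local sliding-window attention sees a fixed-size halo everywhere.
    left = recv_left if rank > 0 else torch.zeros_like(recv_left)
    right = recv_right if rank < world - 1 else torch.zeros_like(recv_right)
    return torch.cat([left, x, right], dim=0)
```

Each rank would then run its sliding-window attention on the padded shard and drop the halo rows from the output, whereas the shard_heads path needs one all-to-all before and one after the attention in every layer.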

@FussyDuck

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ japols
❌ Jan Patrick Polster


Jan Patrick Polster does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@japols japols self-assigned this Nov 28, 2024
@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.85%. Comparing base (fd2bcf1) to head (4ab4205).

Additional details and impacted files
@@           Coverage Diff            @@
##           develop      #90   +/-   ##
========================================
  Coverage    99.85%   99.85%           
========================================
  Files           23       23           
  Lines         1374     1374           
========================================
  Hits          1372     1372           
  Misses           2        2           


3 participants