
[Bug Fix] Balance transformer blocks across shards #41

Merged · 4 commits · Nov 26, 2023

Conversation

abourramouss (Contributor)

As we were discussing, the current implementation works like this:

  1. It first gives an equal number of parameters to each shard.
  2. If a transformer block is about to be split across different shards, it prevents the split and makes that transformer block part of the current shard.

This way, we can guarantee that each shard/partition gets a whole number of transformer blocks, i.e. no block is split across shards.

But there is an edge case: if we specify 5 shards and the model has 6 transformer blocks, then shards 1 to 3 get 2 transformer blocks each, shard 4 gets the final layers, and shard 5 doesn't get anything.
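
A minimal sketch of that parameter-based partitioning, treating each unit (embedding, transformer block, final layers) as atomic and represented only by its parameter count. This is an illustration with hypothetical names, not the code in this repository:

```python
def partition_by_params(unit_sizes, n_shards):
    """unit_sizes: parameter count of each atomic unit (embedding,
    transformer blocks, final layers), in model order (hypothetical)."""
    budget = sum(unit_sizes) / n_shards  # target parameters per shard
    shards, current, acc = [], [], 0
    for idx, size in enumerate(unit_sizes):
        current.append(idx)
        acc += size
        # Close the shard once its parameter budget is met, so a unit
        # (e.g. a transformer block) is never split across shards.
        if acc >= budget and len(shards) < n_shards - 1:
            shards.append(current)
            current, acc = [], 0
    if current:
        shards.append(current)
    # Pad shards that never received a unit -- the edge case above.
    shards += [[] for _ in range(n_shards - len(shards))]
    return shards

# 1 embedding + 6 equal blocks + final layers, over 5 shards:
# shard 5 ends up empty.
print(partition_by_params([10, 50, 50, 50, 50, 50, 50, 5], 5))
# -> [[0, 1, 2], [3, 4], [5, 6], [7], []]
```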

To prevent this, if balance is not critical, we could shard based on transformer blocks: if 5 shards were specified, shard 1 would get 2 transformer blocks and shards 2 to 5 would get 1 transformer block each.
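
A sketch of that block-based assignment under the same assumptions (hypothetical names, not necessarily the merged implementation):

```python
def partition_by_blocks(n_blocks, n_shards):
    # Distribute whole transformer blocks as evenly as possible;
    # earlier shards absorb the remainder.
    base, extra = divmod(n_blocks, n_shards)
    shards, start = [], 0
    for rank in range(n_shards):
        count = base + (1 if rank < extra else 0)
        shards.append(list(range(start, start + count)))
        start += count
    return shards

# 6 transformer blocks over 5 shards: shard 1 gets 2, shards 2-5 get 1 each.
print(partition_by_blocks(6, 5))  # [[0, 1], [2], [3], [4], [5]]
```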

It must use the transformer_h_X part of the parameter name to identify the block, since the layers from transformer_h_X to the end change from model to model.
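
For illustration, identifying a block by its transformer_h_X name could look like this regex-based sketch; the exact parameter naming scheme is model-dependent and assumed here:

```python
import re

# Match the block index in names like "transformer_h_3_attn_weight";
# the naming scheme is an assumption and varies per model.
BLOCK_RE = re.compile(r"transformer_h_(\d+)")

def block_index(param_name):
    """Return the transformer-block index a parameter belongs to,
    or None for embeddings / final layers."""
    match = BLOCK_RE.search(param_name)
    return int(match.group(1)) if match else None

print(block_index("transformer_h_3_attn_weight"))  # 3
print(block_index("lm_head_weight"))               # None
```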
xrsrke merged commit 05e1c45 into xrsrke:main on Nov 26, 2023