-
Notifications
You must be signed in to change notification settings - Fork 434
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Propagate shardings to the root instruction of while condition.
### Sharding Propagation Although the root instruction of while condition is in the shape `pred[]`. It can have the following meaningful shardings. 1. {replicated} 2. {manual} 3. subgroup sharding, e.g., {devices=[2,2]<=[4] last_tile_dims={manual, replicated}} Thus, we need to propagate the sharding to the root such that the partitioner can correctly handle the while condition. ### SPMD Partitioner The condition root must be replicated so that all partitions follow the same control flow. It can also have some tile dims to be manual. Thus, we need to replicate all data dims and keep the manual dims. PiperOrigin-RevId: 658455111
- Loading branch information
1 parent
0609842
commit e500f12
Showing
4 changed files
with
101 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters