… high-dimensional scatter operation and a flag to disable it
Imported from GitHub PR #19275
This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.
Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked.
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844
Copybara import of the project:
--
b016044 by Chenhao Jiang <[email protected]>:
PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations
Imported from GitHub PR #18326
This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
The change of this PR is on top of #17886
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844
Copybara import of the project:
--
de647d4 by Chenhao Jiang <[email protected]>:
Support scatter with non-scalar indices and updates
Merging this change closes #18326
PiperOrigin-RevId: 691023328
--
fbdb066 by Chenhao Jiang <[email protected]>:
Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.
--
d36c8ac by Chenhao Jiang <[email protected]>:
Fix the scatter determinism expander for various dimension numbers
--
678886f by Chenhao Jiang <[email protected]>:
Add a flag for enabling the scatter_determinism_expander on GPU.
Merging this change closes #19275
FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761