[NVIDIA] Complete the optimization of deterministic scatter operations #18326

Closed
wants to merge 1 commit into from

Conversation

serach24
Contributor

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The changes in this PR build on top of #17886.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
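
The rewrite itself happens on XLA HLO inside the scatter expansion pass, but the core idea can be illustrated with a small, self-contained C++ sketch (hypothetical names, not the PR's implementation): stably sort the updates by their target index, combine duplicate indices with a segmented prefix scan, and then issue exactly one write per unique index, so the reduction order is fixed regardless of how the work is scheduled.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Deterministic scatter-add: operand[indices[i]] += updates[i].
// Instead of applying updates one by one (or atomically, which is
// non-deterministic in parallel), sort the updates by target index and
// combine duplicates with a segmented inclusive prefix scan, so every
// duplicate index is reduced in a fixed order.
void DeterministicScatterAdd(std::vector<float>& operand,
                             const std::vector<int64_t>& indices,
                             const std::vector<float>& updates) {
  const size_t n = indices.size();
  // 1) Stable sort of update positions by their target index.
  std::vector<size_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(),
                   [&](size_t a, size_t b) { return indices[a] < indices[b]; });

  // 2) Segmented inclusive prefix scan over the sorted updates: within a run
  //    of equal indices, each element accumulates the run's partial sum.
  std::vector<float> scanned(n);
  for (size_t i = 0; i < n; ++i) {
    const bool same_segment =
        i > 0 && indices[order[i]] == indices[order[i - 1]];
    scanned[i] = updates[order[i]] + (same_segment ? scanned[i - 1] : 0.0f);
  }

  // 3) The last element of each segment holds the full segment sum; write it
  //    back once per unique index, which is trivially deterministic.
  for (size_t i = 0; i < n; ++i) {
    const bool last_in_segment =
        i + 1 == n || indices[order[i]] != indices[order[i + 1]];
    if (last_in_segment) operand[indices[order[i]]] += scanned[i];
  }
}

int main() {
  std::vector<float> operand(4, 0.0f);
  DeterministicScatterAdd(operand, {2, 0, 2, 2}, {1.0f, 5.0f, 2.0f, 3.0f});
  for (float v : operand) std::printf("%g ", v);  // prints: 5 0 6 0
  std::printf("\n");
}

On GPU, steps 1 and 2 map to a sort and a scan, both of which have efficient deterministic parallel implementations, which is why this approach is so much faster than the original while-loop expansion.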

copybara-service bot pushed a commit that referenced this pull request Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in `xla/service/ScatterExpander.cc`; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves other scatter operations to the original ScatterExpander. The 2nd PR will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedups, especially with larger index sizes, reaching over 9,300x for some input and index size configurations.
- Compared to the non-deterministic scatter, our implementation has only a modest slowdown, roughly 1x-4x in most cases; in the worst case, an uncommon index-size configuration, the slowdown factor is 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |
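
As a reading aid (derived from the table rather than part of the original commit message), the two ratio columns follow directly from the timing columns; for the Input Size 10, Index Size 1000000 row:

Slowdown (vs Non-Det)     = New Det / Non-Det      = 4.45E-04 / 4.31E-04 ≈ 1.03
Speedup (vs Original Det) = Original Det / New Det = 4.17E+00 / 4.45E-04 ≈ 9372

(The displayed timings are rounded; the table's ratios are computed from the unrounded values.)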

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 17, 2024
@akuegel
Member

akuegel commented Oct 17, 2024

Sorry, I didn't have time to take a look yet, but I made a few adjustments to get part 1 landed, so you will probably need to rebase.

copybara-service bot pushed a commit that referenced this pull request Oct 17, 2024
Merging this change closes #17886

COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686871951
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 17, 2024
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

PiperOrigin-RevId: 686871951
@serach24 serach24 force-pushed the chenhao/opt_det_scatter_full branch 2 times, most recently from 53d02b7 to ded14b2 on October 17, 2024 at 20:27
@akuegel akuegel left a comment
Member

Sorry for taking so long; I had a bit of vacation and then was out sick.
In general this looks good. I just stumbled over the downcasts to int32; I think it would be worth adding comments wherever we downcast from int64_t to int32_t.

if (scatter_indices->shape().rank() == 1) {
  CHECK_EQ(scatter_shape.dimensions_size(), 1);
  CHECK_EQ(operand_dims.size(), 1);
  int32_t value = is_out_of_bound ? operand_dims[0]
Member

Please add a comment explaining why it is OK to downcast to int32_t. We are assigning an int64_t value here.
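
For illustration only (this is not the PR's actual code, and the rationale given in the comment is a hypothetical example, assuming <limits> is available), the requested comment plus an explicit guard could look roughly like this:

// Sketch only; the justification is an assumed example, not a claim about
// this pass. Document the narrowing and guard it so an overflow would fail
// loudly instead of wrapping silently.
CHECK_LE(operand_dims[0], std::numeric_limits<int32_t>::max());
int32_t value = static_cast<int32_t>(operand_dims[0]);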


// Return the offset tensor as an HloInstruction
return parent->AddInstruction(HloInstruction::CreateConstant(
    LiteralUtil::CreateR2FromArray2D<int>(offset_tensor)));

Could this also need int64_t?

@serach24 (Contributor, Author) replied:

makes sense, changed
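
For illustration, a sketch of the int64_t variant discussed above, using the XLA literal and HLO factory APIs. Header paths assume the current openxla/xla layout; this is not the exact code from the PR:

#include <cstdint>

#include "xla/array2d.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/literal_util.h"

// Build the offset tensor as an S64 constant so no narrowing to int occurs.
// `parent` is the computation being rewritten; `offset_tensor` holds the
// precomputed per-row offsets.
xla::HloInstruction* CreateOffsetConstant(
    xla::HloComputation* parent, const xla::Array2D<int64_t>& offset_tensor) {
  return parent->AddInstruction(xla::HloInstruction::CreateConstant(
      xla::LiteralUtil::CreateR2FromArray2D<int64_t>(offset_tensor)));
}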

copybara-service bot pushed a commit that referenced this pull request Oct 28, 2024
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 690490783
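
For readers who only see the description above, a minimal CPU sketch of the prefix-scan idea may help. It assumes a scatter-add into a 1-D operand with in-bounds indices and is plain host C++; the actual pass emits equivalent HLO rather than looping on the host:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Sort (index, update) pairs by index so duplicate indices become contiguous,
// then run a segmented prefix sum over each run of equal indices and apply
// only the run total. The result no longer depends on the original order in
// which colliding updates arrive, which is what makes the scatter deterministic.
void DeterministicScatterAdd(std::vector<float>& operand,
                             const std::vector<int64_t>& indices,
                             const std::vector<float>& updates) {
  std::vector<std::size_t> perm(indices.size());
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  std::stable_sort(perm.begin(), perm.end(), [&](std::size_t a, std::size_t b) {
    return indices[a] < indices[b];
  });

  float run_sum = 0.0f;
  for (std::size_t i = 0; i < perm.size(); ++i) {
    const int64_t idx = indices[perm[i]];
    const bool new_run = (i == 0) || indices[perm[i - 1]] != idx;
    run_sum = new_run ? updates[perm[i]] : run_sum + updates[perm[i]];
    // Only the last element of each run writes its accumulated sum back.
    const bool last_in_run =
        (i + 1 == perm.size()) || indices[perm[i + 1]] != idx;
    if (last_in_run) operand[static_cast<std::size_t>(idx)] += run_sum;
  }
}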
@serach24 (Contributor, Author) commented:

@akuegel Hi Adrian, I have fixed the issue as well as several potential problems that I found through thorough testing. I also added a debug option to disable this pass if anything goes wrong. The PR is at #19275.

copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change in this PR builds on #17886 and fixes the issues reported in the reverted PR #18326: the changes in #18326 could not handle various complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of the 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and handles all combinations of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. In the unlikely event that anything goes wrong with the changes in this PR, the user can easily disable the `scatter_determinism_expander` pass via this option without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
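
As a usage note, JAX and TensorFlow users would typically toggle the new option through XLA_FLAGS (e.g. --xla_gpu_enable_scatter_determinism_expander=false). A hedged C++ sketch of doing the same through DebugOptions follows; the setter name is assumed to follow the usual protobuf naming for the field added in #19275:

#include "xla/xla.pb.h"

// Assumed setter name, derived from the flag name added in PR #19275.
xla::DebugOptions DebugOptionsWithScatterExpanderDisabled() {
  xla::DebugOptions opts;
  opts.set_xla_gpu_enable_scatter_determinism_expander(false);
  return opts;
}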
serach24 added a commit to serach24/xla that referenced this pull request Nov 15, 2024
… scatter operations

Imported from GitHub PR openxla#18326

Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
b016044 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
fbdb066 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
d36c8ac by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
678886f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change in this PR is on top of openxla/xla#17886 and fixes the issues reported in the reverted PR openxla/xla#18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. Although problems are unlikely, this option ensures that if anything goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
b01604490908fbe43685aed7178d0a66602b7a8c by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
fbdb066fd38a2fadb4322caaabe8c8d1a9fa77e3 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
d36c8ac7260c241c4ca6ed7dc16018f8030c0b80 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
678886f97bd133c4ffa2fbf0365e15c808383a6f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change in this PR is on top of #17886 and fixes the issues reported in the reverted PR #18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. Although problems are unlikely, this option ensures that if anything goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change in this PR is on top of openxla/xla#17886 and fixes the issues reported in the reverted PR openxla/xla#18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. Although problems are unlikely, this option ensures that if anything goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change in this PR is on top of #17886 and fixes the issues reported in the reverted PR #18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. Although problems are unlikely, this option ensures that if anything goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696956113
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change in this PR is on top of openxla/xla#17886 and fixes the issues reported in the reverted PR openxla/xla#18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. Although problems are unlikely, this option ensures that if anything goes wrong with the changes in this PR, users can easily disable the `scatter_determinism_expander` pass without being blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc; however, because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations by handling non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

PiperOrigin-RevId: 696956113
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based, to speed up pytree flattening.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 694219933
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
Explicitly reserve batch dims in all the involved tensors for gather/scatter operations.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696765965
Successfully merging this pull request may close these issues.

vmap with scatter_add extremely slow when using xla_gpu_deterministic_ops