Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Optimize deterministic scalar scatter #17886

Closed

Conversation

serach24
Copy link
Contributor

@serach24 serach24 commented Oct 3, 2024

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844

Copy link

google-cla bot commented Oct 3, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@serach24 serach24 force-pushed the chenhao/opt_det_scatter_scalar branch 2 times, most recently from d2332e5 to 82f2237 Compare October 3, 2024 10:43
@jprabhas jprabhas requested a review from cheshire October 3, 2024 17:57
Copy link
Contributor

@cheshire cheshire left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Is it possible to smash into one commit with a lot more detailed commit message?
  2. Could you provide microbenchmark results, esp. comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even try to make it deterministic by default?

ScatterDeterminismExpander scatter_determinism_expander;
TF_ASSERT_OK_AND_ASSIGN(
bool result, RunHloPass(&scatter_determinism_expander, module.get()));
EXPECT_TRUE(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we

a. FileCheck the result of the rewrite
b. Launch it and verify correctness
c. Verify that it's indeed deterministic by launching multiple times and comparing numerics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test serves one single purpose: to check if we will pattern match when the scatter combiner is non-associative, and followed the same pattern as in scatter_expander_tests.

The a, b, c that you mentioned are all included in the rest of the tests, specifically:
FileCheck the result of the rewrite -> ScatterAddHloVerificationTest
Launch it and verify correctness -> ScatterAddCorrectnessTest and ScatterAddOutOfBoundCorrectnessTest
Verify that it's indeed deterministic by launching multiple times and comparing numerics -> ScatterAddReproducibilityTest

@serach24
Copy link
Contributor Author

serach24 commented Oct 4, 2024

  1. Is it possible to smash into one commit with a lot more detailed commit message?

I think it is doable, but won't PRs be squashed to merge?

  1. Could you provide microbenchmark results, esp. comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even try to make it deterministic by default?

This is provided in the evaluation section of the attached doc.

)";

RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(),
kExpectedPattern, nullptr /*after_pass_checks*/,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we usually annotate parameters like this:

/*after_pass_checks=*/nullptr

But you are passing the default values, so better to just remove them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as requested

return scatter_indices;
}

if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the if above and the if here can be simplified into one if block to:

if (index_vector_dim >= scatter_indices_shape.dimensions_size() - 1) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as requested

@@ -695,6 +696,7 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you do several canonicalizations that may already be done with GpuScatterExpander. I think it would make sense to move GpuScatterExpander first, check which canonicalizations it already applies, and avoid duplicating those.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I mixed up GpuScatterExpander and ScatterSimplifier. ScatterSimplifier does a bunch of simplifications which seem related to what you do, the normalized scatter has this form (copied from the comment in scatter_simplifier.h):

// The output scatter's attributes will have the following characteristics:
// - scatter_indices is a two-dimensional tensor
// - index_vector_dim is 1
// - inserted_window_dims is []
// - update_window_dims is [0, 1, ...]
// - scatter_dims_to_operand_dims is [0, 1, ...]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there are overlaps between the canonicalization process of ScatterSimplifier and ScatterExpander, as pointed out in the comments in the scatter_simplifier.h, above what you pasted:

// It implements the first two steps of the algorithm decribed in
// ScatterExpander::ExpandInstruction (scatter_expander.cc). Additionally, it
// transposes updates and operands to transform scatter_dims_to_operand_dims
// into the identity mapping. This is different from the algorithm in
// ScatterExpander, which instead applies the mapping in scatter_indices.

I was following the exact same canonicalization of ScatterExpander, that is why I extracted those functions into the scatter_utils.cc file, to be reused by both ScatterExpander and ScatterDeterminismExpander. In the gpu_compiler.cc, I was also following the same convention, adding the ScatterDeterminismExpander pass before the ScatterExpander with kEliminateIndeterministicScatters matching

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I understand that it would be a bit harder to rewrite your pass based on the different canonicalization used in ScatterSimplifier. I guess there is still some potential to combine these canonicalizations, but it is somewhat orthogonal to your change.

@@ -230,8 +230,7 @@ TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) {
)";

RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(),
kExpectedPattern, nullptr /*after_pass_checks*/,
nullptr /*config*/);
kExpectedPattern, nullptr, nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not being clear, I meant not passing the values for after_pass_checks and config at all. They have default values which are nullptr, no need to explicitly pass a default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand now. Changed as requested

@@ -695,6 +696,7 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I understand that it would be a bit harder to rewrite your pass based on the different canonicalization used in ScatterSimplifier. I guess there is still some potential to combine these canonicalizations, but it is somewhat orthogonal to your change.

// traverse the tuple output of the computation
for (int i = 0; i < operand_size; ++i) {
const HloInstruction* output = root->operand(i);
std::unordered_set<const HloInstruction*> input_dependencies;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer to use absl::flat_hash_map instead of std::unordered_set because it is faster.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed

}

namespace {
void RecursivelyGetInputDependencies(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, this recursion can have exponential runtime if you don't also keep track of which instructions you have already visited. Currently, you do deduplication of parameters. If you have a visited set instead, you don't need that and can use a vector for dependencies, and also avoid the exponential runtime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. My initial thought was there isn't really complicated scatter computations so I did not bother to optimize this here.
Changed as suggested.

}
}

// Check if the every output of the computation only depends on the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "Check if the every" -> "Check if every"

also "scatter computation" instead of just "computation". Makes it clearer that this function does not process arbitrary computations, but just scatter computations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as requested

if (input_dependencies.size() > 2) {
return false;
}
if (input_dependencies.size() == 2) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if input_dependencies.size() == 1? For example there can be scatter computations that just throw away the initial value, and just use the value from updates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as requested

return false;
}
if (input_dependencies.size() == 2) {
for (const HloInstruction* input : input_dependencies) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If all we care about are the parameter numbers, maybe also just store the parameter numbers instead of the HloInstruction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as requested

@serach24 serach24 force-pushed the chenhao/opt_det_scatter_scalar branch from 9d8eb96 to 92f6051 Compare October 13, 2024 23:37

#include "xla/service/scatter_determinism_expander.h"
#include <cstdint>
#include "absl/container/flat_hash_set.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will also need a corresponding BUILD dependency "@com_google_absl//absl/container:flat_hash_set"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

HloInstruction* scatter_indices,
const Shape& scatter_shape) {
if (scatter_indices->shape().rank() == 1) {
CHECK(scatter_shape.dimensions_size() == 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: CHECK_EQ instead of CHECK() with '==' op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

LiteralUtil::CreateFromArray(out_of_bound_array)));
}
// More than one dimension in scatter_indices
Array2D<int32_t> out_of_ound_array(scatter_indices->shape().dimensions(0),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: out_of_ound_array -> out_of_bound_array

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed


auto* sorting = parent->AddInstruction(HloInstruction::CreateSort(
ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison,
false /*is_stable*/));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we prefer annotations like /is_stable=/false

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

int64_t num_updates = updates_shape.dimensions(0);

// Calculate the number of iterations needed (log_2(n))
int64_t log_n = static_cast<int64_t>(std::ceil(std::log2(num_updates)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use Log2Ceiling:

xla/xla/util.h

Line 562 in b4abe20

constexpr inline int Log2Ceiling(T x) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

std::vector<int64_t> strides = {1};

for (int64_t iteration = 0; iteration < log_n; ++iteration) {
offset = 1 << iteration;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offset = static_cast<int64_t>(1) << iteration
Unfortunately even if iteration is int64_t, it would still compute the shift with the int32_t value otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

int64_t log_n = static_cast<int64_t>(std::ceil(std::log2(num_updates)));

// Placeholder for offset calculation (2^d)
int64_t offset;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see an advantage to declare the variable here, instead of in the loop where it is assigned. The compiler should be able to do this optimization, and it seems easier to read to move the declaration down.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

@cheshire
Copy link
Contributor

This is provided in the evaluation section of the attached doc

Could we duplicate some of it in the commit message? Google Docs don't tend to live for very long: the access can be pulled at any time, whereas the commit message is there forever.

Copy link
Member

@akuegel akuegel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just a few small things left. Also please address the comment from George and copy the benchmark numbers into the PR description.

scatter_indices, sorted_scalar_indices, scatter, parent, num_indices);

// Finally, recreate the scatter instruction with unique indices
auto* new_scatter = parent->AddInstruction(HloInstruction::CreateScatter(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: directly return without assigning to new_scatter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

auto* new_scatter = parent->AddInstruction(HloInstruction::CreateScatter(
scatter->shape(), scatter_operands, last_occurrence_indices,
prefix_scan_updates, scatter->to_apply(), dim_numbers,
true /*indices_are_sorted*/, true /*unique_indices*/));
Copy link
Member

@akuegel akuegel Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: annotate /*indices_are_sorted=*/true and /*unique_indices=*/true

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

int64_t scatter_indices_count = ScatterIndicesCount(scatter);
if (!IsInt32(scatter_indices_count)) {
// 2147483647 is the maximum value for a 32-bit signed integer (INT32_MAX).
return Unimplemented(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether it would better to just not match in this case (so moving the check to InstructionMatchesPattern), so that we can still support it (although very slow)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the show version (scatter_expander.cc), they have exactly the same check and will abort so I followed that here. Even if we do not match this here, it will still abort in the scatter_expander.cc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that makes sense, thanks for the explanation.

@serach24 serach24 force-pushed the chenhao/opt_det_scatter_scalar branch from f79da13 to 1a524ef Compare October 15, 2024 17:25
@serach24
Copy link
Contributor Author

I squashed the commits into one and included the microbenchmark results there.
Please take another look

@ezhulenev
Copy link
Member

Linter complains with:

CheckLint found errors.
These lines are out of order.
	xla/service/gpu/BUILD:1420-1656](http://google3/third_party/tensorflow/compiler/xla/service/gpu/BUILD:1420-1656)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_utils.h:61](http://google3/third_party/tensorflow/compiler/xla/service/scatter_utils.h?l=61)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_determinism_expander.h:44](http://google3/third_party/tensorflow/compiler/xla/service/scatter_determinism_expander.h?l=44)
Could not find a newline character at the end of the file.  [whitespace/ending_newline] [5]
	xla/service/scatter_utils.cc:212](http://google3/third_party/tensorflow/compiler/xla/service/scatter_utils.cc?l=212)

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |
@serach24 serach24 force-pushed the chenhao/opt_det_scatter_scalar branch from 1a524ef to 42cc615 Compare October 16, 2024 18:12
copybara-service bot pushed a commit that referenced this pull request Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 17, 2024
Imported from GitHub PR openxla/xla#17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at openxla/xla#18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615ed047b28405a0634c42f741a678be605a by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615ed047b28405a0634c42f741a678be605a
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit that referenced this pull request Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit that referenced this pull request Oct 17, 2024
Imported from GitHub PR #17886

This PR is the 1st step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in `xla/service/ScatterExpander.cc`. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR rewrites the scatter operation with scalar indices and updates, and leave the other scatter operations to be handled by original ScatterExpander. The 2nd PR to come will handle non-scalar indices and updates.

The second PR is at #18326

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
42cc615 by Chenhao Jiang <[email protected]>:

Optimize deterministic scalar scatter

Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |

Merging this change closes #17886

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17886 from serach24:chenhao/opt_det_scatter_scalar 42cc615
PiperOrigin-RevId: 686779279
copybara-service bot pushed a commit that referenced this pull request Oct 29, 2024
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 690490783
copybara-service bot pushed a commit that referenced this pull request Oct 29, 2024
…r operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 29, 2024
…r operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328
serach24 added a commit to serach24/xla that referenced this pull request Nov 9, 2024
… scatter operations

Imported from GitHub PR openxla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
serach24 added a commit to serach24/xla that referenced this pull request Nov 12, 2024
… scatter operations

Imported from GitHub PR openxla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 13, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696790875
serach24 added a commit to serach24/xla that referenced this pull request Nov 15, 2024
… scatter operations

Imported from GitHub PR openxla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes openxla#18326

COPYBARA_INTEGRATE_REVIEW=openxla#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 691023328
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
b016044 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
fbdb066 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
d36c8ac by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
678886f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
b01604490908fbe43685aed7178d0a66602b7a8c by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
fbdb066fd38a2fadb4322caaabe8c8d1a9fa77e3 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
d36c8ac7260c241c4ca6ed7dc16018f8030c0b80 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
678886f97bd133c4ffa2fbf0365e15c808383a6f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696078761
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18326 from serach24:chenhao/opt_det_scatter_full de647d44eb28af71e1580b6e8ed9adc751e50f52
PiperOrigin-RevId: 696790875
copybara-service bot pushed a commit that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696956113
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 15, 2024
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR openxla/xla#19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886, and has fixed issues reported in the reverted PR openxla/xla#18326. The issue was that the changes in openxla/xla#18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a2b95e52654daf83a359d17a809dc3b784 by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR openxla/xla#18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of openxla/xla#17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d44eb28af71e1580b6e8ed9adc751e50f52 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952d6ccd3a4c00e1885923cb0f8ba6db9cf2 by Chenhao Jiang <[email protected]>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
1ecb608e3687cda358965d9fb60144362fdba477 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f4257e632e85162b5525cfd4655ddf555d by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

PiperOrigin-RevId: 696956113
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

vmap with scatter_add extremely slow when using xla_gpu_deterministic_ops
4 participants