Skip to content

Commit

Permalink
Optimize deterministic scalar scatter
Browse files Browse the repository at this point in the history
Performance Takeaways:
- Our optimized implementation shows significant speedup, especially with larger index sizes, achieving up to over 9,300x speedup in certain input and index sizes.
- Our implementation has a slight slowdown compared to the non-deterministic scatter. For most cases, we have a slowdown around 1x - 4x. In the worst case with a rare index size setup, we have a slowdown factor of 9.15.

Full Microbenchmark:
| Input Size | Index Size | Non-Det | Original Det | New Det  | Slowdown (vs Non-det) | Speedup (vs Original Det) |
|------------|------------|---------|--------------|----------|-----------------------|---------------------------|
| 10         | 10         | 3.96E-05| 7.82E-05     | 4.66E-05 | 1.18                  | 1.68                      |
| 10         | 100        | 3.72E-05| 4.83E-04     | 9.73E-05 | 2.62                  | 4.96                      |
| 10         | 1000       | 3.92E-05| 4.20E-03     | 6.62E-05 | 1.69                  | 63.50                     |
| 10         | 10000      | 4.36E-05| 4.31E-02     | 1.21E-04 | 2.77                  | 357.37                    |
| 10         | 100000     | 1.06E-04| 4.33E-01     | 1.71E-04 | 1.61                  | 2536.56                   |
| 10         | 1000000    | 4.31E-04| 4.17E+00     | 4.45E-04 | 1.03                  | 9372.37                   |
| 100        | 10         | 4.27E-05| 7.76E-05     | 4.71E-05 | 1.10                  | 1.65                      |
| 100        | 100        | 4.01E-05| 4.91E-04     | 5.61E-05 | 1.40                  | 8.75                      |
| 100        | 1000       | 5.17E-05| 4.21E-03     | 1.10E-04 | 2.13                  | 38.24                     |
| 100        | 10000      | 4.08E-05| 4.27E-02     | 1.05E-04 | 2.57                  | 407.45                    |
| 100        | 100000     | 7.60E-05| 4.14E-01     | 1.69E-04 | 2.22                  | 2455.08                   |
| 100        | 1000000    | 2.86E-04| 4.17E+00     | 4.62E-04 | 1.62                  | 9009.13                   |
| 1000       | 10         | 3.95E-05| 7.85E-05     | 4.97E-05 | 1.26                  | 1.58                      |
| 1000       | 100        | 4.16E-05| 4.85E-04     | 5.27E-05 | 1.27                  | 9.21                      |
| 1000       | 1000       | 3.90E-05| 4.25E-03     | 6.35E-05 | 1.63                  | 66.86                     |
| 1000       | 10000      | 4.08E-05| 4.25E-02     | 1.22E-04 | 3.00                  | 346.99                    |
| 1000       | 100000     | 4.26E-05| 4.15E-01     | 1.92E-04 | 4.51                  | 2161.72                   |
| 1000       | 1000000    | 1.73E-04| 4.26E+00     | 4.75E-04 | 2.74                  | 8964.91                   |
| 10000      | 10         | 4.17E-05| 8.00E-05     | 4.76E-05 | 1.14                  | 1.68                      |
| 10000      | 100        | 3.68E-05| 7.16E-04     | 1.10E-04 | 3.00                  | 6.49                      |
| 10000      | 1000       | 4.13E-05| 4.23E-03     | 1.01E-04 | 2.44                  | 42.12                     |
| 10000      | 10000      | 3.71E-05| 4.23E-02     | 1.44E-04 | 3.89                  | 293.14                    |
| 10000      | 100000     | 9.70E-05| 4.28E-01     | 1.72E-04 | 1.77                  | 2494.21                   |
| 10000      | 1000000    | 1.18E-04| 4.17E+00     | 4.91E-04 | 4.15                  | 8488.57                   |
| 100000     | 10         | 3.73E-05| 7.25E-05     | 4.92E-05 | 1.32                  | 1.47                      |
| 100000     | 100        | 4.09E-05| 4.91E-04     | 6.33E-05 | 1.55                  | 7.76                      |
| 100000     | 1000       | 4.10E-05| 4.25E-03     | 6.40E-05 | 1.56                  | 66.39                     |
| 100000     | 10000      | 3.78E-05| 4.22E-02     | 1.26E-04 | 3.34                  | 334.38                    |
| 100000     | 100000     | 4.42E-05| 4.16E-01     | 1.67E-04 | 3.79                  | 2486.22                   |
| 100000     | 1000000    | 5.37E-05| 4.17E+00     | 4.92E-04 | 9.15                  | 8474.51                   |
| 1000000    | 10         | 3.97E-05| 8.10E-05     | 5.12E-05 | 1.29                  | 1.58                      |
| 1000000    | 100        | 4.56E-05| 4.94E-04     | 6.08E-05 | 1.33                  | 8.13                      |
| 1000000    | 1000       | 4.47E-05| 4.29E-03     | 6.17E-05 | 1.38                  | 69.44                     |
| 1000000    | 10000      | 4.48E-05| 4.27E-02     | 1.18E-04 | 2.63                  | 362.68                    |
| 1000000    | 100000     | 4.25E-05| 4.19E-01     | 1.78E-04 | 4.19                  | 2352.46                   |
| 1000000    | 1000000    | 6.59E-05| 4.18E+00     | 5.01E-04 | 7.60                  | 8334.87                   |
  • Loading branch information
serach24 committed Oct 16, 2024
1 parent f952893 commit 42cc615
Show file tree
Hide file tree
Showing 9 changed files with 1,160 additions and 183 deletions.
53 changes: 53 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2554,6 +2554,18 @@ cc_library(
],
)


cc_library(
name = "scatter_utils",
srcs = ["scatter_utils.cc"],
hdrs = ["scatter_utils.h"],
deps = [
":call_inliner",
":hlo_creation_utils",
],
)


cc_library(
name = "scatter_expander",
srcs = ["scatter_expander.cc"],
Expand All @@ -2562,6 +2574,7 @@ cc_library(
":call_inliner",
":hlo_creation_utils",
":op_expander_pass",
":scatter_utils",
":while_util",
"//xla:literal_util",
"//xla/hlo/ir:hlo",
Expand All @@ -2570,6 +2583,25 @@ cc_library(
],
)


cc_library(
name = "scatter_determinism_expander",
srcs = ["scatter_determinism_expander.cc"],
hdrs = ["scatter_determinism_expander.h"],
deps = [
":call_inliner",
":hlo_creation_utils",
":op_expander_pass",
":scatter_utils",
":while_util",
"//xla:literal_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "scatter_expander_test",
srcs = ["scatter_expander_test.cc"],
Expand All @@ -2587,6 +2619,27 @@ xla_cc_test(
],
)

xla_test(
name = "scatter_determinism_expander_test",
srcs = ["scatter_determinism_expander_test.cc"],
backends = [
"cpu",
"gpu",
],
deps = [
":scatter_determinism_expander",
"//xla:literal",
"//xla:shape_util",
"//xla:test",
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
],
)

cc_library(
name = "triangular_solve_expander",
srcs = ["triangular_solve_expander.cc"],
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,7 @@ cc_library(
"//xla/service:result_caster",
"//xla/service:rng_bit_generator_expander",
"//xla/service:rng_expander",
"//xla/service:scatter_determinism_expander",
"//xla/service:scatter_expander",
"//xla/service:scatter_simplifier",
"//xla/service:sharding_remover",
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ limitations under the License.
#include "xla/service/rng_bit_generator_expander.h"
#include "xla/service/rng_expander.h"
#include "xla/service/scatter_expander.h"
#include "xla/service/scatter_determinism_expander.h"
#include "xla/service/scatter_simplifier.h"
#include "xla/service/sharding_remover.h"
#include "xla/service/simplify_fp_conversions.h"
Expand Down Expand Up @@ -700,6 +701,7 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
Expand Down
Loading

0 comments on commit 42cc615

Please sign in to comment.