PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the second of two steps to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc, but because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
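For intuition, here is a minimal sketch of the prefix-scan idea, written as host-side C++ with an assumed add combiner, 1D operand, and scalar indices; the actual pass emits parallel HLO sort and scan ops rather than this sequential loop:

// Minimal sketch of the prefix-scan idea (assumptions: 1D operand, scalar
// indices, "add" combiner; illustration only, not the HLO rewrite itself).
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<float> DeterministicScatterAddSketch(
    std::vector<float> operand, const std::vector<int64_t>& indices,
    const std::vector<float>& updates) {
  // 1. Sort update positions by target index so duplicates become contiguous.
  std::vector<size_t> order(indices.size());
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(),
                   [&](size_t a, size_t b) { return indices[a] < indices[b]; });

  // 2. Segmented inclusive prefix scan of the sorted updates with the combiner.
  std::vector<float> scanned(order.size());
  float acc = 0.0f;
  for (size_t i = 0; i < order.size(); ++i) {
    const bool new_segment =
        (i == 0) || indices[order[i]] != indices[order[i - 1]];
    acc = new_segment ? updates[order[i]] : acc + updates[order[i]];
    scanned[i] = acc;
  }

  // 3. Only the last position of each run of equal indices writes back, so the
  //    result no longer depends on the order in which duplicate updates arrive.
  for (size_t i = 0; i < order.size(); ++i) {
    const bool last_of_run =
        (i + 1 == order.size()) || indices[order[i + 1]] != indices[order[i]];
    if (last_of_run) operand[indices[order[i]]] += scanned[i];
  }
  return operand;
}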

This PR builds on top of #17886 and fixes the issues reported in the reverted PR #18326, whose changes could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementations of the 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests covering various scatter dimension numbers, and thoroughly handles all the different kinds of dimension numbers.

Moreover, this PR adds an option, `xla_gpu_enable_scatter_determinism_expander`, whose default value is true. In the unlikely event that the changes in this PR cause a problem, this option lets the user disable the `scatter_determinism_expander` pass without getting blocked.
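For illustration, a minimal sketch of disabling the pass when compiling through PjRT, assuming the DebugOptions setter added in this PR and the same header the PjRT test below includes (the flag can also be passed as --xla_gpu_enable_scatter_determinism_expander=false via XLA_FLAGS):

// Hedged sketch: turn the pass off through PjRT compile options. The setter
// name comes from the DebugOptions field added in this PR; the header path is
// assumed to match the one used by se_gpu_pjrt_client_test.cc.
#include "xla/pjrt/pjrt_executable.h"

xla::CompileOptions CompileOptionsWithoutScatterDeterminismExpander() {
  xla::CompileOptions options;
  options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_enable_scatter_determinism_expander(false);
  return options;
}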

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
3b7b56a by Chenhao Jiang <[email protected]>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the second of two steps to improve the performance of deterministic scatter. Originally, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc, but because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <[email protected]>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
126c952 by Chenhao Jiang <[email protected]>:

Add the mapping from scatter indices to operand space, and permute the offset
columns based on scatter_dims_to_operand_dims so that the two can be added
together correctly.
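
To illustrate what that mapping does (a sketch under assumed semantics, not code from this PR): scatter_dims_to_operand_dims names the operand dimension addressed by each column of the scatter indices, so the columns must be permuted into operand order before the per-window offsets can be added to them.

// Sketch with assumed shapes: two-column scatter indices and
// scatter_dims_to_operand_dims = {1, 0}, meaning index column 0 addresses
// operand dimension 1 and index column 1 addresses operand dimension 0.
#include <array>
#include <cstdint>
#include <vector>

std::vector<std::array<int64_t, 2>> IndicesToOperandSpace(
    const std::vector<std::array<int64_t, 2>>& scatter_indices,
    const std::array<int64_t, 2>& scatter_dims_to_operand_dims) {
  std::vector<std::array<int64_t, 2>> operand_space(scatter_indices.size());
  for (size_t i = 0; i < scatter_indices.size(); ++i) {
    for (int64_t col = 0; col < 2; ++col) {
      // Column `col` of the indices lands in operand dimension
      // scatter_dims_to_operand_dims[col].
      operand_space[i][scatter_dims_to_operand_dims[col]] =
          scatter_indices[i][col];
    }
  }
  return operand_space;  // Aligned with operand dims; offsets can now be added.
}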

--
1ecb608 by Chenhao Jiang <[email protected]>:

Fix the scatter determinism expander for various dimension numbers

--
985079f by Chenhao Jiang <[email protected]>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696790875
serach24 authored and Google-ML-Automation committed Nov 15, 2024
1 parent 521cd7b commit ab72e9d
Showing 9 changed files with 1,252 additions and 146 deletions.
19 changes: 19 additions & 0 deletions xla/debug_options_flags.cc
@@ -292,6 +292,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_dot_merger_threshold_mb(32);
opts.set_xla_enable_fast_math(false);
opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
opts.set_xla_gpu_enable_scatter_determinism_expander(true);
return opts;
}

@@ -2056,6 +2058,23 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(),
"This controls how many in-flight collectives "
"latency hiding scheduler can schedule."));
flag_list->push_back(tsl::Flag(
"xla_pjrt_allow_auto_layout_in_hlo",
bool_setter_for(&DebugOptions::set_xla_pjrt_allow_auto_layout_in_hlo),
debug_options->xla_pjrt_allow_auto_layout_in_hlo(),
"Experimental: Make unset entry computation layout mean auto layout "
"instead of default layout in HLO when run through PjRT. In other cases "
"(StableHLO or non-PjRT) the auto layout is already used."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_scatter_determinism_expander",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
debug_options->xla_gpu_enable_scatter_determinism_expander(),
"Enable the scatter determinism expander, an optimized pass that "
"rewrites scatter operations to ensure deterministic behavior with high "
"performance."
"Note that even when this flag is disabled, scatter operations may still "
"be deterministic, although with additional overhead."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
1 change: 1 addition & 0 deletions xla/pjrt/gpu/BUILD
@@ -180,6 +180,7 @@ xla_cc_test(
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_stream_executor_client",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:gpu_plugin",
"//xla/service:platform_util",
"//xla/stream_executor:device_memory",
33 changes: 33 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -1753,5 +1754,37 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) {
EXPECT_EQ(layout.element_size_in_bits(), 4);
}

TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
const char* hlo_text = R"(
HloModule DotLayout,
entry_computation_layout={(f32[2,3,5],f32[3,4,5])->f32[5,2,4]{2,1,0}}
ENTRY dot {
p0 = f32[2,3,5]{2,1,0} parameter(0)
p1 = f32[3,4,5]{2,1,0} parameter(1)
ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
lhs_batch_dims={2}, lhs_contracting_dims={1},
rhs_batch_dims={2}, rhs_contracting_dims={0}
})";

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> m,
ParseAndReturnUnverifiedModule(
hlo_text, {}, HloParserOptions().set_fill_missing_layouts(false)));

TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));
CompileOptions compile_options;
compile_options.executable_build_options.mutable_debug_options()
->set_xla_pjrt_allow_auto_layout_in_hlo(true);
XlaComputation computation = m->ToProto();
TF_ASSERT_OK_AND_ASSIGN(auto executable,
client->Compile(computation, compile_options));
TF_ASSERT_OK_AND_ASSIGN(auto layouts, executable->GetParameterLayouts());
// Check that the assigned layouts are not default.
EXPECT_NE(layouts[0]->ToString(), "{2,1,0}");
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
}

} // namespace
} // namespace xla
9 changes: 6 additions & 3 deletions xla/pjrt/utils.cc
@@ -683,11 +683,14 @@ absl::Status DetermineArgumentLayoutsFromCompileOptions(

// Assign a default layout based on `sharded_shape` to any array subshapes in
// `dst_shape` that are missing layouts.
auto assign_layouts = [&choose_compact_layout_for_shape_function](
const Shape& sharded_shape, Shape* dst_shape) {
const bool allow_auto_layout =
build_options && build_options->has_debug_options() &&
build_options->debug_options().xla_pjrt_allow_auto_layout_in_hlo();
auto assign_layouts = [&](const Shape& sharded_shape, Shape* dst_shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
if (subshape->IsArray() && !subshape->has_layout()) {
if (subshape->IsArray() && !subshape->has_layout() &&
!allow_auto_layout) {
CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
const Shape& sharded_subshape =
ShapeUtil::GetSubshape(sharded_shape, idx);
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -2121,6 +2121,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
4 changes: 3 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
@@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
pipeline.AddPass<ScatterDeterminismExpander>();
}
pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
