Skip to content

Commit

Permalink
PR #16879: General offset computation for dynamic slice fusion
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16879

This patch adds logic for computing general offset while creating dynamic-slice-fusion.
Copybara import of the project:

--
6369888 by Shraiysh Vaishay <[email protected]>:

General offset computation for dynamic slice fusion

This patch adds logic for computing general offset while creating dynamic-slice-fusion.

Merging this change closes #16879

COPYBARA_INTEGRATE_REVIEW=#16879 from shraiysh:general_offset_dynamic_slice_fusion 6369888
PiperOrigin-RevId: 676813991
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Sep 20, 2024
1 parent fe78e6e commit 7b44365
Show file tree
Hide file tree
Showing 4 changed files with 790 additions and 164 deletions.
96 changes: 96 additions & 0 deletions xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -3650,6 +3651,101 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) {
false, true, error));
}

TEST_F(DynamicSliceFusionTest, TestWithRewriter) {
const char* hlo = R"(
HloModule test_module, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
Body {
param = (s32[], s32[16, 32], s32[8, 32]) parameter(0)
i = s32[] get-tuple-element(param), index=0
dest = s32[16,32] get-tuple-element(param), index=1
src = s32[8,32] get-tuple-element(param), index=2
eight = s32[] constant(8)
zero = s32[] constant(0)
thirty_two = s32[] constant(32)
add = s32[] add(eight, i)
add.2 = s32[] subtract(add, thirty_two)
compare = pred[] compare(add, thirty_two), direction=LT
offset = s32[] select(compare, add, add.2)
rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero)
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = tuple(i_plus_one, fusion, src)
}
Cond {
param = (s32[], s32[16,32], s32[8,32]) parameter(0)
loop_iter = s32[] get-tuple-element(param), index=0
c32 = s32[] constant(32)
ROOT compare = pred[] compare(loop_iter, c32), direction=LT
}
ENTRY main {
zero = s32[] constant(0)
dest = s32[16,32] parameter(0)
src = s32[8,32] parameter(1)
tuple = tuple(zero, dest, src)
ROOT while = while(tuple), body=Body, condition=Cond
}
)";

HloModuleConfig config;
DebugOptions dboptions;
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module0,
ParseAndReturnVerifiedModule(hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion,
GetOptimizedModule(std::move(module0)));
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module1,
ParseAndReturnVerifiedModule(hlo, config));
TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion,
GetOptimizedModule(std::move(module1)));

ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0);
auto fusions = GetDynamicSliceFusions(*module_with_fusion);
ASSERT_EQ(fusions.size(), 1);
HloPrintOptions options;
options.set_print_large_constants(true)
.set_print_result_shape(false)
.set_print_operand_shape(false);
TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion,
RunFileCheck(fusions[0]->ToString(options),
R"(
// CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}})
// CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7})
// CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1}
// CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]])
// CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}})
)"));
EXPECT_TRUE(filecheck_fusion);
TF_ASSERT_OK_AND_ASSIGN(
auto filecheck_while_loop,
RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options),
R"(
// CHECK-DAG: %[[p:.+]] = parameter(0)
// CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3
// CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom
// CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}})
// CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]])
)"));
EXPECT_TRUE(filecheck_while_loop);
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(module_without_fusion), std::move(module_with_fusion), false,
true, error_spec));
}

} // namespace
} // namespace gpu
} // namespace xla
5 changes: 5 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1423,19 +1423,23 @@ cc_library(
hdrs = ["dynamic_slice_fusion_rewriter.h"],
tags = ["gpu"],
deps = [
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
"//xla/ffi:ffi_api",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:custom_call_target_registry",
"//xla/service:pattern_matcher",
"//xla/service:while_loop_analysis",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:gpu_constants",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu/kernels:custom_fusion_library",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -1473,6 +1477,7 @@ xla_cc_test(
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
Expand Down
Loading

0 comments on commit 7b44365

Please sign in to comment.