Skip to content

Commit

Permalink
General offset computation for dynamic slice fusion
Browse files Browse the repository at this point in the history
This patch adds logic for computing a general offset while creating a dynamic-slice fusion.
  • Loading branch information
shraiysh committed Sep 18, 2024
1 parent a1299f8 commit 2d5f03a
Show file tree
Hide file tree
Showing 4 changed files with 790 additions and 164 deletions.
96 changes: 96 additions & 0 deletions xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -3650,6 +3651,101 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) {
false, true, error));
}

// Compiles the same HLO module twice -- once with
// xla_gpu_enable_dynamic_slice_fusion disabled and once with it enabled --
// and checks that:
//   1. only the enabled build contains a dynamic-slice fusion,
//   2. the fusion body and the computation that holds the fusion instruction
//      match the expected FileCheck patterns, and
//   3. running both optimized modules produces matching results.
TEST_F(DynamicSliceFusionTest, TestWithRewriter) {
// The while loop runs while the counter `i` is < 32 (see Cond). On each
// iteration Body computes
//   offset = (8 + i < 32) ? (8 + i) : (8 + i - 32)
// and writes the reduce-scatter result into `dest` at row `offset` using
// dynamic-update-slice. Note that the instruction named `add.2` is a
// subtract, not an add.
const char* hlo = R"(
HloModule test_module, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
Body {
param = (s32[], s32[16, 32], s32[8, 32]) parameter(0)
i = s32[] get-tuple-element(param), index=0
dest = s32[16,32] get-tuple-element(param), index=1
src = s32[8,32] get-tuple-element(param), index=2
eight = s32[] constant(8)
zero = s32[] constant(0)
thirty_two = s32[] constant(32)
add = s32[] add(eight, i)
add.2 = s32[] subtract(add, thirty_two)
compare = pred[] compare(add, thirty_two), direction=LT
offset = s32[] select(compare, add, add.2)
rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero)
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = tuple(i_plus_one, fusion, src)
}
Cond {
param = (s32[], s32[16,32], s32[8,32]) parameter(0)
loop_iter = s32[] get-tuple-element(param), index=0
c32 = s32[] constant(32)
ROOT compare = pred[] compare(loop_iter, c32), direction=LT
}
ENTRY main {
zero = s32[] constant(0)
dest = s32[16,32] parameter(0)
src = s32[8,32] parameter(1)
tuple = tuple(zero, dest, src)
ROOT while = while(tuple), body=Body, condition=Cond
}
)";

// Reference build: parse and optimize with dynamic-slice fusion disabled.
HloModuleConfig config;
DebugOptions dboptions;
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module0,
ParseAndReturnVerifiedModule(hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion,
GetOptimizedModule(std::move(module0)));
// Build under test: flip the flag on the same DebugOptions/config objects and
// re-parse the HLO text. The flag must be set before ParseAndReturnVerifiedModule
// is called for the second time, so the statement order here matters.
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module1,
ParseAndReturnVerifiedModule(hlo, config));
TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion,
GetOptimizedModule(std::move(module1)));

// The flag alone must decide whether a fusion is created: none without it,
// exactly one with it. ASSERT (not EXPECT) because fusions[0] is used below.
ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0);
auto fusions = GetDynamicSliceFusions(*module_with_fusion);
ASSERT_EQ(fusions.size(), 1);
// Print large constants in full so the 32-element offset table is visible to
// FileCheck; omit result/operand shapes so the patterns below stay compact.
HloPrintOptions options;
options.set_print_large_constants(true)
.set_print_result_shape(false)
.set_print_operand_shape(false);
// Fusion body: the expected constant lists the offset for each of the 32 loop
// iterations (8..31 followed by 0..7, matching the select in Body). One
// element is picked with a size-1 dynamic-slice, reshaped, and used as the
// first offset operand of the root dynamic-update-slice whose update operand
// is the reduce-scatter.
TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion,
RunFileCheck(fusions[0]->ToString(options),
R"(
// CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}})
// CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7})
// CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1}
// CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]])
// CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}})
)"));
EXPECT_TRUE(filecheck_fusion);
// Computation that holds the fusion instruction: the element at tuple index 3
// is passed as the last operand of the kCustom fusion, incremented with an
// add, and returned as the fourth element of the root tuple. The input HLO's
// loop state only has three elements, so index 3 is presumably a counter
// appended by the rewriter -- NOTE(review): confirm against the rewriter pass.
TF_ASSERT_OK_AND_ASSIGN(
auto filecheck_while_loop,
RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options),
R"(
// CHECK-DAG: %[[p:.+]] = parameter(0)
// CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3
// CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom
// CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}})
// CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]])
)"));
EXPECT_TRUE(filecheck_while_loop);
// Run both optimized modules and compare their outputs within the given
// absolute/relative tolerance.
// NOTE(review): the bare `false, true` arguments are presumably
// run_hlo_passes=false and use_threads=true -- confirm against the
// declaration of RunAndCompareTwoModulesReplicated.
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(module_without_fusion), std::move(module_with_fusion), false,
true, error_spec));
}

} // namespace
} // namespace gpu
} // namespace xla
5 changes: 5 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1388,19 +1388,23 @@ cc_library(
hdrs = ["dynamic_slice_fusion_rewriter.h"],
tags = ["gpu"],
deps = [
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/pass:hlo_pass",
"//xla/service:custom_call_target_registry",
"//xla/service:pattern_matcher",
"//xla/service:while_loop_analysis",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:gpu_constants",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu/kernels:custom_fusion_library",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -1439,6 +1443,7 @@ xla_cc_test(
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/tests:hlo_test_base",
"//xla/tests:filecheck",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@tsl//tsl/platform:status",
Expand Down
Loading

0 comments on commit 2d5f03a

Please sign in to comment.