Skip to content

Commit

Permalink
General offset computation for dynamic slice fusion
Browse files Browse the repository at this point in the history
This patch adds logic for computing general offset while creating dynamic-slice-fusion.
  • Loading branch information
shraiysh committed Sep 6, 2024
1 parent 8e9efb1 commit 470a734
Show file tree
Hide file tree
Showing 4 changed files with 625 additions and 160 deletions.
96 changes: 95 additions & 1 deletion xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -2791,7 +2792,6 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUSTuple) {
hlo_config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

DynamicSliceFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);
Expand Down Expand Up @@ -3407,6 +3407,100 @@ TEST_F(DynamicSliceFusionTest, OffsetArrayTestU64) {
".*");
}

TEST_F(DynamicSliceFusionTest, TestWithRewriter) {
const char* hlo = R"(
HloModule test_module, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
Body {
param = (s32[], s32[16, 32], s32[8, 32]) parameter(0)
i = s32[] get-tuple-element(param), index=0
dest = s32[16,32] get-tuple-element(param), index=1
src = s32[8,32] get-tuple-element(param), index=2
eight = s32[] constant(8)
zero = s32[] constant(0)
thirty_two = s32[] constant(32)
add = s32[] add(eight, i)
add.2 = s32[] subtract(add, thirty_two)
compare = pred[] compare(add, thirty_two), direction=LT
offset = s32[] select(compare, add, add.2)
rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero)
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = tuple(i_plus_one, fusion, src)
}
Cond {
param = (s32[], s32[16,32], s32[8,32]) parameter(0)
loop_iter = s32[] get-tuple-element(param), index=0
four = s32[] constant(32)
ROOT compare = pred[] compare(loop_iter, four), direction=LT
}
ENTRY main {
zero = s32[] constant(0)
dest = s32[16,32] parameter(0)
src = s32[8,32] parameter(1)
tuple = tuple(zero, dest, src)
ROOT while = while(tuple), body=Body, condition=Cond
}
)";

HloModuleConfig config;
DebugOptions dboptions;
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module0,
ParseAndReturnVerifiedModule(hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion,
GetOptimizedModule(std::move(module0)));
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module1,
ParseAndReturnVerifiedModule(hlo, config));
TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion,
GetOptimizedModule(std::move(module1)));

ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0);
auto fusions = GetDynamicSliceFusions(*module_with_fusion);
ASSERT_EQ(fusions.size(), 1);
HloPrintOptions options;
options.set_print_large_constants(true)
.set_print_result_shape(false)
.set_print_operand_shape(false);
TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion,
RunFileCheck(fusions[0]->ToString(options),
R"(
// CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}})
// CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7})
// CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1}
// CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]])
// CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}})
)"));
ASSERT_TRUE(filecheck_fusion);
TF_ASSERT_OK_AND_ASSIGN(
auto filecheck_while_loop,
RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options),
R"(
// CHECK-DAG: %[[p:.+]] = parameter(0)
// CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3
// CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom
// CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}})
// CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]])
)"));
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(module_without_fusion), std::move(module_with_fusion), false,
true, error_spec));
}

} // namespace
} // namespace gpu
} // namespace xla
3 changes: 3 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1358,13 +1358,16 @@ cc_library(
srcs = ["dynamic_slice_fusion_rewriter.cc"],
hdrs = ["dynamic_slice_fusion_rewriter.h"],
deps = [
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/service:custom_call_target_registry",
"//xla/service:hlo_pass",
"//xla/service:pattern_matcher",
"//xla/service:while_loop_analysis",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:gpu_constants",
Expand Down
Loading

0 comments on commit 470a734

Please sign in to comment.