Skip to content

Commit

Permalink
General offset computation for dynamic slice fusion
Browse files Browse the repository at this point in the history
This patch adds logic for computing general offset while creating dynamic-slice-fusion.
  • Loading branch information
shraiysh committed Sep 6, 2024
1 parent 8e9efb1 commit 1ca6525
Show file tree
Hide file tree
Showing 3 changed files with 622 additions and 160 deletions.
96 changes: 95 additions & 1 deletion xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/backend_configs.pb.h"
Expand Down Expand Up @@ -2791,7 +2792,6 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUSTuple) {
hlo_config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
computation.proto(), hlo_config));

DynamicSliceFusionRewriter pass(PLATFORM);
TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
EXPECT_TRUE(changed);
Expand Down Expand Up @@ -3407,6 +3407,100 @@ TEST_F(DynamicSliceFusionTest, OffsetArrayTestU64) {
".*");
}

TEST_F(DynamicSliceFusionTest, TestWithRewriter) {
const char* hlo = R"(
HloModule test_module, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
Body {
param = (s32[], s32[16, 32], s32[8, 32]) parameter(0)
i = s32[] get-tuple-element(param), index=0
dest = s32[16,32] get-tuple-element(param), index=1
src = s32[8,32] get-tuple-element(param), index=2
eight = s32[] constant(8)
zero = s32[] constant(0)
thirty_two = s32[] constant(32)
add = s32[] add(eight, i)
add.2 = s32[] subtract(add, thirty_two)
compare = pred[] compare(add, thirty_two), direction=LT
offset = s32[] select(compare, add, add.2)
rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero)
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = tuple(i_plus_one, fusion, src)
}
Cond {
param = (s32[], s32[16,32], s32[8,32]) parameter(0)
loop_iter = s32[] get-tuple-element(param), index=0
four = s32[] constant(32)
ROOT compare = pred[] compare(loop_iter, four), direction=LT
}
ENTRY main {
zero = s32[] constant(0)
dest = s32[16,32] parameter(0)
src = s32[8,32] parameter(1)
tuple = tuple(zero, dest, src)
ROOT while = while(tuple), body=Body, condition=Cond
}
)";

HloModuleConfig config;
DebugOptions dboptions;
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module0,
ParseAndReturnVerifiedModule(hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion,
GetOptimizedModule(std::move(module0)));
dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true);
config.set_debug_options(dboptions);
TF_ASSERT_OK_AND_ASSIGN(auto module1,
ParseAndReturnVerifiedModule(hlo, config));
TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion,
GetOptimizedModule(std::move(module1)));

ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0);
auto fusions = GetDynamicSliceFusions(*module_with_fusion);
ASSERT_EQ(fusions.size(), 1);
HloPrintOptions options;
options.set_print_large_constants(true)
.set_print_result_shape(false)
.set_print_operand_shape(false);
TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion,
RunFileCheck(fusions[0]->ToString(options),
R"(
// CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}})
// CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7})
// CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1}
// CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]])
// CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}})
)"));
ASSERT_TRUE(filecheck_fusion);
TF_ASSERT_OK_AND_ASSIGN(
auto filecheck_while_loop,
RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options),
R"(
// CHECK-DAG: %[[p:.+]] = parameter(0)
// CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3
// CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom
// CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}})
// CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]])
)"));
ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(module_without_fusion), std::move(module_with_fusion), false,
true, error_spec));
}

} // namespace
} // namespace gpu
} // namespace xla
Loading

0 comments on commit 1ca6525

Please sign in to comment.