PR #17022: Add slicing -> reduce-scatter for dynamic-slice-fusion
Imported from GitHub PR #17022

This patch adds support for slicing operations (`dynamic-slice` and `slice`) on operands of reduce-scatter in dynamic-slice-fusion.
Copybara import of the project:

--
70d09c2 by Shraiysh Vaishay <[email protected]>:

Add slicing -> reduce-scatter for dynamic-slice-fusion

This patch adds support for slicing operations (dynamic-slice and slice) on
operands of reduce-scatter in dynamic-slice-fusion.

The command buffer also supports dynamic-slice-fusion when all the slices are
static, so these operations can now be part of a CUDA graph.

Merging this change closes #17022

COPYBARA_INTEGRATE_REVIEW=#17022 from shraiysh:reduce_scatter_slicing_dus_fusion 70d09c2
PiperOrigin-RevId: 675871194
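
For reference, here is a minimal HLO sketch of the pattern this change targets: a static `slice` on the operand of a `reduce-scatter`. It is adapted from the new tests in this commit and is illustrative only.

```
HloModule example, replica_count=2

add {
  a = s32[] parameter(0)
  b = s32[] parameter(1)
  ROOT add = s32[] add(a, b)
}

ENTRY main {
  p0 = s32[2,8,32] parameter(0)
  slice = s32[1,8,32] slice(p0), slice={[1:2], [0:8], [0:32]}
  bc = s32[8,32] reshape(slice)
  ROOT rs = s32[4,32] reduce-scatter(bc), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
}
```

With `xla_gpu_enable_dynamic_slice_fusion` set in `DebugOptions`, the slice and the reduce-scatter are rewritten into a single dynamic-slice-fusion; because the offsets here are static, the fusion lowers to plain reduce-scatter start/done thunks (no `DynamicSliceThunk` wrapper), which is what allows command-buffer capture.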
shraiysh authored and Google-ML-Automation committed Sep 18, 2024
1 parent 5b0dfe3 commit 409ecf9
Showing 7 changed files with 574 additions and 160 deletions.
189 changes: 112 additions & 77 deletions xla/service/gpu/fusions/custom.cc
@@ -780,10 +780,14 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
     IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor,
     const HloFusionInstruction& fusion_instr, const HloInstType* instr,
     bool use_global_device_ids) {
-  if (instr->opcode() != HloOpcode::kReduceScatter) {
-    return absl::UnimplementedError(
-        "Dynamic slice fusion with collectives only works for reduce-scatter "
-        "instruction");
+  Thunk::Kind collective_done_thunk_kind;
+  switch (instr->opcode()) {
+    case HloOpcode::kReduceScatter:
+      collective_done_thunk_kind = Thunk::kNcclReduceScatterDone;
+      break;
+    default:
+      return absl::InternalError(
+          "Unexpected operation in dynamic slice fusion");
   }
 
   const BufferAssignment& buffer_assignment =
@@ -800,96 +804,120 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(

   // Collect slice information for inputs.
   unsigned arg_idx = 0;
-  TF_ASSIGN_OR_RETURN(arguments.emplace_back(),
-                      GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
-                                      *instr->operand(arg_idx), slice_instrs,
-                                      /*shape_idx=*/{}, arg_idx));
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice src,
+      GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
+                      /*start_instr=*/*instr->operand(arg_idx), slice_instrs,
+                      /*shape_idx=*/{}, arg_idx));
+  arguments.push_back(src);
   TF_RETURN_IF_ERROR(CollectSliceInfo(
       buffer_assignment, fusion_instr,
-      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
-      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++));
+      /*slice_instrs=*/absl::Span<HloInstruction*>(slice_instrs),
+      /*offsets=*/offset_buffer_indices, orig_shapes, sliced_shapes,
+      offset_byte_sizes, arg_idx));
+  arg_idx++;
 
   // Collect slice information for outputs.
-  TF_ASSIGN_OR_RETURN(
-      arguments.emplace_back(),
-      GetResultSlice(buffer_assignment, adaptor, fusion_instr, *instr,
-                     slice_instrs, /*shape_idx=*/{}, arg_idx));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst,
+                      GetResultSlice(buffer_assignment, adaptor, fusion_instr,
+                                     /*start_instr=*/*instr, slice_instrs,
+                                     /*shape_idx=*/{}, arg_idx));
+  arguments.push_back(dst);
   TF_RETURN_IF_ERROR(CollectSliceInfo(
       buffer_assignment, fusion_instr,
-      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
-      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx));
-
+      /*slice_instrs=*/absl::Span<HloInstruction*>(slice_instrs),
+      /*offsets=*/offset_buffer_indices, orig_shapes, sliced_shapes,
+      offset_byte_sizes, arg_idx));
+
+  // Sanity checks.
+  // 1. Expect at least one slicing operation.
+  // 2. Expect at least one dynamic index operation iff the fusion is a
+  //    dynamic-address-fusion.
   if (absl::c_all_of(slice_instrs, [&](HloInstruction* slice_instr) {
-        return slice_instr &&
-               slice_instr->opcode() != HloOpcode::kDynamicUpdateSlice;
+        return slice_instr == nullptr;
       })) {
-    return absl::InternalError(
-        "DynamicSliceFusion with reduce-scatter expects a dynamic-update-slice "
-        "operation.");
+    return absl::InternalError("Expected at least one slicing operation");
   }
+  bool isDynamic =
+      absl::c_any_of(slice_instrs, [&](const HloInstruction* slice_instr) {
+        return DynCastOrNull<HloDynamicIndexInstruction>(slice_instr) !=
+               nullptr;
+      });
+  TF_ASSIGN_OR_RETURN(
+      auto backend_config,
+      fusion_instr.backend_config<xla::gpu::GpuBackendConfig>());
+  const std::string fusion_name =
+      backend_config.fusion_backend_config().custom_fusion_config().name();
+  TF_RET_CHECK(isDynamic == (fusion_name == "dynamic_address_computation"))
+      << "Dynamic index operation found in a fusion instruction that is not "
+         "labelled dynamic_address_computation";
 
-  // Provide fake allocations for inputs and outputs. The dynamic-slice thunk
-  // will own these allocations.
-  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
-  unsigned fake_arg_idx = 0;
-  int64_t operand_byte_size =
-      ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
-  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
-      /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
-  BufferAllocation::Slice src(
-      /*allocation=*/fake_allocations[fake_arg_idx].get(), /*offset=*/0,
-      /*size=*/operand_byte_size);
-  fake_arg_idx++;
-  TF_RET_CHECK(instr->shape().IsArray() &&
-               "The output is not expected to be a tuple.");
-  int64_t out_fake_byte_size =
-      ShapeUtil::ByteSizeOf(instr->shape());  // TODO: we don't need this
-  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
-      /*index=*/fake_arg_idx, /*size=*/out_fake_byte_size, /*color=*/0);
-  BufferAllocation::Slice dst(
-      /*allocation=*/fake_allocations[fake_arg_idx].get(),
-      /*offset=*/0, /*size=*/out_fake_byte_size);
-
-  std::vector<NcclCollectiveThunk::Buffer> buffers;
-  const Shape& src_shape = instr->operand(0)->shape();
-  const Shape& dst_shape = instr->shape();
-  buffers.push_back(NcclCollectiveThunk::Buffer{
-      /*element_count=*/ShapeUtil::ElementsIn(src_shape),
-      /*source_buffer=*/src,
-      /*destination_buffer=*/dst,
-      /*source_memory_space=*/src_shape.layout().memory_space(),
-      /*destination_memory_space=*/dst_shape.layout().memory_space(),
-      /*source_value=*/nullptr,
-      /*destination_value=*/nullptr});
-
-  ThunkSequence seq;
-  auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
   int64_t replica_count = instr->GetModule()->config().replica_count();
   int64_t partition_count = instr->GetModule()->config().num_partitions();
   absl::Status implementable_status =
       NcclThunkType::CheckImplementable(instr, replica_count, partition_count);
   bool is_degenerate = GetNcclCollectiveConfig(instr, use_global_device_ids)
                            .IsDegenerate(replica_count, partition_count);
+  Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
 
+  FusionEmissionResult result;
+  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
+  if (isDynamic) {
+    // Provide fake allocations for inputs and outputs. The dynamic-slice thunk
+    // will own these allocations.
+    unsigned fake_arg_idx = 0;
+    int64_t operand_byte_size =
+        ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
+    fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+        /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
+    src = BufferAllocation::Slice(
+        /*allocation=*/fake_allocations[fake_arg_idx].get(), /*offset=*/0,
+        /*size=*/operand_byte_size);
+    fake_arg_idx++;
+    TF_RET_CHECK(instr->shape().IsArray() &&
+                 "The output is not expected to be a tuple.");
+    int64_t out_fake_byte_size =
+        ShapeUtil::ByteSizeOf(instr->shape());  // TODO: we don't need this
+    fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+        /*index=*/fake_arg_idx, /*size=*/out_fake_byte_size, /*color=*/0);
+    dst = BufferAllocation::Slice(
+        /*allocation=*/fake_allocations[fake_arg_idx].get(),
+        /*offset=*/0, /*size=*/out_fake_byte_size);
+  }
+
+  // First we get the thunk sequence. This decides whether to generate a d2d
+  // copy thunk or a collective thunk.
+  ThunkSequence seq;
   if (is_degenerate) {
     // Degenerate collectives are simply identity function. Buffer
     // assignment expects a copy, so that's what we do.
-    for (int64_t i = 0; i < buffers.size(); i++) {
-      const Shape shape = instr->operand(i)->shape();
-      TF_RET_CHECK(shape == instr->shape())
-          << "Expected operand shape to be equal to result shape, because the "
-             "collective is degenerate: "
-          << shape.ToString() << " vs " << instr->shape().ToString();
-      seq.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
-          thunk_info,
-          /*source_buffer=*/buffers[i].source_buffer,
-          /*destination_buffer=*/buffers[i].destination_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
-    }
+    const Shape shape = instr->operand(0)->shape();
+    TF_RET_CHECK(shape == instr->shape())
+        << "Expected operand shape to be equal to result shape, because the "
+           "collective is degenerate: "
+        << shape.ToString() << " vs " << instr->shape().ToString();
+    seq.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
+        thunk_info,
+        /*source_buffer=*/src,
+        /*destination_buffer=*/dst,
+        /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
   } else if (implementable_status.ok()) {
+    std::vector<NcclCollectiveThunk::Buffer> buffers;
+    const Shape& src_shape = instr->operand(0)->shape();
+    const Shape& dst_shape = instr->shape();
+    buffers.push_back(NcclCollectiveThunk::Buffer{
+        /*element_count=*/ShapeUtil::ElementsIn(src_shape),
+        /*source_buffer=*/src,
+        /*destination_buffer=*/dst,
+        /*source_memory_space=*/src_shape.layout().memory_space(),
+        /*destination_memory_space=*/dst_shape.layout().memory_space(),
+        /*source_value=*/nullptr,
+        /*destination_value=*/nullptr});
     auto collective_start_thunk = std::make_unique<NcclThunkType>(
         thunk_info, NcclApi::Default(), instr, buffers);
     auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
-        /*kind=*/Thunk::kNcclReduceScatterDone,
+        /*kind=*/collective_done_thunk_kind,
         /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
         /*async_events=*/collective_start_thunk->async_events(),
        /*async_stream_kind=*/AsyncStreamKind::kCollective);
@@ -899,13 +927,20 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
     return implementable_status;
   }
 
-  std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
-      thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
-      std::move(arguments), std::move(fake_allocations),
-      std::move(offset_buffer_indices), std::move(orig_shapes),
-      std::move(sliced_shapes), std::move(offset_byte_sizes));
-  FusionEmissionResult result;
-  result.thunks.push_back(std::move(thunk));
+  // Depending on whether this is a dynamic fusion or not, we wrap the thunk(s)
+  // within a dynamic-slice thunk.
+  if (isDynamic) {
+    std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
+        thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
+        std::move(arguments), std::move(fake_allocations),
+        std::move(offset_buffer_indices), std::move(orig_shapes),
+        std::move(sliced_shapes), std::move(offset_byte_sizes));
+    result.thunks.push_back(std::move(thunk));
+  } else {
+    for (auto& thunk : seq) {
+      result.thunks.push_back(std::move(thunk));
+    }
+  }
   return result;
 }

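In outline, the rewritten EmitCollective chooses between two lowering modes. The sketch below is a simplified, hypothetical rendering of that decision with stand-in types (the real Thunk, ThunkSequence, and DynamicSliceThunk are XLA-internal and take many more parameters); it shows the shape of the logic, not the actual API.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Stand-ins for the XLA-internal types used in the diff above.
struct Thunk {
  virtual ~Thunk() = default;
};
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;

// Simplified stand-in: the real DynamicSliceThunk also owns the fake
// allocations, offset buffers, and shape metadata collected earlier.
struct DynamicSliceThunk : Thunk {
  explicit DynamicSliceThunk(ThunkSequence seq) : inner(std::move(seq)) {}
  ThunkSequence inner;
};

// If any slice offset is known only at run time, the collective thunks must
// execute under a DynamicSliceThunk that resolves buffer addresses first.
// With purely static slices the thunks are returned as-is, which is what
// lets a command buffer (CUDA graph) capture them.
ThunkSequence WrapIfDynamic(bool is_dynamic, ThunkSequence seq) {
  ThunkSequence result;
  if (is_dynamic) {
    result.push_back(std::make_unique<DynamicSliceThunk>(std::move(seq)));
  } else {
    result = std::move(seq);  // e.g. a copy thunk, or NCCL start + done pair
  }
  return result;
}
```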
154 changes: 154 additions & 0 deletions xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
@@ -3496,6 +3496,160 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateCollective) {
/*run_hlo_passes=*/false, /*use_threads=*/true, error));
}

TEST_F(DynamicSliceFusionTest, ReduceScatterSlice) {
const char* hlo_ref = R"(
HloModule jit_slice, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
ENTRY %main.9 {
%p0 = s32[2,8,32]{2,1,0} parameter(0)
%slice = s32[1,8,32]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:32]}
%bc1 = s32[8,32]{1,0} reshape(%slice)
ROOT rs = s32[4,32] reduce-scatter(bc1), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
}
)";

HloModuleConfig config;
DebugOptions options;
options.set_xla_gpu_enable_dynamic_slice_fusion(false);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
ParseAndReturnVerifiedModule(hlo_ref, config));

options.set_xla_gpu_enable_dynamic_slice_fusion(true);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_new,
ParseAndReturnVerifiedModule(hlo_ref, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
GetOptimizedModule(std::move(module_ref)));
TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
GetOptimizedModule(std::move(module_new)));

ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());

auto module_new_opt_clone = module_new_opt->Clone();
TF_ASSERT_OK_AND_ASSIGN(
auto exec, CreateExecutable(std::move(module_new_opt_clone), false));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 2ul);
auto& rs_start_thunk = gpu_exec->GetThunk().thunks()[0];
auto& rs_done_thunk = gpu_exec->GetThunk().thunks()[1];
ASSERT_EQ(rs_start_thunk->kind(), Thunk::kNcclReduceScatterStart);
ASSERT_EQ(rs_done_thunk->kind(), Thunk::kNcclReduceScatterDone);

ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
                                              std::move(module_new_opt),
                                              /*run_hlo_passes=*/false,
                                              /*use_threads=*/true, error));
}

TEST_F(DynamicSliceFusionTest, ReduceScatterDynamicSlice) {
const char* hlo_ref = R"(
HloModule jit_slice, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
ENTRY %main.9 {
p0 = s32[2,8,32]{2,1,0} parameter(0)
c0 = s32[] constant(0)
c1 = s32[] constant(1)
slice = s32[1,8,32]{2,1,0} dynamic-slice(p0, c1, c0, c0), dynamic_slice_sizes={1,8,32}
bc1 = s32[8,32]{1,0} reshape(slice)
ROOT rs = s32[4,32] reduce-scatter(bc1), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
})";

HloModuleConfig config;
DebugOptions options;
options.set_xla_gpu_enable_dynamic_slice_fusion(false);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
ParseAndReturnVerifiedModule(hlo_ref, config));

options.set_xla_gpu_enable_dynamic_slice_fusion(true);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_new,
ParseAndReturnVerifiedModule(hlo_ref, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
GetOptimizedModule(std::move(module_ref)));
TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
GetOptimizedModule(std::move(module_new)));

ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());

ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
                                              std::move(module_new_opt),
                                              /*run_hlo_passes=*/false,
                                              /*use_threads=*/true, error));
}

TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) {
const char* hlo_ref = R"(
HloModule test_module, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
ENTRY main {
p0 = s32[2,4,8] parameter(0)
slice = s32[1,4,8] slice(p0), slice={[1:2], [0:4], [0:8]}
bc = s32[4,8] reshape(slice)
ROOT rs = s32[4,8] reduce-scatter(bc), channel_id=64, replica_groups={{0},{1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
}
)";
HloModuleConfig config;
DebugOptions options;
options.set_xla_gpu_enable_dynamic_slice_fusion(false);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
ParseAndReturnVerifiedModule(hlo_ref, config));

options.set_xla_gpu_enable_dynamic_slice_fusion(true);
options.clear_xla_gpu_enable_command_buffer();
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module_new,
ParseAndReturnVerifiedModule(hlo_ref, config));

TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
GetOptimizedModule(std::move(module_ref)));
TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
GetOptimizedModule(std::move(module_new)));

ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());

auto module_new_opt_clone = module_new_opt->Clone();
TF_ASSERT_OK_AND_ASSIGN(
auto exec, CreateExecutable(std::move(module_new_opt_clone), false));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
ASSERT_EQ(gpu_exec->GetThunk().thunks()[0]->kind(), Thunk::kCopy);

ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
                                              std::move(module_new_opt),
                                              /*run_hlo_passes=*/false,
                                              /*use_threads=*/true, error));
}

} // namespace
} // namespace gpu
} // namespace xla