diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc
index b338c78307c9e..c89cdaba4f89d 100644
--- a/xla/service/gpu/fusions/custom.cc
+++ b/xla/service/gpu/fusions/custom.cc
@@ -780,10 +780,14 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
    IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor,
    const HloFusionInstruction& fusion_instr, const HloInstType* instr,
    bool use_global_device_ids) {
-  if (instr->opcode() != HloOpcode::kReduceScatter) {
-    return absl::UnimplementedError(
-        "Dynamic slice fusion with collectives only works for reduce-scatter "
-        "instruction");
+  Thunk::Kind collective_done_thunk_kind;
+  switch (instr->opcode()) {
+    case HloOpcode::kReduceScatter:
+      collective_done_thunk_kind = Thunk::kNcclReduceScatterDone;
+      break;
+    default:
+      return absl::InternalError(
+          "Unexpected operation in dynamic slice fusion");
  }

  const BufferAssignment& buffer_assignment =
@@ -800,96 +804,120 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(

  // Collect slice information for inputs.
  unsigned arg_idx = 0;
-  TF_ASSIGN_OR_RETURN(arguments.emplace_back(),
-                      GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
-                                      *instr->operand(arg_idx), slice_instrs,
-                                      /*shape_idx=*/{}, arg_idx));
+  TF_ASSIGN_OR_RETURN(
+      BufferAllocation::Slice src,
+      GetOperandSlice(buffer_assignment, adaptor, fusion_instr,
+                      /*start_instr=*/*instr->operand(arg_idx), slice_instrs,
+                      /*shape_idx=*/{}, arg_idx));
+  arguments.push_back(src);
  TF_RETURN_IF_ERROR(CollectSliceInfo(
      buffer_assignment, fusion_instr,
-      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
-      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++));
+      /*slice_instrs=*/absl::Span<HloInstruction*>(slice_instrs),
+      /*offsets=*/offset_buffer_indices, orig_shapes, sliced_shapes,
+      offset_byte_sizes, arg_idx));
+  arg_idx++;

  // Collect slice information for outputs.
-  TF_ASSIGN_OR_RETURN(
-      arguments.emplace_back(),
-      GetResultSlice(buffer_assignment, adaptor, fusion_instr, *instr,
-                     slice_instrs, /*shape_idx=*/{}, arg_idx));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst,
+                      GetResultSlice(buffer_assignment, adaptor, fusion_instr,
+                                     /*start_instr=*/*instr, slice_instrs,
+                                     /*shape_idx=*/{}, arg_idx));
+  arguments.push_back(dst);
  TF_RETURN_IF_ERROR(CollectSliceInfo(
      buffer_assignment, fusion_instr,
-      absl::Span<HloInstruction*>(slice_instrs), offset_buffer_indices,
-      orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx));
-
+      /*slice_instrs=*/absl::Span<HloInstruction*>(slice_instrs),
+      /*offsets=*/offset_buffer_indices, orig_shapes, sliced_shapes,
+      offset_byte_sizes, arg_idx));
+
+  // Sanity checks.
+  // 1. Expect at least one slicing operation.
+  // 2. Expect at least one dynamic index operation iff the fusion is a
+  //    dynamic_address_computation fusion.
  if (absl::c_all_of(slice_instrs, [&](HloInstruction* slice_instr) {
-        return slice_instr &&
-               slice_instr->opcode() != HloOpcode::kDynamicUpdateSlice;
+        return slice_instr == nullptr;
      })) {
-    return absl::InternalError(
-        "DynamicSliceFusion with reduce-scatter expects a dynamic-update-slice "
-        "operation.");
+    return absl::InternalError("Expected at least one slicing operation");
  }
+  bool isDynamic =
+      absl::c_any_of(slice_instrs, [&](const HloInstruction* slice_instr) {
+        return DynCastOrNull<HloDynamicIndexInstruction>(slice_instr) !=
+               nullptr;
+      });
+  TF_ASSIGN_OR_RETURN(
+      auto backend_config,
+      fusion_instr.backend_config<GpuBackendConfig>());
+  const std::string fusion_name =
+      backend_config.fusion_backend_config().custom_fusion_config().name();
+  TF_RET_CHECK(isDynamic == (fusion_name == "dynamic_address_computation"))
+      << "Dynamic index operation found in a fusion instruction that is not "
+         "labelled dynamic_address_computation";

-  // Provide fake allocations for inputs and outputs. The dynamic-slice thunk
-  // will own these allocations.
-  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
-  unsigned fake_arg_idx = 0;
-  int64_t operand_byte_size =
-      ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
-  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
-      /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
-  BufferAllocation::Slice src(
-      /*allocation=*/fake_allocations[fake_arg_idx].get(), /*offset=*/0,
-      /*size=*/operand_byte_size);
-  fake_arg_idx++;
-  TF_RET_CHECK(instr->shape().IsArray() &&
-               "The output is not expected to be a tuple.");
-  int64_t out_fake_byte_size =
-      ShapeUtil::ByteSizeOf(instr->shape());  // TODO: we don't need this
-  fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
-      /*index=*/fake_arg_idx, /*size*/ out_fake_byte_size, /*color=*/0);
-  BufferAllocation::Slice dst(
-      /*allocation=*/fake_allocations[fake_arg_idx].get(),
-      /*offset=*/0, /*size=*/out_fake_byte_size);
-
-  std::vector<NcclCollectiveThunk::Buffer> buffers;
-  const Shape& src_shape = instr->operand(0)->shape();
-  const Shape& dst_shape = instr->shape();
-  buffers.push_back(NcclCollectiveThunk::Buffer{
-      /*element_count=*/ShapeUtil::ElementsIn(src_shape), /*source_buffer=*/src,
-      /*destination_buffer=*/dst,
-      /*source_memory_space=*/src_shape.layout().memory_space(),
-      /*destination_memory_space=*/dst_shape.layout().memory_space(),
-      /*source_value=*/nullptr,
-      /*destination_value=*/nullptr});
-
-  ThunkSequence seq;
-  auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
  int64_t replica_count = instr->GetModule()->config().replica_count();
  int64_t partition_count = instr->GetModule()->config().num_partitions();
  absl::Status implementable_status =
      NcclThunkType::CheckImplementable(instr, replica_count, partition_count);
  bool is_degenerate = GetNcclCollectiveConfig(instr, use_global_device_ids)
                           .IsDegenerate(replica_count, partition_count);
+  Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr);
+
+  FusionEmissionResult result;
+  std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
+  if (isDynamic) {
+    // Provide fake allocations for inputs and outputs. The dynamic-slice thunk
+    // will own these allocations.
+    unsigned fake_arg_idx = 0;
+    int64_t operand_byte_size =
+        ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape());
+    fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+        /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
+    src = BufferAllocation::Slice(
+        /*allocation=*/fake_allocations[fake_arg_idx].get(), /*offset=*/0,
+        /*size=*/operand_byte_size);
+    fake_arg_idx++;
+    TF_RET_CHECK(instr->shape().IsArray() &&
+                 "The output is not expected to be a tuple.");
+    int64_t out_fake_byte_size =
+        ShapeUtil::ByteSizeOf(instr->shape());  // TODO: we don't need this
+    fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
+        /*index=*/fake_arg_idx, /*size*/ out_fake_byte_size, /*color=*/0);
+    dst = BufferAllocation::Slice(
+        /*allocation=*/fake_allocations[fake_arg_idx].get(),
+        /*offset=*/0, /*size=*/out_fake_byte_size);
+  }
+  // First we get the thunk sequence. This decides whether to emit a D2D copy
+  // thunk or a collective thunk.
+  ThunkSequence seq;
  if (is_degenerate) {
    // Degenerate collectives are simply identity function. Buffer
    // assignment expects a copy, so that's what we do.
-    for (int64_t i = 0; i < buffers.size(); i++) {
-      const Shape shape = instr->operand(i)->shape();
-      TF_RET_CHECK(shape == instr->shape())
-          << "Expected operand shape to be equal to result shape, because the "
-             "collective is degenerate: "
-          << shape.ToString() << " vs " << instr->shape().ToString();
-      seq.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
-          thunk_info,
-          /*source_buffer=*/buffers[i].source_buffer,
-          /*destination_buffer=*/buffers[i].destination_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
-    }
+    const Shape shape = instr->operand(0)->shape();
+    TF_RET_CHECK(shape == instr->shape())
+        << "Expected operand shape to be equal to result shape, because the "
+           "collective is degenerate: "
+        << shape.ToString() << " vs " << instr->shape().ToString();
+    seq.emplace_back(std::make_unique<DeviceToDeviceCopyThunk>(
+        thunk_info,
+        /*source_buffer=*/src,
+        /*destination_buffer=*/dst,
+        /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
  } else if (implementable_status.ok()) {
+    std::vector<NcclCollectiveThunk::Buffer> buffers;
+    const Shape& src_shape = instr->operand(0)->shape();
+    const Shape& dst_shape = instr->shape();
+    buffers.push_back(NcclCollectiveThunk::Buffer{
+        /*element_count=*/ShapeUtil::ElementsIn(src_shape),
+        /*source_buffer=*/src,
+        /*destination_buffer=*/dst,
+        /*source_memory_space=*/src_shape.layout().memory_space(),
+        /*destination_memory_space=*/dst_shape.layout().memory_space(),
+        /*source_value=*/nullptr,
+        /*destination_value=*/nullptr});
    auto collective_start_thunk = std::make_unique<NcclThunkType>(
        thunk_info, NcclApi::Default(), instr, buffers);
    auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
-        /*kind=*/Thunk::kNcclReduceScatterDone,
+        /*kind=*/collective_done_thunk_kind,
        /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
        /*async_events=*/collective_start_thunk->async_events(),
        /*async_stream_kind=*/AsyncStreamKind::kCollective);
@@ -899,13 +927,20 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
    return implementable_status;
  }

-  std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
-      thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
-      std::move(arguments), std::move(fake_allocations),
-      std::move(offset_buffer_indices), std::move(orig_shapes),
-      std::move(sliced_shapes), std::move(offset_byte_sizes));
-  FusionEmissionResult result;
-  result.thunks.push_back(std::move(thunk));
+  // Depending on whether this is a dynamic fusion or not, we wrap the thunk(s)
+  // within a dynamic-slice thunk.
+  if (isDynamic) {
+    std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
+        thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
+        std::move(arguments), std::move(fake_allocations),
+        std::move(offset_buffer_indices), std::move(orig_shapes),
+        std::move(sliced_shapes), std::move(offset_byte_sizes));
+    result.thunks.push_back(std::move(thunk));
+  } else {
+    for (auto& thunk : seq) {
+      result.thunks.push_back(std::move(thunk));
+    }
+  }

  return result;
}
diff --git a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
index 8a68f785f98d7..212e6b51e5445 100644
--- a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
+++ b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
@@ -3496,6 +3496,160 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateCollective) {
                                                /*run_hlo_passes=*/false,
                                                /*use_threads=*/true, error));
}

+TEST_F(DynamicSliceFusionTest, ReduceScatterSlice) {
+  const char* hlo_ref = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY %main.9 {
+    %p0 = s32[2,8,32]{2,1,0} parameter(0)
+    %slice = s32[1,8,32]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:32]}
+    %bc1 = s32[8,32]{1,0} reshape(%slice)
+    ROOT rs = s32[4,32] reduce-scatter(bc1), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  }
+  )";
+
+  HloModuleConfig config;
+  DebugOptions options;
+  options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  options.set_xla_gpu_enable_dynamic_slice_fusion(true);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
+                          GetOptimizedModule(std::move(module_ref)));
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
+                          GetOptimizedModule(std::move(module_new)));
+
+  ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
+  ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());
+
+  auto module_new_opt_clone = module_new_opt->Clone();
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto exec, CreateExecutable(std::move(module_new_opt_clone), false));
+  GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
+  ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 2ul);
+  auto& rs_start_thunk = gpu_exec->GetThunk().thunks()[0];
+  auto& rs_done_thunk = gpu_exec->GetThunk().thunks()[1];
+  ASSERT_EQ(rs_start_thunk->kind(), Thunk::kNcclReduceScatterStart);
+  ASSERT_EQ(rs_done_thunk->kind(), Thunk::kNcclReduceScatterDone);
+
+  ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
+                                                std::move(module_new_opt),
+                                                false, true, error));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDynamicSlice) {
+  const char* hlo_ref = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY %main.9 {
+    p0 = s32[2,8,32]{2,1,0} parameter(0)
+    c0 = s32[] constant(0)
+    c1 = s32[] constant(1)
+    slice = s32[1,8,32]{2,1,0} dynamic-slice(p0, c1, c0, c0), dynamic_slice_sizes={1,8,32}
+    bc1 = s32[8,32]{1,0} reshape(slice)
+    ROOT rs = s32[4,32] reduce-scatter(bc1), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  })";
+
+  HloModuleConfig config;
+  DebugOptions options;
+  options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  options.set_xla_gpu_enable_dynamic_slice_fusion(true);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
+                          GetOptimizedModule(std::move(module_ref)));
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
+                          GetOptimizedModule(std::move(module_new)));
+
+  ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
+  ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());
+
+  ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
+                                                std::move(module_new_opt),
+                                                false, true, error));
+}
+
+TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) {
+  const char* hlo_ref = R"(
+  HloModule test_module, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = s32[] add(a, b)
+  }
+
+  ENTRY main {
+    p0 = s32[2,4,8] parameter(0)
+    slice = s32[1,4,8] slice(p0), slice={[1:2], [0:4], [0:8]}
+    bc = s32[4,8] reshape(slice)
+    ROOT rs = s32[4,8] reduce-scatter(bc), channel_id=64, replica_groups={{0},{1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  }
+  )";
+  HloModuleConfig config;
+  DebugOptions options;
+  options.set_xla_gpu_enable_dynamic_slice_fusion(false);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  options.set_xla_gpu_enable_dynamic_slice_fusion(true);
+  options.clear_xla_gpu_enable_command_buffer();
+  config.set_debug_options(options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new,
+                          ParseAndReturnVerifiedModule(hlo_ref, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt,
+                          GetOptimizedModule(std::move(module_ref)));
+  TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt,
+                          GetOptimizedModule(std::move(module_new)));
+
+  ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty());
+  ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty());
+
+  auto module_new_opt_clone = module_new_opt->Clone();
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto exec, CreateExecutable(std::move(module_new_opt_clone), false));
+  GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
+  ASSERT_EQ(gpu_exec->GetThunk().thunks()[0]->kind(), Thunk::kCopy);
+
+  ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
+  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
+                                                std::move(module_new_opt),
+                                                false, true, error));
+}
+
}  // namespace
}  // namespace gpu
}  // namespace xla
diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index 842dffa0028a6..1d94823639f28 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -438,14 +438,19 @@ cc_library(
    ],
)

-xla_cc_test(
+xla_test(
    name = "command_buffer_scheduling_test",
    srcs = ["command_buffer_scheduling_test.cc"],
+    backends = [
+        "cpu",
+        "gpu",
+    ],
    deps = [
        ":command_buffer_scheduling",
        "//xla/hlo/ir:hlo",
        "//xla/service:hlo_parser",
        "//xla/service/gpu:gpu_device_info_for_tests",
+        "//xla/service/gpu:gpu_executable",
        "//xla/stream_executor:device_description",
        "//xla/tests:filecheck",
"//xla/tests:hlo_test_base", diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.cc b/xla/service/gpu/transforms/command_buffer_scheduling.cc index a16b908c16330..c772ef052b5a1 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -87,6 +87,78 @@ static bool IsNoOp(const HloInstruction* hlo) { HloOpcode::kGetTupleElement>(hlo); }; +//===----------------------------------------------------------------------===// +// Asynchronous HLO operations mapped to commands. +//===----------------------------------------------------------------------===// + +// Asynchronous HLO operations can be wrapped into command buffers only when +// both start and done operations can be put into the same command buffer. +// Command buffer semantics implies that when command buffer execution +// completes, all recorded commands are also completed, which means that if +// done operation is not part of the same command buffer, we would change the +// execution semantics and create additional synchronization point. + +static bool IsAsyncStartCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + if (hlo->opcode() == HloOpcode::kAllReduceStart || + hlo->opcode() == HloOpcode::kAllGatherStart) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + + if (hlo->opcode() == HloOpcode::kAsyncStart) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } + if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { + return config.enabled_commands.contains(DebugOptions::FUSION); + } + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + } + + if (hlo->opcode() == HloOpcode::kReduceScatter) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + + return false; +} + +static bool IsAsyncDoneCommand(const HloInstruction* hlo, + const CommandBufferConfig& config) { + if (hlo->opcode() == HloOpcode::kAllReduceDone || + hlo->opcode() == HloOpcode::kAllGatherDone) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + + if (hlo->opcode() == HloOpcode::kAsyncDone) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } + if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { + return config.enabled_commands.contains(DebugOptions::FUSION); + } + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return config.enabled_commands.contains(DebugOptions::COLLECTIVES); + } + } + + return false; +} + +// Finds an async-done HLO operation corresponding on an async-start one. +static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) { + if (start->opcode() == HloOpcode::kAllReduceStart || + start->opcode() == HloOpcode::kAllGatherStart) { + CHECK(start->users().size() == 1); // NOLINT, checked by HLO verifier + return start->users().front(); + } else if (start->opcode() == HloOpcode::kAsyncStart) { + return start->async_chain_done(); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // Synchronous HLO operations mapped to commands. 
//===----------------------------------------------------------------------===//
@@ -173,12 +245,13 @@ static bool IsCommand(const HloInstruction* hlo,
    auto fusion_analysis =
        HloFusionAnalysis::Create(*hlo, config.device_description);
    const HloFusionAdaptor& adaptor = fusion_analysis.fusion();
-    auto custom_call_adaptor = HloBfsFindIf(
-        adaptor.GetRoots(), adaptor,
-        [](auto node) { return node.opcode() == HloOpcode::kCustomCall; });
-    const auto* custom_call = static_cast<const HloCustomCallInstruction*>(
-        &custom_call_adaptor->instruction());
-    return IsCommand(custom_call, config);
+    auto hero_adaptor =
+        HloBfsFindIf(adaptor.GetRoots(), adaptor, [](auto node) {
+          return node.opcode() == HloOpcode::kCustomCall ||
+                 node.opcode() == HloOpcode::kReduceScatter;
+        });
+    const HloInstruction* hero = &hero_adaptor->instruction();
+    return IsCommand(hero, config) || IsAsyncStartCommand(hero, config);
  }
  if (custom_config.name() == "dynamic_address_computation") {
    return false;
@@ -206,74 +279,6 @@ static bool IsCommand(const HloInstruction* hlo,
  return false;
}

-//===----------------------------------------------------------------------===//
-// Asynchronous HLO operations mapped to commands.
-//===----------------------------------------------------------------------===//
-
-// Asynchronous HLO operations can be wrapped into command buffers only when
-// both start and done operations can be put into the same command buffer.
-// Command buffer semantics implies that when command buffer execution
-// completes, all recorded commands are also completed, which means that if
-// done operation is not part of the same command buffer, we would change the
-// execution semantics and create additional synchronization point.
-
-static bool IsAsyncStartCommand(const HloInstruction* hlo,
-                                const CommandBufferConfig& config) {
-  if (hlo->opcode() == HloOpcode::kAllReduceStart ||
-      hlo->opcode() == HloOpcode::kAllGatherStart) {
-    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-  }
-
-  if (hlo->opcode() == HloOpcode::kAsyncStart) {
-    if (IsCublasGemm(*hlo->async_wrapped_instruction())) {
-      return config.enabled_commands.contains(DebugOptions::CUBLAS);
-    }
-    if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) {
-      return config.enabled_commands.contains(DebugOptions::FUSION);
-    }
-    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
-      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-    }
-  }
-
-  return false;
-}
-
-static bool IsAsyncDoneCommand(const HloInstruction* hlo,
-                               const CommandBufferConfig& config) {
-  if (hlo->opcode() == HloOpcode::kAllReduceDone ||
-      hlo->opcode() == HloOpcode::kAllGatherDone) {
-    return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-  }
-
-  if (hlo->opcode() == HloOpcode::kAsyncDone) {
-    if (IsCublasGemm(*hlo->async_wrapped_instruction())) {
-      return config.enabled_commands.contains(DebugOptions::CUBLAS);
-    }
-    if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) {
-      return config.enabled_commands.contains(DebugOptions::FUSION);
-    }
-    if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
-      return config.enabled_commands.contains(DebugOptions::COLLECTIVES);
-    }
-  }
-
-  return false;
-}
-
-// Finds an async-done HLO operation corresponding on an async-start one.
-static HloInstruction* FindAsyncDoneCommand(const HloInstruction* start) {
-  if (start->opcode() == HloOpcode::kAllReduceStart ||
-      start->opcode() == HloOpcode::kAllGatherStart) {
-    CHECK(start->users().size() == 1);  // NOLINT, checked by HLO verifier
-    return start->users().front();
-  } else if (start->opcode() == HloOpcode::kAsyncStart) {
-    return start->async_chain_done();
-  }
-
-  return nullptr;
-}
-
//===----------------------------------------------------------------------===//
// HLO computations mapped to command buffers.
//===----------------------------------------------------------------------===//
diff --git a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
index 3bffa6eaa621e..6c79316a75a51 100644
--- a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
+++ b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
+#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/hlo_parser.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
@@ -1071,5 +1072,161 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) {
  });
}

+TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) {
+  if (backend().platform()->Name() == "Host") {
+    GTEST_SKIP() << "GPU support required for this test";
+  }
+  const char* hlo = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY main.9 {
+    p0 = s32[2,8,32]{2,1,0} parameter(0)
+    p1 = s32[8,32]{1,0} parameter(1)
+    c0 = s32[] constant(0)
+    c1 = s32[] constant(1)
+    slice = s32[1,8,32]{2,1,0} dynamic-slice(p0, c1, c0, c0), dynamic_slice_sizes={1,8,32}
+    input = s32[8,32]{1,0} reshape(slice)
+    rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+    ROOT dus = s32[8,32] dynamic-update-slice(p1, rs, c0, c0)
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, GetOptimizedModule(hlo));
+
+  HloModuleConfig config(m->config());
+  DebugOptions options(config.debug_options());
+  options.set_xla_gpu_graph_min_graph_size(0);
+
+  auto check = [&m, this](DebugOptions options) -> absl::Status {
+    auto m_clone = m->Clone();
+    HloModuleConfig config(m_clone->config());
+    config.set_debug_options(options);
+    m_clone->set_config(config);
+    TF_ASSIGN_OR_RETURN(auto exec, CreateExecutable(std::move(m_clone), false));
+    auto gpu_exec = std::unique_ptr<GpuExecutable>(
+        static_cast<GpuExecutable*>(exec.release()));
+    TF_RET_CHECK(llvm::any_of(gpu_exec->GetThunk().thunks(),
+                              [](const std::unique_ptr<Thunk>& thunk) {
+                                return thunk->kind() == Thunk::kDynamicSlice;
+                              }));
+    return absl::OkStatus();
+  };
+
+  // With dynamic slicing, there should be no command buffer regardless of
+  // which command types are enabled.
+  // Case 1: FUSION on, COLLECTIVES on
+  options.clear_xla_gpu_enable_command_buffer();
+  options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+  options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+  TF_ASSERT_OK(check(options));
+
+  // Case 2: FUSION off, COLLECTIVES off
+  options.clear_xla_gpu_enable_command_buffer();
+  TF_ASSERT_OK(check(options));
+
+  // Case 3: FUSION off, COLLECTIVES on
+  options.clear_xla_gpu_enable_command_buffer();
+  options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+  TF_ASSERT_OK(check(options));
+
+  // Case 4: FUSION on, COLLECTIVES off
+  options.clear_xla_gpu_enable_command_buffer();
+  options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+  TF_ASSERT_OK(check(options));
+}
+
+TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
+  if (backend().platform()->Name() == "Host" || backend().device_count() < 2) {
+    GTEST_SKIP() << "At least two GPUs required for this test";
+  }
+  const char* hlo = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY main.9 {
+    p0 = s32[2,8,32]{2,1,0} parameter(0)
+    p1 = s32[8,32]{1,0} parameter(1)
+    c0 = s32[] constant(0)
+    c1 = s32[] constant(1)
+    slice = s32[1,8,32]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:32]}
+    input = s32[8,32]{1,0} reshape(slice)
+    ROOT rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto m, GetOptimizedModule(hlo));
+
+  HloModuleConfig config(m->config());
+  DebugOptions options(config.debug_options());
+
+  options.set_xla_gpu_graph_min_graph_size(0);
+
+  auto get_exec = [&m, this](DebugOptions options)
+      -> absl::StatusOr<std::unique_ptr<GpuExecutable>> {
+    auto m_clone = m->Clone();
+    HloModuleConfig config(m_clone->config());
+    config.set_debug_options(options);
+    m_clone->set_config(config);
+    TF_ASSIGN_OR_RETURN(auto exec, CreateExecutable(std::move(m_clone), false));
+    return std::unique_ptr<GpuExecutable>(
+        static_cast<GpuExecutable*>(exec.release()));
+  };
+
+  // FUSION on, COLLECTIVES on -> command buffer
+  {
+    options.clear_xla_gpu_enable_command_buffer();
+    options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+    options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
+    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
+    ASSERT_EQ(child->kind(), Thunk::kCommandBuffer);
+  }
+
+  // FUSION off, COLLECTIVES off -> no command buffer because collective hero.
+  {
+    options.clear_xla_gpu_enable_command_buffer();
+    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
+    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
+    ASSERT_NE(child->kind(), Thunk::kCommandBuffer);
+  }
+
+  // FUSION off, COLLECTIVES on -> command buffer because static slices.
+  {
+    options.clear_xla_gpu_enable_command_buffer();
+    options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
+    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
+    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
+    ASSERT_EQ(child->kind(), Thunk::kCommandBuffer);
+  }
+
+  // FUSION on, COLLECTIVES off -> no command buffer because collective hero.
+  {
+    options.clear_xla_gpu_enable_command_buffer();
+    options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
+    TF_ASSERT_OK_AND_ASSIGN(auto gpu_exec, get_exec(options));
+    Thunk* child = gpu_exec->GetThunk().thunks()[0].get();
+    ASSERT_NE(child->kind(), Thunk::kCommandBuffer);
+  }
+
+  // Finally compare with/without command buffer.
+  options.clear_xla_gpu_enable_command_buffer();
+  auto m_ref = m->Clone();
+  config.set_debug_options(options);
+  m_ref->set_config(config);
+
+  config.set_debug_options(GetDebugOptionsForTest());
+  m->set_config(config);
+  ASSERT_TRUE(RunAndCompareTwoModulesReplicated(std::move(m_ref), std::move(m),
+                                                false, true, std::nullopt));
+}
+
}  // namespace
}  // namespace xla::gpu
diff --git a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
index 7ffae3fd341e7..8ea3bc3801062 100644
--- a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
+++ b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc
@@ -251,8 +251,12 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) {
  // the matched instructions that we have seen so far.
  InstructionSet processed_instrs;

-  const auto& aliasing_pairs =
-      Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
+  std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
+      aliasing_pairs;
+  if (instr->opcode() == HloOpcode::kCustomCall) {
+    aliasing_pairs =
+        Cast<HloCustomCallInstruction>(instr)->output_to_operand_aliasing();
+  }
  absl::flat_hash_set<int64_t> aliased_operands;
  for (const auto& pair : aliasing_pairs) {
    aliased_operands.insert(pair.second.first);
  }
@@ -519,15 +523,11 @@ absl::StatusOr<bool> DynamicSliceFusionRewriter::Run(
  for (HloComputation* computation : module->computations()) {
    if (computation->IsFusionComputation()) continue;
    for (HloInstruction* instr : computation->instructions()) {
-      UseDefDataflowPaths sliced_operand_paths = {instr};
-      bool has_sliced_operand_paths = false;
-      if (IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
-        sliced_operand_paths = GetSlicedOperandPaths(instr);
-        has_sliced_operand_paths = sliced_operand_paths.size() > 1;
-      }
      if ((instr->opcode() == HloOpcode::kReduceScatter &&
           instr->shape().IsArray()) ||
          IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) {
+        UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr);
+        bool has_sliced_operand_paths = sliced_operand_paths.size() > 1;
        DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr);
        bool has_sliced_user_paths = absl::c_any_of(
            sliced_user_paths,
diff --git a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
index aee3a5db0712a..9a71c9930adc7 100644
--- a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
+++ b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc
@@ -2086,4 +2086,62 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSReduceScatterTupleNoTransform) {
                          std::nullopt);
}

+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterSlice) {
+  const char* hlo = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY %main.9 {
+    p0 = s32[2,8,32]{2,1,0} parameter(0)
+    slice = s32[1,8,32]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:32]}
+    bc = s32[8,32]{1,0} bitcast(%slice)
+    ROOT rs = s32[4,32] reduce-scatter(bc), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  })";
+  const char* expected = R"(
+    // CHECK: dynamic-slice-fusion{{.*}} {
+    // CHECK: %[[p0:.+]] = {{.+}} parameter(0)
+    // CHECK: %[[slice:.+]] = {{.+}} slice(%[[p0]]), slice={[1:2], [0:8], [0:32]}
+    // CHECK: %[[bc:.+]] = {{.+}} bitcast(%[[slice]])
+    // CHECK: ROOT {{.+}} = {{.+}} reduce-scatter(%[[bc]])
+    // CHECK: }
+  )";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
+TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDynamicSlice) {
+  const char* hlo = R"(
+  HloModule jit_slice, replica_count=2
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = add(a,b)
+  }
+
+  ENTRY %main.9 {
+    p0 = s32[2,8,32]{2,1,0} parameter(0)
+    c0 = s32[] constant(0)
+    c1 = s32[] constant(1)
+    slice = s32[1,8,32]{2,1,0} dynamic-slice(p0, c1, c0, c0), dynamic_slice_sizes={1,8,32}
+    bc = s32[8,32]{1,0} bitcast(%slice)
+    ROOT rs = s32[4,32] reduce-scatter(bc), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add
+  })";
+  const char* expected = R"(
+    // CHECK: add
+    // CHECK: dynamic-slice-fusion{{.*}} {
+    // CHECK: %[[p0:.+]] = {{.+}} parameter(0)
+    // CHECK: %[[slice:.+]] = {{.+}} dynamic-slice(%[[p0]], {{.+}}), dynamic_slice_sizes={1,8,32}
+    // CHECK: %[[bc:.+]] = {{.+}} bitcast(%[[slice]])
+    // CHECK: ROOT {{.+}} = {{.+}} reduce-scatter(%[[bc]])
+    // CHECK: }
+    // CHECK: ENTRY
+  )";
+  RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected);
+}
+
}  // namespace xla::gpu
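
To make the control flow introduced in EmitCollective (custom.cc above) easier to follow, here is a small standalone sketch of the emission decision. It is illustrative only: the enum, function, and variable names below are simplified stand-ins rather than the real XLA classes, and the real code also threads buffer slices, thunk info, and error statuses through each branch.

// Simplified model of the EmitCollective dispatch (illustrative names only).
#include <iostream>
#include <vector>

enum class ThunkKind { kCopy, kReduceScatterStart, kReduceScatterDone, kDynamicSlice };

// is_dynamic:    the fusion is labelled dynamic_address_computation (dynamic offsets).
// is_degenerate: the collective is a no-op across its replica groups.
// implementable: the NCCL CheckImplementable-style check would return OK.
std::vector<ThunkKind> EmitCollectiveModel(bool is_dynamic, bool is_degenerate,
                                           bool implementable) {
  std::vector<ThunkKind> seq;
  if (is_degenerate) {
    // Degenerate collectives reduce to a device-to-device copy.
    seq.push_back(ThunkKind::kCopy);
  } else if (implementable) {
    // Otherwise emit an async start/done pair for the collective.
    seq.push_back(ThunkKind::kReduceScatterStart);
    seq.push_back(ThunkKind::kReduceScatterDone);
  } else {
    return {};  // The real code returns the implementability error here.
  }
  // Only dynamic fusions get wrapped into a dynamic-slice thunk that owns the
  // inner sequence; static slices emit the inner thunks directly.
  if (is_dynamic) return {ThunkKind::kDynamicSlice};
  return seq;
}

int main() {
  auto print = [](const std::vector<ThunkKind>& thunks) {
    for (ThunkKind kind : thunks) std::cout << static_cast<int>(kind) << ' ';
    std::cout << '\n';
  };
  print(EmitCollectiveModel(/*is_dynamic=*/true, /*is_degenerate=*/false, true));
  print(EmitCollectiveModel(/*is_dynamic=*/false, /*is_degenerate=*/false, true));
  print(EmitCollectiveModel(/*is_dynamic=*/false, /*is_degenerate=*/true, true));
  return 0;
}

This mirrors why ReduceScatterSlice expects exactly a start/done thunk pair at the top level, ReduceScatterDegenerateSlice expects a kCopy thunk, and the dynamic-slice tests expect a kDynamicSlice wrapper.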
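The command_buffer_scheduling tests above also encode a small decision table. The sketch below restates the tests' expected outcomes as plain C++ so they are easy to scan; it is a simplified model of what the tests assert, not the actual IsCommand/IsAsyncStartCommand logic, which additionally inspects the fusion's backend config and the other enabled command types.

// Expected scheduling outcome for a dynamic-slice fusion whose hero is a
// reduce-scatter, as asserted by the DynamicSliceFusion*Slicing tests.
#include <iostream>

struct CommandBufferFlags {
  bool fusion;       // DebugOptions::FUSION enabled.
  bool collectives;  // DebugOptions::COLLECTIVES enabled.
};

// Dynamic offsets (dynamic-slice / dynamic-update-slice) keep the fusion out
// of command buffers; with static offsets the collective hero is captured
// whenever COLLECTIVES commands are enabled.
bool ExpectCommandBuffer(bool has_dynamic_offsets, CommandBufferFlags flags) {
  if (has_dynamic_offsets) return false;
  return flags.collectives;
}

int main() {
  const CommandBufferFlags cases[] = {
      {true, true}, {false, false}, {false, true}, {true, false}};
  for (CommandBufferFlags flags : cases) {
    std::cout << "FUSION=" << flags.fusion
              << " COLLECTIVES=" << flags.collectives
              << " | static slice -> "
              << ExpectCommandBuffer(/*has_dynamic_offsets=*/false, flags)
              << ", dynamic slice -> "
              << ExpectCommandBuffer(/*has_dynamic_offsets=*/true, flags)
              << '\n';
  }
  return 0;
}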