PR #16849: [XLA:GPU] Fix dropout state in cuDNN FMHA
Imported from GitHub PR #16849

* Fix an issue where the dropout state was not set after switching to the cuDNN thunk in commit 0aa816e.
* Maintain a separate offset copy for each device, since the thunk can be executed by multiple threads concurrently (see the sketch below).
* Add a unit test for the FMHA dropout lowering. Since it is hard to implement a comparable Philox RNG in HLO, the test only guarantees that the FMHA custom call with dropout is lowered correctly.
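
To make the second bullet concrete, here is a minimal C++ sketch of the per-device offset bookkeeping this change introduces (illustrative names only, not the actual XLA classes):

```cpp
#include <cstdint>
#include <vector>

// Per-device dropout RNG state: one offset slot per local device, so
// concurrent executions on different devices never race on a single
// shared counter. Hypothetical class mirroring the change's design.
class DropoutRngState {
 public:
  void Init(int64_t local_device_count, int64_t seed, int64_t increment) {
    seed_ = seed;
    increment_ = increment;
    offsets_.assign(local_device_count, 0);
  }
  // Advance this device's offset once per kernel launch so every launch
  // draws a fresh dropout mask from the same Philox stream.
  int64_t NextOffset(int64_t local_device_ordinal) {
    offsets_[local_device_ordinal] += increment_;
    return offsets_[local_device_ordinal];
  }
  int64_t seed() const { return seed_; }

 private:
  int64_t seed_ = 0;
  int64_t increment_ = 0;
  std::vector<int64_t> offsets_;
};
```
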
Copybara import of the project:

--
6e4753d by cjkkkk <[email protected]>:

init

--
6fea371 by cjkkkk <[email protected]>:

fix rocm

--
28dbac1 by cjkkkk <[email protected]>:

address comments

Merging this change closes #16849

COPYBARA_INTEGRATE_REVIEW=#16849 from Cjkkkk:fix_fmha_dropout 28dbac1
PiperOrigin-RevId: 676319582
Cjkkkk authored and Google-ML-Automation committed Sep 19, 2024
1 parent a89f900 commit 42b04a6
Showing 8 changed files with 109 additions and 26 deletions.
9 changes: 8 additions & 1 deletion xla/service/gpu/ir_emitter_unnested.cc
@@ -961,9 +961,16 @@ absl::Status IrEmitterUnnested::EmitCuDnnThunk(
                                              instr->operands()));
   TF_ASSIGN_OR_RETURN(const std::string fingerprint,
                       FingerprintWithBackendConfig<GpuBackendConfig>(*instr));
+  // check if sdpa dropout is enabled
+  std::optional<int64_t> dropout_seed = std::nullopt;
+  if (MHACallHasDropout(instr->custom_call_target())) {
+    TF_ASSIGN_OR_RETURN(const auto gpu_config,
+                        instr->backend_config<xla::gpu::GpuBackendConfig>());
+    dropout_seed = gpu_config.cudnn_fmha_backend_config().seed();
+  }
   AddThunkToThunkSequence(std::make_unique<CuDnnThunk>(
       fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr),
-      kernel_arguments.args()));
+      kernel_arguments.args(), dropout_seed));
   return absl::OkStatus();
 }

3 changes: 2 additions & 1 deletion xla/service/gpu/runtime/command_buffer_cmd.cc
@@ -1369,7 +1369,8 @@ absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params,
   return AddTracedCommandBuffer(
       execute_params, record_params, command_buffer, [&](se::Stream* stream) {
         return graph_->get()->Execute(
-            *stream, absl::Span<se::DeviceMemoryBase>(operands));
+            *stream, absl::Span<se::DeviceMemoryBase>(operands),
+            execute_params.collective_params->local_device_ordinal);
       });
 }

13 changes: 10 additions & 3 deletions xla/service/gpu/runtime/cudnn_thunk.cc
@@ -34,10 +34,12 @@ namespace xla {
 namespace gpu {
 
 CuDnnThunk::CuDnnThunk(std::string fingerprint, ThunkInfo thunk_info,
-                       absl::Span<const KernelArgument> kernel_arguments)
+                       absl::Span<const KernelArgument> kernel_arguments,
+                       std::optional<int64_t> sdpa_dropout_seed)
     : Thunk(Kind::kCuDnn, std::move(thunk_info)),
       fingerprint_(std::move(fingerprint)),
-      graph_(std::make_shared<se::dnn::LazyDnnGraph>(nullptr)) {
+      graph_(std::make_shared<se::dnn::LazyDnnGraph>(nullptr)),
+      sdpa_dropout_seed_(sdpa_dropout_seed) {
   args_.reserve(kernel_arguments.size());
   for (const KernelArgument& kernel_argument : kernel_arguments) {
     args_.push_back(kernel_argument.slice());
@@ -52,6 +54,10 @@ absl::Status CuDnnThunk::Initialize(const InitializeParams& params) {
     std::string().swap(fingerprint_);
     if (result.ok()) {
       graph_->swap(*result);
+      if (sdpa_dropout_seed_.has_value()) {
+        graph_->get()->InitDropoutState(params.local_device_count,
+                                        *sdpa_dropout_seed_, 16);
+      }
     }
     ret = result.status();
   });
@@ -68,7 +74,8 @@ absl::Status CuDnnThunk::ExecuteOnStream(const ExecuteParams& params) {
     buffer_args.push_back(params.buffer_allocations->GetDeviceAddress(arg));
   }
   return graph_->get()->Execute(*params.stream,
-                                absl::Span<se::DeviceMemoryBase>(buffer_args));
+                                absl::Span<se::DeviceMemoryBase>(buffer_args),
+                                params.collective_params->local_device_ordinal);
 }
 
 }  // namespace gpu
5 changes: 4 additions & 1 deletion xla/service/gpu/runtime/cudnn_thunk.h
@@ -35,7 +35,8 @@ namespace gpu {
 class CuDnnThunk : public Thunk {
  public:
   CuDnnThunk(std::string fingerprint, ThunkInfo,
-             absl::Span<const KernelArgument>);
+             absl::Span<const KernelArgument>,
+             std::optional<int64_t> sdpa_dropout_seed = std::nullopt);
   CuDnnThunk(const CuDnnThunk&) = delete;
   CuDnnThunk& operator=(const CuDnnThunk&) = delete;
   ~CuDnnThunk() override = default;
@@ -53,6 +54,8 @@ class CuDnnThunk : public Thunk {
   std::string fingerprint_;
   std::shared_ptr<se::dnn::LazyDnnGraph> graph_;
   std::vector<BufferAllocation::Slice> args_;
+  // SDPA dropout seed.
+  std::optional<int64_t> sdpa_dropout_seed_;
 };
 
 }  // namespace gpu
63 changes: 63 additions & 0 deletions xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -1352,6 +1352,63 @@ class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8
 }
 };
 
+class FlashAttentionBMMScaleSoftmaxDropoutBMM
+    : public MultiHeadedAttentionTest {
+ protected:
+  static constexpr absl::string_view
+      kModuleFlashAttentionTrainingBMM1SoftmaxDropoutBMM2HloStringBF16 = R"(
+HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0})->(bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true}
+ENTRY main.21 {
+  Arg_0.1 = bf16[4,1024,4,64]{3,2,1,0} parameter(0)
+  Arg_1.2 = bf16[4,1024,4,64]{3,2,1,0} parameter(1)
+  Arg_2.3 = bf16[4,1024,4,64]{3,2,1,0} parameter(2)
+  constant.5 = s32[] constant(512)
+  broadcast.6 = s32[4]{0} broadcast(constant.5), dimensions={}
+  custom-call.7 = (bf16[4,4,1024,64]{3,1,2,0}, f32[4,4,1024]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, broadcast.6, broadcast.6), custom_call_target="__cudnn$fmhaSoftmaxDropout", operand_layout_constraints={bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.5, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "1024", "1024"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "PADDING", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}}}
+  get-tuple-element.9 = u8[0]{0} get-tuple-element(custom-call.7), index=2
+  get-tuple-element.10 = f32[4,4,1024]{2,1,0} get-tuple-element(custom-call.7), index=1
+  Arg_3.4 = bf16[4,1024,4,64]{3,2,1,0} parameter(3)
+  get-tuple-element.8 = bf16[4,4,1024,64]{3,1,2,0} get-tuple-element(custom-call.7), index=0
+  transpose.11 = bf16[4,1024,4,64]{3,2,1,0} transpose(get-tuple-element.8), dimensions={0,2,1,3}
+  custom-call.12 = (bf16[4,4,1024,64]{3,1,2,0}, bf16[4,4,1024,64]{3,1,2,0}, bf16[4,4,1024,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.10, Arg_3.4, /*index=5*/transpose.11, broadcast.6, broadcast.6), custom_call_target="__cudnn$fmhaSoftmaxDropoutBackward", operand_layout_constraints={bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, f32[4,4,1024]{2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.5, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "1024", "1024"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "PADDING", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}}}
+  get-tuple-element.16 = u8[0]{0} get-tuple-element(custom-call.12), index=3
+  get-tuple-element.13 = bf16[4,4,1024,64]{3,1,2,0} get-tuple-element(custom-call.12), index=0
+  transpose.17 = bf16[4,1024,4,64]{3,2,1,0} transpose(get-tuple-element.13), dimensions={0,2,1,3}
+  get-tuple-element.14 = bf16[4,4,1024,64]{3,1,2,0} get-tuple-element(custom-call.12), index=1
+  transpose.18 = bf16[4,1024,4,64]{3,2,1,0} transpose(get-tuple-element.14), dimensions={0,2,1,3}
+  get-tuple-element.15 = bf16[4,4,1024,64]{3,1,2,0} get-tuple-element(custom-call.12), index=2
+  transpose.19 = bf16[4,1024,4,64]{3,2,1,0} transpose(get-tuple-element.15), dimensions={0,2,1,3}
+  ROOT tuple.20 = (bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}, bf16[4,1024,4,64]{3,2,1,0}) tuple(transpose.11, transpose.17, transpose.18, transpose.19)
+} // main.21
+)";
+
+  void TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2() {
+    if (skip_reason_) GTEST_SKIP() << *skip_reason_;
+    if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
+        se::dnn::VersionInfo(9, 0, 0)) {
+      GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
+    }
+    XlaBuilder builder(TestName());
+
+    auto lhs_bmm1_literal =
+        GetInput4DLiteral<bfloat16>({4, 1024, 4, 64}, {3, 2, 1, 0});
+    auto rhs_bmm1_literal =
+        GetInput4DLiteral<bfloat16>({4, 1024, 4, 64}, {3, 2, 1, 0});
+    auto rhs_bmm2_literal =
+        GetInput4DLiteral<bfloat16>({4, 1024, 4, 64}, {3, 2, 1, 0});
+    auto do_literal =
+        GetInput4DLiteral<bfloat16>({4, 1024, 4, 64}, {3, 2, 1, 0});
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        std::unique_ptr<HloModule> module,
+        ParseAndReturnVerifiedModule(
+            kModuleFlashAttentionTrainingBMM1SoftmaxDropoutBMM2HloStringBF16));
+    ExecuteAndTransfer(std::move(module), {&lhs_bmm1_literal, &rhs_bmm1_literal,
+                                           &rhs_bmm2_literal, &do_literal});
+  }
+};
+
 // BMM1 - Scale - CausalMask - Softmax - BMM2
 XLA_TEST_F(FlashAttentionBMMScaleCausalMaskSoftmaxBMM,
            Flash_Attention_BMM1_CausalMask_Softmax_BMM2_BF16) {
@@ -1413,6 +1470,12 @@ XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8,
            Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) {
   TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8();
 }
+
+// BMM1 - Scale - Softmax - Dropout - BMM2
+XLA_TEST_F(FlashAttentionBMMScaleSoftmaxDropoutBMM,
+           Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2) {
+  TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2();
+}
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
21 changes: 5 additions & 16 deletions xla/stream_executor/cuda/cuda_dnn.cc
@@ -7161,19 +7161,6 @@ CudnnSupport::NormRunnerFromDesc(
 #endif  // CUDNN_VERSION >= 8905
 }
 
-// Returns the offset to increment for the dropout rng.
-// The offset is used by runner to increment by the offset_increment for
-// every call to cudnn fmha kernel to make sure dropout mask is evenly
-// distributed. The recommended offset value by cudnn is max_sequence_length
-// * max_sequence_length / number_of_threads_launched in kernel.
-int64_t GetDropoutRngOffset(std::vector<int64_t>& intermediate_shape) {
-  int64_t kv_seq_len = intermediate_shape[intermediate_shape.size() - 1];
-  int64_t q_seq_len = intermediate_shape[intermediate_shape.size() - 2];
-  int64_t max_seq_len = std::max(q_seq_len, kv_seq_len);
-  int64_t cudnn_mha_num_threads = 256;
-  return max_seq_len * max_seq_len / cudnn_mha_num_threads;
-}
-
 bool CudnnSupport::GetRnnAlgorithms(
     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
   PreloadCudnnSubLibs(PreloadCudnnType::Rnn);
@@ -8254,7 +8241,8 @@ absl::Status CudnnGraph::Build(dnn::DnnSupport& dnn_support,
 }
 
 absl::Status CudnnGraph::Execute(Stream& stream,
-                                 absl::Span<DeviceMemoryBase> operands) const {
+                                 absl::Span<DeviceMemoryBase> operands,
+                                 int64_t local_device_ordinal) const {
   std::unordered_map<int64_t, void*> tensor_to_ptr_map;
   absl::Span<DeviceMemoryBase> operands_without_workspace = operands;
   DeviceMemoryBase workspace;
@@ -8272,9 +8260,10 @@ absl::Status CudnnGraph::Execute(Stream& stream,

   if (dropout_rng_offset_increment_ > 0) {
 #if CUDNN_VERSION >= 8800
+    UpdateDropoutState(local_device_ordinal);
     tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_;
-    current_dropout_rng_offset_ += dropout_rng_offset_increment_;
-    tensor_to_ptr_map[next_uid()] = (void*)&current_dropout_rng_offset_;
+    tensor_to_ptr_map[next_uid()] =
+        (void*)&current_dropout_rng_offset_[local_device_ordinal];
 #else
     return absl::UnimplementedError(
         "Cudnn dropout offset and seed are only supported with Cudnn >= "
16 changes: 13 additions & 3 deletions xla/stream_executor/cuda/cuda_dnn.h
@@ -64,14 +64,24 @@ class CudnnGraph : public dnn::DnnGraph {
   // Builds single plan of the graph with given ID.
   absl::Status Build(dnn::DnnSupport&, std::optional<int64_t> plan_id) override;
   // Builds all the plans
-  absl::Status Execute(Stream& stream,
-                       absl::Span<DeviceMemoryBase> operands) const override;
+  absl::Status Execute(Stream& stream, absl::Span<DeviceMemoryBase> operands,
+                       int64_t local_device_ordinal) const override;
   const cudnn_frontend::graph::Graph& Graph() const { return graph_; }
+  void InitDropoutState(int64_t local_device_count, int64_t seed,
+                        int64_t increment) {
+    dropout_rng_seed_ = seed;
+    current_dropout_rng_offset_ = std::vector<int64_t>(local_device_count, 0);
+    dropout_rng_offset_increment_ = increment;
+  }
+  void UpdateDropoutState(int64_t local_device_ordinal) const {
+    current_dropout_rng_offset_[local_device_ordinal] +=
+        dropout_rng_offset_increment_;
+  }
 
  private:
   cudnn_frontend::graph::Graph graph_;
   int64_t dropout_rng_seed_;
-  mutable int64_t current_dropout_rng_offset_;
+  mutable std::vector<int64_t> current_dropout_rng_offset_;
   int64_t dropout_rng_offset_increment_ = 0;
 };
 #endif  // CUDNN_VERSION >= 8100
5 changes: 4 additions & 1 deletion xla/stream_executor/dnn.h
@@ -1236,7 +1236,10 @@ class DnnGraph {
   virtual absl::Status Prepare(DnnSupport&, const NumericOptions&) = 0;
   virtual absl::Status Build(DnnSupport&, std::optional<int64_t> plan_id) = 0;
   virtual absl::Status Execute(Stream& stream,
-                               absl::Span<DeviceMemoryBase> operands) const = 0;
+                               absl::Span<DeviceMemoryBase> operands,
+                               int64_t local_device_ordinal) const = 0;
+  virtual void InitDropoutState(int64_t local_device_count, int64_t seed,
+                                int64_t increment) = 0;
 };
 
 using LazyDnnGraph = std::unique_ptr<DnnGraph>;
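
For readers skimming the diff, here is a hedged sketch of how the new hooks compose at runtime, based only on the signatures above. The wrapper function and its setup are hypothetical (namespace aliases as in the diff; `graph`, `stream`, and `operands` assumed to be prepared elsewhere); the increment of 16 mirrors the value `CuDnnThunk::Initialize` passes:

```cpp
// Hypothetical call sequence mirroring CuDnnThunk::Initialize and
// CuDnnThunk::ExecuteOnStream above; not code from this commit.
absl::Status RunFmhaWithDropout(se::gpu::CudnnGraph& graph, se::Stream& stream,
                                absl::Span<se::DeviceMemoryBase> operands,
                                int64_t local_device_count,
                                int64_t local_device_ordinal,
                                int64_t dropout_seed) {
  // Once, at initialization: one offset slot per local device, each
  // starting at zero.
  graph.InitDropoutState(local_device_count, dropout_seed, /*increment=*/16);
  // Per launch: Execute calls UpdateDropoutState(local_device_ordinal)
  // internally, advancing only this device's offset, so concurrent
  // executions on different devices stay independent.
  return graph.Execute(stream, operands, local_device_ordinal);
}
```
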
