Add cuda c++ for command buffer kernels
Reverts 8bd6744

PiperOrigin-RevId: 675204038
IllogicalMoose authored and Google-ML-Automation committed Sep 16, 2024
1 parent 0f8556c commit a03a49b
Showing 3 changed files with 163 additions and 1 deletion.
xla/stream_executor/cuda/BUILD (6 changes: 5 additions & 1 deletion)
@@ -1,4 +1,5 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
load(
"@tsl//tsl/platform:build_config_root.bzl",
"if_static",
@@ -511,13 +512,16 @@ cuda_only_cc_library(
],
)

cc_library(
cuda_library(
name = "command_buffer_kernels",
srcs = [
"command_buffer_kernels.cc",
"command_buffer_kernels.cu.cc",
],
tags = ["no_rocm"],
deps = [
"//xla/stream_executor:kernel_spec",
"//xla/stream_executor/gpu:gpu_types_header",
"@com_google_absl//absl/status:statusor",
],
)
xla/stream_executor/cuda/command_buffer_kernels.cc (41 changes: 41 additions & 0 deletions)
@@ -16,6 +16,7 @@ limitations under the License.
#include <string_view>

#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/kernel_spec.h"

namespace stream_executor {
@@ -757,45 +758,85 @@ inline constexpr std::string_view kNoOpKernel = R"(
})";

} // namespace

#if CUDA_VERSION >= 12030
void* GetSetIfConditionKernel();
void* GetSetIfElseConditionKernel();
void* GetSetCaseConditionKernel();
void* GetSetForConditionKernel();
void* GetSetWhileConditionKernel();
void* GetNoOpKernel();
#endif

} // namespace cuda

namespace gpu {

// TODO(b/362786589): Remove PTX usage when we only support cuda >= 12.4.1
// See comment at top of this file for why PTX is used for cuda < 12.4.1.
absl::StatusOr<MultiKernelLoaderSpec> GetSetIfConditionKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/2);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetSetIfConditionKernel(), "set_if_condition");
#else
spec.AddCudaPtxInMemory(cuda::kSetIfConditionKernel, "set_if_condition");
#endif
return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetIfElseConditionKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/3);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetSetIfElseConditionKernel(),
"set_if_else_condition");
#else
spec.AddCudaPtxInMemory(cuda::kSetIfElseConditionKernel,
"set_if_else_condition");
#endif
return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetCaseConditionKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/10);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetSetCaseConditionKernel(),
"set_case_condition");
#else
spec.AddCudaPtxInMemory(cuda::kSetCaseConditionKernel, "set_case_condition");
#endif
return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetForConditionKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/3);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetSetForConditionKernel(),
"set_for_condition");
#else
spec.AddCudaPtxInMemory(cuda::kSetForConditionKernel, "set_for_condition");
#endif
return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetWhileConditionKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/2);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetSetWhileConditionKernel(),
"set_while_condition");
#else
spec.AddCudaPtxInMemory(cuda::kSetWhileConditionKernel,
"set_while_condition");
#endif
return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetNoOpKernelLoaderSpec() {
MultiKernelLoaderSpec spec(/*arity=*/0);
#if CUDA_VERSION >= 12030
spec.AddInProcessSymbol(cuda::GetNoOpKernel(), "noop");
#else
spec.AddCudaPtxInMemory(cuda::kNoOpKernel, "noop");
#endif
return spec;
}

xla/stream_executor/cuda/command_buffer_kernels.cu.cc (117 changes: 117 additions & 0 deletions)
@@ -0,0 +1,117 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_device_runtime_api.h"
#include "third_party/gpus/cuda/include/cuda_runtime.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"

namespace stream_executor::cuda {

#if CUDA_VERSION >= 12030
// In all kernels defined below we set conditional handle value to `1` when we
// want to execute a CUDA graph tied to it, and to `0` otherwise. For loops, the
// graph will keep being executed until the conditional handle becomes `0`.

__global__ void SetIfCondition(cudaGraphConditionalHandle then_handle,
bool* predicate) {
if (*predicate) {
cudaGraphSetConditional(then_handle, 1);
} else {
cudaGraphSetConditional(then_handle, 0);
}
}

__global__ void SetIfElseCondition(cudaGraphConditionalHandle then_handle,
cudaGraphConditionalHandle else_handle,
bool* predicate) {
if (*predicate) {
cudaGraphSetConditional(then_handle, 1);
cudaGraphSetConditional(else_handle, 0);
} else {
cudaGraphSetConditional(then_handle, 0);
cudaGraphSetConditional(else_handle, 1);
}
}

__global__ void SetCaseCondition(
cudaGraphConditionalHandle h0, cudaGraphConditionalHandle h1,
cudaGraphConditionalHandle h2, cudaGraphConditionalHandle h3,
cudaGraphConditionalHandle h4, cudaGraphConditionalHandle h5,
cudaGraphConditionalHandle h6, cudaGraphConditionalHandle h7,
int32_t* index, int32_t num_handles) {
// Only handles in [0, num_handles) range are valid.
//
// We can't define a device function with dynamic number of handle arguments,
// so we always pass 8 handles, but only some of them are valid. Size 8 picked
// as a reasonable (but random) upper bound for what we see in XLA uses.
std::array<cudaGraphConditionalHandle, 8> handles = {h0, h1, h2, h3,
h4, h5, h6, h7};

// If branch index is out of range activate the last valid handle.
int32_t branch_index = *index;
if (branch_index < 0 || branch_index >= num_handles) {
branch_index = num_handles - 1;
}

for (int32_t i = 0; i < num_handles; ++i) {
if (branch_index == i) {
cudaGraphSetConditional(handles[i], 1);
} else {
cudaGraphSetConditional(handles[i], 0);
}
}
}

__global__ void SetForCondition(cudaGraphConditionalHandle handle,
int32_t* loop_index, int32_t num_iterations) {
if (*loop_index < num_iterations) {
cudaGraphSetConditional(handle, 1);
} else {
cudaGraphSetConditional(handle, 0);
}
*loop_index += 1;
}

__global__ void NoOp() {}

void* GetSetIfConditionKernel() {
return reinterpret_cast<void*>(&cuda::SetIfCondition);
}

void* GetSetIfElseConditionKernel() {
return reinterpret_cast<void*>(&SetIfElseCondition);
}

void* GetSetCaseConditionKernel() {
return reinterpret_cast<void*>(&SetCaseCondition);
}

void* GetSetForConditionKernel() {
return reinterpret_cast<void*>(&SetForCondition);
}

void* GetSetWhileConditionKernel() {
// While condition kernel is the same as an `If` with a single branch.
return reinterpret_cast<void*>(&SetIfCondition);
}

void* GetNoOpKernel() { return reinterpret_cast<void*>(&NoOp); }

#endif

} // namespace stream_executor::cuda
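
Note for readers: the kernels above only write a 0/1 value into a cudaGraphConditionalHandle; the handle itself is created and attached to a conditional node on the host. The sketch below is illustrative only and is not part of this commit or of XLA's command buffer implementation. It shows, assuming the CUDA >= 12.3 runtime graph API and a .cu translation unit, how a SetIfCondition-style kernel and an IF conditional node could be wired together. The names SetIfConditionDemo, BuildIfGraph, and device_predicate are hypothetical, and error checking is omitted.

#include <cuda_runtime.h>

// Stand-in for the SetIfCondition kernel in this commit: copies a device-side
// predicate into the conditional handle (1 = run the attached body graph).
__global__ void SetIfConditionDemo(cudaGraphConditionalHandle then_handle,
                                   bool* predicate) {
  cudaGraphSetConditional(then_handle, *predicate ? 1 : 0);
}

// Builds a graph of the form: [set-condition kernel] -> [conditional IF node].
// The IF node's body graph (phGraph_out[0]) runs only when the handle is 1.
void BuildIfGraph(bool* device_predicate) {
  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  // Conditional handle; resets to 0 (skip the branch) on every graph launch.
  cudaGraphConditionalHandle handle;
  cudaGraphConditionalHandleCreate(&handle, graph, /*defaultLaunchValue=*/0,
                                   cudaGraphCondAssignDefault);

  // Kernel node that evaluates the predicate and assigns the handle.
  void* kernel_args[] = {&handle, &device_predicate};
  cudaKernelNodeParams kernel_params = {};
  kernel_params.func = reinterpret_cast<void*>(&SetIfConditionDemo);
  kernel_params.gridDim = dim3(1);
  kernel_params.blockDim = dim3(1);
  kernel_params.kernelParams = kernel_args;
  cudaGraphNode_t set_condition_node;
  cudaGraphAddKernelNode(&set_condition_node, graph, /*pDependencies=*/nullptr,
                         /*numDependencies=*/0, &kernel_params);

  // Conditional IF node, dependent on the kernel node above.
  cudaGraphNodeParams cond_params = {};
  cond_params.type = cudaGraphNodeTypeConditional;
  cond_params.conditional.handle = handle;
  cond_params.conditional.type = cudaGraphCondTypeIf;
  cond_params.conditional.size = 1;
  cudaGraphNode_t cond_node;
  cudaGraphAddNode(&cond_node, graph, &set_condition_node,
                   /*numDependencies=*/1, &cond_params);

  // The "then" branch would be built into this child graph before the parent
  // graph is instantiated and launched; omitted here.
  cudaGraph_t then_body = cond_params.conditional.phGraph_out[0];
  (void)then_body;

  cudaGraphDestroy(graph);
}

The kernels in this commit follow the same device-side contract; the host-side wiring lives elsewhere in XLA's command buffer code and is not shown in this diff.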
