vector load/store fp32 vector addition kernel #11

Merged: 10 commits, Sep 1, 2024
1 change: 1 addition & 0 deletions .clang-format
@@ -21,3 +21,4 @@ PenaltyBreakString: 1000
 PenaltyBreakTemplateDeclaration: 10
 PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 200
+IndentPPDirectives: BeforeHash
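For context, IndentPPDirectives: BeforeHash tells clang-format to indent nested preprocessor directives, placing the indentation before the # character. A small illustrative snippet (not from this PR) of how formatted code looks under this setting:

// With IndentPPDirectives: BeforeHash, nested directives are indented
// before the '#'; names here are illustrative only.
#if defined(USE_CUDA)
    #if defined(USE_FP16)
        #include <cuda_fp16.h>
    #endif
#endif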
4 changes: 2 additions & 2 deletions khd/scattermoe/triton_implementation/ops.py
@@ -72,7 +72,7 @@ def grid(META):
         E=W.size(0),
         BLOCK_M=BLOCK_M,
         ACC_TYPE=tl.float32,
-        allow_tf32=True,
+        allow_tf32=torch.backends.cudnn.allow_tf32,
         x_grouped=x_grouped,
         y_grouped=y_grouped,
     )
@@ -112,7 +112,7 @@ def grid(META):
         K=X.size(-1),
         # ACC_TYPE: tl.constexpr,
         ACC_TYPE=tl.float32,
-        allow_tf32=True,
+        allow_tf32=torch.backends.cudnn.allow_tf32,
     )
     return DW

46 changes: 41 additions & 5 deletions khd/vector_addition/cuda_implementation/kernels.cu
@@ -1,23 +1,59 @@
 #include <cuda.h>
+#include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <torch/extension.h>

+#define BLOCK_SIZE 256
+#define NUM_ELEMENTS_PER_THREAD 4 // vectorized load store
+
 template <typename scalar_t>
 __global__ void vector_addition_forward_kernel(const scalar_t *x,
                                                const scalar_t *y,
                                                scalar_t *output,
                                                const int num_elements) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
-    if (index < num_elements) {
-        output[index] = x[index] + y[index];
+
+    if (std::is_same<scalar_t, float>::value) {
+        if (index * NUM_ELEMENTS_PER_THREAD < num_elements) {
+            // float4 is a datatype used for vectorized loads and stores
+            float4 *x4 = (float4 *)x;
+            float4 *y4 = (float4 *)y;
+            float4 *output4 = (float4 *)output;
+
+            // tmp is initialized here to avoid doing multiple writes
+            float4 tmp;
+            tmp.x = x4[index].x + y4[index].x;
+            tmp.y = x4[index].y + y4[index].y;
+            tmp.z = x4[index].z + y4[index].z;
+            tmp.w = x4[index].w + y4[index].w;
+
+            output4[index] = tmp;
+        }
+    } else {
+        if (index < num_elements) {
+            output[index] = x[index] + y[index];
+        }
     }
 }

-void vector_addition_forward_kernel_dispatcher(
-    torch::Tensor x, torch::Tensor y, torch::Tensor output, const int NUM_BLOCKS, const int BLOCK_SIZE) {
+torch::Tensor vector_addition_forward_kernel_dispatcher(torch::Tensor x, torch::Tensor y) {
+    int num_elements = x.numel();
+
+    torch::Tensor output = torch::empty_like(x);
+
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "vector_addition_forward_kernel", ([&] {
+            int num_elements_per_thread = 1;
+            if (std::is_same<scalar_t, float>::value) {
+                num_elements_per_thread = NUM_ELEMENTS_PER_THREAD;
+            }
+
+            int NUM_BLOCKS =
+                (num_elements + num_elements_per_thread * BLOCK_SIZE - 1) / (num_elements_per_thread * BLOCK_SIZE);
+
             vector_addition_forward_kernel<scalar_t><<<NUM_BLOCKS, BLOCK_SIZE>>>(
-                x.data<scalar_t>(), y.data<scalar_t>(), output.data<scalar_t>(), x.numel());
+                x.data<scalar_t>(), y.data<scalar_t>(), output.data<scalar_t>(), num_elements);
         }));
+
+    return output;
 }
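A note on the float4 branch above: reading x4[index] touches elements 4*index through 4*index+3, so the kernel assumes num_elements is a multiple of NUM_ELEMENTS_PER_THREAD and that the tensors are 16-byte aligned (fresh torch allocations normally are). A minimal standalone sketch of the same pattern with an explicit scalar tail for arbitrary sizes; this is illustrative, not the PR's code, and all names are assumptions:

#include <cuda_runtime.h>

// Minimal sketch: float4 vectorized body plus a scalar tail so that
// num_elements need not be a multiple of 4.
__global__ void vector_addition_float4_sketch(const float *x, const float *y, float *output, const int num_elements) {
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int num_vectors = num_elements / 4; // complete float4 chunks

    if (index < num_vectors) {
        // one 128-bit load per operand instead of four 32-bit loads
        const float4 a = reinterpret_cast<const float4 *>(x)[index];
        const float4 b = reinterpret_cast<const float4 *>(y)[index];
        float4 c;
        c.x = a.x + b.x;
        c.y = a.y + b.y;
        c.z = a.z + b.z;
        c.w = a.w + b.w;
        reinterpret_cast<float4 *>(output)[index] = c;
    }

    // only threads 0..(num_elements % 4 - 1) pass this check,
    // so they mop up the leftover scalar elements
    const int tail = num_vectors * 4 + index;
    if (tail < num_elements) {
        output[tail] = x[tail] + y[tail];
    }
}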
16 changes: 3 additions & 13 deletions khd/vector_addition/cuda_implementation/ops.cpp
@@ -1,7 +1,6 @@
 #include <torch/extension.h>

-void vector_addition_forward_kernel_dispatcher(
-    torch::Tensor x, torch::Tensor y, torch::Tensor output, const int NUM_BLOCKS, const int BLOCK_SIZE);
+torch::Tensor vector_addition_forward_kernel_dispatcher(torch::Tensor x, torch::Tensor y);

 torch::Tensor vector_addition_forward(torch::Tensor x, torch::Tensor y) {
     TORCH_CHECK(x.device().is_cuda(), "tensor x is not on GPU")
@@ -13,19 +12,10 @@ torch::Tensor vector_addition_forward(torch::Tensor x, torch::Tensor y) {
     TORCH_CHECK(x.dim() == 1, "tensor x should be 1 dimensional")
     TORCH_CHECK(y.dim() == 1, "tensor y should be 1 dimensional")

-    int num_elements = x.numel();
-
-    TORCH_CHECK(y.numel() == num_elements, "both tensors should have same number of elements");
+    TORCH_CHECK(x.numel() == y.numel(), "both tensors should have same number of elements");
     TORCH_CHECK(x.scalar_type() == y.scalar_type(), "both tensors should have same dtype");

-    int BLOCK_SIZE = 1024;
-    int NUM_BLOCKS = (int)ceil((float)num_elements / BLOCK_SIZE);
-
-    torch::Tensor output = torch::empty_like(x);
-
-    vector_addition_forward_kernel_dispatcher(x, y, output, NUM_BLOCKS, BLOCK_SIZE);
-
-    return output;
+    return vector_addition_forward_kernel_dispatcher(x, y);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
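The PYBIND11_MODULE body is collapsed in the diff view above. A hypothetical sketch of how such a binding typically looks, assuming it exports vector_addition_forward (the exported name and docstring are assumptions, not the PR's code):

#include <torch/extension.h>

torch::Tensor vector_addition_forward(torch::Tensor x, torch::Tensor y);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // exported name and docstring assumed for illustration
    m.def("vector_addition_forward", &vector_addition_forward, "vector addition forward (CUDA)");
}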
37 changes: 37 additions & 0 deletions tools/thoughput.py
@@ -0,0 +1,37 @@
+from time import perf_counter
+
+import torch
+from tabulate import tabulate
+
+from khd import vector_addition_cuda, vector_addition_torch, vector_addition_triton
+
+
+n = 100
+
+headers = ["dtype", "torch", "cuda", "triton"]
+kernels = [vector_addition_torch, vector_addition_cuda, vector_addition_triton]
+
+table = []
+
+for dtype in [torch.float16, torch.bfloat16, torch.float32]:
+    row = [str(dtype)]
+    for kernel in kernels:
+        # kernel = torch.compile(kernel)
+        x = torch.randn(10485760, device=torch.cuda.current_device(), dtype=dtype)
+        y = torch.randn(10485760, device=torch.cuda.current_device(), dtype=dtype)
+
+        for i in range(n):
+            z = kernel(x, y)
+
+        torch.cuda.synchronize()
+        s = perf_counter()
+        for i in range(n):
+            z = kernel(x, y)
+        torch.cuda.synchronize()
+        e = perf_counter()
+
+        row.append((e - s) / n)
+    table.append(row)
+
+
+print(tabulate(table, headers=headers))
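The script above measures wall-clock time with perf_counter() around a synchronized loop, after a warmup pass. A device-side alternative is CUDA event timing, which brackets the work on the GPU stream itself and excludes host jitter. A self-contained sketch; the kernel, sizes, and launch configuration are illustrative assumptions, not this PR's code:

#include <cuda_runtime.h>
#include <cstdio>

// toy scalar kernel, standing in for the real one being benchmarked
__global__ void add(const float *x, const float *y, float *out, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = x[i] + y[i];
}

int main() {
    const int n = 10485760, iters = 100, block = 256;
    float *x, *y, *out;
    cudaMalloc(&x, n * sizeof(float));
    cudaMalloc(&y, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    // warmup, mirroring the script above
    for (int i = 0; i < iters; i++) add<<<(n + block - 1) / block, block>>>(x, y, out, n);

    cudaEventRecord(start);
    for (int i = 0; i < iters; i++) add<<<(n + block - 1) / block, block>>>(x, y, out, n);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop); // events are recorded on the same stream as the launches

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("%.6f ms per launch\n", ms / iters);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    cudaFree(x);
    cudaFree(y);
    cudaFree(out);
    return 0;
}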