CUDA GEMM naive (#138)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Jan 22, 2025
1 parent 5300a56 commit b0a534e
Showing 36 changed files with 754 additions and 192 deletions.
9 changes: 9 additions & 0 deletions cute_kernels/cpp_registry.yml
@@ -34,3 +34,12 @@
  - kernels/swiglu/cuda_implementation/forward.cu
  - kernels/swiglu/cuda_implementation/backward.cu
  build_path: swiglu

- functions:
  - naive_gemm_cuda
  - no_tile_quantization_cuda
  sources:
  - kernels/gemm/cuda_implementation/ops.cpp
  - kernels/gemm/cuda_implementation/naive.cu
  - kernels/gemm/cuda_implementation/no_tile_quantization.cu
  build_path: gemm
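
For orientation, a naive GEMM assigns one thread to each element of the output matrix and loops over the shared dimension. The sketch below only illustrates that scheme; the kernel name, signature, and row-major layout are assumptions of this example, not the contents of kernels/gemm/cuda_implementation/naive.cu.

    // Illustrative naive GEMM: one thread computes one element of C = A @ B.
    // A is M x K, B is K x N, C is M x N, all row-major. The name and signature
    // are hypothetical and are not taken from naive.cu.
    __global__ void naive_gemm_fp32_sketch(const float *A,
                                           const float *B,
                                           float *C,
                                           const unsigned int M,
                                           const unsigned int K,
                                           const unsigned int N) {
        const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
        const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;

        if (row < M && col < N) {
            float accumulator = 0.0f;
            for (unsigned int k = 0; k < K; k++) {
                accumulator += A[row * K + k] * B[k * N + col];
            }
            C[row * N + col] = accumulator;
        }
    }

Such a kernel would typically be launched with a 2D thread block (for example 16 x 16) and a grid of ceil(N / 16) x ceil(M / 16) blocks; the no_tile_quantization_cuda kernel registered above is not shown in this excerpt.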
48 changes: 48 additions & 0 deletions cute_kernels/include/dtypes/alias.h
@@ -0,0 +1,48 @@
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#define AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_CUSTOM_FLOAT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_CUSTOM_INT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define AT_DISPATCH_CUSTOM_INT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_INT_TYPES(__VA_ARGS__))

// define dtype aliases
using fp64 = double;
using fp64_2 = double2;
using fp64_4 = double4;

using fp32 = float;
using fp32_2 = float2;
using fp32_4 = float4;

using fp16 = half;
using fp16_2 = half2;

using bf16 = __nv_bfloat16;
using bf16_2 = __nv_bfloat162;

using int64 = long;
using uint64 = ulong;
using uint64_2 = ulong2;

using int32 = int;
using uint32 = uint;
using uint32_2 = uint2;
using uint32_4 = uint4;

using int16 = short;
using uint16 = ushort;
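
These are the standard PyTorch AT_DISPATCH_SWITCH / AT_DISPATCH_CASE macros, narrowed to the floating types (fp16, bf16, fp32) and integer types (int32, int64) the kernels support; inside the dispatched lambda the selected C++ type is available as scalar_t. A hypothetical launcher showing the intended usage (the kernel and function names here are illustrative, not part of this commit):

    #include <torch/extension.h>

    #include "alias.h"  // the header added above; the include path is illustrative

    // Hypothetical elementwise kernel, used only to demonstrate the dispatch macro.
    template <typename scalar_t>
    __global__ void _scale_kernel(const scalar_t *x, scalar_t *output, const fp32 factor, const uint64 num_elements) {
        const uint64 i = (uint64)blockIdx.x * blockDim.x + threadIdx.x;
        if (i < num_elements) {
            output[i] = static_cast<scalar_t>(static_cast<fp32>(x[i]) * factor);
        }
    }

    void scale_cuda(const torch::Tensor &x, torch::Tensor &output, const fp32 factor, const uint32 BLOCK_SIZE) {
        const uint64 num_elements = x.numel();
        const uint32 num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;

        // expands to a switch over Half, BFloat16 and Float; scalar_t is bound inside the lambda
        AT_DISPATCH_CUSTOM_FLOAT_TYPES(x.scalar_type(), "scale_cuda", [&] {
            _scale_kernel<scalar_t><<<num_blocks, BLOCK_SIZE>>>(
                x.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), factor, num_elements);
        });
    }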
1 change: 1 addition & 0 deletions cute_kernels/include/dtypes/bf16.h
@@ -5,6 +5,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "alias.h"
#include "common.h"

template <>
42 changes: 1 addition & 41 deletions cute_kernels/include/dtypes/common.h
@@ -5,47 +5,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#define AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_CUSTOM_FLOAT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(__VA_ARGS__))

#define AT_DISPATCH_CASE_CUSTOM_INT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define AT_DISPATCH_CUSTOM_INT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_INT_TYPES(__VA_ARGS__))

// define dtype aliases
using fp64 = double;
using fp64_2 = double2;
using fp64_4 = double4;

using fp32 = float;
using fp32_2 = float2;
using fp32_4 = float4;

using fp16 = half;
using fp16_2 = half2;

using bf16 = __nv_bfloat16;
using bf16_2 = __nv_bfloat162;

using int64 = long;
using uint64 = ulong;
using uint64_2 = ulong2;

using int32 = int;
using uint32 = uint;
using uint32_2 = uint2;
using uint32_4 = uint4;

using int16 = short;
using uint16 = ushort;
#include "alias.h"

inline __device__ std::tuple<uint16, uint16> split_fp32_into_16_bits(const fp32 &value) {
uint32 left_right_int = __float_as_uint(value);
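
split_fp32_into_16_bits (truncated above) reinterprets the fp32 bit pattern as a uint32 and returns its two 16-bit halves. For reference, the inverse direction would look roughly like the sketch below; it is not part of this commit, and which half is treated as the upper one is an assumption of this example.

    // Illustrative inverse of split_fp32_into_16_bits (not from this commit):
    // repack two 16-bit halves into an fp32 via the same bit-reinterpretation
    // intrinsics. Treating the first argument as the upper half is an assumption.
    inline __device__ fp32 combine_16_bits_into_fp32(const uint16 &upper, const uint16 &lower) {
        const uint32 packed = (static_cast<uint32>(upper) << 16) | static_cast<uint32>(lower);
        return __uint_as_float(packed);
    }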
1 change: 1 addition & 0 deletions cute_kernels/include/dtypes/fp16.h
@@ -5,6 +5,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "alias.h"
#include "common.h"

template <>
1 change: 1 addition & 0 deletions cute_kernels/include/dtypes/fp32.h
@@ -5,6 +5,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "alias.h"
#include "common.h"

template <>
1 change: 1 addition & 0 deletions cute_kernels/include/dtypes/fp64.h
@@ -5,6 +5,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "alias.h"
#include "common.h"

template <>
1 change: 1 addition & 0 deletions cute_kernels/include/dtypes/uint32.h
@@ -4,6 +4,7 @@
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "alias.h"
#include "common.h"

template <>
12 changes: 8 additions & 4 deletions cute_kernels/include/threads.h
@@ -13,18 +13,22 @@ inline __device__ uint32 get_threads_per_block() { return blockDim.x * blockDim.

inline __device__ uint32 get_num_blocks() { return gridDim.x * gridDim.y * gridDim.z; }

inline __device__ uint32 get_block_id() {
return gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x;
}
inline __device__ uint32 get_block_id() { return gridDim.x * (gridDim.y * blockIdx.z + blockIdx.y) + blockIdx.x; }

inline __device__ uint32 get_local_thread_id() {
return blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x;
return blockDim.x * (blockDim.y * threadIdx.z + threadIdx.y) + threadIdx.x;
}

inline __device__ uint64 get_global_thread_id() {
return get_threads_per_block() * get_block_id() + get_local_thread_id();
}

inline __device__ uint64 get_thread_id_along_axis(const uint32 &block_size,
const uint32 &block_id,
const uint32 &thread_id) {
return block_size * block_id + thread_id;
}

inline __host__ int get_max_thread_blocks(const int &sm_count, const int &thread_block_cluster_size) {
int max_num_blocks = sm_count;
if (max_num_blocks % thread_block_cluster_size != 0) {
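
The rewritten get_block_id and get_local_thread_id are algebraically identical to the versions they replace (the one-liners simply factor out a multiplication), while get_thread_id_along_axis is new: it wraps the usual block_size * block_id + thread_id index for a single grid axis, which is convenient for 2D kernels such as the GEMM added in this commit. A hypothetical use, not the actual code in naive.cu:

    #include "threads.h"  // include path is illustrative; assumes the helpers and dtype aliases above are in scope

    // Hypothetical 2D kernel: each thread recovers the (row, col) position it owns.
    __global__ void _add_identity_kernel(fp32 *A, const uint32 M, const uint32 N) {
        const uint64 row = get_thread_id_along_axis(blockDim.y, blockIdx.y, threadIdx.y);
        const uint64 col = get_thread_id_along_axis(blockDim.x, blockIdx.x, threadIdx.x);

        if (row < M && col < N && row == col) {
            A[row * N + col] += 1.0f;
        }
    }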
@@ -1 +1,13 @@
from .forward import add_scalar_cuda
import torch

from .....constants import LIBRARY_NAME
from .....jit import cpp_jit
from .....utils import cute_op


_KERNEL_NAME = "add_scalar_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def add_scalar_cuda(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None: ...
13 changes: 0 additions & 13 deletions cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py

This file was deleted.

@@ -1 +1,13 @@
from .forward import add_tensor_cuda
import torch

from .....constants import LIBRARY_NAME
from .....jit import cpp_jit
from .....utils import cute_op


_KERNEL_NAME = "add_tensor_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def add_tensor_cuda(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: ...
13 changes: 0 additions & 13 deletions cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py

This file was deleted.

@@ -1 +1,15 @@
from .forward import continuous_count_cuda
import torch

from ....constants import LIBRARY_NAME
from ....jit import cpp_jit
from ....utils import cute_op


_KERNEL_NAME = "continuous_count_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def continuous_count_cuda(
x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int
) -> None: ...

This file was deleted.

@@ -1 +1,21 @@
from .forward import continuous_count_and_sort_cuda
import torch

from ....constants import LIBRARY_NAME
from ....jit import cpp_jit
from ....utils import cute_op


_KERNEL_NAME = "continuous_count_and_sort_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"count_output", "sorted_output", "argsort_output"})
@cpp_jit(_KERNEL_NAME)
def continuous_count_and_sort_cuda(
x: torch.Tensor,
count_output: torch.Tensor,
sorted_output: torch.Tensor,
argsort_output: torch.Tensor,
sm_count: int,
size: int,
BLOCK_SIZE: int,
) -> None: ...

This file was deleted.

