-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Mayank Mishra <[email protected]>
- Loading branch information
1 parent
5300a56
commit b0a534e
Showing
36 changed files
with
754 additions
and
192 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#pragma once | ||
|
||
#include <cuda.h> | ||
#include <cuda_fp16.h> | ||
#include <cuda_runtime.h> | ||
#include <torch/extension.h> | ||
|
||
// Case list for the floating-point dtypes: expands to one AT_DISPATCH_CASE each for Half, BFloat16 and Float. | ||
#define AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(...) \ | ||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ | ||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ | ||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) | ||
|
||
// Wraps the float case list in AT_DISPATCH_SWITCH(TYPE, NAME, ...); the variadic arguments are forwarded to every case. | ||
#define AT_DISPATCH_CUSTOM_FLOAT_TYPES(TYPE, NAME, ...) \ | ||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_FLOAT_TYPES(__VA_ARGS__)) | ||
|
||
// Case list for the integer dtypes: expands to one AT_DISPATCH_CASE each for Int and Long. | ||
#define AT_DISPATCH_CASE_CUSTOM_INT_TYPES(...) \ | ||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ | ||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) | ||
|
||
// Wraps the integer case list in AT_DISPATCH_SWITCH(TYPE, NAME, ...); the variadic arguments are forwarded to every case. | ||
#define AT_DISPATCH_CUSTOM_INT_TYPES(TYPE, NAME, ...) \ | ||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_CUSTOM_INT_TYPES(__VA_ARGS__)) | ||
|
||
// Short dtype aliases: <kind><bits> names a scalar type, and a trailing _2 / _4 names the | ||
// correspondingly named vector type (e.g. fp32_4 is float4). | ||
// NOTE(review): int64 / uint64 are spelled with `long` / `ulong`, whose width is platform-dependent, | ||
// and `ulong`, `uint`, `ushort` are not declared in this header - presumably they come from one of | ||
// the includes above; confirm for every target toolchain. | ||
using fp64 = double; | ||
using fp64_2 = double2; | ||
using fp64_4 = double4; | ||
|
||
using fp32 = float; | ||
using fp32_2 = float2; | ||
using fp32_4 = float4; | ||
|
||
using fp16 = half; | ||
using fp16_2 = half2; | ||
|
||
using bf16 = __nv_bfloat16; | ||
using bf16_2 = __nv_bfloat162; | ||
|
||
using int64 = long; | ||
using uint64 = ulong; | ||
using uint64_2 = ulong2; | ||
|
||
using int32 = int; | ||
using uint32 = uint; | ||
using uint32_2 = uint2; | ||
using uint32_4 = uint4; | ||
|
||
using int16 = short; | ||
using uint16 = ushort; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 13 additions & 1 deletion
14
cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,13 @@ | ||
from .forward import add_scalar_cuda | ||
import torch | ||
|
||
from .....constants import LIBRARY_NAME | ||
from .....jit import cpp_jit | ||
from .....utils import cute_op | ||
|
||
|
||
# Single name used both in the qualified op name given to cute_op and as the argument to cpp_jit. | ||
_KERNEL_NAME = "add_scalar_cuda" | ||
|
||
|
||
# Signature-only stub: the Python body is just `...`. | ||
# cute_op receives the qualified name f"{LIBRARY_NAME}::{_KERNEL_NAME}" and is told that `output` is mutated. | ||
# NOTE(review): presumably cpp_jit(_KERNEL_NAME) supplies the real implementation from a compiled kernel | ||
# of the same name, and the result of x + y is written into `output` - confirm against the jit module. | ||
@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) | ||
@cpp_jit(_KERNEL_NAME) | ||
def add_scalar_cuda(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None: ... |
13 changes: 0 additions & 13 deletions
13
cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py
This file was deleted.
Oops, something went wrong.
14 changes: 13 additions & 1 deletion
14
cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,13 @@ | ||
from .forward import add_tensor_cuda | ||
import torch | ||
|
||
from .....constants import LIBRARY_NAME | ||
from .....jit import cpp_jit | ||
from .....utils import cute_op | ||
|
||
|
||
# Single name used both in the qualified op name given to cute_op and as the argument to cpp_jit. | ||
_KERNEL_NAME = "add_tensor_cuda" | ||
|
||
|
||
# Signature-only stub: the Python body is just `...`. | ||
# cute_op receives the qualified name f"{LIBRARY_NAME}::{_KERNEL_NAME}" and is told that `output` is mutated. | ||
# NOTE(review): presumably cpp_jit(_KERNEL_NAME) supplies the real implementation from a compiled kernel | ||
# of the same name, and the elementwise sum of x and y is written into `output` - confirm against the jit module. | ||
@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) | ||
@cpp_jit(_KERNEL_NAME) | ||
def add_tensor_cuda(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: ... |
13 changes: 0 additions & 13 deletions
13
cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py
This file was deleted.
Oops, something went wrong.
16 changes: 15 additions & 1 deletion
16
cute_kernels/kernels/continuous_count/cuda_implementation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,15 @@ | ||
from .forward import continuous_count_cuda | ||
import torch | ||
|
||
from ....constants import LIBRARY_NAME | ||
from ....jit import cpp_jit | ||
from ....utils import cute_op | ||
|
||
|
||
# Single name used both in the qualified op name given to cute_op and as the argument to cpp_jit. | ||
_KERNEL_NAME = "continuous_count_cuda" | ||
|
||
|
||
# Signature-only stub: the Python body is just `...`. | ||
# cute_op receives the qualified name f"{LIBRARY_NAME}::{_KERNEL_NAME}" and is told that `output` is mutated. | ||
# NOTE(review): presumably cpp_jit(_KERNEL_NAME) supplies the real implementation from a compiled kernel | ||
# of the same name; the meaning of sm_count, thread_block_cluster_size and size is not visible here - | ||
# confirm against the kernel source. | ||
@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) | ||
@cpp_jit(_KERNEL_NAME) | ||
def continuous_count_cuda( | ||
x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int | ||
) -> None: ... |
15 changes: 0 additions & 15 deletions
15
cute_kernels/kernels/continuous_count/cuda_implementation/forward.py
This file was deleted.
Oops, something went wrong.
22 changes: 21 additions & 1 deletion
22
cute_kernels/kernels/continuous_count_and_sort/cuda_implementation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,21 @@ | ||
from .forward import continuous_count_and_sort_cuda | ||
import torch | ||
|
||
from ....constants import LIBRARY_NAME | ||
from ....jit import cpp_jit | ||
from ....utils import cute_op | ||
|
||
|
||
# Single name used both in the qualified op name given to cute_op and as the argument to cpp_jit. | ||
_KERNEL_NAME = "continuous_count_and_sort_cuda" | ||
|
||
|
||
# Signature-only stub: the Python body is just `...`. | ||
# cute_op receives the qualified name f"{LIBRARY_NAME}::{_KERNEL_NAME}" and is told that three arguments | ||
# are mutated: count_output, sorted_output and argsort_output. | ||
# NOTE(review): presumably cpp_jit(_KERNEL_NAME) supplies the real implementation from a compiled kernel | ||
# of the same name; the meaning of sm_count and size is not visible here - confirm against the kernel source. | ||
@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"count_output", "sorted_output", "argsort_output"}) | ||
@cpp_jit(_KERNEL_NAME) | ||
def continuous_count_and_sort_cuda( | ||
x: torch.Tensor, | ||
count_output: torch.Tensor, | ||
sorted_output: torch.Tensor, | ||
argsort_output: torch.Tensor, | ||
sm_count: int, | ||
size: int, | ||
BLOCK_SIZE: int, | ||
) -> None: ... |
21 changes: 0 additions & 21 deletions
21
cute_kernels/kernels/continuous_count_and_sort/cuda_implementation/forward.py
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.