Commit
Signed-off-by: Mayank Mishra <[email protected]>
1 parent ee79499 · commit ed19747

Showing 43 changed files with 783 additions and 741 deletions.
cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py (16 changes: 1 addition & 15 deletions)
@@ -1,15 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....jit import cpp_jit
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_scalar_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def add_scalar_forward_cuda(
-    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import add_scalar_forward_cuda
File renamed without changes.
cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py (15 changes: 15 additions & 0 deletions)
@@ -0,0 +1,15 @@
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....jit import cpp_jit
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_scalar_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def add_scalar_forward_cuda(
+    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None: ...
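For context, a minimal usage sketch of the relocated CUDA stub. It assumes the repository's CUDA extension builds via cpp_jit, and the values chosen for vector_instruction_width and BLOCK_SIZE are illustrative rather than documented defaults:

import torch

from cute_kernels.kernels.add.add_scalar.cuda_implementation import add_scalar_forward_cuda

x = torch.randn(1 << 20, device="cuda", dtype=torch.float32)
output = torch.empty_like(x)

# The op writes into `output` in place (mutates_args={"output"}); 4 and 1024 are illustrative values.
add_scalar_forward_cuda(x=x, y=3.0, output=output, vector_instruction_width=4, BLOCK_SIZE=1024)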
cute_kernels/kernels/add/add_scalar/triton_implementation/__init__.py (21 changes: 1 addition & 20 deletions)
@@ -1,20 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....math import ceil_divide
-from .....utils import cute_op
-from .kernels_forward import _add_scalar_forward_triton_kernel
-
-
-_KERNEL_NAME = "add_scalar_forward_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_scalar_forward_triton(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None:
-    num_elements = x.numel()
-    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
-
-    with torch.device(x.device):
-        _add_scalar_forward_triton_kernel[(num_programs,)](
-            x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
-        )
+from .forward import add_scalar_forward_triton
cute_kernels/kernels/add/add_scalar/triton_implementation/forward.py (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
+import torch
+import triton
+import triton.language as tl
+
+from .....constants import LIBRARY_NAME
+from .....math import ceil_divide
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_scalar_forward_triton"
+
+
+@triton.jit
+def _add_scalar_forward_triton_kernel(x_ptr, y, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = indices < num_elements
+
+    x = tl.load(x_ptr + indices, mask=mask)
+    output = x + y
+
+    tl.store(output_ptr + indices, output, mask=mask)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_scalar_forward_triton(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None:
+    num_elements = x.numel()
+    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
+
+    with torch.device(x.device):
+        _add_scalar_forward_triton_kernel[(num_programs,)](
+            x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
+        )
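A quick usage sketch of the Triton path. The import mirrors the package layout shown above; BLOCK_SIZE=1024 is an illustrative value, and a CUDA device is assumed:

import torch

from cute_kernels.kernels.add.add_scalar.triton_implementation import add_scalar_forward_triton

x = torch.randn(4096, device="cuda")
output = torch.empty_like(x)

# One program per BLOCK_SIZE-sized chunk; the kernel masks the tail elements.
add_scalar_forward_triton(x=x, y=2.5, output=output, BLOCK_SIZE=1024)
assert torch.allclose(output, x + 2.5)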
cute_kernels/kernels/add/add_scalar/triton_implementation/kernels_forward.py (15 changes: 0 additions & 15 deletions)
This file was deleted.
cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py (16 changes: 1 addition & 15 deletions)
@@ -1,15 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....jit import cpp_jit
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_tensor_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def add_tensor_forward_cuda(
-    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import add_tensor_forward_cuda
File renamed without changes.
cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py (15 changes: 15 additions & 0 deletions)
@@ -0,0 +1,15 @@
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....jit import cpp_jit
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_tensor_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def add_tensor_forward_cuda(
+    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None: ...
cute_kernels/kernels/add/add_tensor/triton_implementation/__init__.py (20 changes: 1 addition & 19 deletions)
@@ -1,19 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....math import ceil_divide
-from .....utils import cute_op
-from .kernels_forward import _add_tensor_forward_triton_kernel
-
-
-_KERNEL_NAME = "add_tensor_forward_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
-    num_elements = x.numel()
-    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
-
-    _add_tensor_forward_triton_kernel[(num_programs,)](
-        x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
-    )
+from .forward import add_tensor_forward_triton
cute_kernels/kernels/add/add_tensor/triton_implementation/forward.py (35 changes: 35 additions & 0 deletions)
@@ -0,0 +1,35 @@
+import torch
+import triton
+import triton.language as tl
+
+from .....constants import LIBRARY_NAME
+from .....math import ceil_divide
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_tensor_forward_triton"
+
+
+@triton.jit
+def _add_tensor_forward_triton_kernel(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = indices < num_elements
+
+    x = tl.load(x_ptr + indices, mask=mask)
+    y = tl.load(y_ptr + indices, mask=mask)
+
+    output = x + y
+
+    tl.store(output_ptr + indices, output, mask=mask)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
+    num_elements = x.numel()
+    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
+
+    _add_tensor_forward_triton_kernel[(num_programs,)](
+        x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
+    )
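A sanity-check sketch for the tensor-tensor path, under the assumption that x, y, and output share shape, dtype, and device (BLOCK_SIZE=1024 is illustrative). Unlike the scalar variant, this launch is not wrapped in `with torch.device(x.device)`, so the sketch also assumes the current CUDA device already matches x.device:

import torch

from cute_kernels.kernels.add.add_tensor.triton_implementation import add_tensor_forward_triton

x = torch.randn(10_000, device="cuda")
y = torch.randn(10_000, device="cuda")
output = torch.empty_like(x)

# Elementwise add written into `output`, then checked against eager PyTorch.
add_tensor_forward_triton(x=x, y=y, output=output, BLOCK_SIZE=1024)
assert torch.allclose(output, x + y)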
cute_kernels/kernels/add/add_tensor/triton_implementation/kernels_forward.py (17 changes: 0 additions & 17 deletions)
This file was deleted.
cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py (16 changes: 1 addition & 15 deletions)
@@ -1,15 +1 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....jit import cpp_jit
-from ....utils import cute_op
-
-
-_KERNEL_NAME = "contiguous_count_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def contiguous_count_cuda(
-    x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import contiguous_count_cuda
File renamed without changes.
cute_kernels/kernels/contiguous_count/cuda_implementation/forward.py (15 changes: 15 additions & 0 deletions)
@@ -0,0 +1,15 @@
+import torch
+
+from ....constants import LIBRARY_NAME
+from ....jit import cpp_jit
+from ....utils import cute_op
+
+
+_KERNEL_NAME = "contiguous_count_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def contiguous_count_cuda(
+    x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int
+) -> None: ...
cute_kernels/kernels/contiguous_count/triton_implementation/__init__.py (30 changes: 1 addition & 29 deletions)
@@ -1,29 +1 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....math import ceil_divide
-from ....utils import cute_op, get_sm_count
-from .kernels_forward import _contiguous_count_triton_kernel
-
-
-_KERNEL_NAME = "contiguous_count_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def contiguous_count_triton(
-    x: torch.Tensor, output: torch.Tensor, size: int, BLOCK_SIZE: int, BLOCK_SIZE_C: int
-) -> None:
-    B = x.numel()
-
-    sm_count = get_sm_count(x.device)
-    num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE))
-
-    with torch.device(x.device):
-        _contiguous_count_triton_kernel[(num_programs,)](
-            x_ptr=x,
-            output_ptr=output,
-            B=B,
-            C=size,
-            BLOCK_SIZE_B=BLOCK_SIZE,
-            BLOCK_SIZE_C=BLOCK_SIZE_C,
-        )
+from .forward import contiguous_count_triton
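For reference, a hedged usage sketch of contiguous_count_triton. The diff shows only the launch wrapper, so the semantics here are inferred from the signature: x is assumed to hold integer ids in [0, size), output is assumed to accumulate int32 counts with one slot per id, and the BLOCK_SIZE / BLOCK_SIZE_C values are illustrative, not library defaults:

import torch

from cute_kernels.kernels.contiguous_count.triton_implementation import contiguous_count_triton

x = torch.randint(0, 8, (1_000_000,), device="cuda")
output = torch.zeros(8, dtype=torch.int32, device="cuda")  # assumed dtype and zero-initialization

# Grid size is capped at the device's SM count inside the wrapper (see the deleted __init__.py above).
contiguous_count_triton(x=x, output=output, size=8, BLOCK_SIZE=1024, BLOCK_SIZE_C=8)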