Commit 477ea68, 1 parent 4901695. Showing 7 changed files with 172 additions and 0 deletions.
khd/__init__.py
@@ -1,3 +1,4 @@
from .kernel_registry import KernelRegistry
from .scattermoe import MoE_Torch, MoE_Triton
from .swiglu import swiglu_torch, swiglu_triton
from .vector_addition import vector_addition_cuda, vector_addition_torch, vector_addition_triton
khd/swiglu/__init__.py
@@ -0,0 +1,2 @@
from .torch_implementation import swiglu_torch
from .triton_implementation import swiglu_triton
khd/swiglu/torch_implementation.py
@@ -0,0 +1,6 @@
import torch
import torch.nn.functional as F


def swiglu_torch(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    return up * F.silu(gate)
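
For reference, F.silu(x) computes x * sigmoid(x), so swiglu_torch is the elementwise product up * gate * sigmoid(gate). A minimal usage sketch (the shapes here are illustrative assumptions, not taken from the diff):

import torch

from khd import swiglu_torch

# illustrative shapes: a batch of 4 rows with hidden size 8
gate = torch.randn(4, 8)
up = torch.randn(4, 8)

out = swiglu_torch(gate, up)  # same as up * gate * torch.sigmoid(gate)
assert out.shape == (4, 8)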
khd/swiglu/triton_implementation.py
@@ -0,0 +1,54 @@
import torch
import triton

from .kernels import swiglu_backward_triton_kernel, swiglu_forward_triton_kernel


class _Swiglu_Triton(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        assert gate.is_cuda, "tensor gate is not on GPU"
        assert up.is_cuda, "tensor up is not on GPU"

        output = torch.empty_like(gate)

        ctx.save_for_backward(gate, up)

        original_shape = gate.size()
        gate = gate.view(-1)
        up = up.view(-1)

        assert gate.numel() == up.numel(), "both tensors should have the same number of elements"
        assert gate.type() == up.type(), "both tensors should have the same dtype"

        num_elements = gate.numel()
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

        swiglu_forward_triton_kernel[grid](gate, up, output, num_elements, BLOCK_SIZE=1024)

        output = output.view(original_shape)

        return output

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        gate, up = ctx.saved_tensors

        original_shape = gate.size()
        gate = gate.view(-1)
        up = up.view(-1)

        num_elements = output_grad.numel()
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

        # the kernel uses the gate and up tensors to store the gradients in-place for memory savings
        swiglu_backward_triton_kernel[grid](gate, up, output_grad, num_elements, BLOCK_SIZE=1024)

        gate = gate.view(original_shape)
        up = up.view(original_shape)

        return gate, up


def swiglu_triton(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    return _Swiglu_Triton.apply(gate, up)
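
A minimal end-to-end sketch of the wrapper, assuming the khd package is importable and a CUDA device is available (shapes are illustrative):

import torch

from khd import swiglu_triton

gate = torch.randn(4, 8, device="cuda", requires_grad=True)
up = torch.randn(4, 8, device="cuda", requires_grad=True)

out = swiglu_triton(gate, up)  # launches swiglu_forward_triton_kernel
out.sum().backward()           # launches swiglu_backward_triton_kernel

print(gate.grad.shape, up.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])

Note the design choice in backward: the gradients are written into the saved gate and up buffers rather than into freshly allocated tensors, which avoids two extra allocations but clobbers the saved activations once backward has run.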
khd/swiglu/kernels.py
@@ -0,0 +1,45 @@
import triton
import triton.language as tl


@triton.jit
def swiglu_forward_triton_kernel(gate_ptr, up_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    block_indices = block_start + tl.arange(0, BLOCK_SIZE)

    mask = block_indices < num_elements

    gate = tl.load(gate_ptr + block_indices, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + block_indices, mask=mask)

    gate_sigmoid = tl.sigmoid(gate)
    gate_silu = gate * gate_sigmoid

    output = up * gate_silu

    tl.store(output_ptr + block_indices, output, mask=mask)


@triton.jit
def swiglu_backward_triton_kernel(gate_ptr, up_ptr, output_grad_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    block_indices = block_start + tl.arange(0, BLOCK_SIZE)

    mask = block_indices < num_elements

    gate = tl.load(gate_ptr + block_indices, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + block_indices, mask=mask)
    output_grad = tl.load(output_grad_ptr + block_indices, mask=mask)

    gate_sigmoid = tl.sigmoid(gate)
    gate_silu = gate * gate_sigmoid

    up_grad = output_grad * gate_silu
    gate_grad = output_grad * up * (gate_sigmoid + gate_silu * (1 - gate_sigmoid))

    tl.store(gate_ptr + block_indices, gate_grad, mask=mask)
    tl.store(up_ptr + block_indices, up_grad, mask=mask)
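
The gate_grad factor is the derivative of SiLU: since silu(g) = g * sigmoid(g) and sigmoid'(g) = sigmoid(g) * (1 - sigmoid(g)), the product rule gives silu'(g) = sigmoid(g) + g * sigmoid(g) * (1 - sigmoid(g)) = sigmoid(g) + silu(g) * (1 - sigmoid(g)), which is exactly the expression multiplied by output_grad * up in the kernel. A quick CPU-side check of this identity against PyTorch autograd (a sketch, not part of the diff):

import torch
import torch.nn.functional as F

gate = torch.randn(16, dtype=torch.float64, requires_grad=True)
up = torch.randn(16, dtype=torch.float64)

# reference gradient from autograd: d/d(gate) of sum(up * silu(gate))
(up * F.silu(gate)).sum().backward()

# analytic gradient, computed the same way as in the kernel
with torch.no_grad():
    sig = torch.sigmoid(gate)
    silu = gate * sig
    analytic = up * (sig + silu * (1 - sig))

assert torch.allclose(gate.grad, analytic)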
@@ -0,0 +1,51 @@
from typing import Callable

import torch
from parameterized import parameterized

from khd import swiglu_torch, swiglu_triton

from .test_commons import TestCommons


class SwigluTest(TestCommons):
    @parameterized.expand(
        TestCommons.make_args_matrix(
            TestCommons.get_2d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
        )
    )
    def test_swiglu_triton_forward(self, size: tuple[int], device: torch.device, dtype: torch.dtype) -> None:
        self._test_swiglu_forward(size, device, dtype, swiglu_triton)

    def _test_swiglu_forward(
        self, size: tuple[int], device: torch.device, dtype: torch.dtype, function: Callable
    ) -> None:
        x = torch.randn(size, device=device, dtype=dtype)
        y = torch.randn(size, device=device, dtype=dtype)

        z_kernel = function(x, y)
        z_expected = swiglu_torch(x, y)

        self.assert_equal_tensors(z_kernel, z_expected, False, atol_float32=5e-6, rtol_float32=0)

    @parameterized.expand(
        TestCommons.make_args_matrix(
            TestCommons.get_2d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
        )
    )
    def test_swiglu_triton_backward(self, size: tuple[int], device: torch.device, dtype: torch.dtype) -> None:
        self._test_swiglu_backward(size, device, dtype, swiglu_triton)

    def _test_swiglu_backward(
        self, size: tuple[int], device: torch.device, dtype: torch.dtype, function: Callable
    ) -> None:
        x_kernel = torch.randn(size, device=device, dtype=dtype, requires_grad=True)
        y_kernel = torch.randn(size, device=device, dtype=dtype, requires_grad=True)

        x_expected = x_kernel.clone().detach().requires_grad_()
        y_expected = y_kernel.clone().detach().requires_grad_()

        z_kernel = function(x_kernel, y_kernel)
        z_expected = swiglu_torch(x_expected, y_expected)

        z_kernel.mean().backward()
        z_expected.mean().backward()

        self.assert_equal_tensors(x_kernel.grad, x_expected.grad, False, atol_float32=5e-6, rtol_float32=0)
        self.assert_equal_tensors(y_kernel.grad, y_expected.grad, False, atol_float32=5e-6, rtol_float32=0)
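
TestCommons itself is not part of this commit, so its helpers are assumptions here; make_args_matrix presumably builds the cartesian product of its argument lists so parameterized.expand runs every (size, device, dtype) combination. A hypothetical equivalent:

import itertools

def make_args_matrix(*args_lists):
    # hypothetical sketch: one test case per combination of sizes, devices and dtypes
    return [list(combo) for combo in itertools.product(*args_lists)]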