
Commit

Swiglu triton (#18)
mayank31398 authored Sep 9, 2024
1 parent 4901695 commit 477ea68
Showing 7 changed files with 172 additions and 0 deletions.
1 change: 1 addition & 0 deletions khd/__init__.py
@@ -1,3 +1,4 @@
from .kernel_registry import KernelRegistry
from .scattermoe import MoE_Torch, MoE_Triton
from .swiglu import swiglu_torch, swiglu_triton
from .vector_addition import vector_addition_cuda, vector_addition_torch, vector_addition_triton
2 changes: 2 additions & 0 deletions khd/swiglu/__init__.py
@@ -0,0 +1,2 @@
from .torch_implementation import swiglu_torch
from .triton_implementation import swiglu_triton
6 changes: 6 additions & 0 deletions khd/swiglu/torch_implementation.py
@@ -0,0 +1,6 @@
import torch
import torch.nn.functional as F


def swiglu_torch(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
return up * F.silu(gate)
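
For reference, swiglu_torch computes the element-wise SwiGLU combination up * silu(gate), where silu(x) = x * sigmoid(x). A minimal CPU sketch (not part of the commit) that checks this identity:

import torch

from khd import swiglu_torch

gate = torch.randn(4, 8)
up = torch.randn(4, 8)

out = swiglu_torch(gate, up)
# silu(x) = x * sigmoid(x), so the output equals up * gate * sigmoid(gate)
assert torch.allclose(out, up * gate * torch.sigmoid(gate))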
54 changes: 54 additions & 0 deletions khd/swiglu/triton_implementation/__init__.py
@@ -0,0 +1,54 @@
import torch
import triton

from .kernels import swiglu_backward_triton_kernel, swiglu_forward_triton_kernel


class _Swiglu_Triton(torch.autograd.Function):
@staticmethod
def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
assert gate.is_cuda, "tensor gate is not on GPU"
assert up.is_cuda, "tensor up is not on GPU"

output = torch.empty_like(gate)

ctx.save_for_backward(gate, up)

original_shape = gate.size()
gate = gate.view(-1)
up = up.view(-1)

assert gate.numel() == up.numel(), "both tensors should have same number of elements"
assert gate.type() == up.type(), "both tensors should have same dtype"

num_elements = gate.numel()
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

swiglu_forward_triton_kernel[grid](gate, up, output, num_elements, BLOCK_SIZE=1024)

output = output.view(original_shape)

return output

@staticmethod
def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
gate, up = ctx.saved_tensors

original_shape = gate.size()
gate = gate.view(-1)
up = up.view(-1)

num_elements = output_grad.numel()
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

# the kernel uses the gate and up tensors to store the gradients in-place for memory savings
swiglu_backward_triton_kernel[grid](gate, up, output_grad, num_elements, BLOCK_SIZE=1024)

gate = gate.view(original_shape)
up = up.view(original_shape)

return gate, up


def swiglu_triton(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
return _Swiglu_Triton.apply(gate, up)
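
A minimal end-to-end sketch of the autograd wrapper (not part of the commit; assumes a CUDA device and that the khd package is importable):

import torch

from khd import swiglu_triton

gate = torch.randn(1024, 1024, device="cuda", requires_grad=True)
up = torch.randn(1024, 1024, device="cuda", requires_grad=True)

out = swiglu_triton(gate, up)  # launches swiglu_forward_triton_kernel
out.sum().backward()           # launches swiglu_backward_triton_kernel
print(gate.grad.shape, up.grad.shape)

Note that, as the in-place comment in backward indicates, the saved gate and up tensors are reused as gradient storage, trading mutation of the saved inputs for lower peak memory.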
45 changes: 45 additions & 0 deletions khd/swiglu/triton_implementation/kernels.py
@@ -0,0 +1,45 @@
import triton
import triton.language as tl


@triton.jit
def swiglu_forward_triton_kernel(gate_ptr, up_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)

block_start = pid * BLOCK_SIZE
block_indices = block_start + tl.arange(0, BLOCK_SIZE)

mask = block_indices < num_elements

gate = tl.load(gate_ptr + block_indices, mask=mask).to(tl.float32)
up = tl.load(up_ptr + block_indices, mask=mask)

gate_sigmoid = tl.sigmoid(gate)
gate_silu = gate * gate_sigmoid

output = up * gate_silu

tl.store(output_ptr + block_indices, output, mask=mask)


@triton.jit
def swiglu_backward_triton_kernel(gate_ptr, up_ptr, output_grad_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)

block_start = pid * BLOCK_SIZE
block_indices = block_start + tl.arange(0, BLOCK_SIZE)

mask = block_indices < num_elements

gate = tl.load(gate_ptr + block_indices, mask=mask).to(tl.float32)
up = tl.load(up_ptr + block_indices, mask=mask)
output_grad = tl.load(output_grad_ptr + block_indices, mask=mask)

gate_sigmoid = tl.sigmoid(gate)
gate_silu = gate * gate_sigmoid

up_grad = output_grad * gate_silu
gate_grad = output_grad * up * (gate_sigmoid + gate_silu * (1 - gate_sigmoid))

tl.store(gate_ptr + block_indices, gate_grad, mask=mask)
tl.store(up_ptr + block_indices, up_grad, mask=mask)
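
The gate gradient follows from the SiLU derivative: d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)), which the kernel writes as gate_sigmoid + gate_silu * (1 - gate_sigmoid); the up gradient is simply output_grad * silu(gate), as in up_grad above. Both gradients are stored back into gate_ptr and up_ptr rather than into freshly allocated buffers.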
51 changes: 51 additions & 0 deletions tests/swiglu_test.py
@@ -0,0 +1,51 @@
from typing import Callable

import torch
from parameterized import parameterized

from khd import swiglu_torch, swiglu_triton

from .test_commons import TestCommons


class SwigluTest(TestCommons):
@parameterized.expand(
TestCommons.make_args_matrix(
TestCommons.get_2d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
)
)
def test_swiglu_triton_forward(self, size: tuple[int], device: torch.device, dtype: torch.dtype) -> None:
self._test_swiglu_forward(size, device, dtype, swiglu_triton)

    def _test_swiglu_forward(self, size: tuple[int], device: torch.device, dtype: torch.dtype, function: Callable) -> None:
x = torch.randn(size, device=device, dtype=dtype)
y = torch.randn(size, device=device, dtype=dtype)

z_kernel = function(x, y)
z_expected = swiglu_torch(x, y)

self.assert_equal_tensors(z_kernel, z_expected, False, atol_float32=5e-6, rtol_float32=0)

@parameterized.expand(
TestCommons.make_args_matrix(
TestCommons.get_2d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
)
)
def test_swiglu_triton_backward(self, size: tuple[int], device: torch.device, dtype: torch.dtype) -> None:
self._test_swiglu_backward(size, device, dtype, swiglu_triton)

    def _test_swiglu_backward(self, size: tuple[int], device: torch.device, dtype: torch.dtype, function: Callable) -> None:
x_kernel = torch.randn(size, device=device, dtype=dtype, requires_grad=True)
y_kernel = torch.randn(size, device=device, dtype=dtype, requires_grad=True)

x_expected = x_kernel.clone().detach().requires_grad_()
y_expected = y_kernel.clone().detach().requires_grad_()

z_kernel = function(x_kernel, y_kernel)
z_expected = swiglu_torch(x_expected, y_expected)

z_kernel.mean().backward()
z_expected.mean().backward()

self.assert_equal_tensors(x_kernel.grad, x_expected.grad, False, atol_float32=5e-6, rtol_float32=0)
self.assert_equal_tensors(y_kernel.grad, y_expected.grad, False, atol_float32=5e-6, rtol_float32=0)
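
Assuming pytest is installed and a CUDA-capable GPU is available, these tests can be run from the repository root with, for example:

python -m pytest tests -k swiglu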
13 changes: 13 additions & 0 deletions tests/test_commons.py
@@ -30,6 +30,19 @@ def get_1d_tensor_sizes() -> list[tuple[int]]:
sizes.add(3000 + random.randint(-1000, 1000))
return sizes

    @staticmethod
    def get_2d_tensor_sizes() -> list[tuple[int]]:
        sizes = set()
        # sizes at and just above powers of 2
        for i in range(15):
            start = 2**i
            for j in range(10):
                sizes.add((start + j, start + j))
        # sizes that are not powers of 2
        for _ in range(50):
            sizes.add((3000 + random.randint(-1000, 1000), 3000 + random.randint(-1000, 1000)))
        return list(sizes)

def make_args_matrix(*args_lists) -> list[Any]:
return [p for p in product(*args_lists)]
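
For illustration, make_args_matrix simply takes the Cartesian product of its argument lists, so a hypothetical call such as

make_args_matrix([(4, 4), (8, 8)], [torch.float32])

returns [((4, 4), torch.float32), ((8, 8), torch.float32)], which parameterized.expand then unpacks into individual test cases.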

