Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Optimize Sum
Browse files Browse the repository at this point in the history
  • Loading branch information
danny jang committed Sep 23, 2023
1 parent a70a73e commit ee77440
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 88 deletions.
72 changes: 0 additions & 72 deletions tests/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,75 +55,3 @@ def test_sum(y_size, x_size, dim, device, dtype):
output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype


@pytest.mark.parametrize("dim", [0, 1])
def test_sum_issue1(dim, device):
factory_kwargs = {"device": device, "dtype": torch.float16}
input = torch.tensor(
[
[
30.6875,
-40.4375,
-29.1719,
81.1875,
23.3125,
3.6348,
6.0508,
-100.5000,
-6.0273,
11.6562,
],
[
21.5469,
11.3438,
14.0000,
33.7188,
13.4844,
-18.0938,
27.5156,
-29.0625,
-1.7559,
20.8594,
],
[
28.6406,
-30.1094,
22.6406,
-35.8750,
3.5410,
-66.1250,
15.6016,
-22.4375,
50.0625,
39.6562,
],
[
5.3281,
-75.1875,
-13.3828,
-39.9688,
-59.9062,
14.7812,
-23.0625,
-3.4336,
-34.8125,
32.7812,
],
[
20.1406,
-33.4375,
-50.3438,
-25.2812,
69.6250,
2.2090,
18.9062,
16.3750,
-7.9922,
27.1562,
],
],
**factory_kwargs,
)

assert util.equal(torch.sum(input, dim), trident.function.sum(input, dim))
4 changes: 2 additions & 2 deletions trident/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


import math
from typing import Optional
from typing import Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -255,7 +255,7 @@ def softmax(input: torch.Tensor, dim: int = None):
return operation.Softmax.apply(input, dim)


def sum(input, dim):
def sum(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None):
"""
Returns the sum along the specified dimension in an input.
Expand Down
15 changes: 13 additions & 2 deletions trident/kernel/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def sum_configs():
class Sum:
@staticmethod
@util.autotune(sum_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
Expand All @@ -40,9 +41,13 @@ def forward(
x_stride: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
output = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, x_block_size, dtype)

output = language.Sum.forward(
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, require_x_boundary_check
)
output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
Expand All @@ -55,6 +60,7 @@ def forward(

@staticmethod
@util.autotune(sum_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def backward(
grad_input_ptr: tl.tensor,
Expand All @@ -64,6 +70,7 @@ def backward(
y_stride: tl.int32,
x_stride: tl.int32,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
grad_input_block_ptr = tl.make_block_ptr(
Expand All @@ -77,5 +84,9 @@ def backward(
grad_input = language.Sum.backward(grad_output_ptr, y_size, y_offset, x_block_size)

for x_offset in range(0, x_size, x_block_size):
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
if require_x_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
else:
tl.store(grad_input_block_ptr, grad_input)

grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, x_block_size))
1 change: 0 additions & 1 deletion trident/language/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import triton.language as tl

dim = [tl.constexpr(i) for i in range(3)]
zero = tl.constexpr(0)
e = tl.constexpr(2.71828182846)
eps = tl.constexpr(2.220446049250313e-16)
Expand Down
2 changes: 1 addition & 1 deletion trident/language/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def forward(
dtype: tl.constexpr,
x_block_size: tl.constexpr,
):
sum = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, x_block_size, dtype)
sum = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, True)
mean = sum / x_size

return mean.to(dtype)
Expand Down
27 changes: 17 additions & 10 deletions trident/language/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ class Sum:
@triton.jit
def forward(
input_ptr: tl.tensor,
y_size: int,
x_size: int,
y_stride: int,
x_stride: int,
y_offset: int,
x_block_size: tl.constexpr,
y_size: tl.int32,
x_size: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
y_offset: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
input_block_ptr = tl.make_block_ptr(
input_ptr,
Expand All @@ -37,10 +38,15 @@ def forward(
block_shape=(1, x_block_size),
order=(1, 0),
)
sum = tl.zeros((1, x_block_size), tl.float32)

sum = tl.zeros((1, x_block_size), dtype)

for _ in range(0, x_size, x_block_size):
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32)
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
else:
input = tl.load(input_block_ptr)

sum += input
input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))

Expand All @@ -50,8 +56,8 @@ def forward(
@triton.jit
def backward(
grad_output_ptr: tl.tensor,
y_size: int,
y_offset: int,
y_size: tl.int32,
y_offset: tl.int32,
x_block_size: tl.constexpr,
):
grad_output_block_ptr = tl.make_block_ptr(
Expand All @@ -62,6 +68,7 @@ def backward(
block_shape=(1, 1),
order=(0, 1),
)

grad_output = tl.load(grad_output_block_ptr)
grad_input = tl.broadcast_to(grad_output, (1, x_block_size))

Expand Down

0 comments on commit ee77440

Please sign in to comment.