This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Optimize Sum #145

Merged (1 commit) on Sep 25, 2023
77 changes: 4 additions & 73 deletions tests/test_sum.py
@@ -50,80 +50,11 @@ def test_sum(y_size, x_size, dim, device, dtype):
grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

output = trident.Sum(dim).forward(input)
assert output is not None and output.dtype == dtype

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype


@pytest.mark.parametrize("dim", [0, 1])
def test_sum_issue1(dim, device):
factory_kwargs = {"device": device, "dtype": torch.float16}
input = torch.tensor(
[
[
30.6875,
-40.4375,
-29.1719,
81.1875,
23.3125,
3.6348,
6.0508,
-100.5000,
-6.0273,
11.6562,
],
[
21.5469,
11.3438,
14.0000,
33.7188,
13.4844,
-18.0938,
27.5156,
-29.0625,
-1.7559,
20.8594,
],
[
28.6406,
-30.1094,
22.6406,
-35.8750,
3.5410,
-66.1250,
15.6016,
-22.4375,
50.0625,
39.6562,
],
[
5.3281,
-75.1875,
-13.3828,
-39.9688,
-59.9062,
14.7812,
-23.0625,
-3.4336,
-34.8125,
32.7812,
],
[
20.1406,
-33.4375,
-50.3438,
-25.2812,
69.6250,
2.2090,
18.9062,
16.3750,
-7.9922,
27.1562,
],
],
**factory_kwargs,
)

assert util.equal(torch.sum(input, dim), trident.function.sum(input, dim))
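The updated assertions exercise the module-style entry point end to end (forward, then backward through the saved graph). A condensed, self-contained sketch of that flow is below; it is illustrative only and assumes a CUDA device, float32 inputs, and that `trident` is importable.

```python
import torch
import trident

x = torch.randn(5, 10, device="cuda", requires_grad=True)

out = trident.Sum(1).forward(x)      # row sums, shape (5,)
out.backward(torch.ones_like(out))   # d(sum)/dx is 1 for every element

assert out is not None and out.dtype == x.dtype
assert torch.allclose(x.grad, torch.ones_like(x))
```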
4 changes: 2 additions & 2 deletions trident/function/function.py
@@ -14,7 +14,7 @@


import math
from typing import Optional
from typing import Optional, Tuple, Union

import torch

@@ -255,7 +255,7 @@ def softmax(input: torch.Tensor, dim: int = None):
return operation.Softmax.apply(input, dim)


def sum(input, dim):
def sum(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None):
"""
Returns the sum along the specified dimension in an input.

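For orientation, a minimal usage sketch of the widened functional signature follows. It assumes `trident` is installed and a CUDA device is available; the tensor shape, dtype, and tolerance are illustrative rather than taken from this PR.

```python
import torch
import trident

x = torch.randn(4, 10, device="cuda", dtype=torch.float16)

# dim is now annotated and optional, mirroring torch.sum's signature.
out = trident.function.sum(x, 1)
ref = torch.sum(x, 1)

# float16 accumulation order may differ between implementations,
# so compare with a tolerance rather than exact equality.
assert torch.allclose(out.float(), ref.float(), atol=1e-2)
```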
15 changes: 13 additions & 2 deletions trident/kernel/sum.py
@@ -30,6 +30,7 @@ def sum_configs():
class Sum:
@staticmethod
@util.autotune(sum_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
@@ -40,9 +41,13 @@ def forward(
x_stride: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
output = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, x_block_size, dtype)

output = language.Sum.forward(
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, require_x_boundary_check
)
output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
@@ -55,6 +60,7 @@

@staticmethod
@util.autotune(sum_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def backward(
grad_input_ptr: tl.tensor,
@@ -64,6 +70,7 @@ def backward(
y_stride: tl.int32,
x_stride: tl.int32,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
grad_input_block_ptr = tl.make_block_ptr(
@@ -77,5 +84,9 @@ def backward(
grad_input = language.Sum.backward(grad_output_ptr, y_size, y_offset, x_block_size)

for x_offset in range(0, x_size, x_block_size):
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
if require_x_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
else:
tl.store(grad_input_block_ptr, grad_input)

grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, x_block_size))
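The new `triton.heuristics` decorator derives `require_x_boundary_check` from the launch arguments: `x_size % x_block_size` is non-zero (hence truthy) exactly when a row is not a multiple of the block size, and because the result feeds a `tl.constexpr` parameter, each specialization compiles either the checked or the unchecked load/store path, never both. The sketch below shows the same pattern on a trivial copy kernel; it is not code from this PR, and the kernel name, `BLOCK`, and `NEEDS_MASK` are made-up identifiers.

```python
import torch
import triton
import triton.language as tl


@triton.heuristics({"NEEDS_MASK": lambda args: args["n"] % args["BLOCK"] != 0})
@triton.jit
def copy_kernel(dst_ptr, src_ptr, n, BLOCK: tl.constexpr, NEEDS_MASK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    if NEEDS_MASK:
        # Masked path: only compiled when n is not a multiple of BLOCK.
        mask = offsets < n
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)
    else:
        # Fast path: no mask, no padding logic.
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets))


src = torch.randn(1000, device="cuda")
dst = torch.empty_like(src)
copy_kernel[(triton.cdiv(src.numel(), 256),)](dst, src, src.numel(), BLOCK=256)
assert torch.equal(dst, src)
```

The kernels in this diff apply the same idea: when `x_size` is an exact multiple of `x_block_size`, the boundary-checked, zero-padded loads and stores are skipped entirely.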
1 change: 0 additions & 1 deletion trident/language/constant.py
@@ -14,7 +14,6 @@

import triton.language as tl

dim = [tl.constexpr(i) for i in range(3)]
zero = tl.constexpr(0)
e = tl.constexpr(2.71828182846)
eps = tl.constexpr(2.220446049250313e-16)
2 changes: 1 addition & 1 deletion trident/language/mean.py
@@ -31,7 +31,7 @@ def forward(
dtype: tl.constexpr,
x_block_size: tl.constexpr,
):
sum = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, x_block_size, dtype)
sum = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, True)
mean = sum / x_size

return mean.to(dtype)
27 changes: 17 additions & 10 deletions trident/language/sum.py
@@ -21,13 +21,14 @@ class Sum:
@triton.jit
def forward(
input_ptr: tl.tensor,
y_size: int,
x_size: int,
y_stride: int,
x_stride: int,
y_offset: int,
x_block_size: tl.constexpr,
y_size: tl.int32,
x_size: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
y_offset: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
input_block_ptr = tl.make_block_ptr(
input_ptr,
@@ -37,10 +38,15 @@ def forward(
block_shape=(1, x_block_size),
order=(1, 0),
)
sum = tl.zeros((1, x_block_size), tl.float32)

sum = tl.zeros((1, x_block_size), dtype)

for _ in range(0, x_size, x_block_size):
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero").to(tl.float32)
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
else:
input = tl.load(input_block_ptr)

sum += input
input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))

@@ -50,8 +56,8 @@
@triton.jit
def backward(
grad_output_ptr: tl.tensor,
y_size: int,
y_offset: int,
y_size: tl.int32,
y_offset: tl.int32,
x_block_size: tl.constexpr,
):
grad_output_block_ptr = tl.make_block_ptr(
Expand All @@ -62,6 +68,7 @@ def backward(
block_shape=(1, 1),
order=(0, 1),
)

grad_output = tl.load(grad_output_block_ptr)
grad_input = tl.broadcast_to(grad_output, (1, x_block_size))

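For reference, the forward hunk above keeps a `(1, x_block_size)` vector of per-lane partial sums while the block pointer slides across the row; the accumulator now uses the kernel's `dtype` instead of a hard-coded `tl.float32`, zero padding is requested only when the boundary check is required, and the collapsed tail of the hunk presumably reduces the partial sums to a single value per row. A plain PyTorch sketch of that block-wise pattern follows; the row length, block size, and variable names are illustrative only.

```python
import torch

row = torch.randn(10)   # one row of the input
block = 4                # stands in for x_block_size

# Accumulate per-lane partial sums block by block, zero-padding the tail
# (the analogue of boundary_check=(1,) with padding_option="zero").
partial = torch.zeros(block)
for start in range(0, row.numel(), block):
    chunk = row[start:start + block]
    padded = torch.zeros(block)
    padded[:chunk.numel()] = chunk
    partial += padded

# Reducing the partial sums reproduces the full row sum.
assert torch.allclose(partial.sum(), row.sum(), atol=1e-6)
```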