Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Optimize VarMean
Browse files Browse the repository at this point in the history
  • Loading branch information
danny.jang authored Sep 25, 2023
1 parent e7a61c5 commit bd48fb6
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 44 deletions.
2 changes: 1 addition & 1 deletion trident/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def var(input, dim, correction=1):
return operation.Var.apply(input, dim, correction)


def var_mean(input, dim, correction=1):
def var_mean(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, correction: int = 1):
"""
Returns the variance and mean along the specified dimension in an input.
Expand Down
10 changes: 1 addition & 9 deletions trident/kernel/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,7 @@ def forward(
):
y_offset = tl.program_id(0)
output, mean = language.VarMean.forward(
input_ptr,
y_size,
x_size,
y_stride,
x_stride,
y_offset,
correction,
dtype,
x_block_size,
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, correction, dtype, x_block_size, True
)
output_block_ptr = tl.make_block_ptr(
output_ptr,
Expand Down
36 changes: 24 additions & 12 deletions trident/kernel/var_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def var_mean_configs():
class VarMean:
@staticmethod
@util.autotune(var_mean_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
Expand All @@ -42,19 +43,10 @@ def forward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
output, mean = language.VarMean.forward(
input_ptr,
y_size,
x_size,
y_stride,
x_stride,
y_offset,
correction,
dtype,
x_block_size,
)

output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
Expand All @@ -71,11 +63,25 @@ def forward(
block_shape=(1,),
order=(0,),
)

output, mean = language.VarMean.forward(
input_ptr,
y_size,
x_size,
y_stride,
x_stride,
y_offset,
correction,
dtype,
x_block_size,
require_x_boundary_check,
)
tl.store(output_block_ptr, output)
tl.store(mean_block_ptr, mean)

@staticmethod
@util.autotune(var_mean_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def backward(
grad_input_ptr: tl.tensor,
Expand All @@ -89,6 +95,7 @@ def backward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
num_x_blocks = tl.cdiv(x_size, x_block_size)
Expand All @@ -112,6 +119,7 @@ def backward(
block_shape=(1,),
order=(0,),
)

mean = tl.load(mean_block_ptr)
grad_input = language.Var.backward(
grad_output_ptr,
Expand All @@ -127,4 +135,8 @@ def backward(
dtype,
x_block_size,
)
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))

if require_x_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
else:
tl.store(grad_input_block_ptr, grad_input)
57 changes: 37 additions & 20 deletions trident/language/var_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ class VarMean:
@triton.jit
def forward(
input_ptr: tl.tensor,
y_size: int,
x_size: int,
y_stride: int,
x_stride: int,
y_offset: int,
y_size: tl.int32,
x_size: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
y_offset: tl.int32,
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
input_block_ptr = tl.make_block_ptr(
input_ptr,
Expand All @@ -40,17 +41,26 @@ def forward(
block_shape=(1, x_block_size),
order=(1, 0),
)

m2 = tl.zeros((1, x_block_size), tl.float32)
count = tl.zeros((1, x_block_size), tl.float32)
mean = tl.zeros((1, x_block_size), tl.float32)

for x_offset in range(0, x_size, x_block_size):
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
delta = tl.where(condition, input - mean, 0.0)
count += tl.where(condition, 1.0, language.eps)
mean += delta / count
m2 += delta * tl.where(condition, input - mean, 0.0)
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
delta = tl.where(condition, input - mean, 0.0)
count += tl.where(condition, 1.0, language.eps)
mean += delta / count
m2 += delta * tl.where(condition, input - mean, 0.0)
else:
input = tl.load(input_block_ptr)
delta = input - mean
count += 1
mean += delta / count
m2 += delta * (input - mean)

input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))

m2, mean, count = tl.reduce((m2, mean, count), 1, language.combine_welford)
Expand All @@ -63,16 +73,17 @@ def forward(
def backward(
grad_output_ptr: tl.tensor,
input_ptr: tl.tensor,
y_size: int,
x_size: int,
y_stride: int,
x_stride: int,
y_offset: int,
x_offset: int,
y_size: tl.int32,
x_size: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
y_offset: tl.int32,
x_offset: tl.int32,
mean: tl.tensor,
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
grad_output_block_ptr = tl.make_block_ptr(
grad_output_ptr,
Expand All @@ -90,10 +101,16 @@ def backward(
block_shape=(1, x_block_size),
order=(1, 0),
)

if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
centered_mean = tl.where(condition[None, :], input - mean, 0.0)
else:
input = tl.load(input_block_ptr)
centered_mean = input - mean

grad_output = tl.load(grad_output_block_ptr)
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
centered_mean = tl.where(condition[None, :], input - mean, 0.0)
grad_input = grad_output * 2 * centered_mean / (x_size - correction)

return grad_input.to(dtype)
6 changes: 4 additions & 2 deletions trident/operation/var_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def backward(ctx: Any, *grad_outputs: Any):
return grad_input, None, None

@staticmethod
def __forward(input: torch.Tensor, dim: int, correction: int):
def __forward(input: torch.Tensor, dim: torch.int32, correction: torch.int32):
factory_kwargs = {"device": input.device, "dtype": input.dtype}
y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim)
output = torch.empty(y_size, **factory_kwargs)
Expand All @@ -73,7 +73,9 @@ def grid(meta):
return output, mean

@staticmethod
def __backward(grad_output: torch.Tensor, input: torch.Tensor, mean: torch.Tensor, dim: int, correction: int):
def __backward(
grad_output: torch.Tensor, input: torch.Tensor, mean: torch.Tensor, dim: torch.int32, correction: torch.int32
):
y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim)
grad_input = torch.zeros_like(input)

Expand Down

0 comments on commit bd48fb6

Please sign in to comment.