This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Optimize Var #148

Merged (1 commit) on Sep 26, 2023
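
The core of the optimization is a `triton.heuristics` decorator that derives `require_x_boundary_check` from `x_size % x_block_size` at launch time, so masked loads and stores are only emitted when the reduction axis does not divide evenly into blocks. A minimal standalone sketch of the pattern, using a hypothetical `scale_kernel` that is not part of this PR:

```python
import torch
import triton
import triton.language as tl


@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def scale_kernel(
    output_ptr,
    input_ptr,
    x_size,
    factor,
    x_block_size: tl.constexpr,
    require_x_boundary_check: tl.constexpr,
):
    offsets = tl.program_id(0) * x_block_size + tl.arange(0, x_block_size)
    if require_x_boundary_check:
        # Slow path: the last block is partial, so mask the tail.
        mask = offsets < x_size
        input = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, input * factor, mask=mask)
    else:
        # Fast path: x_size divides evenly into blocks, no masking at all.
        input = tl.load(input_ptr + offsets)
        tl.store(output_ptr + offsets, input * factor)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
scale_kernel[(triton.cdiv(x.numel(), 128),)](y, x, x.numel(), 2.0, x_block_size=128)
```

Because `require_x_boundary_check` is a `tl.constexpr`, the untaken branch is compiled out of the generated kernel entirely.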
2 changes: 1 addition & 1 deletion trident/function/function.py
@@ -264,7 +264,7 @@ def sum(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None):
return operation.Sum.apply(input, dim)


def var(input, dim, correction=1):
def var(input: torch.Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, correction: int = 1):
"""
Returns the variance along the specified dimension in an input.

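The annotated signature now mirrors `sum` above, with `dim` defaulting to `None`. A quick smoke test of the entry point, assuming it is importable as `trident.function.var` and a CUDA device is available:

```python
import torch
import trident

x = torch.randn(4, 16, device="cuda")
# correction=1 matches torch's default unbiased estimator.
torch.testing.assert_close(trident.function.var(x, 1), x.var(dim=1))
```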
4 changes: 3 additions & 1 deletion trident/kernel/mean.py
@@ -42,7 +42,9 @@ def forward(
x_block_size: tl.constexpr,
):
y_offset = tl.program_id(0)
output = language.Mean.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size)
output = language.Mean.forward(
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, True
)
output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
33 changes: 28 additions & 5 deletions trident/kernel/var.py
@@ -30,6 +30,7 @@ def var_configs():
class Var:
@staticmethod
@util.autotune(var_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
@@ -41,11 +42,10 @@ def forward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)
output, mean = language.VarMean.forward(
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, correction, dtype, x_block_size, True
)

output_block_ptr = tl.make_block_ptr(
output_ptr,
shape=(y_size,),
@@ -54,10 +54,24 @@
block_shape=(1,),
order=(0,),
)

output, mean = language.VarMean.forward(
input_ptr,
y_size,
x_size,
y_stride,
x_stride,
y_offset,
correction,
dtype,
x_block_size,
require_x_boundary_check,
)
tl.store(output_block_ptr, output)

@staticmethod
@util.autotune(var_configs(), ["x_size"])
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def backward(
grad_input_ptr: tl.tensor,
@@ -70,6 +84,7 @@ def backward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
num_x_blocks = tl.cdiv(x_size, x_block_size)
@@ -85,7 +100,10 @@
block_shape=(1, x_block_size),
order=(1, 0),
)
mean = language.Mean.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size)

mean = language.Mean.forward(
input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, require_x_boundary_check
)
grad_input = language.Var.backward(
grad_output_ptr,
input_ptr,
@@ -99,5 +117,10 @@
correction,
dtype,
x_block_size,
require_x_boundary_check,
)
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))

if require_x_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1,))
else:
tl.store(grad_input_block_ptr, grad_input)
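
For reference, the per-row quantity the forward kernel stores is the sum of squared deviations over the x axis divided by `x_size - correction`. An eager PyTorch sketch of those semantics, not the kernel itself:

```python
import torch

def var_forward_reference(input: torch.Tensor, correction: int = 1) -> torch.Tensor:
    # One output per row, matching the kernel's epilogue.
    mean = input.mean(dim=1, keepdim=True)
    centered = input - mean
    return (centered * centered).sum(dim=1) / (input.shape[1] - correction)
```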
1 change: 1 addition & 0 deletions trident/kernel/var_mean.py
@@ -134,6 +134,7 @@ def backward(
correction,
dtype,
x_block_size,
require_x_boundary_check,
)

if require_x_boundary_check:
42 changes: 33 additions & 9 deletions trident/language/mean.py
@@ -23,18 +23,39 @@ class Mean:
@triton.jit
def forward(
input_ptr: tl.tensor,
y_size: int,
x_size: int,
y_stride: int,
x_stride: int,
y_offset: int,
y_size: tl.int32,
x_size: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
y_offset: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
sum = language.Sum.forward(input_ptr, y_size, x_size, y_stride, x_stride, y_offset, dtype, x_block_size, True)
mean = sum / x_size
input_block_ptr = tl.make_block_ptr(
input_ptr,
shape=(y_size, x_size),
strides=(y_stride, x_stride),
offsets=(y_offset, 0),
block_shape=(1, x_block_size),
order=(1, 0),
)

sum = tl.zeros((1, x_block_size), dtype)

return mean.to(dtype)
for _ in range(0, x_size, x_block_size):
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
else:
input = tl.load(input_block_ptr)

sum += input
input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))

sum = tl.sum(sum, 1)
output = sum / x_size

return output.to(dtype)

@staticmethod
@triton.jit
@@ -54,5 +75,8 @@ def backward(
block_shape=(1, 1),
order=(0, 1),
)

grad_output = tl.load(grad_output_block_ptr)
return tl.broadcast_to(grad_output * 1.0 / x_size, (1, x_block_size)).to(dtype)
grad_input = tl.broadcast_to(grad_output * 1.0 / x_size, (1, x_block_size))

return grad_input.to(dtype)
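
The rewritten `Mean.forward` inlines the reduction instead of delegating to `Sum.forward` with a hard-coded `True`: it accumulates one zero-padded block at a time and reduces once at the end. An eager sketch of that loop, where `x_block_size` is a free parameter (the kernel gets it from autotuning):

```python
import torch
import torch.nn.functional as F

def mean_forward_reference(row: torch.Tensor, x_block_size: int) -> torch.Tensor:
    acc = torch.zeros(x_block_size, dtype=row.dtype)
    for offset in range(0, row.numel(), x_block_size):
        block = row[offset : offset + x_block_size]
        # Zero padding mirrors padding_option="zero": it cannot perturb the sum.
        acc += F.pad(block, (0, x_block_size - block.numel()))
    return acc.sum() / row.numel()
```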
26 changes: 20 additions & 6 deletions trident/language/var.py
@@ -30,6 +30,7 @@ def forward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
input_block_ptr = tl.make_block_ptr(
input_ptr,
@@ -39,12 +40,18 @@
block_shape=(1, x_block_size),
order=(1, 0),
)

output = tl.zeros((1, x_block_size), tl.float32)

for block_offset in range(0, x_size, x_block_size):
input = tl.load(input_block_ptr, boundary_check=(1,))
mask = (tl.arange(0, x_block_size) + block_offset) < x_size
centered_mean = tl.where(mask, input - mean, 0.0)
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + block_offset < x_size
centered_mean = tl.where(condition, input - mean, 0.0)
else:
input = tl.load(input_block_ptr)
centered_mean = input - mean

output += centered_mean * centered_mean
input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))

@@ -67,6 +74,7 @@ def backward(
correction: tl.constexpr,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
grad_output_block_ptr = tl.make_block_ptr(
grad_output_ptr,
@@ -84,10 +92,16 @@
block_shape=(1, x_block_size),
order=(1, 0),
)

if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
centered_mean = tl.where(condition[None, :], input - mean, 0.0)
else:
input = tl.load(input_block_ptr)
centered_mean = input - mean

grad_output = tl.load(grad_output_block_ptr)
input = tl.load(input_block_ptr, boundary_check=(1,))
condition = tl.arange(0, x_block_size) + x_offset < x_size
centered_mean = tl.where(condition[None, :], input - mean, 0.0)
grad_input = grad_output * 2 * centered_mean / (x_size - correction)

return grad_input.to(dtype)
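
`Var.backward` relies on the centered values summing to zero: the mean's own dependence on each element cancels, so the full gradient reduces to `2 * (x - mean) / (x_size - correction)`, exactly what the kernel computes. A double-precision check against autograd:

```python
import torch

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
x.var(unbiased=True).backward()  # correction = 1
manual = 2 * (x - x.mean()).detach() / (x.numel() - 1)
torch.testing.assert_close(x.grad, manual)
```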
6 changes: 3 additions & 3 deletions trident/operation/var.py
@@ -47,7 +47,7 @@ def backward(ctx: Any, *grad_outputs: Any):
return grad_input, None, None

@staticmethod
def __forward(input: torch.Tensor, dim: int, correction: int):
def __forward(input: torch.Tensor, dim: torch.int32, correction: torch.int32):
factory_kwargs = {"device": input.device, "dtype": input.dtype}
y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim)
output = torch.empty(y_size, **factory_kwargs)
@@ -64,14 +64,14 @@ def grid(meta):
y_stride,
x_stride,
correction,
util.dtype(input.dtype),
util.dtype(output.dtype),
)
util.pop_trace()

return output

@staticmethod
def __backward(grad_output: torch.Tensor, input: torch.Tensor, dim: int, correction: int):
def __backward(grad_output: torch.Tensor, input: torch.Tensor, dim: torch.int32, correction: torch.int32):
y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim)
grad_input = torch.zeros_like(input)

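The backward launch itself is collapsed in this view, but the visible `pid` and `num_x_blocks = tl.cdiv(x_size, x_block_size)` lines suggest one program per `(row, x-block)` pair. A hypothetical grid consistent with that decoding; treat it as an inference, not the file's actual code:

```python
import triton

def make_grid(y_size: int, x_size: int):
    # One program per (row, x-block) pair; the kernel recovers the pair from pid.
    def grid(meta):
        return (y_size * triton.cdiv(x_size, meta["x_block_size"]),)
    return grid
```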