This repository was archived by the owner on Oct 16, 2023, and is now read-only.

Optimize PReLU
mejai1206 committed Sep 25, 2023
1 parent 60707a8 commit 9c55c8e
Showing 3 changed files with 48 additions and 11 deletions.
trident/kernel/cosine_similarity.py: 2 additions & 2 deletions
@@ -49,9 +49,9 @@ def forward(
         y_stride: tl.int32,
         x_stride: tl.int32,
         eps: tl.float32,
-        size_along_dim: tl.int32,
         output_y_size: tl.int32,
         output_x_size: tl.int32,
+        size_along_dim: tl.constexpr,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
         require_boundary_check: tl.constexpr,
@@ -154,9 +154,9 @@ def backward(
         z_stride: tl.int32,
         y_stride: tl.int32,
         x_stride: tl.int32,
-        size_along_dim: tl.int32,
         output_y_size: tl.int32,
         output_x_size: tl.int32,
+        size_along_dim: tl.constexpr,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
         require_boundary_check: tl.constexpr,
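Note: the substance of the cosine_similarity change is moving size_along_dim from a tl.int32 runtime argument to tl.constexpr. A constexpr argument is baked into the compiled kernel, so Triton can specialize on it (for example, fully unrolling a reduction loop over that dimension), at the cost of one compiled variant per distinct value. A minimal sketch of the idea, using a hypothetical row-sum kernel rather than the trident source:

import torch
import triton
import triton.language as tl


@triton.jit
def row_sum(out_ptr, in_ptr, size_along_dim: tl.constexpr, block_size: tl.constexpr):
    # size_along_dim is a compile-time constant, so the trip count of this
    # loop is known to the compiler and the loop can be unrolled/specialized.
    pid = tl.program_id(0)
    acc = tl.zeros((block_size,), dtype=tl.float32)
    for off in range(0, size_along_dim, block_size):
        cols = off + tl.arange(0, block_size)
        acc += tl.load(in_ptr + pid * size_along_dim + cols, mask=cols < size_along_dim, other=0.0)
    tl.store(out_ptr + pid, tl.sum(acc, axis=0))


x = torch.randn(8, 300, device="cuda")
out = torch.empty(8, device="cuda")
row_sum[(8,)](out, x, size_along_dim=300, block_size=128)  # a new size recompiles the kernel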
trident/kernel/prelu.py: 44 additions & 7 deletions
@@ -31,6 +31,12 @@ def prelu_configs():
 class PReLU:
     @staticmethod
     @util.autotune(prelu_configs(), ["x_size"])
+    @triton.heuristics(
+        {
+            "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"],
+            "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
+        }
+    )
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -45,6 +51,8 @@ def forward(
         dtype: tl.constexpr,
         y_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_y_boundary_check: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_y_blocks = tl.cdiv(y_size, y_block_size)
@@ -82,13 +90,31 @@ def forward(
             order=(1, 0),
         )
 
-        input = tl.load(input_block_ptr, boundary_check=(1, 2))
-        weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        if require_y_boundary_check | require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1, 2))
+        else:
+            input = tl.load(input_block_ptr)
+
+        if require_y_boundary_check:
+            weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        else:
+            weight = tl.load(weight_block_ptr)
 
         output = language.math.LeakyReLU.forward(input, weight)
-        tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2))
+
+        if require_y_boundary_check | require_x_boundary_check:
+            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2))
+        else:
+            tl.store(output_block_ptr, output.to(dtype))
 
     @staticmethod
     @util.autotune(prelu_configs(), ["x_size"])
+    @triton.heuristics(
+        {
+            "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"],
+            "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
+        }
+    )
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -105,6 +131,8 @@ def backward(
         dtype: tl.constexpr,
         y_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_y_boundary_check: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_y_blocks = tl.cdiv(y_size, y_block_size)
@@ -157,11 +185,20 @@ def backward(
             block_shape=(y_block_size, 1),
             order=(1, 0),
         )
 
-        input = tl.load(input_block_ptr, boundary_check=(1, 2))
+        if require_y_boundary_check | require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1, 2))
+            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
+        else:
+            input = tl.load(input_block_ptr)
+            grad_output = tl.load(grad_output_block_ptr)
+
         weight = tl.load(weight_block_ptr)
-        grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
         grad_input = language.math.LeakyReLU.backward(grad_output, input, weight)
         grad_weight = grad_output * tl.where(input > 0, 0, input)
-        tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
-        tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
+
+        if require_y_boundary_check | require_x_boundary_check:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
+        else:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype))
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype))
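Note: triton.heuristics evaluates each named lambda against the launch arguments and feeds the result to the kernel as an extra constexpr. Here the heuristic is the remainder size % block_size, which is nonzero exactly when a dimension is not a multiple of its block size, so fully tiled shapes compile to unmasked loads and stores with no boundary checks at all. A standalone sketch of the same pattern, with a hypothetical kernel that is not part of trident:

import torch
import triton
import triton.language as tl


@triton.heuristics({"require_boundary_check": lambda args: args["x_size"] % args["block_size"] != 0})
@triton.jit
def scale(x_ptr, x_size, alpha, block_size: tl.constexpr, require_boundary_check: tl.constexpr):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    if require_boundary_check:
        # Partial tile: mask off the tail elements past x_size.
        x = tl.load(x_ptr + offsets, mask=offsets < x_size)
        tl.store(x_ptr + offsets, x * alpha, mask=offsets < x_size)
    else:
        # Exact tiling: the whole block is in bounds, no mask needed.
        x = tl.load(x_ptr + offsets)
        tl.store(x_ptr + offsets, x * alpha)


x = torch.randn(1024, device="cuda")
scale[(triton.cdiv(1024, 256),)](x, 1024, 2.0, block_size=256)  # 1024 % 256 == 0: unmasked path

The commit's lambdas return the raw remainder rather than a bool; any nonzero value is truthy in the kernel's if, so the behavior is the same.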
trident/operation/cosine_similarity.py: 2 additions & 2 deletions
@@ -72,9 +72,9 @@ def grid(meta):
             y_stride,
             x_stride,
             eps,
-            size_along_dim,
             output_y_size,
             output_x_size,
+            size_along_dim,
             util.dtype(x1.dtype),
         )
         util.pop_trace()
@@ -107,9 +107,9 @@ def grid(meta):
             z_stride,
             y_stride,
             x_stride,
-            size_along_dim,
             output_y_size,
             output_x_size,
+            size_along_dim,
             util.dtype(x1.dtype),
         )
         util.pop_trace()
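Note: the operation-layer edits exist only to keep call sites in sync with the kernel's new parameter order; Triton kernel arguments, constexpr or not, are matched by position at launch, so size_along_dim must move after output_x_size wherever the kernel is invoked. A tiny hypothetical launch showing that positional pairing:

import torch
import triton
import triton.language as tl


@triton.jit
def axpy(y_ptr, x_ptr, alpha, size: tl.constexpr, block_size: tl.constexpr):
    # size and block_size are filled by the 4th and 5th launch arguments.
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, y + alpha * x, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.randn(1000, device="cuda")
axpy[(triton.cdiv(1000, 256),)](y, x, 2.0, 1000, 256)  # 1000 -> size, 256 -> block_size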
