
Commit

Optimize PReLU
mejai1206 authored Sep 25, 2023
1 parent 633effd commit cef626c
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions trident/kernel/prelu.py
@@ -31,6 +31,12 @@ def prelu_configs():
 class PReLU:
     @staticmethod
     @util.autotune(prelu_configs(), ["x_size"])
+    @triton.heuristics(
+        {
+            "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"],
+            "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
+        }
+    )
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -45,6 +51,8 @@ def forward(
         dtype: tl.constexpr,
         y_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_y_boundary_check: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_y_blocks = tl.cdiv(y_size, y_block_size)
@@ -82,13 +90,31 @@ def forward(
             order=(1, 0),
         )
 
-        input = tl.load(input_block_ptr, boundary_check=(1, 2))
-        weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        if require_y_boundary_check | require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1, 2))
+        else:
+            input = tl.load(input_block_ptr)
+
+        if require_y_boundary_check:
+            weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        else:
+            weight = tl.load(weight_block_ptr)
+
         output = language.math.LeakyReLU.forward(input, weight)
-        tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2))
+
+        if require_y_boundary_check | require_x_boundary_check:
+            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2))
+        else:
+            tl.store(output_block_ptr, output.to(dtype))
 
     @staticmethod
     @util.autotune(prelu_configs(), ["x_size"])
+    @triton.heuristics(
+        {
+            "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"],
+            "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
+        }
+    )
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -105,6 +131,8 @@ def backward(
         dtype: tl.constexpr,
         y_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_y_boundary_check: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_y_blocks = tl.cdiv(y_size, y_block_size)
@@ -157,11 +185,20 @@ def backward(
             block_shape=(y_block_size, 1),
             order=(1, 0),
         )
+        if require_y_boundary_check | require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1, 2))
+            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
+        else:
+            input = tl.load(input_block_ptr)
+            grad_output = tl.load(grad_output_block_ptr)
 
-        input = tl.load(input_block_ptr, boundary_check=(1, 2))
         weight = tl.load(weight_block_ptr)
-        grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
         grad_input = language.math.LeakyReLU.backward(grad_output, input, weight)
         grad_weight = grad_output * tl.where(input > 0, 0, input)
-        tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
-        tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
+
+        if require_y_boundary_check | require_x_boundary_check:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
+        else:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype))
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype))
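Note on the change (not part of the diff above): triton.heuristics evaluates each lambda against the kernel's launch arguments and passes the result in as a constexpr, so require_y_boundary_check and require_x_boundary_check are simply the remainders y_size % y_block_size and x_size % x_block_size. A non-zero remainder means the last block along that axis is partial and tl.load / tl.store must keep their boundary_check masks; when both sizes divide evenly, the kernel now takes the unmasked fast path. A minimal plain-Python sketch of that rule, for illustration only (the shape values below are made up):

def require_boundary_check(size: int, block_size: int) -> bool:
    # Mirrors the lambdas passed to triton.heuristics above: a non-zero
    # remainder means the final block along this axis is partial, so the
    # kernel must use boundary_check (masked) loads and stores.
    return size % block_size != 0


# Hypothetical shapes, purely to show when the fast path applies.
assert require_boundary_check(1000, 64)       # 1000 % 64 != 0 -> masked accesses
assert not require_boundary_check(1024, 64)   # 1024 % 64 == 0 -> unmasked fast path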
