This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Optimize Softmax
daemyung authored Sep 20, 2023
1 parent 9cb516a commit 4af4b37
Showing 1 changed file with 47 additions and 12 deletions.
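
The change threads a compile-time require_x_boundary_check flag through the forward, backward, and backward_delta kernels via triton.heuristics: the boundary-checked load/store path is compiled only when x_size is not an exact multiple of x_block_size, so evenly divisible rows take the cheaper unchecked path. A rough sketch of how the heuristic value resolves, with illustrative shapes that are not taken from this commit:

    # Same lambda as in the decorator below; the example shapes are made up.
    heuristic = lambda args: args["x_size"] % args["x_block_size"]

    print(heuristic({"x_size": 1000, "x_block_size": 128}))  # 104 -> truthy: keep boundary checks
    print(heuristic({"x_size": 1024, "x_block_size": 128}))  # 0   -> falsy: skip boundary checks

Because the flag arrives as a tl.constexpr, the branch that is not taken is eliminated when the kernel is compiled.
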
59 changes: 47 additions & 12 deletions trident/kernel/softmax.py
@@ -29,6 +29,7 @@ def softmax_configs():
 class Softmax:
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -39,6 +40,7 @@ def forward(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -63,10 +65,15 @@ def forward(
         sum = tl.zeros((1, x_block_size), tl.float32)
 
         for x_offset in range(0, x_size, x_block_size):
-            input = tl.load(input_block_ptr, boundary_check=(1,))
-            condition = tl.arange(0, x_block_size) + x_offset < x_size
-            input = tl.where(condition, input, -float("inf"))
-            peak = tl.where(condition, tl.maximum(max, input), 0)
+            if require_x_boundary_check:
+                input = tl.load(input_block_ptr, boundary_check=(1,))
+                condition = tl.arange(0, x_block_size) + x_offset < x_size
+                input = tl.where(condition, input, -float("inf"))
+                peak = tl.where(condition, tl.maximum(max, input), 0)
+            else:
+                input = tl.load(input_block_ptr)
+                peak = tl.maximum(max, input)
+
             sum = sum * tl.math.fast_expf(max - peak) + tl.math.fast_expf(input - peak)
             max = peak
             input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))
@@ -83,14 +90,24 @@ def forward(
         )
 
         for x_offset in range(0, x_size, x_block_size):
-            input = tl.load(input_block_ptr, boundary_check=(1,))
+            if require_x_boundary_check:
+                input = tl.load(input_block_ptr, boundary_check=(1,))
+            else:
+                input = tl.load(input_block_ptr)
+
             output = tl.math.fast_expf(input - max) / sum
-            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+
+            if require_x_boundary_check:
+                tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+            else:
+                tl.store(output_block_ptr, output.to(dtype))
+
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
             input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))
 
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -103,6 +120,7 @@ def backward(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -140,17 +158,28 @@ def backward(
         )
 
         for x_offset in range(0, x_size, x_block_size):
-            output = tl.load(output_block_ptr, boundary_check=(1,))
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
-            delta = tl.load(delta_block_ptr, boundary_check=(0,))
+            if require_x_boundary_check:
+                output = tl.load(output_block_ptr, boundary_check=(1,))
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
+            else:
+                output = tl.load(output_block_ptr)
+                grad_output = tl.load(grad_output_block_ptr)
+
+            delta = tl.load(delta_block_ptr)
             grad_input = output * (grad_output - delta)
-            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+
+            if require_x_boundary_check:
+                tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+            else:
+                tl.store(grad_input_block_ptr, grad_input.to(dtype))
+
             grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, x_block_size))
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size))
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
 
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     @triton.jit
     def backward_delta(
         delta_ptr: tl.tensor,
@@ -162,6 +191,7 @@ def backward_delta(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -193,8 +223,13 @@ def backward_delta(
         delta = tl.zeros((1, x_block_size), dtype)
 
        for _ in range(0, x_size, x_block_size):
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
-            output = tl.load(output_block_ptr, boundary_check=(1,))
+            if require_x_boundary_check:
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+                output = tl.load(output_block_ptr, boundary_check=(1,))
+            else:
+                grad_output = tl.load(grad_output_block_ptr)
+                output = tl.load(output_block_ptr)
+
             delta += grad_output * output
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size))
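
For reference, the running max/sum update in forward is the standard streaming (online) softmax. A scalar NumPy sketch of the same recurrence, with an illustrative row length and block size rather than anything taken from this file:

    import numpy as np

    def streaming_softmax(row, block_size):
        # First pass, as in forward: keep a running max and rescale the partial sum
        # whenever a larger value appears (sum = sum * exp(max - peak) + exp(x - peak)).
        running_max, running_sum = -np.inf, 0.0
        for start in range(0, row.size, block_size):
            block = row[start:start + block_size]
            peak = max(running_max, block.max())
            running_sum = running_sum * np.exp(running_max - peak) + np.exp(block - peak).sum()
            running_max = peak
        # Second pass: normalize every element with the final max and sum.
        return np.exp(row - running_max) / running_sum

    x = np.random.randn(1000).astype(np.float32)
    reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
    assert np.allclose(streaming_softmax(x, 128), reference, atol=1e-6)
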
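Likewise, backward_delta and backward together implement the usual softmax vector-Jacobian product, grad_x = y * (grad_y - sum(grad_y * y)). A quick PyTorch check of that identity against autograd, with made-up shapes:

    import torch

    x = torch.randn(4, 1000, dtype=torch.float64, requires_grad=True)
    y = torch.softmax(x, dim=-1)
    grad_y = torch.randn_like(y)

    delta = (grad_y * y).sum(dim=-1, keepdim=True)  # the per-row sum backward_delta accumulates
    grad_x = y * (grad_y - delta)                   # the expression backward stores block by block

    grad_x_autograd, = torch.autograd.grad(y, x, grad_y)
    assert torch.allclose(grad_x, grad_x_autograd)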
