This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit 11421bf
Optimize RMSNorm
danny.jang authored Sep 21, 2023
1 parent c4d102a commit 11421bf
Showing 1 changed file with 68 additions and 14 deletions.
trident/kernel/rms_norm.py
@@ -18,6 +18,7 @@
 
 class RMSNorm:
     @staticmethod
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -27,12 +28,13 @@ def forward(
         x_size: tl.int32,
         y_stride: tl.int32,
         x_stride: tl.int32,
-        partial_size: tl.int32,
+        partial_size: tl.constexpr,
         weight_ptr: tl.tensor,
         bias_ptr: tl.tensor,
         eps: tl.float32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -69,11 +71,25 @@ def forward(
             order=(0,),
         )
 
-        input = tl.load(input_block_ptr, boundary_check=(1,))
-        partial_input = tl.where(tl.arange(0, x_block_size) < partial_size, input, 0)
+        if require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1,))
+        else:
+            input = tl.load(input_block_ptr)
+
+        if x_block_size != partial_size:
+            condition = tl.arange(0, x_block_size) < partial_size
+            partial_input = tl.where(condition, input, 0)
+        else:
+            partial_input = input
+
         rms = tl.math.sqrt(tl.sum(partial_input * partial_input / partial_size, 1))
         norm = input / (rms + eps)
-        weight = tl.load(weight_block_ptr, boundary_check=(0,))
+
+        if require_x_boundary_check:
+            weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        else:
+            weight = tl.load(weight_block_ptr)
+
         output = norm * weight
 
         if bias_ptr is not None:
@@ -85,13 +101,23 @@ def forward(
                 block_shape=(1, x_block_size),
                 order=(1, 0),
             )
-            bias = tl.load(bias_block_ptr, boundary_check=(1,))
+
+            if require_x_boundary_check:
+                bias = tl.load(bias_block_ptr, boundary_check=(1,))
+            else:
+                bias = tl.load(bias_block_ptr)
+
             output += bias
 
         tl.store(rms_block_ptr, rms.to(dtype))
-        tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+
+        if require_x_boundary_check:
+            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+        else:
+            tl.store(output_block_ptr, output.to(dtype))
 
     @staticmethod
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -103,11 +129,12 @@ def backward(
         y_stride: tl.int32,
         x_stride: tl.int32,
         rms_ptr: tl.tensor,
-        partial_size: tl.int32,
+        partial_size: tl.constexpr,
         weight_ptr: tl.tensor,
         eps: tl.float32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -160,19 +187,46 @@ def backward(
             order=(0,),
         )
 
-        grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
-        input = tl.load(input_block_ptr, boundary_check=(1,))
+        if require_x_boundary_check:
+            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
+            input = tl.load(input_block_ptr, boundary_check=(1,))
+        else:
+            grad_output = tl.load(grad_output_block_ptr)
+            input = tl.load(input_block_ptr)
+
         rms = tl.load(rms_block_ptr)
-        weight = tl.load(weight_block_ptr, boundary_check=(0,))
+
+        if require_x_boundary_check:
+            weight = tl.load(weight_block_ptr, boundary_check=(0,))
+        else:
+            weight = tl.load(weight_block_ptr)
+
         grad_norm = grad_output * weight
         norm = input / (rms + eps)
         grad_weight = grad_output * norm
-        tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1,))
+
+        if require_x_boundary_check:
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1,))
+        else:
+            tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype))
+
         grad_rms = grad_norm * -input / (rms * rms + eps)
-        grad_rms = tl.where(tl.arange(0, x_block_size) < x_size, grad_rms, 0.0)
+
+        if require_x_boundary_check:
+            condition = tl.arange(0, x_block_size) < x_size
+            grad_rms = tl.where(condition, grad_rms, 0.0)
+
         grad_rms = tl.sum(grad_rms, 1)
         grad_mean_square = grad_rms / (2 * rms)
         grad_partial_input = 2 * input * grad_mean_square / partial_size
-        grad_partial_input = tl.where(tl.arange(0, x_block_size) < partial_size, grad_partial_input, 0)
+
+        if x_block_size != partial_size:
+            condition = tl.arange(0, x_block_size) < partial_size
+            grad_partial_input = tl.where(condition, grad_partial_input, 0)
+
         grad_input = (grad_norm / (rms + eps)) + grad_partial_input
-        tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+
+        if require_x_boundary_check:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+        else:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype))
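
Why this is faster: the change has two parts. First, a triton.heuristics decorator on both kernels computes require_x_boundary_check as x_size % x_block_size at launch time; because the value arrives as a tl.constexpr, Triton specializes the compiled kernel and drops the untaken branch, so shapes where x_size divides x_block_size evenly skip every boundary_check. Second, partial_size becomes a tl.constexpr, so the x_block_size != partial_size test also resolves at compile time and the tl.where masking disappears whenever one block exactly covers the reduced region. Below is a minimal sketch of the same heuristic pattern on a stand-alone copy kernel; copy_kernel and copy are hypothetical names for illustration (assumes Triton and a CUDA device), not code from this repository.

import torch
import triton
import triton.language as tl


# Same heuristic pattern as the diff above, on a trivial copy kernel.
@triton.heuristics({"require_boundary_check": lambda args: args["size"] % args["block_size"]})
@triton.jit
def copy_kernel(
    output_ptr: tl.tensor,
    input_ptr: tl.tensor,
    size: tl.int32,
    block_size: tl.constexpr,
    require_boundary_check: tl.constexpr,
):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)

    # require_boundary_check is a constexpr, so the untaken branch is removed
    # when the kernel is specialized: evenly divisible sizes pay no mask cost.
    if require_boundary_check:
        mask = offsets < size
        tl.store(output_ptr + offsets, tl.load(input_ptr + offsets, mask=mask), mask=mask)
    else:
        tl.store(output_ptr + offsets, tl.load(input_ptr + offsets))


def copy(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    output = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), block_size),)
    # The heuristic fills in require_boundary_check; callers never pass it.
    copy_kernel[grid](output, x, x.numel(), block_size=block_size)
    return output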

0 comments on commit 11421bf
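For spot-checking the forward arithmetic, here is a plain-PyTorch reference of the math as the kernel states it: the mean square is reduced over the first partial_size columns only (exactly like partial_input above), norm = input / (rms + eps), and output = norm * weight plus an optional bias. rms_norm_reference is a hypothetical helper, not part of trident.

import torch


def rms_norm_reference(x, weight, bias, partial_size, eps):
    # Hypothetical reference, not trident code. x: (y_size, x_size).
    partial = x[:, :partial_size]
    rms = torch.sqrt((partial * partial).sum(dim=1) / partial_size)  # (y_size,)
    norm = x / (rms[:, None] + eps)
    output = norm * weight
    if bias is not None:
        output = output + bias
    return output, rms  # rms is stored by the kernel for the backward pass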

Please sign in to comment.
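The backward hunk is the chain rule through the mean square m and the root r, sketched below; note the kernel writes rms * rms + eps in the denominator where the exact derivative of x / (r + eps) would give (r + eps)^2, so this is an eps-regularized approximation, consistent with where eps sits in the forward pass.

r = \sqrt{m}, \qquad m = \frac{1}{p} \sum_{j < p} x_j^2, \qquad \mathrm{grad\_norm}_i = g_i \, w_i

\mathrm{grad\_rms} = \sum_i \mathrm{grad\_norm}_i \cdot \frac{-x_i}{r^2 + \varepsilon}, \qquad
\frac{\partial r}{\partial m} = \frac{1}{2r}, \qquad
\frac{\partial m}{\partial x_j} = \frac{2 x_j}{p} \quad (j < p)

\frac{\partial L}{\partial x_j} = \frac{\mathrm{grad\_norm}_j}{r + \varepsilon}
  + \mathrm{grad\_rms} \cdot \frac{1}{2r} \cdot \frac{2 x_j}{p}

This matches the kernel line for line: grad_rms / (2 * rms) is grad_mean_square, 2 * input * grad_mean_square / partial_size (masked to the first partial_size columns) is grad_partial_input, and grad_input = grad_norm / (rms + eps) + grad_partial_input.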