This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Optimize CosineSimilarity
mejai1206 authored Sep 20, 2023
1 parent 4af4b37 commit c4d102a
Showing 3 changed files with 71 additions and 57 deletions.
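The core of the change: boundary-checked block-pointer loads and stores are now gated behind a @triton.heuristics flag, so the guarded path is only compiled in when the size along the reduced dimension is not a multiple of the block size. Below is a minimal sketch of that pattern as an illustrative copy kernel; it is not code from this commit, and the names (copy_kernel, size, block_size) are placeholders.

import torch
import triton
import triton.language as tl


@triton.heuristics({"require_boundary_check": lambda args: args["size"] % args["block_size"]})
@triton.jit
def copy_kernel(
    output_ptr: tl.tensor,
    input_ptr: tl.tensor,
    size: tl.int32,
    block_size: tl.constexpr,
    require_boundary_check: tl.constexpr,
):
    pid = tl.program_id(0)
    input_block_ptr = tl.make_block_ptr(
        input_ptr, shape=(size,), strides=(1,), offsets=(pid * block_size,), block_shape=(block_size,), order=(0,)
    )
    output_block_ptr = tl.make_block_ptr(
        output_ptr, shape=(size,), strides=(1,), offsets=(pid * block_size,), block_shape=(block_size,), order=(0,)
    )
    if require_boundary_check:
        # Partial last block: pay for the guard and zero-pad out-of-range elements.
        x = tl.load(input_block_ptr, boundary_check=(0,), padding_option="zero")
        tl.store(output_block_ptr, x, boundary_check=(0,))
    else:
        # The size divides evenly: use the cheaper unguarded block access.
        x = tl.load(input_block_ptr)
        tl.store(output_block_ptr, x)


def copy(input: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Assumes a contiguous 1-D CUDA tensor.
    output = torch.empty_like(input)
    grid = (triton.cdiv(input.numel(), block_size),)
    copy_kernel[grid](output, input, input.numel(), block_size=block_size)
    return output

Because require_boundary_check is a tl.constexpr filled in by the heuristic from the actual size and block_size, Triton specializes the kernel per value and the if/else costs nothing at run time; the host launch never passes the flag explicitly.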
46 changes: 30 additions & 16 deletions trident/kernel/cosine_similarity.py
@@ -30,6 +30,11 @@ def cosine_similarity_configs():
class CosineSimilarity:
@staticmethod
@util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
@triton.heuristics(
{
"require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
}
)
@triton.jit
def forward(
output_ptr: tl.tensor,
@@ -49,6 +54,7 @@ def forward(
output_x_size: tl.int32,
dtype: tl.constexpr,
block_size: tl.constexpr,
require_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)

@@ -101,8 +107,12 @@ def forward(
numerator_accumulation = tl.zeros((block_size, 1, 1), tl.float32)

for _ in range(0, size_along_dim, block_size):
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
if require_boundary_check:
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
else:
x1 = tl.load(x1_block_ptr)
x2 = tl.load(x2_block_ptr)

denominator_accumulation1 += x1 * x1
denominator_accumulation2 += x2 * x2
@@ -124,6 +134,11 @@ def forward(

@staticmethod
@util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
@triton.heuristics(
{
"require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
}
)
@triton.jit
def backward(
grad_x1_ptr: tl.tensor,
@@ -144,6 +159,7 @@ def backward(
output_x_size: tl.int32,
dtype: tl.constexpr,
block_size: tl.constexpr,
require_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
num_output_y = pid // output_x_size
@@ -207,8 +223,12 @@ def backward(
)

for _ in range(0, size_along_dim, block_size):
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
if require_boundary_check:
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
else:
x1 = tl.load(x1_block_ptr)
x2 = tl.load(x2_block_ptr)

denominator = tl.load(denominator_block_ptr)
numerator = tl.load(numerator_block_ptr)
@@ -232,18 +252,12 @@ def backward(
grad_x1 = (grad_sqrt1 * 2 * x1) + (grad_to_dot * x2)
grad_x2 = (grad_sqrt2 * 2 * x2) + (grad_to_dot * x1)

tl.store(
grad_x1_block_ptr,
grad_x1.to(dtype),
mask=None,
boundary_check=(0,),
)
tl.store(
grad_x2_block_ptr,
grad_x2.to(dtype),
mask=None,
boundary_check=(0,),
)
if require_boundary_check:
tl.store(grad_x1_block_ptr, grad_x1.to(dtype), boundary_check=(0,))
tl.store(grad_x2_block_ptr, grad_x2.to(dtype), boundary_check=(0,))
else:
tl.store(grad_x1_block_ptr, grad_x1.to(dtype))
tl.store(grad_x2_block_ptr, grad_x2.to(dtype))

x1_block_ptr = tl.advance(x1_block_ptr, (block_size, 0, 0))
x2_block_ptr = tl.advance(x2_block_ptr, (block_size, 0, 0))
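For orientation only (this is not part of the commit), the accumulators in the forward loop earlier in this file follow the standard cosine-similarity definition: the numerator accumulates x1 * x2 while the two denominator terms accumulate x1 * x1 and x2 * x2 over the reduced dimension. A plain PyTorch equivalent, with an assumed eps clamp in the spirit of torch.nn.functional.cosine_similarity, looks like:

import torch


def cosine_similarity_reference(x1: torch.Tensor, x2: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
    # numerator_accumulation in the kernel corresponds to (x1 * x2).sum(dim);
    # denominator_accumulation1/2 correspond to the squared-norm sums below.
    numerator = (x1 * x2).sum(dim)
    denominator = x1.pow(2).sum(dim).sqrt() * x2.pow(2).sum(dim).sqrt()
    return numerator / denominator.clamp_min(eps)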
38 changes: 19 additions & 19 deletions trident/kernel/linear.py
@@ -114,9 +114,9 @@ class Linear:
@util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -184,18 +184,18 @@ def forward(
block_shape=(m_block_size, n_block_size),
order=(1, 0),
)
if require_m_boundary_check & require_n_boundary_check:
tl.store(output_block_ptr, output)
else:
if require_m_boundary_check | require_n_boundary_check:
tl.store(output_block_ptr, output, boundary_check=(0, 1))
else:
tl.store(output_block_ptr, output)

@staticmethod
@util.autotune(linear_backward_configs([64, 128], [32, 64], [32, 64, 128]), ["m_size", "n_size", "k_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -259,10 +259,10 @@ def backward(
order=(1, 0),
)

if require_m_boundary_check & require_k_boundary_check:
tl.store(grad_input_block_ptr, grad_input)
else:
if require_m_boundary_check | require_k_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
else:
tl.store(grad_input_block_ptr, grad_input)

@staticmethod
@util.autotune(
@@ -271,9 +271,9 @@ def backward(
)
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -336,14 +336,14 @@ def backward_weight(
order=(1, 0),
)

if require_n_boundary_check & require_k_boundary_check:
tl.store(grad_weight_staging_block_ptr, grad_weight)
else:
if require_n_boundary_check | require_k_boundary_check:
tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(0, 1))
else:
tl.store(grad_weight_staging_block_ptr, grad_weight)

@staticmethod
@util.autotune(linear_configs_for_backward_bias(), ["m_size", "n_size"])
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0})
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]})
@triton.jit
def backward_bias(
grad_bias_staging_ptr: tl.tensor,
44 changes: 22 additions & 22 deletions trident/language/linear.py
@@ -62,15 +62,15 @@ def forward(
output = tl.zeros((m_block_size, n_block_size), dtype)

for k_offset in range(0, k_size, k_block_size):
if require_k_boundary_check & require_m_boundary_check:
input = tl.load(input_block_ptr)
else:
if require_k_boundary_check | require_m_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_k_boundary_check & require_n_boundary_check:
weight = tl.load(weight_block_ptr)
else:
input = tl.load(input_block_ptr)

if require_k_boundary_check | require_n_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
weight = tl.load(weight_block_ptr)

output += language.dot(input, weight, use_accelerator, dtype)
input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size))
@@ -86,9 +86,9 @@ def forward(
order=(0,),
)
if require_n_boundary_check:
bias = tl.load(bias_block_ptr)
else:
bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")
else:
bias = tl.load(bias_block_ptr)

output += bias

@@ -134,15 +134,15 @@ def backward(
grad_input = tl.zeros((m_block_size, k_block_size), dtype)

for _ in range(0, n_size, n_block_size):
if require_n_boundary_check & require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
if require_n_boundary_check | require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_n_boundary_check & require_k_boundary_check:
weight = tl.load(weight_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr)

if require_n_boundary_check | require_k_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
weight = tl.load(weight_block_ptr)

grad_input += language.dot(grad_output, weight, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size))
@@ -190,15 +190,15 @@ def backward_weight(
grad_weight = tl.zeros((n_block_size, k_block_size), dtype)

for _ in range(0, m_size, m_block_size):
if require_m_boundary_check & require_n_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
if require_m_boundary_check | require_n_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_m_boundary_check & require_k_boundary_check:
input = tl.load(input_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr)

if require_m_boundary_check | require_k_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
input = tl.load(input_block_ptr)

grad_weight += language.dot(grad_output, input, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
@@ -229,9 +229,9 @@ def backward_bias(

for m_offset in range(0, m_size, m_block_size):
if require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
else:
grad_output = tl.load(grad_output_block_ptr)

grad_bias += grad_output
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
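As a quick illustration with hypothetical sizes (not taken from the repository), the rewritten predicates are truthy exactly when a dimension does not divide evenly by its block size, so after the & -> | swap the boundary-checked branch is taken whenever at least one dimension of a two-dimensional load or store is partial:

# Hypothetical sizes; the heuristic lambdas above reduce to these expressions.
m_size, m_block_size = 100, 32   # 100 % 32 == 4 -> truthy: the last m-block is partial
k_size, k_block_size = 256, 64   # 256 % 64 == 0 -> falsy: k divides evenly

require_m_boundary_check = m_size % m_block_size
require_k_boundary_check = k_size % k_block_size

# A guarded (boundary-checked) access is used whenever either flag is set.
print(bool(require_m_boundary_check | require_k_boundary_check))  # True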
