This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Optimize CosineSimilarity #141

Merged: 1 commit, Sep 20, 2023
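
For context, the optimization follows a standard Triton specialization pattern: a @triton.heuristics entry computes, from the kernel arguments, whether the reduction size is a multiple of the block size; the result reaches the kernel as a tl.constexpr, so the `if require_boundary_check:` branch is resolved at compile time. When the size divides evenly, the compiled kernel loads and stores with no bounds logic at all; the padded, checked path is only generated for ragged shapes. A minimal sketch of the pattern follows; the kernel name, the 1D row-sum computation, and the argument names are illustrative and not taken from this PR:

import triton
import triton.language as tl


@triton.heuristics(
    {
        # Truthy only when the row length is not a multiple of the block size,
        # i.e. only when the last block would read past the end of the row.
        "require_boundary_check": lambda args: args["x_size"] % args["block_size"],
    }
)
@triton.jit
def row_sum(
    output_ptr: tl.tensor,
    input_ptr: tl.tensor,
    x_size: tl.int32,
    block_size: tl.constexpr,
    require_boundary_check: tl.constexpr,
):
    input_block_ptr = tl.make_block_ptr(
        input_ptr,
        shape=(x_size,),
        strides=(1,),
        offsets=(0,),
        block_shape=(block_size,),
        order=(0,),
    )
    accumulation = tl.zeros((block_size,), tl.float32)

    for _ in range(0, x_size, block_size):
        if require_boundary_check:
            # Ragged tail: clamp the block to the tensor and pad with zeros.
            x = tl.load(input_block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            # x_size divides evenly: plain load, no bounds arithmetic.
            x = tl.load(input_block_ptr)
        accumulation += x
        input_block_ptr = tl.advance(input_block_ptr, (block_size,))

    tl.store(output_ptr, tl.sum(accumulation, 0))

At launch, the caller passes only the pointers, x_size, and block_size; @triton.heuristics fills in require_boundary_check from those arguments, so the even and ragged cases share a single kernel source.
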
46 changes: 30 additions & 16 deletions trident/kernel/cosine_similarity.py
@@ -30,6 +30,11 @@ def cosine_similarity_configs():
class CosineSimilarity:
@staticmethod
@util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
@triton.heuristics(
{
"require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
}
)
@triton.jit
def forward(
output_ptr: tl.tensor,
@@ -49,6 +54,7 @@ def forward(
output_x_size: tl.int32,
dtype: tl.constexpr,
block_size: tl.constexpr,
require_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)

@@ -101,8 +107,12 @@ def forward(
numerator_accumulation = tl.zeros((block_size, 1, 1), tl.float32)

for _ in range(0, size_along_dim, block_size):
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
if require_boundary_check:
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
else:
x1 = tl.load(x1_block_ptr)
x2 = tl.load(x2_block_ptr)

denominator_accumulation1 += x1 * x1
denominator_accumulation2 += x2 * x2
@@ -124,6 +134,11 @@

@staticmethod
@util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
@triton.heuristics(
{
"require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
}
)
@triton.jit
def backward(
grad_x1_ptr: tl.tensor,
@@ -144,6 +159,7 @@ def backward(
output_x_size: tl.int32,
dtype: tl.constexpr,
block_size: tl.constexpr,
require_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
num_output_y = pid // output_x_size
@@ -207,8 +223,12 @@ def backward(
)

for _ in range(0, size_along_dim, block_size):
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
if require_boundary_check:
x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
else:
x1 = tl.load(x1_block_ptr)
x2 = tl.load(x2_block_ptr)

denominator = tl.load(denominator_block_ptr)
numerator = tl.load(numerator_block_ptr)
@@ -232,18 +252,12 @@ def backward(
grad_x1 = (grad_sqrt1 * 2 * x1) + (grad_to_dot * x2)
grad_x2 = (grad_sqrt2 * 2 * x2) + (grad_to_dot * x1)

tl.store(
grad_x1_block_ptr,
grad_x1.to(dtype),
mask=None,
boundary_check=(0,),
)
tl.store(
grad_x2_block_ptr,
grad_x2.to(dtype),
mask=None,
boundary_check=(0,),
)
if require_boundary_check:
tl.store(grad_x1_block_ptr, grad_x1.to(dtype), boundary_check=(0,))
tl.store(grad_x2_block_ptr, grad_x2.to(dtype), boundary_check=(0,))
else:
tl.store(grad_x1_block_ptr, grad_x1.to(dtype))
tl.store(grad_x2_block_ptr, grad_x2.to(dtype))

x1_block_ptr = tl.advance(x1_block_ptr, (block_size, 0, 0))
x2_block_ptr = tl.advance(x2_block_ptr, (block_size, 0, 0))
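Apart from the specialized loads, the forward math is unchanged: the kernel accumulates x1*x1, x2*x2, and x1*x2 one block at a time and only at the end combines them into dot(x1, x2) / (||x1|| * ||x2||). A plain PyTorch reference of that blockwise scheme follows; the block size and tensor shapes are arbitrary (chosen so the last block is ragged) and are not the kernel's launch configuration:

import torch
import torch.nn.functional as F


def blockwise_cosine_similarity(x1, x2, block_size=128, eps=1e-8):
    # Mirrors the kernel's accumulation scheme in plain PyTorch:
    # square sums and the dot product are built up one block at a time.
    denominator1 = torch.zeros(x1.shape[:-1])
    denominator2 = torch.zeros(x1.shape[:-1])
    numerator = torch.zeros(x1.shape[:-1])
    for offset in range(0, x1.shape[-1], block_size):
        a = x1[..., offset : offset + block_size].float()
        b = x2[..., offset : offset + block_size].float()
        denominator1 += (a * a).sum(-1)  # denominator_accumulation1
        denominator2 += (b * b).sum(-1)  # denominator_accumulation2
        numerator += (a * b).sum(-1)     # numerator_accumulation
    return numerator / (denominator1.sqrt() * denominator2.sqrt()).clamp_min(eps)


x1, x2 = torch.randn(4, 300), torch.randn(4, 300)  # 300 is not a multiple of 128
expected = F.cosine_similarity(x1, x2, dim=-1)
assert torch.allclose(blockwise_cosine_similarity(x1, x2), expected, atol=1e-5)
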
38 changes: 19 additions & 19 deletions trident/kernel/linear.py
@@ -114,9 +114,9 @@ class Linear:
@util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -184,18 +184,18 @@ def forward(
block_shape=(m_block_size, n_block_size),
order=(1, 0),
)
if require_m_boundary_check & require_n_boundary_check:
tl.store(output_block_ptr, output)
else:
if require_m_boundary_check | require_n_boundary_check:
tl.store(output_block_ptr, output, boundary_check=(0, 1))
else:
tl.store(output_block_ptr, output)

@staticmethod
@util.autotune(linear_backward_configs([64, 128], [32, 64], [32, 64, 128]), ["m_size", "n_size", "k_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -259,10 +259,10 @@ def backward(
order=(1, 0),
)

if require_m_boundary_check & require_k_boundary_check:
tl.store(grad_input_block_ptr, grad_input)
else:
if require_m_boundary_check | require_k_boundary_check:
tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
else:
tl.store(grad_input_block_ptr, grad_input)

@staticmethod
@util.autotune(
@@ -271,9 +271,9 @@ def backward(
)
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
}
)
@triton.jit
@@ -336,14 +336,14 @@ def backward_weight(
order=(1, 0),
)

if require_n_boundary_check & require_k_boundary_check:
tl.store(grad_weight_staging_block_ptr, grad_weight)
else:
if require_n_boundary_check | require_k_boundary_check:
tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(0, 1))
else:
tl.store(grad_weight_staging_block_ptr, grad_weight)

@staticmethod
@util.autotune(linear_configs_for_backward_bias(), ["m_size", "n_size"])
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0})
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]})
@triton.jit
def backward_bias(
grad_bias_staging_ptr: tl.tensor,
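The linear.py change is a sign flip of the same heuristic: previously require_*_boundary_check was truthy when a dimension divided evenly by its block size (the opposite of what the name suggests), so the unchecked load/store path sat under an & of the flags. Now each flag is truthy exactly when a boundary check is needed, the branches combine with | and put the checked path first, and the same shapes take the same path as before. A small plain-Python check of that equivalence follows; the block sizes are arbitrary examples:

def old_fast_path(m_size, n_size, m_block_size, n_block_size):
    # Old heuristic: flags are True when a dimension divides evenly,
    # and the unchecked store is taken only when both flags are True.
    require_m = m_size % m_block_size == 0
    require_n = n_size % n_block_size == 0
    return require_m & require_n


def new_fast_path(m_size, n_size, m_block_size, n_block_size):
    # New heuristic: flags are truthy when a boundary check is required,
    # and the unchecked store is the else-branch of `require_m | require_n`.
    require_m = m_size % m_block_size
    require_n = n_size % n_block_size
    return not (require_m | require_n)


for m, n in [(128, 64), (130, 64), (128, 70), (131, 65)]:
    assert old_fast_path(m, n, 64, 32) == new_fast_path(m, n, 64, 32)
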
44 changes: 22 additions & 22 deletions trident/language/linear.py
@@ -62,15 +62,15 @@ def forward(
output = tl.zeros((m_block_size, n_block_size), dtype)

for k_offset in range(0, k_size, k_block_size):
if require_k_boundary_check & require_m_boundary_check:
input = tl.load(input_block_ptr)
else:
if require_k_boundary_check | require_m_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_k_boundary_check & require_n_boundary_check:
weight = tl.load(weight_block_ptr)
else:
input = tl.load(input_block_ptr)

if require_k_boundary_check | require_n_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
weight = tl.load(weight_block_ptr)

output += language.dot(input, weight, use_accelerator, dtype)
input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size))
@@ -86,9 +86,9 @@ def forward(
order=(0,),
)
if require_n_boundary_check:
bias = tl.load(bias_block_ptr)
else:
bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")
else:
bias = tl.load(bias_block_ptr)

output += bias

@@ -134,15 +134,15 @@ def backward(
grad_input = tl.zeros((m_block_size, k_block_size), dtype)

for _ in range(0, n_size, n_block_size):
if require_n_boundary_check & require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
if require_n_boundary_check | require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_n_boundary_check & require_k_boundary_check:
weight = tl.load(weight_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr)

if require_n_boundary_check | require_k_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
weight = tl.load(weight_block_ptr)

grad_input += language.dot(grad_output, weight, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size))
@@ -190,15 +190,15 @@ def backward_weight(
grad_weight = tl.zeros((n_block_size, k_block_size), dtype)

for _ in range(0, m_size, m_block_size):
if require_m_boundary_check & require_n_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
if require_m_boundary_check | require_n_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")

if require_m_boundary_check & require_k_boundary_check:
input = tl.load(input_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr)

if require_m_boundary_check | require_k_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
else:
input = tl.load(input_block_ptr)

grad_weight += language.dot(grad_output, input, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
@@ -229,9 +229,9 @@ def backward_bias(

for m_offset in range(0, m_size, m_block_size):
if require_m_boundary_check:
grad_output = tl.load(grad_output_block_ptr)
else:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
else:
grad_output = tl.load(grad_output_block_ptr)

grad_bias += grad_output
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))