From 8079ddca40034c59dfb99bde09ceeed1a674ee23 Mon Sep 17 00:00:00 2001
From: "mejai.p"
Date: Wed, 20 Sep 2023 17:39:09 +0900
Subject: [PATCH] Optimize CosineSimilarity

---
 trident/kernel/cosine_similarity.py | 46 +++++++++++++++++++----------
 trident/kernel/linear.py            | 38 ++++++++++++------------
 trident/language/linear.py          | 44 +++++++++++++--------------
 3 files changed, 71 insertions(+), 57 deletions(-)

diff --git a/trident/kernel/cosine_similarity.py b/trident/kernel/cosine_similarity.py
index 48197862..1a88626a 100644
--- a/trident/kernel/cosine_similarity.py
+++ b/trident/kernel/cosine_similarity.py
@@ -30,6 +30,11 @@ def cosine_similarity_configs():
 class CosineSimilarity:
     @staticmethod
     @util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
+    @triton.heuristics(
+        {
+            "require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
+        }
+    )
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -49,6 +54,7 @@ def forward(
         output_x_size: tl.int32,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
+        require_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)

@@ -101,8 +107,12 @@ def forward(
         numerator_accumulation = tl.zeros((block_size, 1, 1), tl.float32)

         for _ in range(0, size_along_dim, block_size):
-            x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
-            x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
+            if require_boundary_check:
+                x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero")
+                x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero")
+            else:
+                x1 = tl.load(x1_block_ptr)
+                x2 = tl.load(x2_block_ptr)

             denominator_accumulation1 += x1 * x1
             denominator_accumulation2 += x2 * x2
@@ -124,6 +134,11 @@

     @staticmethod
     @util.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
+    @triton.heuristics(
+        {
+            "require_boundary_check": lambda args: args["size_along_dim"] % args["block_size"],
+        }
+    )
     @triton.jit
     def backward(
         grad_x1_ptr: tl.tensor,
@@ -144,6 +159,7 @@ def backward(
         output_x_size: tl.int32,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
+        require_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_output_y = pid // output_x_size
@@ -207,8 +223,12 @@
         )

         for _ in range(0, size_along_dim, block_size):
-            x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
-            x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
+            if require_boundary_check:
+                x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
+                x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
+            else:
+                x1 = tl.load(x1_block_ptr).to(tl.float32)
+                x2 = tl.load(x2_block_ptr).to(tl.float32)

             denominator = tl.load(denominator_block_ptr)
             numerator = tl.load(numerator_block_ptr)
@@ -232,18 +252,12 @@
             grad_x1 = (grad_sqrt1 * 2 * x1) + (grad_to_dot * x2)
             grad_x2 = (grad_sqrt2 * 2 * x2) + (grad_to_dot * x1)

-            tl.store(
-                grad_x1_block_ptr,
-                grad_x1.to(dtype),
-                mask=None,
-                boundary_check=(0,),
-            )
-            tl.store(
-                grad_x2_block_ptr,
-                grad_x2.to(dtype),
-                mask=None,
-                boundary_check=(0,),
-            )
+            if require_boundary_check:
+                tl.store(grad_x1_block_ptr, grad_x1.to(dtype), boundary_check=(0,))
+                tl.store(grad_x2_block_ptr, grad_x2.to(dtype), boundary_check=(0,))
+            else:
+                tl.store(grad_x1_block_ptr, grad_x1.to(dtype))
+                tl.store(grad_x2_block_ptr, grad_x2.to(dtype))

             x1_block_ptr = tl.advance(x1_block_ptr, (block_size, 0, 0))
             x2_block_ptr = tl.advance(x2_block_ptr, (block_size, 0, 0))
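Note (illustrative, not part of the patch): the CosineSimilarity change applies the same boundary-check elision already used by the Linear kernels. The heuristic size_along_dim % block_size is evaluated at launch time and handed to the kernel as a constexpr flag, so when the reduction length is an exact multiple of the block size every block is full and the masked loads and stores (boundary_check plus zero padding) are never compiled in. A minimal sketch of the pattern on a toy 1-D sum kernel, assuming a recent Triton with block-pointer support; the kernel and helper names here are made up for illustration:

import torch
import triton
import triton.language as tl


# As in the patch, the flag is non-zero only when a partial tail block exists.
@triton.heuristics({"require_boundary_check": lambda args: args["size"] % args["block_size"]})
@triton.jit
def sum_kernel(output_ptr, input_ptr, size, block_size: tl.constexpr, require_boundary_check: tl.constexpr):
    input_block_ptr = tl.make_block_ptr(
        input_ptr, shape=(size,), strides=(1,), offsets=(0,), block_shape=(block_size,), order=(0,)
    )
    accumulation = tl.zeros((block_size,), tl.float32)

    for _ in range(0, size, block_size):
        if require_boundary_check:
            # A partial tail block is possible: mask out-of-range elements and pad with zeros.
            x = tl.load(input_block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            # size divides evenly into blocks: every load is full, no masking needed.
            x = tl.load(input_block_ptr)
        accumulation += x
        input_block_ptr = tl.advance(input_block_ptr, (block_size,))

    tl.store(output_ptr, tl.sum(accumulation, 0))


def run_sum(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty(1, device=x.device, dtype=torch.float32)
    sum_kernel[(1,)](output, x, x.numel(), block_size=1024)
    return output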
diff --git a/trident/kernel/linear.py b/trident/kernel/linear.py
index 4b727794..39ea3a79 100644
--- a/trident/kernel/linear.py
+++ b/trident/kernel/linear.py
@@ -114,9 +114,9 @@ class Linear:
     @util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
     @triton.heuristics(
         {
-            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
-            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
-            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
         }
     )
     @triton.jit
@@ -184,18 +184,18 @@ def forward(
             block_shape=(m_block_size, n_block_size),
             order=(1, 0),
         )
-        if require_m_boundary_check & require_n_boundary_check:
-            tl.store(output_block_ptr, output)
-        else:
+        if require_m_boundary_check | require_n_boundary_check:
             tl.store(output_block_ptr, output, boundary_check=(0, 1))
+        else:
+            tl.store(output_block_ptr, output)

     @staticmethod
     @util.autotune(linear_backward_configs([64, 128], [32, 64], [32, 64, 128]), ["m_size", "n_size", "k_size"])
     @triton.heuristics(
         {
-            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
-            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
-            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
         }
     )
     @triton.jit
@@ -259,10 +259,10 @@ def backward(
             order=(1, 0),
         )

-        if require_m_boundary_check & require_k_boundary_check:
-            tl.store(grad_input_block_ptr, grad_input)
-        else:
+        if require_m_boundary_check | require_k_boundary_check:
             tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
+        else:
+            tl.store(grad_input_block_ptr, grad_input)

     @staticmethod
     @util.autotune(
@@ -271,9 +271,9 @@
     )
     @triton.heuristics(
         {
-            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
-            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
-            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"],
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
         }
     )
     @triton.jit
@@ -336,14 +336,14 @@ def backward_weight(
             order=(1, 0),
         )

-        if require_n_boundary_check & require_k_boundary_check:
-            tl.store(grad_weight_staging_block_ptr, grad_weight)
-        else:
+        if require_n_boundary_check | require_k_boundary_check:
             tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(0, 1))
+        else:
+            tl.store(grad_weight_staging_block_ptr, grad_weight)

     @staticmethod
     @util.autotune(linear_configs_for_backward_bias(), ["m_size", "n_size"])
-    @triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0})
+    @triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]})
     @triton.jit
     def backward_bias(
         grad_bias_staging_ptr: tl.tensor,
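Note (illustrative, not part of the patch): in trident/kernel/linear.py the heuristics previously computed size % block_size == 0, so a truthy flag meant "divides evenly" and the unchecked store ran when both flags were truthy (the & branch). The new heuristics keep the remainder itself, which is truthy exactly when a partial block exists, so the boundary-checked store runs when either flag is truthy (the | branch). Both formulations select the same path; the plain-Python sketch below (helper names made up for illustration) just spells out that equivalence:

# Old and new formulations pick the unchecked store under the same condition.
def unchecked_store_old(m_size, n_size, m_block_size, n_block_size):
    require_m = m_size % m_block_size == 0  # old meaning: "divides evenly"
    require_n = n_size % n_block_size == 0
    return require_m & require_n            # old: skip boundary_check when both divide


def unchecked_store_new(m_size, n_size, m_block_size, n_block_size):
    require_m = m_size % m_block_size       # new meaning: "has a partial tail block"
    require_n = n_size % n_block_size
    return not (require_m | require_n)      # new: skip boundary_check when neither does


for m_size, n_size in [(128, 64), (100, 64), (128, 70), (99, 33)]:
    assert unchecked_store_old(m_size, n_size, 32, 32) == unchecked_store_new(m_size, n_size, 32, 32)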
diff --git a/trident/language/linear.py b/trident/language/linear.py
index 95efd85b..c58e7795 100644
--- a/trident/language/linear.py
+++ b/trident/language/linear.py
@@ -62,15 +62,15 @@ def forward(
         output = tl.zeros((m_block_size, n_block_size), dtype)

         for k_offset in range(0, k_size, k_block_size):
-            if require_k_boundary_check & require_m_boundary_check:
-                input = tl.load(input_block_ptr)
-            else:
+            if require_k_boundary_check | require_m_boundary_check:
                 input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
-
-            if require_k_boundary_check & require_n_boundary_check:
-                weight = tl.load(weight_block_ptr)
             else:
+                input = tl.load(input_block_ptr)
+
+            if require_k_boundary_check | require_n_boundary_check:
                 weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
+            else:
+                weight = tl.load(weight_block_ptr)

             output += language.dot(input, weight, use_accelerator, dtype)
             input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size))
@@ -86,9 +86,9 @@ def forward(
             order=(0,),
         )
         if require_n_boundary_check:
-            bias = tl.load(bias_block_ptr)
-        else:
             bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")
+        else:
+            bias = tl.load(bias_block_ptr)

         output += bias

@@ -134,15 +134,15 @@ def backward(
         grad_input = tl.zeros((m_block_size, k_block_size), dtype)

         for _ in range(0, n_size, n_block_size):
-            if require_n_boundary_check & require_m_boundary_check:
-                grad_output = tl.load(grad_output_block_ptr)
-            else:
+            if require_n_boundary_check | require_m_boundary_check:
                 grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
-
-            if require_n_boundary_check & require_k_boundary_check:
-                weight = tl.load(weight_block_ptr)
             else:
+                grad_output = tl.load(grad_output_block_ptr)
+
+            if require_n_boundary_check | require_k_boundary_check:
                 weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
+            else:
+                weight = tl.load(weight_block_ptr)

             grad_input += language.dot(grad_output, weight, use_accelerator, dtype)
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size))
@@ -190,15 +190,15 @@ def backward_weight(
         grad_weight = tl.zeros((n_block_size, k_block_size), dtype)

         for _ in range(0, m_size, m_block_size):
-            if require_m_boundary_check & require_n_boundary_check:
-                grad_output = tl.load(grad_output_block_ptr)
-            else:
+            if require_m_boundary_check | require_n_boundary_check:
                 grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
-
-            if require_m_boundary_check & require_k_boundary_check:
-                input = tl.load(input_block_ptr)
             else:
+                grad_output = tl.load(grad_output_block_ptr)
+
+            if require_m_boundary_check | require_k_boundary_check:
                 input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
+            else:
+                input = tl.load(input_block_ptr)

             grad_weight += language.dot(grad_output, input, use_accelerator, dtype)
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
@@ -229,9 +229,9 @@ def backward_bias(

         for m_offset in range(0, m_size, m_block_size):
             if require_m_boundary_check:
-                grad_output = tl.load(grad_output_block_ptr)
-            else:
                 grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+            else:
+                grad_output = tl.load(grad_output_block_ptr)

             grad_bias += grad_output
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
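Note (illustrative, not part of the patch): the unchecked fast path only fires when every size is an exact multiple of its block size, so a quick sanity check should cover both divisible and non-divisible shapes. The sketch below assumes trident exposes torch.nn.functional-compatible wrappers named trident.function.cosine_similarity and trident.function.linear; if the actual entry points differ, substitute them accordingly:

import torch
import trident  # assumed to provide trident.function.* wrappers

torch.manual_seed(0)
device = "cuda"

# k_size = 128 is a multiple of the power-of-two block sizes and exercises the
# unchecked path; k_size = 100 leaves a partial tail block and exercises the
# boundary-checked path.
for k_size in (128, 100):
    x1 = torch.randn(64, k_size, device=device)
    x2 = torch.randn(64, k_size, device=device)
    expected = torch.nn.functional.cosine_similarity(x1, x2, dim=1)
    actual = trident.function.cosine_similarity(x1, x2, dim=1)  # assumed entry point
    assert torch.allclose(actual, expected, rtol=1e-4, atol=1e-4)

    weight = torch.randn(32, k_size, device=device)
    bias = torch.randn(32, device=device)
    expected = torch.nn.functional.linear(x1, weight, bias)
    actual = trident.function.linear(x1, weight, bias)  # assumed entry point
    assert torch.allclose(actual, expected, rtol=1e-2, atol=1e-2)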