diff --git a/trident/kernel/cosine_similarity.py b/trident/kernel/cosine_similarity.py index 1a88626..ce0939c 100644 --- a/trident/kernel/cosine_similarity.py +++ b/trident/kernel/cosine_similarity.py @@ -49,9 +49,9 @@ def forward( y_stride: tl.int32, x_stride: tl.int32, eps: tl.float32, - size_along_dim: tl.int32, output_y_size: tl.int32, output_x_size: tl.int32, + size_along_dim: tl.constexpr, dtype: tl.constexpr, block_size: tl.constexpr, require_boundary_check: tl.constexpr, @@ -154,9 +154,9 @@ def backward( z_stride: tl.int32, y_stride: tl.int32, x_stride: tl.int32, - size_along_dim: tl.int32, output_y_size: tl.int32, output_x_size: tl.int32, + size_along_dim: tl.constexpr, dtype: tl.constexpr, block_size: tl.constexpr, require_boundary_check: tl.constexpr, diff --git a/trident/kernel/prelu.py b/trident/kernel/prelu.py index b72dad5..5f66463 100644 --- a/trident/kernel/prelu.py +++ b/trident/kernel/prelu.py @@ -31,6 +31,12 @@ def prelu_configs(): class PReLU: @staticmethod @util.autotune(prelu_configs(), ["x_size"]) + @triton.heuristics( + { + "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"], + "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"], + } + ) @triton.jit def forward( output_ptr: tl.tensor, @@ -45,6 +51,8 @@ def forward( dtype: tl.constexpr, y_block_size: tl.constexpr, x_block_size: tl.constexpr, + require_y_boundary_check: tl.constexpr, + require_x_boundary_check: tl.constexpr, ): pid = tl.program_id(0) num_y_blocks = tl.cdiv(y_size, y_block_size) @@ -82,13 +90,31 @@ def forward( order=(1, 0), ) - input = tl.load(input_block_ptr, boundary_check=(1, 2)) - weight = tl.load(weight_block_ptr, boundary_check=(0,)) + if require_y_boundary_check | require_x_boundary_check: + input = tl.load(input_block_ptr, boundary_check=(1, 2)) + else: + input = tl.load(input_block_ptr) + + if require_y_boundary_check: + weight = tl.load(weight_block_ptr, boundary_check=(0,)) + else: + weight = tl.load(weight_block_ptr) + output = language.math.LeakyReLU.forward(input, weight) - tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2)) + + if require_y_boundary_check | require_x_boundary_check: + tl.store(output_block_ptr, output.to(dtype), boundary_check=(1, 2)) + else: + tl.store(output_block_ptr, output.to(dtype)) @staticmethod @util.autotune(prelu_configs(), ["x_size"]) + @triton.heuristics( + { + "require_y_boundary_check": lambda args: args["y_size"] % args["y_block_size"], + "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"], + } + ) @triton.jit def backward( grad_input_ptr: tl.tensor, @@ -105,6 +131,8 @@ def backward( dtype: tl.constexpr, y_block_size: tl.constexpr, x_block_size: tl.constexpr, + require_y_boundary_check: tl.constexpr, + require_x_boundary_check: tl.constexpr, ): pid = tl.program_id(0) num_y_blocks = tl.cdiv(y_size, y_block_size) @@ -157,11 +185,20 @@ def backward( block_shape=(y_block_size, 1), order=(1, 0), ) + if require_y_boundary_check | require_x_boundary_check: + input = tl.load(input_block_ptr, boundary_check=(1, 2)) + grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2)) + else: + input = tl.load(input_block_ptr) + grad_output = tl.load(grad_output_block_ptr) - input = tl.load(input_block_ptr, boundary_check=(1, 2)) weight = tl.load(weight_block_ptr) - grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2)) grad_input = language.math.LeakyReLU.backward(grad_output, input, weight) grad_weight = grad_output * tl.where(input > 0, 0, input) - tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2)) - tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2)) + + if require_y_boundary_check | require_x_boundary_check: + tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2)) + tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2)) + else: + tl.store(grad_input_block_ptr, grad_input.to(dtype)) + tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype)) diff --git a/trident/operation/cosine_similarity.py b/trident/operation/cosine_similarity.py index 91bc31e..b517b5d 100644 --- a/trident/operation/cosine_similarity.py +++ b/trident/operation/cosine_similarity.py @@ -72,9 +72,9 @@ def grid(meta): y_stride, x_stride, eps, - size_along_dim, output_y_size, output_x_size, + size_along_dim, util.dtype(x1.dtype), ) util.pop_trace() @@ -107,9 +107,9 @@ def grid(meta): z_stride, y_stride, x_stride, - size_along_dim, output_y_size, output_x_size, + size_along_dim, util.dtype(x1.dtype), ) util.pop_trace()