softmax: skip boundary checks when x_size is a multiple of x_block_size

Adds a `require_x_boundary_check` constexpr, filled in by `triton.heuristics`,
to the Softmax forward, backward and backward_delta kernels.  When the flag is
false the kernels issue plain block loads/stores (and forward skips its tail
masking with `tl.where`); when it is true they keep the previous
`boundary_check=(1,)` path.

Review notes for this revision of the patch:

* The heuristic returns a bool (`... != 0`) instead of the raw remainder.
  The flag is only ever used for its truthiness, and a constexpr that takes
  many distinct integer values would be expected to produce one kernel
  specialization per value instead of two.
* In backward_delta the bounds-checked load of `output` now uses
  `padding_option="zero"`, matching `grad_output`.  Both tiles feed the
  `delta` accumulator, so padded lanes must hold defined values; zero times
  an undefined lane is not guaranteed to be zero.
* NOTE(review): backward now loads `delta` without `boundary_check=(0,)` on
  both paths.  Presumably the row index is always in range; confirm against
  the launch grid, which is not visible in this patch.
* The line structure of the patch was restored from the hunk line counts.
  Indentation of context lines was reconstructed and should be confirmed
  against the target file.  The post-image blob id on the `index` line
  predates the two changes above.

diff --git a/trident/kernel/softmax.py b/trident/kernel/softmax.py
index 5704a0f..396089c 100644
--- a/trident/kernel/softmax.py
+++ b/trident/kernel/softmax.py
@@ -29,6 +29,7 @@ def softmax_configs():
 class Softmax:
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] != 0})
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -39,6 +40,7 @@ def forward(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -63,10 +65,15 @@ def forward(
         sum = tl.zeros((1, x_block_size), tl.float32)
 
         for x_offset in range(0, x_size, x_block_size):
-            input = tl.load(input_block_ptr, boundary_check=(1,))
-            condition = tl.arange(0, x_block_size) + x_offset < x_size
-            input = tl.where(condition, input, -float("inf"))
-            peak = tl.where(condition, tl.maximum(max, input), 0)
+            if require_x_boundary_check:
+                input = tl.load(input_block_ptr, boundary_check=(1,))
+                condition = tl.arange(0, x_block_size) + x_offset < x_size
+                input = tl.where(condition, input, -float("inf"))
+                peak = tl.where(condition, tl.maximum(max, input), 0)
+            else:
+                input = tl.load(input_block_ptr)
+                peak = tl.maximum(max, input)
+
             sum = sum * tl.math.fast_expf(max - peak) + tl.math.fast_expf(input - peak)
             max = peak
             input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))
@@ -83,14 +90,24 @@ def forward(
         )
 
         for x_offset in range(0, x_size, x_block_size):
-            input = tl.load(input_block_ptr, boundary_check=(1,))
+            if require_x_boundary_check:
+                input = tl.load(input_block_ptr, boundary_check=(1,))
+            else:
+                input = tl.load(input_block_ptr)
+
             output = tl.math.fast_expf(input - max) / sum
-            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+
+            if require_x_boundary_check:
+                tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+            else:
+                tl.store(output_block_ptr, output.to(dtype))
+
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
             input_block_ptr = tl.advance(input_block_ptr, (0, x_block_size))
 
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] != 0})
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -103,6 +120,7 @@ def backward(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -140,17 +158,28 @@ def backward(
         )
 
         for x_offset in range(0, x_size, x_block_size):
-            output = tl.load(output_block_ptr, boundary_check=(1,))
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
-            delta = tl.load(delta_block_ptr, boundary_check=(0,))
+            if require_x_boundary_check:
+                output = tl.load(output_block_ptr, boundary_check=(1,))
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
+            else:
+                output = tl.load(output_block_ptr)
+                grad_output = tl.load(grad_output_block_ptr)
+
+            delta = tl.load(delta_block_ptr)
             grad_input = output * (grad_output - delta)
-            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+
+            if require_x_boundary_check:
+                tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+            else:
+                tl.store(grad_input_block_ptr, grad_input.to(dtype))
+
             grad_input_block_ptr = tl.advance(grad_input_block_ptr, (0, x_block_size))
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size))
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
 
     @staticmethod
     @util.autotune(softmax_configs(), ["x_size"])
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] != 0})
     @triton.jit
     def backward_delta(
         delta_ptr: tl.tensor,
@@ -162,6 +191,7 @@ def backward_delta(
         x_stride: tl.int32,
         dtype: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         y_offset = tl.program_id(0)
 
@@ -193,8 +223,13 @@ def backward_delta(
         delta = tl.zeros((1, x_block_size), dtype)
 
         for _ in range(0, x_size, x_block_size):
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
-            output = tl.load(output_block_ptr, boundary_check=(1,))
+            if require_x_boundary_check:
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+                output = tl.load(output_block_ptr, boundary_check=(1,), padding_option="zero")
+            else:
+                grad_output = tl.load(grad_output_block_ptr)
+                output = tl.load(output_block_ptr)
+
             delta += grad_output * output
             output_block_ptr = tl.advance(output_block_ptr, (0, x_block_size))
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, x_block_size))