diff --git a/tests/test_geglu.py b/tests/test_geglu.py
index a3376484..b9418984 100644
--- a/tests/test_geglu.py
+++ b/tests/test_geglu.py
@@ -86,6 +86,9 @@ def train(func):
 
 @pytest.mark.parametrize("num_batches, m_size, n_size, k_size", [(1, 16, 16, 16)])
 def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):
+    if dtype is torch.bfloat16:
+        pytest.skip("Triton has a bug.")
+
     factory_kwargs = {"device": device, "dtype": dtype}
     input = torch.randn(num_batches, m_size, k_size, **factory_kwargs)
     x_size = n_size // 2
diff --git a/tests/test_linear.py b/tests/test_linear.py
index fa99d612..12ae5db8 100644
--- a/tests/test_linear.py
+++ b/tests/test_linear.py
@@ -84,6 +84,9 @@ def train(func):
 
 @pytest.mark.parametrize("m_size, n_size, k_size", [(32, 32, 32)])
 def test_linear(m_size, n_size, k_size, device, dtype):
+    if dtype is torch.bfloat16:
+        pytest.skip("Triton has a bug.")
+
     factory_kwargs = {"device": device, "dtype": dtype}
     input = torch.randn(m_size, k_size, **factory_kwargs, requires_grad=True)
     weight = torch.randn(n_size, k_size, **factory_kwargs, requires_grad=True)
diff --git a/trident/kernel/attention.py b/trident/kernel/attention.py
index dddea92f..dcfefe70 100644
--- a/trident/kernel/attention.py
+++ b/trident/kernel/attention.py
@@ -120,7 +120,7 @@ def forward(
             score = tl.where(condition, score, float("-inf"))
 
         key = tl.load(key_block_ptr)
-        score += tl.dot(query, key, use_accelerator, dtype)
+        score += tl.dot(query, key, allow_tf32=use_accelerator, out_dtype=dtype)
         peak = tl.maximum(max, tl.max(score, 1))
         alpha = tl.math.exp2(max - peak)
         beta = tl.math.exp2(score - peak[:, None])
@@ -128,7 +128,7 @@ def forward(
         max = peak
         output *= alpha[:, None].to(dtype)
         value = tl.load(value_block_ptr)
-        gamma = tl.dot(beta.to(dtype), value, use_accelerator, dtype)
+        gamma = tl.dot(beta.to(dtype), value, allow_tf32=use_accelerator, out_dtype=dtype)
         output += gamma.to(dtype)
         key_block_ptr = tl.advance(key_block_ptr, (0, y_block_size))
         value_block_ptr = tl.advance(value_block_ptr, (y_block_size, 0))
@@ -254,7 +254,7 @@ def backward(
         else:
             score = tl.zeros((y_block_size, y_block_size), dtype)
 
-        score += tl.dot(query, tl.trans(key), use_accelerator, dtype)
+        score += tl.dot(query, tl.trans(key), allow_tf32=use_accelerator, out_dtype=dtype)
         score *= score_scale
         log2sum = tl.load(log2sum_block_ptr)
         alpha = tl.math.exp2(score - log2sum[:, None]).to(dtype)
@@ -266,14 +266,14 @@ def backward(
             grad_dropout = tl.where(output > 0.0, dropout_scale, 0.0).to(dtype)
             grad_output *= grad_dropout
 
-        grad_value += tl.dot(tl.trans(alpha), grad_output, use_accelerator, dtype)
+        grad_value += tl.dot(tl.trans(alpha), grad_output, allow_tf32=use_accelerator, out_dtype=dtype)
         delta = tl.load(delta_block_ptr)
         grad_alpha = tl.zeros((y_block_size, y_block_size), dtype) - delta[:, None]
-        grad_alpha += tl.dot(grad_output, tl.trans(value), use_accelerator, dtype)
+        grad_alpha += tl.dot(grad_output, tl.trans(value), allow_tf32=use_accelerator, out_dtype=dtype)
         grad_softmax = (alpha * grad_alpha * softmax_scale).to(dtype)
-        grad_key += tl.dot(tl.trans(grad_softmax), query, use_accelerator, dtype)
+        grad_key += tl.dot(tl.trans(grad_softmax), query, allow_tf32=use_accelerator, out_dtype=dtype)
         grad_query = tl.load(grad_query_block_ptr)
-        grad_query += tl.dot(grad_softmax, key, use_accelerator, dtype)
+        grad_query += tl.dot(grad_softmax, key, allow_tf32=use_accelerator, out_dtype=dtype)
         tl.store(grad_query_block_ptr, grad_query)
         grad_query_block_ptr = tl.advance(grad_query_block_ptr, (y_block_size, 0))
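The attention hunks above switch every `tl.dot` call from positional flag arguments to the keyword form `allow_tf32=use_accelerator, out_dtype=dtype`. A minimal standalone sketch of that calling convention (the kernel and its pointer math are illustrative, not taken from this repository):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, size: tl.constexpr):
    rows = tl.arange(0, size)[:, None]
    cols = tl.arange(0, size)[None, :]
    a = tl.load(a_ptr + rows * size + cols)
    b = tl.load(b_ptr + rows * size + cols)
    # Passing allow_tf32/out_dtype by keyword instead of positionally keeps the
    # call unambiguous as Triton's tl.dot signature evolves between releases.
    c = tl.dot(a, b, allow_tf32=True, out_dtype=tl.float32)
    tl.store(c_ptr + rows * size + cols, c)


a = torch.randn(16, 16, device="cuda", dtype=torch.float16)
b = torch.randn(16, 16, device="cuda", dtype=torch.float16)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
dot_kernel[(1,)](a, b, c, size=16)
```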
diff --git a/trident/kernel/geglu.py b/trident/kernel/geglu.py
index 660cc26e..ba82a1d6 100644
--- a/trident/kernel/geglu.py
+++ b/trident/kernel/geglu.py
@@ -18,28 +18,48 @@
 from trident import language, util
 
 
+def num_warps_and_stages_for_geglu(size):
+    if size >= 2**15:
+        num_warps = 8
+        num_stages = 3
+    elif size >= 2**14:
+        num_warps = 4
+        num_stages = 4
+    else:
+        num_warps = 2
+        num_stages = 5
+    return num_warps, num_stages
+
+
 def geglu_configs():
     configs = []
-    for m_block_size in [32, 64]:
-        for k_block_size in [32, 64, 128, 256]:
-            for x_block_size in [32, 64]:
-                for num_stages in [2, 3]:
-                    config = triton.Config(
-                        {
-                            "m_block_size": m_block_size,
-                            "k_block_size": k_block_size,
-                            "x_block_size": x_block_size,
-                        },
-                        2 if k_block_size <= 64 else 4,
-                        num_stages,
-                    )
-                    configs.append(config)
+    for k_block_size in [32, 64]:
+        for m_block_size in [16, 64, 128]:
+            for x_block_size in [32, 64, 128]:
+                num_warps, num_stages = num_warps_and_stages_for_geglu(m_block_size * x_block_size)
+                config = triton.Config(
+                    {
+                        "m_block_size": m_block_size,
+                        "k_block_size": k_block_size,
+                        "x_block_size": x_block_size,
+                    },
+                    num_warps=num_warps,
+                    num_stages=num_stages,
+                )
+                configs.append(config)
     return configs
 
 
 class GEGLU:
     @staticmethod
-    @util.autotune(configs=geglu_configs(), key=["m_size", "k_size", "x_size"])
+    @util.autotune(geglu_configs(), ["m_size", "k_size", "x_size"])
+    @triton.heuristics(
+        {
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+            "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] == 0,
+        }
+    )
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -61,6 +81,9 @@ def forward(
         m_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_m_blocks = tl.cdiv(m_size, m_block_size)
@@ -115,6 +138,9 @@ def forward(
             m_block_size,
             x_block_size,
             k_block_size,
+            require_m_boundary_check,
+            require_x_boundary_check,
+            require_k_boundary_check,
             dtype,
         )
         gate = language.Linear.forward(
@@ -134,12 +160,21 @@ def forward(
             m_block_size,
             x_block_size,
             k_block_size,
+            require_m_boundary_check,
+            require_x_boundary_check,
+            require_k_boundary_check,
             dtype,
         )
         output = state * language.math.GELU.forward(gate)
-        tl.store(output_block_ptr, output.to(dtype), boundary_check=(0, 1))
-        tl.store(state_block_ptr, state.to(dtype), boundary_check=(0, 1))
-        tl.store(gate_block_ptr, gate.to(dtype), boundary_check=(0, 1))
+
+        if require_m_boundary_check & require_x_boundary_check:
+            tl.store(output_block_ptr, output.to(dtype))
+            tl.store(state_block_ptr, state.to(dtype))
+            tl.store(gate_block_ptr, gate.to(dtype))
+        else:
+            tl.store(output_block_ptr, output.to(dtype), boundary_check=(0, 1))
+            tl.store(state_block_ptr, state.to(dtype), boundary_check=(0, 1))
+            tl.store(gate_block_ptr, gate.to(dtype), boundary_check=(0, 1))
 
     @staticmethod
     @triton.jit
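The rewritten `geglu_configs` no longer derives `num_warps` from `k_block_size`; instead `num_warps_and_stages_for_geglu` picks warp and pipeline-stage counts from the output tile size `m_block_size * x_block_size`. A plain-Python sketch of the schedule it produces (the helper is condensed from the hunk above; thresholds are `2**14 = 16384` and `2**15 = 32768`):

```python
# Reproduces the warp/stage selection for every (m, x) tile geglu_configs() enumerates.
def num_warps_and_stages_for_geglu(size):
    if size >= 2**15:
        return 8, 3   # large tiles: more warps, shallower software pipeline
    elif size >= 2**14:
        return 4, 4
    else:
        return 2, 5   # small tiles: fewer warps, deeper pipelining

for m_block_size in [16, 64, 128]:
    for x_block_size in [32, 64, 128]:
        num_warps, num_stages = num_warps_and_stages_for_geglu(m_block_size * x_block_size)
        print(f"m={m_block_size:3d} x={x_block_size:3d} -> num_warps={num_warps}, num_stages={num_stages}")
```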
diff --git a/trident/kernel/linear.py b/trident/kernel/linear.py
index a4b6be82..7a29182b 100644
--- a/trident/kernel/linear.py
+++ b/trident/kernel/linear.py
@@ -12,49 +12,113 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
 import triton
 import triton.language as tl
 
 from trident import language, util
 
 
-def linear_configs():
+def num_warps_and_stages_for_linear(size):
+    if size >= 2**15:
+        num_warps = 8
+        num_stages = 3
+    elif size >= 2**14:
+        num_warps = 4
+        num_stages = 4
+    else:
+        num_warps = 2
+        num_stages = 5
+    return num_warps, num_stages
+
+
+def linear_configs(m_block_size_list: List[int], n_block_size_list: List[int], k_block_size_list: List[int]):
+    configs = []
+    for m_block_size in m_block_size_list:
+        for n_block_size in n_block_size_list:
+            for k_block_size in k_block_size_list:
+                num_warps, num_stages = num_warps_and_stages_for_linear(m_block_size * n_block_size)
+                config = triton.Config(
+                    {
+                        "m_block_size": m_block_size,
+                        "n_block_size": n_block_size,
+                        "k_block_size": k_block_size,
+                    },
+                    num_warps=num_warps,
+                    num_stages=num_stages,
+                )
+                configs.append(config)
+
+    return configs
+
+
+def linear_backward_configs(m_block_size_list: List[int], n_block_size_list: List[int], k_block_size_list: List[int]):
     configs = []
-    for m_block_size in [16, 32, 64]:
-        for n_block_size in [32, 64]:
-            for k_block_size in [32, 64, 128, 256]:
-                for num_stages in [2, 3]:
-                    config = triton.Config(
-                        {
-                            "m_block_size": m_block_size,
-                            "n_block_size": n_block_size,
-                            "k_block_size": k_block_size,
-                        },
-                        2 if k_block_size <= 64 else 4,
-                        num_stages,
-                    )
-                    configs.append(config)
+    for m_block_size in m_block_size_list:
+        for n_block_size in n_block_size_list:
+            for k_block_size in k_block_size_list:
+                num_warps, num_stages = num_warps_and_stages_for_linear(m_block_size * k_block_size)
+                config = triton.Config(
+                    {
+                        "m_block_size": m_block_size,
+                        "n_block_size": n_block_size,
+                        "k_block_size": k_block_size,
+                    },
+                    num_warps=num_warps,
+                    num_stages=num_stages,
+                )
+                configs.append(config)
+
+    return configs
+
+
+def linear_backward_weight_configs(
+    m_block_size_list: List[int], n_block_size_list: List[int], k_block_size_list: List[int]
+):
+    configs = []
+    for m_block_size in m_block_size_list:
+        for n_block_size in n_block_size_list:
+            for k_block_size in k_block_size_list:
+                num_warps, num_stages = num_warps_and_stages_for_linear(n_block_size * k_block_size)
+                config = triton.Config(
+                    {
+                        "m_block_size": m_block_size,
+                        "n_block_size": n_block_size,
+                        "k_block_size": k_block_size,
+                    },
+                    num_warps=num_warps,
+                    num_stages=num_stages,
+                )
+                configs.append(config)
+
     return configs
 
 
 def linear_configs_for_backward_bias():
     configs = []
-    for m_block_size in [32, 64, 128, 256]:
+    for m_block_size in [32, 64, 128]:
         for num_stages in [2, 3]:
             config = triton.Config(
-                {
-                    "m_block_size": m_block_size,
-                },
+                {"m_block_size": m_block_size},
                 2 if m_block_size <= 64 else 4,
                 num_stages,
             )
             configs.append(config)
+
     return configs
 
 
 class Linear:
     @staticmethod
-    @util.autotune(configs=linear_configs(), key=["m_size", "n_size", "k_size"])
+    @util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
+    @triton.heuristics(
+        {
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+        }
+    )
     @triton.jit
     def forward(
         output_ptr: tl.tensor,
@@ -74,6 +138,9 @@ def forward(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_m_blocks = tl.cdiv(m_size, m_block_size)
@@ -103,8 +170,12 @@ def forward(
             m_block_size,
             n_block_size,
             k_block_size,
+            require_m_boundary_check,
+            require_n_boundary_check,
+            require_k_boundary_check,
             dtype,
         )
+
         output_block_ptr = tl.make_block_ptr(
             output_ptr + batch * m_size * n_size,
             shape=(m_size, n_size),
@@ -113,10 +184,20 @@ def forward(
             block_shape=(m_block_size, n_block_size),
             order=(1, 0),
         )
-        tl.store(output_block_ptr, output, boundary_check=(0, 1))
+        if require_m_boundary_check & require_n_boundary_check:
+            tl.store(output_block_ptr, output)
+        else:
+            tl.store(output_block_ptr, output, boundary_check=(0, 1))
 
     @staticmethod
-    @util.autotune(configs=linear_configs(), key=["m_size", "n_size", "k_size"])
+    @util.autotune(linear_backward_configs([64, 128], [32, 64], [32, 64, 128]), ["m_size", "n_size", "k_size"])
+    @triton.heuristics(
+        {
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+        }
+    )
     @triton.jit
     def backward(
         grad_input_ptr: tl.tensor,
@@ -134,6 +215,9 @@ def backward(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         num_m_blocks = tl.cdiv(m_size, m_block_size)
@@ -160,8 +244,12 @@ def backward(
             m_block_size,
             n_block_size,
             k_block_size,
+            require_m_boundary_check,
+            require_n_boundary_check,
+            require_k_boundary_check,
             dtype,
         )
+
         grad_input_block_ptr = tl.make_block_ptr(
             grad_input_ptr + batch * m_size * k_size,
             shape=(m_size, k_size),
@@ -170,10 +258,24 @@ def backward(
             block_shape=(m_block_size, k_block_size),
             order=(1, 0),
         )
-        tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
+
+        if require_m_boundary_check & require_k_boundary_check:
+            tl.store(grad_input_block_ptr, grad_input)
+        else:
+            tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
 
     @staticmethod
-    @util.autotune(configs=linear_configs(), key=["m_size", "n_size", "k_size"])
+    @util.autotune(
+        linear_backward_weight_configs([32, 64], [64, 128], [32, 64, 128]),
+        ["m_size", "n_size", "k_size"],
+    )
+    @triton.heuristics(
+        {
+            "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
+            "require_n_boundary_check": lambda args: args["n_size"] % args["n_block_size"] == 0,
+            "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
+        }
+    )
     @triton.jit
     def backward_weight(
         grad_weight_staging_ptr: tl.tensor,
@@ -190,6 +292,9 @@ def backward_weight(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
    ):
         pid = tl.program_id(0)
         num_n_blocks = tl.cdiv(n_size, n_block_size)
@@ -216,8 +321,12 @@ def backward_weight(
             m_block_size,
             n_block_size,
             k_block_size,
+            require_m_boundary_check,
+            require_n_boundary_check,
+            require_k_boundary_check,
             dtype,
         )
+
         grad_weight_staging_block_ptr = tl.make_block_ptr(
             grad_weight_staging_ptr + batch * n_size * k_size,
             shape=(n_size, k_size),
@@ -226,10 +335,15 @@ def backward_weight(
             block_shape=(n_block_size, k_block_size),
             order=(1, 0),
         )
-        tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(0, 1))
+
+        if require_n_boundary_check & require_k_boundary_check:
+            tl.store(grad_weight_staging_block_ptr, grad_weight)
+        else:
+            tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(0, 1))
 
     @staticmethod
-    @util.autotune(configs=linear_configs_for_backward_bias(), key=["m_size", "n_size"])
+    @util.autotune(linear_configs_for_backward_bias(), ["m_size", "n_size"])
+    @triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0})
     @triton.jit
     def backward_bias(
         grad_bias_staging_ptr: tl.tensor,
@@ -238,13 +352,21 @@
         n_size: tl.int32,
         dtype: tl.constexpr,
         m_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
     ):
         pid = tl.program_id(0)
         batch = pid // n_size
         n_offset = pid % n_size
         grad_bias = language.Linear.backward_bias(
-            grad_output_ptr + batch * m_size * n_size, m_size, n_size, n_offset, m_block_size, dtype
+            grad_output_ptr + batch * m_size * n_size,
+            m_size,
+            n_size,
+            n_offset,
+            m_block_size,
+            require_m_boundary_check,
+            dtype,
         )
+
         grad_bias_staging_block_ptr = tl.make_block_ptr(
             grad_bias_staging_ptr + batch * n_size,
             shape=(n_size,),
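One subtlety in the heuristics above: `require_*_boundary_check` is computed as `size % block_size == 0`, i.e. it is True exactly when a dimension tiles evenly, so despite its name the flag marks the case where the boundary check can be *dropped*. Because `@triton.heuristics` evaluates once per autotune config and injects the result as a `tl.constexpr`, the branch is resolved at compile time and the even-division case pays nothing for masking. A self-contained sketch of the same pattern with a more literal flag name (the copy kernel is illustrative, not from this repository):

```python
import torch
import triton
import triton.language as tl


@triton.heuristics({"evenly_divisible": lambda args: args["m_size"] % args["m_block_size"] == 0})
@triton.jit
def copy_kernel(dst_ptr, src_ptr, m_size, m_block_size: tl.constexpr, evenly_divisible: tl.constexpr):
    offsets = tl.program_id(0) * m_block_size + tl.arange(0, m_block_size)
    if evenly_divisible:
        # no tail block exists, so loads and stores need no mask
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets))
    else:
        mask = offsets < m_size
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


src = torch.randn(1000, device="cuda")
dst = torch.empty_like(src)
copy_kernel[(triton.cdiv(1000, 128),)](dst, src, 1000, m_block_size=128)
```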
diff --git a/trident/language/linear.py b/trident/language/linear.py
index b07ad3fe..4e0f6f23 100644
--- a/trident/language/linear.py
+++ b/trident/language/linear.py
@@ -36,6 +36,9 @@ def forward(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
         dtype: tl.constexpr,
     ):
         input_block_ptr = tl.make_block_ptr(
@@ -57,9 +60,18 @@ def forward(
         output = tl.zeros((m_block_size, n_block_size), dtype)
 
         for k_offset in range(0, k_size, k_block_size):
-            input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            output += tl.dot(input, weight, use_accelerator).to(dtype)
+            if require_k_boundary_check & require_m_boundary_check:
+                input = tl.load(input_block_ptr)
+            else:
+                input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            if require_k_boundary_check & require_n_boundary_check:
+                weight = tl.load(weight_block_ptr)
+            else:
+                weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            output += tl.dot(input, weight, allow_tf32=use_accelerator, out_dtype=dtype)
+
             input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size))
             weight_block_ptr = tl.advance(weight_block_ptr, (k_block_size, 0))
 
@@ -72,7 +84,10 @@ def forward(
                 block_shape=(n_block_size,),
                 order=(0,),
             )
-            bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")
+            if require_n_boundary_check:
+                bias = tl.load(bias_block_ptr)
+            else:
+                bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")
             output += bias
 
         return output
@@ -93,6 +108,9 @@ def backward(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
         dtype: tl.constexpr,
     ):
         grad_output_block_ptr = tl.make_block_ptr(
@@ -111,16 +129,25 @@ def backward(
             block_shape=(n_block_size, k_block_size),
             order=(1, 0),
         )
-        grad_input = tl.zeros((m_block_size, k_block_size), tl.float32)
+        grad_input = tl.zeros((m_block_size, k_block_size), dtype)
+
+        for _ in range(0, n_size, n_block_size):
+            if require_n_boundary_check & require_m_boundary_check:
+                grad_output = tl.load(grad_output_block_ptr)
+            else:
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            if require_n_boundary_check & require_k_boundary_check:
+                weight = tl.load(weight_block_ptr)
+            else:
+                weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            grad_input += tl.dot(grad_output, weight, allow_tf32=use_accelerator, out_dtype=dtype)
 
-        for n_offset in range(0, n_size, n_block_size):
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            grad_input += tl.dot(grad_output, weight, use_accelerator)
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size))
             weight_block_ptr = tl.advance(weight_block_ptr, (n_block_size, 0))
 
-        return grad_input.to(dtype)
+        return grad_input
 
     @staticmethod
     @triton.jit
     def backward_weight(
         grad_output_ptr: tl.tensor,
@@ -138,6 +165,9 @@ def backward_weight(
         m_block_size: tl.constexpr,
         n_block_size: tl.constexpr,
         k_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
+        require_n_boundary_check: tl.constexpr,
+        require_k_boundary_check: tl.constexpr,
         dtype: tl.constexpr,
     ):
         grad_output_block_ptr = tl.make_block_ptr(
@@ -156,25 +186,35 @@ def backward_weight(
             block_shape=(m_block_size, k_block_size),
             order=(1, 0),
         )
-        grad_weight = tl.zeros((n_block_size, k_block_size), tl.float32)
+        grad_weight = tl.zeros((n_block_size, k_block_size), dtype)
+
+        for _ in range(0, m_size, m_block_size):
+            if require_m_boundary_check & require_n_boundary_check:
+                grad_output = tl.load(grad_output_block_ptr)
+            else:
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            if require_m_boundary_check & require_k_boundary_check:
+                input = tl.load(input_block_ptr)
+            else:
+                input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+            grad_weight += tl.dot(grad_output, input, allow_tf32=use_accelerator, out_dtype=dtype)
 
-        for m_offset in range(0, m_size, m_block_size):
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
-            grad_weight += tl.dot(grad_output, input, use_accelerator)
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
             input_block_ptr = tl.advance(input_block_ptr, (m_block_size, 0))
 
-        return grad_weight.to(dtype)
+        return grad_weight
 
     @staticmethod
     @triton.jit
     def backward_bias(
         grad_output_ptr: tl.tensor,
-        m_size: int,
-        n_size: int,
-        n_offset: int,
+        m_size: tl.int32,
+        n_size: tl.int32,
+        n_offset: tl.int32,
         m_block_size: tl.constexpr,
+        require_m_boundary_check: tl.constexpr,
         dtype: tl.constexpr,
     ):
         grad_output_block_ptr = tl.make_block_ptr(
@@ -188,7 +228,11 @@ def backward_bias(
         grad_bias = tl.zeros((1, m_block_size), dtype)
 
         for m_offset in range(0, m_size, m_block_size):
-            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+            if require_m_boundary_check:
+                grad_output = tl.load(grad_output_block_ptr)
+            else:
+                grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+
             grad_bias += grad_output
             grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
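A side effect of the `out_dtype` change in `backward` and `backward_weight`: `grad_input` and `grad_weight` now accumulate in the working `dtype` rather than `tl.float32`, trading accumulation precision for lower register pressure in half-precision runs. A hypothetical PyTorch sketch of what that trade looks like (sizes and loop count are arbitrary, purely illustrative):

```python
import torch

# Contrast float32 accumulation (the old grad_input/grad_weight behavior)
# with accumulating directly in the working dtype (the new behavior).
a = torch.randn(64, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 64, device="cuda", dtype=torch.float16)

acc_fp32 = torch.zeros(64, 64, device="cuda", dtype=torch.float32)
acc_fp16 = torch.zeros(64, 64, device="cuda", dtype=torch.float16)

for _ in range(8):  # mimic the loop accumulating partial dot products
    acc_fp32 += a.float() @ b.float()  # old: accumulate in float32, cast once at the end
    acc_fp16 += a @ b                  # new: every partial sum rounds to float16

print((acc_fp32.half() - acc_fp16).abs().max())  # accumulated rounding gap
```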