From e384a4ec68a790e55d740512df24b58fb9e52903 Mon Sep 17 00:00:00 2001 From: danny jang Date: Wed, 20 Sep 2023 12:51:02 +0900 Subject: [PATCH] Support bfloat16 in Linear and Attention --- tests/test_attention.py | 3 --- tests/test_geglu.py | 3 --- tests/test_linear.py | 3 --- trident/kernel/attention.py | 15 +++++++-------- trident/language/__init__.py | 1 + trident/language/linear.py | 12 ++++++------ trident/language/standard.py | 24 ++++++++++++++++++++++++ 7 files changed, 38 insertions(+), 23 deletions(-) create mode 100644 trident/language/standard.py diff --git a/tests/test_attention.py b/tests/test_attention.py index 985436e5..e83060a3 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -62,9 +62,6 @@ def train(func): "num_batches, num_heads, y_size, x_size, is_causal", [(1, 1, 1, 16, True), (1, 1, 1, 16, False)] ) def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dtype): - if dtype is torch.bfloat16: - pytest.skip("Triton has a bug.") - factory_kwargs = {"device": device, "dtype": dtype} query = torch.rand(num_batches, num_heads, y_size, x_size, **factory_kwargs, requires_grad=True) key = torch.randn_like(query, requires_grad=True) diff --git a/tests/test_geglu.py b/tests/test_geglu.py index 90922997..0232abc2 100644 --- a/tests/test_geglu.py +++ b/tests/test_geglu.py @@ -86,9 +86,6 @@ def train(func): @pytest.mark.parametrize("num_batches, m_size, n_size, k_size", [(1, 16, 16, 16)]) def test_geglu(num_batches, m_size, n_size, k_size, device, dtype): - if dtype is torch.bfloat16: - pytest.skip("Triton has a bug.") - factory_kwargs = {"device": device, "dtype": dtype} x_size = n_size // 2 input = torch.randn(num_batches, m_size, k_size, **factory_kwargs, requires_grad=True) diff --git a/tests/test_linear.py b/tests/test_linear.py index 12ae5db8..fa99d612 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -84,9 +84,6 @@ def train(func): @pytest.mark.parametrize("m_size, n_size, k_size", [(32, 32, 32)]) def test_linear(m_size, n_size, k_size, device, dtype): - if dtype is torch.bfloat16: - pytest.skip("Triton has a bug.") - factory_kwargs = {"device": device, "dtype": dtype} input = torch.randn(m_size, k_size, **factory_kwargs, requires_grad=True) weight = torch.randn(n_size, k_size, **factory_kwargs, requires_grad=True) diff --git a/trident/kernel/attention.py b/trident/kernel/attention.py index dcfefe70..3dc108c7 100644 --- a/trident/kernel/attention.py +++ b/trident/kernel/attention.py @@ -120,7 +120,7 @@ def forward( score = tl.where(condition, score, float("-inf")) key = tl.load(key_block_ptr) - score += tl.dot(query, key, allow_tf32=use_accelerator, out_dtype=dtype) + score += language.dot(query, key, use_accelerator, dtype) peak = tl.maximum(max, tl.max(score, 1)) alpha = tl.math.exp2(max - peak) beta = tl.math.exp2(score - peak[:, None]) @@ -128,8 +128,7 @@ def forward( max = peak output *= alpha[:, None].to(dtype) value = tl.load(value_block_ptr) - gamma = tl.dot(beta.to(dtype), value, allow_tf32=use_accelerator, out_dtype=dtype) - output += gamma.to(dtype) + output += language.dot(beta.to(dtype), value, use_accelerator, dtype) key_block_ptr = tl.advance(key_block_ptr, (0, y_block_size)) value_block_ptr = tl.advance(value_block_ptr, (y_block_size, 0)) @@ -254,7 +253,7 @@ def backward( else: score = tl.zeros((y_block_size, y_block_size), dtype) - score += tl.dot(query, tl.trans(key), allow_tf32=use_accelerator, out_dtype=dtype) + score += language.dot(query, tl.trans(key), use_accelerator, dtype) score *= score_scale log2sum = tl.load(log2sum_block_ptr) alpha = tl.math.exp2(score - log2sum[:, None]).to(dtype) @@ -266,14 +265,14 @@ def backward( grad_dropout = tl.where(output > 0.0, dropout_scale, 0.0).to(dtype) grad_output *= grad_dropout - grad_value += tl.dot(tl.trans(alpha), grad_output, allow_tf32=use_accelerator, out_dtype=dtype) + grad_value += language.dot(tl.trans(alpha), grad_output, use_accelerator, dtype) delta = tl.load(delta_block_ptr) grad_alpha = tl.zeros((y_block_size, y_block_size), dtype) - delta[:, None] - grad_alpha += tl.dot(grad_output, tl.trans(value), allow_tf32=use_accelerator, out_dtype=dtype) + grad_alpha += language.dot(grad_output, tl.trans(value), use_accelerator, dtype) grad_softmax = (alpha * grad_alpha * softmax_scale).to(dtype) - grad_key += tl.dot(tl.trans(grad_softmax), query, allow_tf32=use_accelerator, out_dtype=dtype) + grad_key += language.dot(tl.trans(grad_softmax), query, use_accelerator, dtype) grad_query = tl.load(grad_query_block_ptr) - grad_query += tl.dot(grad_softmax, key, allow_tf32=use_accelerator, out_dtype=dtype) + grad_query += language.dot(grad_softmax, key, use_accelerator, dtype) tl.store(grad_query_block_ptr, grad_query) grad_query_block_ptr = tl.advance(grad_query_block_ptr, (y_block_size, 0)) diff --git a/trident/language/__init__.py b/trident/language/__init__.py index 0347f0de..e1ed90b6 100644 --- a/trident/language/__init__.py +++ b/trident/language/__init__.py @@ -17,6 +17,7 @@ from .constant import * from .linear import * from .mean import * +from .standard import * from .sum import * from .var import * from .var_mean import * diff --git a/trident/language/linear.py b/trident/language/linear.py index 4e0f6f23..95efd85b 100644 --- a/trident/language/linear.py +++ b/trident/language/linear.py @@ -15,6 +15,8 @@ import triton import triton.language as tl +from trident import language + class Linear: @staticmethod @@ -70,8 +72,7 @@ def forward( else: weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero") - output += tl.dot(input, weight, allow_tf32=use_accelerator, out_dtype=dtype) - + output += language.dot(input, weight, use_accelerator, dtype) input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size)) weight_block_ptr = tl.advance(weight_block_ptr, (k_block_size, 0)) @@ -88,6 +89,7 @@ def forward( bias = tl.load(bias_block_ptr) else: bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero") + output += bias return output @@ -142,8 +144,7 @@ def backward( else: weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero") - grad_input += tl.dot(grad_output, weight, allow_tf32=use_accelerator, out_dtype=dtype) - + grad_input += language.dot(grad_output, weight, use_accelerator, dtype) grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size)) weight_block_ptr = tl.advance(weight_block_ptr, (n_block_size, 0)) @@ -199,8 +200,7 @@ def backward_weight( else: input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero") - grad_weight += tl.dot(grad_output, input, allow_tf32=use_accelerator, out_dtype=dtype) - + grad_weight += language.dot(grad_output, input, use_accelerator, dtype) grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size)) input_block_ptr = tl.advance(input_block_ptr, (m_block_size, 0)) diff --git a/trident/language/standard.py b/trident/language/standard.py new file mode 100644 index 00000000..55d05a20 --- /dev/null +++ b/trident/language/standard.py @@ -0,0 +1,24 @@ +# Copyright 2023 ⓒ Kakao Brain Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton +import triton.language as tl + + +@triton.jit +def dot(input: tl.tensor, other: tl.tensor, use_accelerator: tl.constexpr, dtype: tl.constexpr): + if dtype is tl.bfloat16: + return tl.dot(input, other, allow_tf32=use_accelerator, out_dtype=tl.float16).to(dtype) + else: + return tl.dot(input, other, allow_tf32=use_accelerator, out_dtype=dtype)