This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Support bfloat16 in Linear and Attention
daemyung committed Sep 20, 2023
1 parent 300a11e commit e384a4e
Showing 7 changed files with 38 additions and 23 deletions.
3 changes: 0 additions & 3 deletions tests/test_attention.py
@@ -62,9 +62,6 @@ def train(func):
"num_batches, num_heads, y_size, x_size, is_causal", [(1, 1, 1, 16, True), (1, 1, 1, 16, False)]
)
def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dtype):
- if dtype is torch.bfloat16:
-     pytest.skip("Triton has a bug.")
-
factory_kwargs = {"device": device, "dtype": dtype}
query = torch.rand(num_batches, num_heads, y_size, x_size, **factory_kwargs, requires_grad=True)
key = torch.randn_like(query, requires_grad=True)
3 changes: 0 additions & 3 deletions tests/test_geglu.py
@@ -86,9 +86,6 @@ def train(func):

@pytest.mark.parametrize("num_batches, m_size, n_size, k_size", [(1, 16, 16, 16)])
def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):
- if dtype is torch.bfloat16:
-     pytest.skip("Triton has a bug.")
-
factory_kwargs = {"device": device, "dtype": dtype}
x_size = n_size // 2
input = torch.randn(num_batches, m_size, k_size, **factory_kwargs, requires_grad=True)
3 changes: 0 additions & 3 deletions tests/test_linear.py
@@ -84,9 +84,6 @@ def train(func):

@pytest.mark.parametrize("m_size, n_size, k_size", [(32, 32, 32)])
def test_linear(m_size, n_size, k_size, device, dtype):
- if dtype is torch.bfloat16:
-     pytest.skip("Triton has a bug.")
-
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(m_size, k_size, **factory_kwargs, requires_grad=True)
weight = torch.randn(n_size, k_size, **factory_kwargs, requires_grad=True)
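Note: each of these tests receives device and dtype from pytest fixtures defined elsewhere in the repository (not shown in this diff). A minimal sketch of what such a dtype fixture could look like so that bfloat16 is now exercised instead of skipped; the exact parameter list is an assumption, not the repository's actual conftest.py:

import pytest
import torch


@pytest.fixture(params=[torch.float32, torch.float16, torch.bfloat16])
def dtype(request):
    # Each test runs once per dtype; with this commit, bfloat16 is no longer skipped.
    return request.param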
15 changes: 7 additions & 8 deletions trident/kernel/attention.py
@@ -120,16 +120,15 @@ def forward(
score = tl.where(condition, score, float("-inf"))

key = tl.load(key_block_ptr)
- score += tl.dot(query, key, allow_tf32=use_accelerator, out_dtype=dtype)
+ score += language.dot(query, key, use_accelerator, dtype)
peak = tl.maximum(max, tl.max(score, 1))
alpha = tl.math.exp2(max - peak)
beta = tl.math.exp2(score - peak[:, None])
sum = sum * alpha + tl.sum(beta, 1)
max = peak
output *= alpha[:, None].to(dtype)
value = tl.load(value_block_ptr)
- gamma = tl.dot(beta.to(dtype), value, allow_tf32=use_accelerator, out_dtype=dtype)
- output += gamma.to(dtype)
+ output += language.dot(beta.to(dtype), value, use_accelerator, dtype)
key_block_ptr = tl.advance(key_block_ptr, (0, y_block_size))
value_block_ptr = tl.advance(value_block_ptr, (y_block_size, 0))

@@ -254,7 +253,7 @@ def backward(
else:
score = tl.zeros((y_block_size, y_block_size), dtype)

- score += tl.dot(query, tl.trans(key), allow_tf32=use_accelerator, out_dtype=dtype)
+ score += language.dot(query, tl.trans(key), use_accelerator, dtype)
score *= score_scale
log2sum = tl.load(log2sum_block_ptr)
alpha = tl.math.exp2(score - log2sum[:, None]).to(dtype)
@@ -266,14 +265,14 @@
grad_dropout = tl.where(output > 0.0, dropout_scale, 0.0).to(dtype)
grad_output *= grad_dropout

- grad_value += tl.dot(tl.trans(alpha), grad_output, allow_tf32=use_accelerator, out_dtype=dtype)
+ grad_value += language.dot(tl.trans(alpha), grad_output, use_accelerator, dtype)
delta = tl.load(delta_block_ptr)
grad_alpha = tl.zeros((y_block_size, y_block_size), dtype) - delta[:, None]
- grad_alpha += tl.dot(grad_output, tl.trans(value), allow_tf32=use_accelerator, out_dtype=dtype)
+ grad_alpha += language.dot(grad_output, tl.trans(value), use_accelerator, dtype)
grad_softmax = (alpha * grad_alpha * softmax_scale).to(dtype)
- grad_key += tl.dot(tl.trans(grad_softmax), query, allow_tf32=use_accelerator, out_dtype=dtype)
+ grad_key += language.dot(tl.trans(grad_softmax), query, use_accelerator, dtype)
grad_query = tl.load(grad_query_block_ptr)
- grad_query += tl.dot(grad_softmax, key, allow_tf32=use_accelerator, out_dtype=dtype)
+ grad_query += language.dot(grad_softmax, key, use_accelerator, dtype)
tl.store(grad_query_block_ptr, grad_query)

grad_query_block_ptr = tl.advance(grad_query_block_ptr, (y_block_size, 0))
1 change: 1 addition & 0 deletions trident/language/__init__.py
@@ -17,6 +17,7 @@
from .constant import *
from .linear import *
from .mean import *
+ from .standard import *
from .sum import *
from .var import *
from .var_mean import *
12 changes: 6 additions & 6 deletions trident/language/linear.py
@@ -15,6 +15,8 @@
import triton
import triton.language as tl

+ from trident import language
+

class Linear:
@staticmethod
@@ -70,8 +72,7 @@ def forward(
else:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")

- output += tl.dot(input, weight, allow_tf32=use_accelerator, out_dtype=dtype)
-
+ output += language.dot(input, weight, use_accelerator, dtype)
input_block_ptr = tl.advance(input_block_ptr, (0, k_block_size))
weight_block_ptr = tl.advance(weight_block_ptr, (k_block_size, 0))

@@ -88,6 +89,7 @@ def forward(
bias = tl.load(bias_block_ptr)
else:
bias = tl.load(bias_block_ptr, boundary_check=(0,), padding_option="zero")

output += bias

return output
@@ -142,8 +144,7 @@ def backward(
else:
weight = tl.load(weight_block_ptr, boundary_check=(0, 1), padding_option="zero")

- grad_input += tl.dot(grad_output, weight, allow_tf32=use_accelerator, out_dtype=dtype)
-
+ grad_input += language.dot(grad_output, weight, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, n_block_size))
weight_block_ptr = tl.advance(weight_block_ptr, (n_block_size, 0))

@@ -199,8 +200,7 @@ def backward_weight(
else:
input = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")

- grad_weight += tl.dot(grad_output, input, allow_tf32=use_accelerator, out_dtype=dtype)
-
+ grad_weight += language.dot(grad_output, input, use_accelerator, dtype)
grad_output_block_ptr = tl.advance(grad_output_block_ptr, (0, m_block_size))
input_block_ptr = tl.advance(input_block_ptr, (m_block_size, 0))

24 changes: 24 additions & 0 deletions trident/language/standard.py
@@ -0,0 +1,24 @@
# Copyright 2023 ⓒ Kakao Brain Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl


@triton.jit
def dot(input: tl.tensor, other: tl.tensor, use_accelerator: tl.constexpr, dtype: tl.constexpr):
    if dtype is tl.bfloat16:
        return tl.dot(input, other, allow_tf32=use_accelerator, out_dtype=tl.float16).to(dtype)
    else:
        return tl.dot(input, other, allow_tf32=use_accelerator, out_dtype=dtype)
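For reference, a minimal sketch (not part of this commit) of how the new helper is called from inside a Triton kernel, mirroring the call sites changed in attention.py and linear.py above; the function and pointer names below are illustrative only:

import triton
import triton.language as tl

from trident import language


@triton.jit
def block_matmul(x_block_ptr, w_block_ptr, use_accelerator: tl.constexpr, dtype: tl.constexpr):
    # Device-side helper callable from other jitted kernels: load two tiles and
    # multiply them with language.dot, which routes bfloat16 through a float16
    # accumulator and casts the result back to the requested dtype.
    x = tl.load(x_block_ptr)
    w = tl.load(w_block_ptr)
    return language.dot(x, w, use_accelerator, dtype)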
