
Commit 835ec72
(linter)
ilia-cher committed Dec 4, 2024
1 parent ecdf357 commit 835ec72
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions vllm/attention/ops/triton_flash_attention.py
@@ -751,11 +751,15 @@ def forward(
             use_fp8 = True
             (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
             float8 = torch.float8_e4m3fnuz
+
             def check_and_convert(t, scale):
-                finfo = torch.finfo(float8)
-                descale = 1.0 / scale
-                return (t * descale).clamp(min=finfo.min, max=finfo.max).to(float8) \
-                    if t.dtype != float8 else t
+                if t.dtype != float8:
+                    finfo = torch.finfo(float8)
+                    descale = 1.0 / scale
+                    ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
+                    return ts.to(float8)
+                else:
+                    return t
 
             q = check_and_convert(q, q_scale)
             k = check_and_convert(k, k_scale)
@@ -865,7 +869,7 @@ def check_and_convert(t, scale):
             BIAS_TYPE=0 if bias is None else 1,
             ENABLE_DROPOUT=False,
             RETURN_ENCODED_SOFTMAX=False,
-            USE_FP8 = use_fp8,
+            USE_FP8=use_fp8,
         )
 
         ctx.grid = grid
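For context, here is a minimal standalone sketch of what the refactored check_and_convert helper does, assuming a PyTorch build that exposes torch.float8_e4m3fnuz and supports casting to it; the tensor shape and the scale value below are illustrative assumptions, not values taken from the attention kernel. The helper divides the input by its scale, clamps it to the finite range of the fp8 format, and casts, while passing already-converted tensors through unchanged.

import torch

# Sketch of the per-tensor fp8 conversion: divide by the scale, clamp to
# the representable range of float8_e4m3fnuz, then cast. Shapes and the
# scale value are illustrative assumptions.
float8 = torch.float8_e4m3fnuz


def check_and_convert(t, scale):
    if t.dtype != float8:
        finfo = torch.finfo(float8)
        descale = 1.0 / scale
        ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
        return ts.to(float8)
    else:
        return t


q = torch.randn(1, 16, 128, 64, dtype=torch.float16)  # hypothetical query tensor
q_scale = 0.5  # hypothetical per-tensor scale
q_fp8 = check_and_convert(q, q_scale)
print(q_fp8.dtype)  # torch.float8_e4m3fnuz
# Tensors already in fp8 are returned as-is:
assert check_and_convert(q_fp8, q_scale) is q_fp8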
