force the same dtype when comparing FA3 and cuDNN FP8
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Sep 18, 2024
1 parent de3db0a commit 19e7f87
Showing 2 changed files with 32 additions and 10 deletions.
tests/pytorch/fused_attn/test_fused_attn.py (2 additions, 0 deletions)
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
     logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
     logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
     try:
+        if a.dtype != b.dtype:
+            a = a.to(b.dtype)
         torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
     except Exception as e:
         logging.debug(e)
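For context, torch.testing.assert_close compares dtypes by default, so FA3 and cuDNN FP8 outputs produced in different precisions would fail the check before any values are compared; casting one operand to the other's dtype, as the test now does, restricts the comparison to the values. A minimal sketch with assumed tensors (not the actual test data):

import torch

a = torch.randn(8, dtype=torch.float16)
b = a.to(torch.bfloat16)  # nearly identical values, different dtype

try:
    torch.testing.assert_close(a, b, atol=1e-2, rtol=1e-2)
except AssertionError as e:
    print(e)  # fails on the dtype check, not on the values

# Cast to a common dtype first, as the updated test does.
torch.testing.assert_close(a.to(b.dtype), b, atol=1e-2, rtol=1e-2)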
transformer_engine/pytorch/attention.py (30 additions, 10 deletions)
@@ -4828,6 +4828,10 @@ def __init__(
         self.attention_type = attention_type
         self.layer_number = 1 if layer_number is None else layer_number
         self.deterministic = deterministic
+        self.logger = logging.getLogger("FlashAttention")
+        self.logger.setLevel(_log_level)
+        if not self.logger.hasHandlers():
+            self.logger.addHandler(_stream_handler)
 
     def forward(
         self,
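The logger added in __init__ above relies on module-level _log_level and _stream_handler objects defined elsewhere in attention.py and not shown in this diff; a minimal sketch of that kind of setup, with an assumed level and format chosen purely for illustration:

import logging

# Assumed stand-ins; the real definitions live elsewhere in attention.py.
_log_level = logging.WARNING
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(logging.Formatter("[%(levelname)s | %(name)s]: %(message)s"))

logger = logging.getLogger("FlashAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
    logger.addHandler(_stream_handler)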
@@ -5067,16 +5071,32 @@ def convert_to_torch_float8(tensor, dtype):
                     convert_to_torch_float8(x, torch_dtype)
                     for x in [query_layer, key_layer, value_layer]
                 )
-                output, _ = func(
-                    query_layer,
-                    key_layer,
-                    value_layer,
-                    *fa_optional_forward_args_thd,
-                    softmax_scale=self.softmax_scale,
-                    causal="causal" in attn_mask_type,
-                    deterministic=self.deterministic,
-                    **fa_optional_forward_kwargs_fp8,
-                )
+                try:
+                    output, _ = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        deterministic=self.deterministic,
+                        **fa_optional_forward_kwargs_fp8,
+                    )
+                except TypeError:
+                    self.logger.debug(
+                        "Running with default q, k, v descales, i.e. 1s. To enable custom "
+                        "descales, please install flashattn-hopper (FA3) with this PR: "
+                        "https://github.com/Dao-AILab/flash-attention/pull/1210."
+                    )
+                    output, _ = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        deterministic=self.deterministic,
+                    )
                 if fp8 and fp8_meta["recipe"].fp8_mha:
                     output = cast_to_fp8(
                         output,
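The new try/except is a feature-detection fallback: the FP8 descale keyword arguments are only understood by FA3 builds that include the linked pull request, and older builds reject them with a TypeError, so the call is retried without them and default descales of 1 are used. A minimal, self-contained sketch of the pattern (the stub function and keyword names below are illustrative assumptions, not the flash-attn API):

import torch

def attn_stub(q, k, v, softmax_scale=None, causal=False, deterministic=False):
    # Stands in for an FA3 build that predates descale support; the unknown
    # keyword arguments below make the call raise TypeError.
    return q, None

q = k = v = torch.randn(2, 4, 8)
fp8_kwargs = {"descale_q": None, "descale_k": None, "descale_v": None}  # assumed names

try:
    out, _ = attn_stub(q, k, v, softmax_scale=0.5, causal=True, **fp8_kwargs)
except TypeError:
    # Older signature: retry without the optional kwargs, i.e. default descales of 1.
    out, _ = attn_stub(q, k, v, softmax_scale=0.5, causal=True)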
