diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index d110dece53..14456010b4 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
     logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
     logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
     try:
+        if a.dtype != b.dtype:
+            a = a.to(b.dtype)
         torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
     except Exception as e:
         logging.debug(e)
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 817f4bb62e..4dd70ade4e 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -4828,6 +4828,10 @@ def __init__(
         self.attention_type = attention_type
         self.layer_number = 1 if layer_number is None else layer_number
         self.deterministic = deterministic
+        self.logger = logging.getLogger("FlashAttention")
+        self.logger.setLevel(_log_level)
+        if not self.logger.hasHandlers():
+            self.logger.addHandler(_stream_handler)
 
     def forward(
         self,
@@ -5067,16 +5071,32 @@ def convert_to_torch_float8(tensor, dtype):
                        convert_to_torch_float8(x, torch_dtype)
                        for x in [query_layer, key_layer, value_layer]
                    )
-                output, _ = func(
-                    query_layer,
-                    key_layer,
-                    value_layer,
-                    *fa_optional_forward_args_thd,
-                    softmax_scale=self.softmax_scale,
-                    causal="causal" in attn_mask_type,
-                    deterministic=self.deterministic,
-                    **fa_optional_forward_kwargs_fp8,
-                )
+                try:
+                    output, _ = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        deterministic=self.deterministic,
+                        **fa_optional_forward_kwargs_fp8,
+                    )
+                except TypeError:
+                    self.logger.debug(
+                        "Running with default q, k, v descales, i.e. 1s. To enable custom "
+                        "descales, please install flashattn-hopper (FA3) with this PR: "
+                        "https://github.com/Dao-AILab/flash-attention/pull/1210."
+                    )
+                    output, _ = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        deterministic=self.deterministic,
+                    )
                 if fp8 and fp8_meta["recipe"].fp8_mha:
                     output = cast_to_fp8(
                         output,