diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index b69aed6648..7b21b997cd 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index c33cd1ab16..e6a36104ad 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
     logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
     logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
     try:
+        if a.dtype != b.dtype:
+            a = a.to(b.dtype)
         torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
     except Exception as e:
         logging.debug(e)
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b4e04fc685..7aa8f0def1 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -85,6 +85,16 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
 from transformer_engine.pytorch.graph import is_graph_capturing
 
+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
+_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
+_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
+_stream_handler = logging.StreamHandler()
+_stream_handler.setFormatter(_formatter)
 
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -100,29 +110,31 @@ _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
 
 _flash_attn_3_plus = False
 _use_flash_attn_3 = False
+_flash_attn_3_installation_steps = """\
+(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
+(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
+(3) mkdir -p $python_path/flashattn_hopper
+(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
 try:
     _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
-    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
+    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
+    _flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
 except PackageNotFoundError:
     if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
-        warnings.warn(
-            "To use flash-attn v3, please use the following commands to install: \n"
-            """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
-            """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
-            """(3) mkdir -p $python_path/flashattn_hopper \n"""
-            """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
+        fa3_logger = logging.getLogger()
+        fa3_logger.setLevel(_log_level)
+        if not fa3_logger.hasHandlers():
+            fa3_logger.addHandler(_stream_handler)
+        fa3_logger.debug(
+            "To use flash-attn v3, please follow these steps to install the flashattn-hopper "
+            "package: \n%s",
+            _flash_attn_3_installation_steps,
         )
 else:
     from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
     from flashattn_hopper.flash_attn_interface import (
         flash_attn_varlen_func as flash_attn_varlen_func_v3,
     )
-    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
-        _flash_attn_forward as _flash_attn_forward_v3,
-    )
-    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
-        _flash_attn_backward as _flash_attn_backward_v3,
-    )
 
     _use_flash_attn_3 = True
@@ -132,18 +144,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
 from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
 
-
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
-_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
-_stream_handler = logging.StreamHandler()
-_stream_handler.setFormatter(_formatter)
-
 _attention_backends = {
     "attention_params": None,
     "use_flash_attention": None,
@@ -348,7 +348,7 @@ def get_attention_backend(
         use_fused_attention = False
 
     # Filter: Compute capability
-    global _flash_attn_3_plus, _use_flash_attn_3
+    global _use_flash_attn_3
     if device_compute_capability < (8, 0):
         if use_flash_attention:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -357,7 +357,7 @@ def get_attention_backend(
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
     if device_compute_capability < (9, 0):
-        if use_flash_attention and _flash_attn_3_plus:
+        if use_flash_attention and _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
             _use_flash_attn_3 = False
 
@@ -438,10 +438,9 @@ def get_attention_backend(
             use_flash_attention = False
 
     # Filter: Dropout
-    if attention_dropout != 0.0 and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
-            logger.debug("Disabling FlashAttention 3 for dropout")
-            _use_flash_attn_3 = False
+    if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
+        logger.debug("Disabling FlashAttention 3 for dropout")
+        _use_flash_attn_3 = False
 
     # Filter: Context parallelism
     # qkv_format | attn_mask_type | attn_bias_type | supported backends
@@ -461,7 +460,7 @@ def get_attention_backend(
         )
         use_unfused_attention = False
     if context_parallel and use_flash_attention:
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for context parallelism")
             _use_flash_attn_3 = False
         if fp8 and fp8_meta["recipe"].fp8_dpa:
@@ -556,7 +555,7 @@ def get_attention_backend(
         use_fused_attention = False
     if (
         use_flash_attention
-        and _flash_attn_3_plus
+        and _use_flash_attn_3
         and attn_mask_type in ["causal", "padding_causal"]
         and max_seqlen_q != max_seqlen_kv
     ):
@@ -590,6 +589,15 @@ def get_attention_backend(
             "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
         )
         use_flash_attention = False
+    if (
+        use_flash_attention
+        and _use_flash_attn_3
+        and fp8
+        and fp8_meta["recipe"].fp8_dpa
+        and "padding" in attn_mask_type
+    ):
+        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
+        _use_flash_attn_3 = False
 
     # Filter: Sliding window attention
     # backend | window_size | diagonal alignment
@@ -633,15 +641,6 @@ def get_attention_backend(
                 attn_mask_type,
             )
             use_fused_attention = False
-    if (
-        use_flash_attention
-        and (window_size[0] != -1 or window_size[1] not in [-1, 0])
-        and _flash_attn_3_plus
-    ):
-        logger.debug(
-            "Disabling FlashAttention 3 as it does not support sliding window attention"
-        )
-        _use_flash_attn_3 = False
     if (
         use_flash_attention
         and (window_size[0] != -1 or window_size[1] not in [-1, 0])
@@ -662,12 +661,12 @@ def get_attention_backend(
     # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
     #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
     if use_flash_attention and core_attention_bias_type == "alibi":
-        if _flash_attn_3_plus and _use_flash_attn_3:
+        if _use_flash_attn_3:
             logger.debug("Disabling FlashAttention 3 for ALiBi")
             _use_flash_attn_3 = False
-        if not _flash_attn_2_4_plus:
-            logger.debug("Disabling FlashAttention for ALiBi")
-            use_flash_attention = False
+        if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
+            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
+            use_flash_attention = False
 
     if use_flash_attention and (
         core_attention_bias_type not in ["no_bias", "alibi"]
@@ -827,10 +826,6 @@ def get_attention_backend(
             "for performance reasons"
         )
         use_flash_attention = False
-
-    # Select FusedAttention for FP8
-    # FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
-    # scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
     if (
         use_flash_attention
         and use_fused_attention
@@ -838,8 +833,8 @@ def get_attention_backend(
         and _use_flash_attn_3
     ):
         logger.debug(
-            "Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention "
-            "supports more accurate scaling factors in FP8 execution"
+            "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
+            "in FP8 execution"
        )
         use_flash_attention = False
@@ -4963,6 +4958,10 @@ def __init__(
         self.attention_type = attention_type
         self.layer_number = 1 if layer_number is None else layer_number
         self.deterministic = deterministic
+        self.logger = logging.getLogger("FlashAttention")
+        self.logger.setLevel(_log_level)
+        if not self.logger.hasHandlers():
+            self.logger.addHandler(_stream_handler)
 
     def forward(
         self,
@@ -5033,6 +5032,10 @@ def forward(
                     x.transpose(0, 1)
                     for x in (query_layer._data, key_layer._data, value_layer._data)
                 ]
+                query_layer, key_layer, value_layer = [
+                    Float8Tensor.make_like(x, data=x._data)
+                    for x in (query_layer, key_layer, value_layer)
+                ]
             if context_parallel:
                 query_layer._data, key_layer._data, value_layer._data = [
                     x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
@@ -5168,33 +5171,62 @@ def forward(
                 fa_optional_forward_args_thd.append(max_seqlen_q)
                 fa_optional_forward_args_thd.append(max_seqlen_kv)
             if _use_flash_attn_3:
+                fa_3_optional_forward_kwargs = {}
+                fa_3_optional_forward_kwargs["window_size"] = window_size
+                fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                 if fp8:
                     fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
                     activation_dtype = query_layer.dtype
                     torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
+
+                    def convert_to_torch_float8(tensor, dtype):
+                        out = torch.Tensor().to(device=tensor.device, dtype=dtype)
+                        out.set_(
+                            tensor._data.untyped_storage(),
+                            tensor._data.storage_offset(),
+                            tensor._data.shape,
+                            tensor._data.stride(),
+                        )
+                        return out
+
                     if fp8_meta["recipe"].fp8_mha:
                         assert all(
                             isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
                         ), "q/k/v must be Float8Tensors for FP8 MHA."
                         fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
-                        query_layer, key_layer, value_layer = (
-                            x.to(activation_dtype).to(torch_dtype)
-                            for x in [query_layer, key_layer, value_layer]
-                        )
                     else:
                         query_layer, key_layer, value_layer = (
-                            x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
+                            Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
+                            for x in [query_layer, key_layer, value_layer]
                         )
-                output, _ = func(
-                    query_layer,
-                    key_layer,
-                    value_layer,
-                    *fa_optional_forward_args_thd,
-                    softmax_scale=self.softmax_scale,
-                    causal="causal" in attn_mask_type,
-                    deterministic=self.deterministic,
-                )
+                    fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
+                    fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
+                    fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
+                    query_layer, key_layer, value_layer = (
+                        convert_to_torch_float8(x, torch_dtype)
+                        for x in [query_layer, key_layer, value_layer]
+                    )
+                try:
+                    output, _ = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        **fa_3_optional_forward_kwargs,
+                    )
+                except TypeError as e:
+                    if _flash_attn_3_0_0_beta:
+                        e.args = (
+                            e.args[0]
+                            + ". Please update your FlashAttention 3 (beta) installation as it "
+                            + "may have added more supported arguments to its API. \n"
+                            + _flash_attn_3_installation_steps,
+                        ) + e.args[1:]
+                    raise
+
                 if fp8 and fp8_meta["recipe"].fp8_mha:
                     output = cast_to_fp8(
                         output,
@@ -5228,8 +5260,12 @@ def forward(
         if qkv_format == "sbhd":
             # (bs)hd -> bs(hd) -> sb(hd)
             if fp8 and fp8_meta["recipe"].fp8_mha:
-                output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
-                output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
+                output = Float8Tensor.make_like(
+                    output,
+                    data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
+                    .transpose(0, 1)
+                    .contiguous(),
+                )
             else:
                 output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
         elif qkv_format == "bshd":