[PyTorch] Miscellaneous fixes for FA3 attention #1174

Merged
merged 29 commits on Oct 8, 2024
Changes from all commits
Commits (29)
bcdc4d1
add qkv descales to FA3
cyanguwa Sep 10, 2024
3ed49d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2024
1db61e2
fix sbhd shapes
cyanguwa Sep 17, 2024
6a86660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
7da4b6c
Merge branch 'main' into add_descales
cyanguwa Sep 17, 2024
de3db0a
Merge branch 'main' into add_descales
cyanguwa Sep 18, 2024
19e7f87
force the same dtype when comparing FA3 and cuDNN FP8
cyanguwa Sep 18, 2024
bff80b6
Revert "force the same dtype when comparing FA3 and cuDNN FP8"
cyanguwa Sep 18, 2024
68b9b48
force the same dtype when comparing FA3 and cuDNN FP8
cyanguwa Sep 18, 2024
0553a83
add try/except for FA3 when custom qkv descales are not supported
cyanguwa Sep 18, 2024
b73760b
replace FA3 installation warning with a debug logging message
cyanguwa Sep 19, 2024
66cc6f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
3269685
fix lint
cyanguwa Sep 19, 2024
39a4e1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
5bcc355
Merge branch 'main' into add_descales
cyanguwa Sep 19, 2024
3dbee25
Merge branch 'main' into add_descales
cyanguwa Sep 27, 2024
c01a5b2
Merge branch 'NVIDIA:main' into add_descales
cyanguwa Oct 1, 2024
12dc8a9
Merge branch 'main' into add_descales
cyanguwa Oct 3, 2024
336a452
remove unused imports
cyanguwa Oct 3, 2024
2e140c5
avoid varlen_func for FP8 and improve messaging
cyanguwa Oct 3, 2024
1138edf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2024
4095be8
add SWA support for FA3
cyanguwa Oct 3, 2024
8e2bcc2
fix lint
cyanguwa Oct 3, 2024
7bf4936
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2024
b765f3d
change preference reason for FP8 logic
cyanguwa Oct 6, 2024
a4030e8
minor fixes
cyanguwa Oct 7, 2024
569532a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
e907ad7
Merge branch 'main' into add_descales
cyanguwa Oct 7, 2024
f006a25
minor fix
cyanguwa Oct 7, 2024
qa/L0_pytorch_unittest/test.sh (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
tests/pytorch/fused_attn/test_fused_attn.py (2 changes: 2 additions & 0 deletions)
@@ -1319,6 +1319,8 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
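The comparison helper now coerces both tensors to a common dtype before checking closeness, since an FP8-path output and its reference can legitimately arrive in different dtypes (the "force the same dtype when comparing FA3 and cuDNN FP8" commits). A minimal self-contained sketch of that pattern, with made-up tensors and tolerances:

```python
# Cast one tensor to the other's dtype before torch.testing.assert_close, so an
# output returned in one precision can be checked against a reference in another.
import torch

a = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for an FP8-path output
b = a.to(torch.float16)                      # stand-in reference in a different dtype

if a.dtype != b.dtype:
    a = a.to(b.dtype)                        # align dtypes, mirroring the test change
torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)
```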
transformer_engine/pytorch/attention.py (172 changes: 104 additions & 68 deletions)
@@ -85,6 +85,16 @@
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)
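The block above maps NVTE_DEBUG and NVTE_DEBUG_LEVEL to a logging level and builds a shared formatter/handler; the per-component loggers added later in this diff (fa3_logger, self.logger in FlashAttention) attach to it. Below is a standalone sketch of the same mapping, assuming only the two environment variables (every other name is local to the sketch); this is what the updated test.sh line exercises by setting NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 with pytest's live log output.

```python
# Standalone sketch of the NVTE_DEBUG / NVTE_DEBUG_LEVEL mapping used above.
import logging
import os

debug = int(os.getenv("NVTE_DEBUG", "0"))               # 0 = off, 1 = on
debug_level = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))   # 0/1/2 = increasing verbosity
level = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}[
    min(debug * debug_level, 2)
]

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s"))

logger = logging.getLogger("FlashAttention")  # same pattern as self.logger in the diff
logger.setLevel(level)
if not logger.hasHandlers():
    logger.addHandler(handler)

logger.info("Backend-selection messages appear at NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1")
```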

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -100,29 +110,31 @@
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_3_plus = False
_use_flash_attn_3 = False
_flash_attn_3_installation_steps = """\
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.9")
_flash_attn_3_0_0_beta = _flash_attn_3_plus and _flash_attn_v3_version < PkgVersion("3.0.0")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
fa3_logger = logging.getLogger()
fa3_logger.setLevel(_log_level)
if not fa3_logger.hasHandlers():
fa3_logger.addHandler(_stream_handler)
fa3_logger.debug(
"To use flash-attn v3, please follow these steps to install the flashattn-hopper "
"package: \n%s",
_flash_attn_3_installation_steps,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)

_use_flash_attn_3 = True
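This hunk records the installation steps once, probes for the optional flashattn-hopper package, and downgrades the old import warning to a debug message when the package is absent. A compact sketch of the probe-and-fallback pattern, assuming the package name and the 2.9 version threshold shown in the diff (the logger and variable names below are local to the sketch):

```python
# Probe an optional dependency and fall back quietly, only mentioning the
# install steps at debug level instead of emitting a warning on every import.
import logging
from importlib.metadata import PackageNotFoundError, version as get_pkg_version

from packaging.version import Version as PkgVersion

logger = logging.getLogger("FA3Probe")  # local name for this sketch

use_flash_attn_3 = False
try:
    fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
    use_flash_attn_3 = fa3_version >= PkgVersion("2.9")
except PackageNotFoundError:
    logger.debug(
        "flashattn-hopper not found; FlashAttention 3 disabled. "
        "See the installation steps listed in the diff above."
    )
```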

@@ -132,18 +144,6 @@
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd


# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
@@ -348,7 +348,7 @@ def get_attention_backend(
use_fused_attention = False

# Filter: Compute capability
global _flash_attn_3_plus, _use_flash_attn_3
global _use_flash_attn_3
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -357,7 +357,7 @@
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and _flash_attn_3_plus:
if use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
_use_flash_attn_3 = False

@@ -438,10 +438,9 @@ def get_attention_backend(
use_flash_attention = False

# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False
if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False

# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
@@ -461,7 +460,7 @@
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
@@ -556,7 +555,7 @@ def get_attention_backend(
use_fused_attention = False
if (
use_flash_attention
and _flash_attn_3_plus
and _use_flash_attn_3
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
@@ -590,6 +589,15 @@
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
if (
use_flash_attention
and _use_flash_attn_3
and fp8
and fp8_meta["recipe"].fp8_dpa
and "padding" in attn_mask_type
):
logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
_use_flash_attn_3 = False
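The new check above disables FlashAttention 3 when FP8 attention is combined with a padding mask. Every filter in get_attention_backend follows this shape: test an unsupported combination, log the reason at debug level, and flip the corresponding flag. A condensed, standalone mirror of this particular filter (all names are stand-ins for the real module state):

```python
# Standalone mirror of the FP8 + padding-mask filter for FlashAttention 3.
import logging

logger = logging.getLogger("DotProductAttention")

def fa3_fp8_padding_filter(use_flash_attention, use_flash_attn_3, fp8, fp8_dpa, attn_mask_type):
    """Return the updated use_flash_attn_3 flag for this combination."""
    if (
        use_flash_attention
        and use_flash_attn_3
        and fp8
        and fp8_dpa
        and "padding" in attn_mask_type
    ):
        logger.debug("Disabling FlashAttention 3 for FP8 and padding masks")
        return False
    return use_flash_attn_3

# e.g. fa3_fp8_padding_filter(True, True, True, True, "padding_causal") -> False
```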

# Filter: Sliding window attention
# backend | window_size | diagonal alignment
@@ -633,15 +641,6 @@
attn_mask_type,
)
use_fused_attention = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and _flash_attn_3_plus
):
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
_use_flash_attn_3 = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
@@ -662,12 +661,12 @@
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if _flash_attn_3_plus and _use_flash_attn_3:
if _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
_use_flash_attn_3 = False
if not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention for ALiBi")
use_flash_attention = False
if not _use_flash_attn_3 and not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention = False

if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
@@ -827,19 +826,15 @@ def get_attention_backend(
"for performance reasons"
)
use_flash_attention = False

# Select FusedAttention for FP8
# FA3 uses default scaling factors (i.e. 1) in FP8 execution, while FusedAttention takes
# scaling factors from `fp8_meta` and offers more accurate quantization/de-quantization
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference as FusedAttention "
"supports more accurate scaling factors in FP8 execution"
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
)
use_flash_attention = False

@@ -4963,6 +4958,10 @@ def __init__(
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.logger = logging.getLogger("FlashAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)

def forward(
self,
@@ -5033,6 +5032,10 @@ def forward(
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
@@ -5168,33 +5171,62 @@ def forward(
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

def convert_to_torch_float8(tensor, dtype):
out = torch.Tensor().to(device=tensor.device, dtype=dtype)
out.set_(
tensor._data.untyped_storage(),
tensor._data.storage_offset(),
tensor._data.shape,
tensor._data.stride(),
)
return out

if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv
fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv
fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv
query_layer, key_layer, value_layer = (
convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
try:
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_3_optional_forward_kwargs,
)
except TypeError as e:
if _flash_attn_3_0_0_beta:
e.args = (
e.args[0]
+ ". Please update your FlashAttention 3 (beta) installation as it "
+ "may have added more supported arguments to its API. \n"
+ _flash_attn_3_installation_steps,
) + e.args[1:]
raise
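For FP8 execution this hunk passes the q/k/v descale factors to FA3 and hands the kernels torch-native float8 views of the Float8Tensor payloads via convert_to_torch_float8, which re-interprets the raw bytes in place rather than copying. A self-contained sketch of that zero-copy view, with a plain uint8 tensor standing in for Float8Tensor._data (assumes a PyTorch build that provides torch.float8_e4m3fn):

```python
# Re-interpret an FP8 payload stored as raw bytes as a torch.float8 tensor,
# without copying the underlying storage.
import torch

raw = torch.randint(0, 255, (16, 64), dtype=torch.uint8)  # stand-in for tensor._data

out = torch.Tensor().to(device=raw.device, dtype=torch.float8_e4m3fn)
out.set_(raw.untyped_storage(), raw.storage_offset(), raw.shape, raw.stride())

assert out.dtype == torch.float8_e4m3fn
assert out.data_ptr() == raw.data_ptr()  # same bytes, no copy
```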

if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
output,
@@ -5228,8 +5260,12 @@ def forward(
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
output = Float8Tensor.make_like(
output,
data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous(),
)
else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":