From 92187b8bffeb4dcff608d6b018b0431773a43715 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 26 Dec 2024 15:43:51 +0800
Subject: [PATCH] remove dependency on flash_attn (#410)

---
 setup.py                                      |   6 +-
 xfuser/config/config.py                       |   4 -
 xfuser/core/fast_attention/attn_layer.py      |   8 +-
 xfuser/core/long_ctx_attention/__init__.py    |   2 -
 .../long_ctx_attention/hybrid/attn_layer.py   |   7 +
 .../ring/ring_flash_attn.py                   |  10 +-
 .../long_ctx_attention/ulysses/__init__.py    |   5 -
 .../long_ctx_attention/ulysses/attn_layer.py  | 168 ------------------
 .../layers/attention_processor.py             |  43 ++---
 .../transformers/consisid_transformer_3d.py   |   6 +-
 .../models/transformers/register.py           |   2 +
 .../pipelines/pipeline_consisid.py            |  18 +-
 12 files changed, 66 insertions(+), 213 deletions(-)
 delete mode 100644 xfuser/core/long_ctx_attention/ulysses/__init__.py
 delete mode 100644 xfuser/core/long_ctx_attention/ulysses/attn_layer.py

diff --git a/setup.py b/setup.py
index 87a5d7fe..c933a467 100644
--- a/setup.py
+++ b/setup.py
@@ -32,19 +32,19 @@ def get_cuda_version():
         "sentencepiece>=0.1.99",
         "beautifulsoup4>=4.12.3",
         "distvae",
-        "yunchang>=0.3.0",
+        "yunchang>=0.6.0",
         "pytest",
         "flask",
         "opencv-python",
         "imageio",
         "imageio-ffmpeg",
         "optimum-quanto",
-        "flash_attn>=2.6.3",
         "ray"
     ],
     extras_require={
         "diffusers": [
-            "diffusers>=0.31.0",  # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
+            "diffusers>=0.32.0",  # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
+            "flash_attn>=2.6.3",
         ]
     },
     url="https://github.com/xdit-project/xDiT.",
diff --git a/xfuser/config/config.py b/xfuser/config/config.py
index 801949ea..ce1c87f8 100644
--- a/xfuser/config/config.py
+++ b/xfuser/config/config.py
@@ -130,10 +130,6 @@ def __post_init__(self):
                 f"sp_degree is {self.sp_degree}, please set it "
                 f"to 1 or install 'yunchang' to use it"
             )
-        if not HAS_FLASH_ATTN and self.ring_degree > 1:
-            raise ValueError(
-                f"Flash attention not found. Ring attention not available. Please set ring_degree to 1"
-            )


 @dataclass
diff --git a/xfuser/core/fast_attention/attn_layer.py b/xfuser/core/fast_attention/attn_layer.py
index e82c4bfa..0d306d40 100644
--- a/xfuser/core/fast_attention/attn_layer.py
+++ b/xfuser/core/fast_attention/attn_layer.py
@@ -7,7 +7,12 @@
 from diffusers.models.attention_processor import Attention
 from typing import Optional
 import torch.nn.functional as F
-import flash_attn
+
+try:
+    import flash_attn
+except ImportError:
+    flash_attn = None
+
 from enum import Flag, auto
 from .fast_attn_state import get_fast_attn_window_size

@@ -165,6 +170,7 @@ def __call__(
                 is_causal=False,
             ).transpose(1, 2)
         elif method.has(FastAttnMethod.FULL_ATTN):
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             all_hidden_states = flash_attn.flash_attn_func(query, key, value)
             if need_compute_residual:
                 # Compute the full-window attention residual
diff --git a/xfuser/core/long_ctx_attention/__init__.py b/xfuser/core/long_ctx_attention/__init__.py
index 2fe9d667..cdca17a0 100644
--- a/xfuser/core/long_ctx_attention/__init__.py
+++ b/xfuser/core/long_ctx_attention/__init__.py
@@ -1,7 +1,5 @@
 from .hybrid import xFuserLongContextAttention
-from .ulysses import xFuserUlyssesAttention

 __all__ = [
     "xFuserLongContextAttention",
-    "xFuserUlyssesAttention",
 ]
diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
index d4fadf78..a459630a 100644
--- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
+++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -3,6 +3,11 @@
 import torch.distributed
 from yunchang import LongContextAttention
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")
+
 from yunchang.comm.all_to_all import SeqAllToAll4D

 from xfuser.logger import init_logger

@@ -21,6 +26,7 @@ def __init__(
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
         use_kv_cache: bool = False,
+        attn_type: AttnType = AttnType.FA,
     ) -> None:
         """
         Arguments:
@@ -35,6 +41,7 @@
             gather_idx=gather_idx,
             ring_impl_type=ring_impl_type,
             use_pack_qkv=use_pack_qkv,
+            attn_type = attn_type,
         )
         self.use_kv_cache = use_kv_cache
         if (
diff --git a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
index 9e8b116d..a4e8a501 100644
--- a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
+++ b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -1,11 +1,16 @@
 import torch
-import flash_attn
-from flash_attn.flash_attn_interface import _flash_attn_forward
+
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from xfuser.core.cache_manager.cache_manager import get_cache_manager
 from yunchang.ring.utils import RingComm, update_out_and_lse
 from yunchang.ring.ring_flash_attn import RingFlashAttnFunc

+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+except ImportError:
+    flash_attn = None
+    _flash_attn_forward = None

 def xdit_ring_flash_attn_forward(
     process_group,
@@ -80,6 +85,7 @@ def xdit_ring_flash_attn_forward(
             key, value = k, v

         if not causal or step <= comm.rank:
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             if flash_attn.__version__ <= "2.6.3":
                 block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                     q,
diff --git a/xfuser/core/long_ctx_attention/ulysses/__init__.py b/xfuser/core/long_ctx_attention/ulysses/__init__.py
deleted file mode 100644
index 5b11ebdf..00000000
--- a/xfuser/core/long_ctx_attention/ulysses/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .attn_layer import xFuserUlyssesAttention
-
-__all__ = [
-    "xFuserUlyssesAttention",
-]
diff --git a/xfuser/core/long_ctx_attention/ulysses/attn_layer.py b/xfuser/core/long_ctx_attention/ulysses/attn_layer.py
deleted file mode 100644
index ff8788f1..00000000
--- a/xfuser/core/long_ctx_attention/ulysses/attn_layer.py
+++ /dev/null
@@ -1,168 +0,0 @@
-from typing import Any
-
-import torch
-import torch.distributed as dist
-from torch import Tensor
-
-from xfuser.core.cache_manager.cache_manager import get_cache_manager
-from yunchang import UlyssesAttention
-from yunchang.globals import PROCESS_GROUP
-from yunchang.comm.all_to_all import SeqAllToAll4D
-try:
-    # yunchang > 0.4.0
-    from yunchang.kernels.attention import torch_attn
-except:
-    print(f"detect you are not use the latest yunchang. Please install yunchang>=0.4.0")
-    try:
-        from yunchang.ulysses.attn_layer import torch_attn
-    except:
-        raise ImportError(f"yunchang import torch_attn error")
-
-
-class xFuserUlyssesAttention(UlyssesAttention):
-    def __init__(
-        self,
-        scatter_idx: int = 2,
-        gather_idx: int = 1,
-        use_fa: bool = True,
-        use_kv_cache: bool = True,
-    ) -> None:
-
-        super(UlyssesAttention, self).__init__()
-        self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG
-
-        self.scatter_idx = scatter_idx
-        self.gather_idx = gather_idx
-        self.use_fa = use_fa
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        gpu_name = torch.cuda.get_device_name(device)
-        if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
-            self.use_fa = False
-        self.use_kv_cache = use_kv_cache
-
-        if self.use_fa:
-            from flash_attn import flash_attn_func
-
-            self.fn = flash_attn_func
-        else:
-            self.fn = torch_attn
-
-    def forward(
-        self,
-        attn,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        *,
-        joint_tensor_query=None,
-        joint_tensor_key=None,
-        joint_tensor_value=None,
-        dropout_p=0.0,
-        softmax_scale=None,
-        causal=False,
-        window_size=(-1, -1),
-        alibi_slopes=None,
-        deterministic=False,
-        return_attn_probs=False,
-        joint_strategy="none",
-    ) -> Tensor:
-        """forward
-
-        Arguments:
-            query (Tensor): query input to the layer
-            key (Tensor): key input to the layer
-            value (Tensor): value input to the layer
-            args: other args
-
-        Returns:
-            * output (Tensor): context output
-        """
-        if (
-            joint_tensor_key is not None
-            and joint_tensor_value is not None
-            and joint_tensor_query is not None
-        ):
-            if joint_strategy == "rear":
-                query = torch.cat([query, joint_tensor_query], dim=1)
-            elif joint_strategy == "front":
-                query = torch.cat([joint_tensor_query, query], dim=1)
-            elif joint_strategy == "none":
-                raise ValueError(
-                    f"joint_strategy: {joint_strategy} not supported when joint tensors is not None."
-                )
-            else:
-                raise ValueError(f"joint_strategy: {joint_strategy} not supported.")
-            ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
-            ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
-            attn_heads_per_ulysses_rank = (
-                joint_tensor_key.shape[-2] // ulysses_world_size
-            )
-            joint_tensor_key = joint_tensor_key[
-                ...,
-                attn_heads_per_ulysses_rank
-                * ulysses_rank : attn_heads_per_ulysses_rank
-                * (ulysses_rank + 1),
-                :,
-            ]
-            joint_tensor_value = joint_tensor_value[
-                ...,
-                attn_heads_per_ulysses_rank
-                * ulysses_rank : attn_heads_per_ulysses_rank
-                * (ulysses_rank + 1),
-                :,
-            ]
-
-        # TODO Merge three alltoall calls into one
-        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
-        # in shape : e.g., [s/p:h:]
-        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
-
-        # scatter 2, gather 1
-        q = SeqAllToAll4D.apply(
-            self.ulysses_pg, query, self.scatter_idx, self.gather_idx
-        )
-        k = SeqAllToAll4D.apply(self.ulysses_pg, key, self.scatter_idx, self.gather_idx)
-        v = SeqAllToAll4D.apply(
-            self.ulysses_pg, value, self.scatter_idx, self.gather_idx
-        )
-
-        if self.use_kv_cache:
-            k, v = get_cache_manager().update_and_get_kv_cache(
-                new_kv=[k, v],
-                layer=attn,
-                slice_dim=1,
-                layer_type="attn",
-            )
-
-        if joint_strategy != "none":
-            if joint_strategy == "rear":
-                k = torch.cat([k, joint_tensor_key], dim=1)
-                v = torch.cat([v, joint_tensor_value], dim=1)
-
-            elif joint_strategy == "front":
-                k = torch.cat([joint_tensor_key, k], dim=1)
-                v = torch.cat([joint_tensor_value, v], dim=1)
-
-        context_layer = self.fn(
-            q,
-            k,
-            v,
-            dropout_p=dropout_p,
-            causal=causal,
-            window_size=window_size,
-            alibi_slopes=alibi_slopes,
-            deterministic=deterministic,
-            return_attn_probs=return_attn_probs,
-        )
-
-        if isinstance(context_layer, tuple):
-            context_layer = context_layer[0]
-
-        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
-        # scatter 1, gather 2
-        output = SeqAllToAll4D.apply(
-            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx
-        )
-
-        # out e.g., [s/p::h]
-        return output
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index b7e6e5ad..2deaca63 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -187,7 +187,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -196,9 +195,11 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                assert yunchang.__version__ >= "0.6.0"
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -395,7 +396,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -403,9 +403,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )

         if get_fast_attn_enable():
@@ -588,7 +589,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -596,9 +596,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )

     def __call__(
@@ -789,7 +790,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -797,9 +797,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -991,7 +992,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -999,9 +999,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1168,7 +1169,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -1176,9 +1176,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1340,7 +1341,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )

             if HAS_FLASH_ATTN:
@@ -1348,9 +1348,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import AttnType
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
diff --git a/xfuser/model_executor/models/transformers/consisid_transformer_3d.py b/xfuser/model_executor/models/transformers/consisid_transformer_3d.py
index 75346334..16f7227b 100644
--- a/xfuser/model_executor/models/transformers/consisid_transformer_3d.py
+++ b/xfuser/model_executor/models/transformers/consisid_transformer_3d.py
@@ -5,7 +5,11 @@
 from diffusers.models.embeddings import CogVideoXPatchEmbed
-from diffusers.models import ConsisIDTransformer3DModel
+try:
+    from diffusers.models import ConsisIDTransformer3DModel
+except ImportError:
+    ConsisIDTransformer3DModel = None
+
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers
diff --git a/xfuser/model_executor/models/transformers/register.py b/xfuser/model_executor/models/transformers/register.py
index 33c45319..6e0a4790 100644
--- a/xfuser/model_executor/models/transformers/register.py
+++ b/xfuser/model_executor/models/transformers/register.py
@@ -40,6 +40,8 @@ def get_wrapper(cls, transformer: nn.Module) -> xFuserTransformerBaseWrapper:
             origin_transformer_class,
             wrapper_class,
         ) in cls._XFUSER_TRANSFORMER_MAPPING.items():
+            if origin_transformer_class is None:
+                continue
             if isinstance(transformer, origin_transformer_class):
                 if (
                     candidate is None
diff --git a/xfuser/model_executor/pipelines/pipeline_consisid.py b/xfuser/model_executor/pipelines/pipeline_consisid.py
index 44ce0604..06344e2d 100644
--- a/xfuser/model_executor/pipelines/pipeline_consisid.py
+++ b/xfuser/model_executor/pipelines/pipeline_consisid.py
@@ -3,15 +3,21 @@
 import torch
 import torch.distributed
-from diffusers import ConsisIDPipeline
-from diffusers.pipelines.consisid.pipeline_consisid import (
-    ConsisIDPipelineOutput,
-    retrieve_timesteps,
-)
+
 from diffusers.schedulers import CogVideoXDPMScheduler
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.image_processor import PipelineImageInput

+try:
+    from diffusers import ConsisIDPipeline
+
+    from diffusers.pipelines.consisid.pipeline_consisid import (
+        ConsisIDPipelineOutput,
+        retrieve_timesteps,
+    )
+except ImportError:
+    ConsisIDPipeline = None
+
 import math
 import cv2
 import PIL
@@ -124,7 +130,7 @@ def __call__(
         id_vit_hidden: Optional[torch.Tensor] = None,
         id_cond: Optional[torch.Tensor] = None,
         kps_cond: Optional[torch.Tensor] = None,
-    ) -> Union[ConsisIDPipelineOutput, Tuple]:
+    ) -> Union['ConsisIDPipelineOutput', Tuple]:
         """
         Function invoked when calling the pipeline for generation.
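
The changes above all follow the same optional-dependency pattern: flash_attn is imported inside try/except, the module handle falls back to None when the wheel is absent, and FA-dependent call sites guard with an assert; when FlashAttention is unavailable, the sequence-parallel path instead constructs xFuserLongContextAttention with yunchang's AttnType.TORCH backend. The snippet below is a minimal, self-contained sketch of that pattern; the _attention helper and its scaled_dot_product_attention fallback are illustrative only and are not part of this patch.

# Minimal sketch of the optional-dependency pattern applied in this patch.
# `_attention` is an illustrative helper, not a function from the xDiT codebase.
import torch
import torch.nn.functional as F

try:
    import flash_attn  # optional backend; only needed for the FA code path
except ImportError:
    flash_attn = None


def _attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Use flash_attn when it is installed, otherwise fall back to PyTorch SDPA."""
    if flash_attn is not None:
        # flash_attn_func expects (batch, seqlen, heads, head_dim) fp16/bf16 CUDA tensors
        return flash_attn.flash_attn_func(q, k, v)
    # scaled_dot_product_attention expects (batch, heads, seqlen, head_dim)
    return F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)

Keeping the guard at the call site rather than at import time is what lets the package install and import cleanly on machines without a flash_attn wheel, while FA-only code paths fail loudly only when they are actually exercised.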