From cfb34179c196004de3b1b74322eb49ce55154c42 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 27 Aug 2024 14:13:12 +0800 Subject: [PATCH] upgrade yunchang version to 0.3.0 and flash_attn to 2.6.0 (#234) --- README.md | 20 +++-- setup.py | 3 +- xfuser/__version__.py | 2 +- .../ring/ring_flash_attn.py | 1 + xfuser/envs.py | 77 +++++++++---------- 5 files changed, 54 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 646941f4..b346b7b1 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,6 @@ The overview of xDiT is shown as follows.

📢 Updates

- * 🎉**August 26, 2024**: We apply torch.compile and [onediff](https://github.com/siliconflow/onediff) nexfort backend to accelerate GPU kernels speed. * 🎉**August 9, 2024**: Support Latte sequence parallel version. The inference scripts are [examples/latte_example](examples/latte_example.py). * 🎉**August 8, 2024**: Support Flux sequence parallel version. The inference scripts are [examples/flux_example](examples/flux_example.py). @@ -91,6 +90,8 @@ The overview of xDiT is shown as follows.

🎯 Supported DiTs

+
+ | Model Name | CFG | SP | PipeFusion | | --- | --- | --- | --- | | [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) | ❎ | ✔️ | ❎ | @@ -100,6 +101,8 @@ The overview of xDiT is shown as follows. | [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ | | [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ | +
+ ### Supported by legacy version only: - [🔴 DiT-XL](https://huggingface.co/facebook/DiT-XL-2-256) @@ -163,13 +166,15 @@ Runtime Options: --warmup_steps WARMUP_STEPS Warmup steps in generation. --use_parallel_vae + --use_torch_compile Enable torch.compile to accelerate inference in a single card --seed SEED Random seed for operations. --output_type OUTPUT_TYPE Output type of the pipeline. + --enable_sequential_cpu_offload + Offloading the weights to the CPU. Parallel Processing Options: - --do_classifier_free_guidance - --use_split_batch Use split batch in classifier_free_guidance. cfg_degree will be 2 if set + --use_cfg_parallel Use split batch in classifier_free_guidance. cfg_degree will be 2 if set --data_parallel_degree DATA_PARALLEL_DEGREE Data parallel degree. --ulysses_degree ULYSSES_DEGREE @@ -241,7 +246,6 @@ The (xDiT) highlights the methods first propose xdit methods - The communication and memory costs associated with the aforementioned intra-image parallelism, except for the CFG and DP (they are inter-image parallel), in DiTs are detailed in the table below. (* denotes that communication can be overlapped with computation.) As we can see, PipeFusion and Sequence Parallel achieve lowest communication cost on different scales and hardware configurations, making them suitable foundational components for a hybrid approach. @@ -333,10 +337,10 @@ We also welcome developers to join and contribute more features and models to th } @article{fang2024unified, - title={USP: a Unified Sequence Parallelism Approach for Long Context Generative AI}, - author={Fang, Jiarui and Zhao, Shangchun}, - journal={arXiv preprint arXiv:2405.07719}, - year={2024} + title={USP: a Unified Sequence Parallelism Approach for Long Context Generative AI}, + author={Fang, Jiarui and Zhao, Shangchun}, + journal={arXiv preprint arXiv:2405.07719}, + year={2024} } ``` diff --git a/setup.py b/setup.py index 0b4d4769..97289f0e 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ "accelerate==0.33.0", "beautifulsoup4>=4.12.3", "distvae", - "yunchang==0.2", + "yunchang==0.3", + "flash_attn>=2.6.3", ], url="https://github.com/xdit-project/xDiT.", description="xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters", diff --git a/xfuser/__version__.py b/xfuser/__version__.py index 6a35e852..260c070a 100644 --- a/xfuser/__version__.py +++ b/xfuser/__version__.py @@ -1 +1 @@ -__version__ = "0.3" +__version__ = "0.3.1" diff --git a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py index 6f4d91a9..141b4c83 100644 --- a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py +++ b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py @@ -81,6 +81,7 @@ def ring_flash_attn_forward( softmax_scale, causal=causal and step == 0, window_size=window_size, + softcap=0.0, alibi_slopes=alibi_slopes, return_softmax=True and dropout_p > 0, ) diff --git a/xfuser/envs.py b/xfuser/envs.py index 568efdc6..2a7e1797 100644 --- a/xfuser/envs.py +++ b/xfuser/envs.py @@ -20,49 +20,35 @@ environment_variables: Dict[str, Callable[[], Any]] = { - # ================== Runtime Env Vars ================== - # used in distributed environment to determine the master address - 'MASTER_ADDR': - lambda: os.getenv('MASTER_ADDR', ""), - + "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""), # used in distributed environment to manually set the communication port - 'MASTER_PORT': - lambda: int(os.getenv('MASTER_PORT', '0')) - if 'MASTER_PORT' in os.environ else None, - + "MASTER_PORT": lambda: ( + int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None + ), # path to cudatoolkit home directory, under which should be bin, include, # and lib directories. - "CUDA_HOME": - lambda: os.environ.get("CUDA_HOME", None), - + "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), # local rank of the process in the distributed setting, used to determine # the GPU device id - "LOCAL_RANK": - lambda: int(os.environ.get("LOCAL_RANK", "0")), - + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), # used to control the visible devices in the distributed setting - "CUDA_VISIBLE_DEVICES": - lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), - + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), # this is used for configuring the default logging level - "XDIT_LOGGING_LEVEL": - lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"), + "XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"), } variables: Dict[str, Callable[[], Any]] = { - # ================== Other Vars ================== - # used in version checking - 'CUDA_VERSION': - lambda: version.parse(torch.version.cuda), - - 'TORCH_VERSION': - lambda: version.parse(version.parse(torch.__version__).base_version), + "CUDA_VERSION": lambda: version.parse(torch.version.cuda), + "TORCH_VERSION": lambda: version.parse( + version.parse(torch.__version__).base_version + ), } + class PackagesEnvChecker: _instance = None @@ -74,11 +60,10 @@ def __new__(cls): def initialize(self): self.packages_info = { - 'has_flash_attn': self.check_flash_attn(), - 'has_long_ctx_attn': self.check_long_ctx_attn(), - 'diffusers_version': self.check_diffusers_version(), - } - + "has_flash_attn": self.check_flash_attn(), + "has_long_ctx_attn": self.check_long_ctx_attn(), + "diffusers_version": self.check_diffusers_version(), + } def check_flash_attn(self): try: @@ -88,10 +73,16 @@ def check_flash_attn(self): return False else: from flash_attn import flash_attn_func + from flash_attn import __version__ + + if __version__ < "2.6.0": + raise ImportError(f"install flash_attn >= 2.6.0") return True except ImportError: - logger.warning(f'Flash Attention library "flash_attn" not found, ' - f'using pytorch attention implementation') + logger.warning( + f'Flash Attention library "flash_attn" not found, ' + f"using pytorch attention implementation" + ) return False def check_long_ctx_attn(self): @@ -103,21 +94,29 @@ def check_long_ctx_attn(self): LongContextAttention, LongContextAttentionQKVPacked, ) + return True except ImportError: - logger.warning(f'Ring Flash Attention library "yunchang" not found, ' - f'using pytorch attention implementation') + logger.warning( + f'Ring Flash Attention library "yunchang" not found, ' + f"using pytorch attention implementation" + ) return False def check_diffusers_version(self): - if version.parse(version.parse(diffusers.__version__).base_version) < version.parse("0.30.0"): - raise RuntimeError(f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," - f"please upgrade to version > 0.30.0") + if version.parse( + version.parse(diffusers.__version__).base_version + ) < version.parse("0.30.0"): + raise RuntimeError( + f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," + f"please upgrade to version > 0.30.0" + ) return version.parse(version.parse(diffusers.__version__).base_version) def get_packages_info(self): return self.packages_info + PACKAGES_CHECKER = PackagesEnvChecker()