
upgrade yunchang version to 0.3.0 and flash_attn to 2.6.0 (#234)
feifeibear authored Aug 27, 2024
1 parent 4838960 commit cfb3417
Showing 5 changed files with 54 additions and 49 deletions.
20 changes: 12 additions & 8 deletions README.md
@@ -77,7 +77,6 @@ The overview of xDiT is shown as follows.

<h2 id="updates">📢 Updates</h2>


* 🎉**August 26, 2024**: We apply torch.compile and the [onediff](https://github.com/siliconflow/onediff) nexfort backend to accelerate GPU kernels.
* 🎉**August 9, 2024**: Support Latte sequence parallel version. The inference scripts are [examples/latte_example](examples/latte_example.py).
* 🎉**August 8, 2024**: Support Flux sequence parallel version. The inference scripts are [examples/flux_example](examples/flux_example.py).
@@ -91,6 +90,8 @@ The overview of xDiT is shown as follows.

<h2 id="support-dits">🎯 Supported DiTs</h2>

<div align="center">

| Model Name | CFG | SP | PipeFusion |
| --- | --- | --- | --- |
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) | | ✔️ | |
@@ -100,6 +101,8 @@ The overview of xDiT is shown as follows.
| [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ |
| [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ |

</div>

### Supported by legacy version only:

- [🔴 DiT-XL](https://huggingface.co/facebook/DiT-XL-2-256)
@@ -163,13 +166,15 @@ Runtime Options:
--warmup_steps WARMUP_STEPS
Warmup steps in generation.
--use_parallel_vae
--use_torch_compile Enable torch.compile to accelerate inference on a single card
--seed SEED Random seed for operations.
--output_type OUTPUT_TYPE
Output type of the pipeline.
--enable_sequential_cpu_offload
Offloading the weights to the CPU.

Parallel Processing Options:
--do_classifier_free_guidance
--use_split_batch Use split batch in classifier_free_guidance. cfg_degree will be 2 if set
--use_cfg_parallel Use split batch in classifier_free_guidance. cfg_degree will be 2 if set
--data_parallel_degree DATA_PARALLEL_DEGREE
Data parallel degree.
--ulysses_degree ULYSSES_DEGREE
@@ -241,7 +246,6 @@ The (<span style="color: red;">xDiT</span>) highlights the methods first propose
<img src="assets/methods/xdit_method.png" alt="xdit methods">
</div>
The communication and memory costs of the aforementioned intra-image parallelism methods in DiTs, excluding CFG and DP (which are inter-image parallel), are detailed in the table below. (* denotes communication that can be overlapped with computation.)
As shown, PipeFusion and Sequence Parallel achieve the lowest communication costs across different scales and hardware configurations, making them suitable foundational components for a hybrid approach.
@@ -333,10 +337,10 @@ We also welcome developers to join and contribute more features and models to th
}
@article{fang2024unified,
title={USP: a Unified Sequence Parallelism Approach for Long Context Generative AI},
author={Fang, Jiarui and Zhao, Shangchun},
journal={arXiv preprint arXiv:2405.07719},
year={2024}
title={USP: a Unified Sequence Parallelism Approach for Long Context Generative AI},
author={Fang, Jiarui and Zhao, Shangchun},
journal={arXiv preprint arXiv:2405.07719},
year={2024}
}
```
3 changes: 2 additions & 1 deletion setup.py
@@ -21,7 +21,8 @@
"accelerate==0.33.0",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang==0.2",
"yunchang==0.3",
"flash_attn>=2.6.3",
],
url="https://github.com/xdit-project/xDiT.",
description="xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters",
2 changes: 1 addition & 1 deletion xfuser/__version__.py
@@ -1 +1 @@
__version__ = "0.3"
__version__ = "0.3.1"
1 change: 1 addition & 0 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -81,6 +81,7 @@ def ring_flash_attn_forward(
softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
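The new `softcap=0.0` argument mirrors the logit soft-capping parameter that flash_attn introduced in its 2.6 release; passing it explicitly keeps the ring-attention wrapper aligned with the updated forward signature. A minimal sketch of the corresponding public API call, assuming flash_attn >= 2.6 and a CUDA device are available:

```python
# Sketch, not part of this commit: flash_attn >= 2.6 accepts `softcap`
# on its public entry point as well; 0.0 disables soft-capping, matching
# the behavior of earlier releases.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), fp16 on GPU as flash_attn requires
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True, softcap=0.0)
```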
77 changes: 38 additions & 39 deletions xfuser/envs.py
@@ -20,49 +20,35 @@


environment_variables: Dict[str, Callable[[], Any]] = {

# ================== Runtime Env Vars ==================

# used in distributed environment to determine the master address
'MASTER_ADDR':
lambda: os.getenv('MASTER_ADDR', ""),

"MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""),
# used in distributed environment to manually set the communication port
'MASTER_PORT':
lambda: int(os.getenv('MASTER_PORT', '0'))
if 'MASTER_PORT' in os.environ else None,

"MASTER_PORT": lambda: (
int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None
),
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
"CUDA_HOME":
lambda: os.environ.get("CUDA_HOME", None),

"CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
lambda: int(os.environ.get("LOCAL_RANK", "0")),

"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES":
lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),

"CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# this is used for configuring the default logging level
"XDIT_LOGGING_LEVEL":
lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
"XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}

variables: Dict[str, Callable[[], Any]] = {

# ================== Other Vars ==================

# used in version checking
'CUDA_VERSION':
lambda: version.parse(torch.version.cuda),

'TORCH_VERSION':
lambda: version.parse(version.parse(torch.__version__).base_version),
"CUDA_VERSION": lambda: version.parse(torch.version.cuda),
"TORCH_VERSION": lambda: version.parse(
version.parse(torch.__version__).base_version
),
}
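Both dictionaries map names to zero-argument callables, so values are read from the environment lazily at lookup time rather than once at import. A minimal usage sketch (assuming the module is importable as `xfuser.envs`, per the file path above):

```python
# Sketch, not part of this commit: each entry is a thunk, so mutating
# os.environ before the lookup changes the result.
import os

from xfuser.envs import environment_variables

os.environ["XDIT_LOGGING_LEVEL"] = "DEBUG"
level = environment_variables["XDIT_LOGGING_LEVEL"]()  # returns "DEBUG"
```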


class PackagesEnvChecker:
_instance = None

@@ -74,11 +60,10 @@ def __new__(cls):

def initialize(self):
self.packages_info = {
'has_flash_attn': self.check_flash_attn(),
'has_long_ctx_attn': self.check_long_ctx_attn(),
'diffusers_version': self.check_diffusers_version(),
}

"has_flash_attn": self.check_flash_attn(),
"has_long_ctx_attn": self.check_long_ctx_attn(),
"diffusers_version": self.check_diffusers_version(),
}

def check_flash_attn(self):
try:
@@ -88,10 +73,16 @@ def check_flash_attn(self):
return False
else:
from flash_attn import flash_attn_func
from flash_attn import __version__

if __version__ < "2.6.0":
raise ImportError(f"install flash_attn >= 2.6.0")
return True
except ImportError:
logger.warning(f'Flash Attention library "flash_attn" not found, '
f'using pytorch attention implementation')
logger.warning(
f'Flash Attention library "flash_attn" not found, '
f"using pytorch attention implementation"
)
return False
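One caveat about the added guard: `__version__ < "2.6.0"` compares strings lexicographically, which would misorder a hypothetical "2.10.0" against "2.6.0". A sketch of a numeric comparison using `packaging` (which this file already imports as `version`), offered as an assumption rather than part of the commit:

```python
# Sketch, not part of this commit: version.parse compares release
# segments numerically, so "2.10.0" correctly sorts above "2.6.0".
from packaging import version

import flash_attn

if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
    raise ImportError("xDiT requires flash_attn >= 2.6.0")
```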

def check_long_ctx_attn(self):
Expand All @@ -103,21 +94,29 @@ def check_long_ctx_attn(self):
LongContextAttention,
LongContextAttentionQKVPacked,
)

return True
except ImportError:
logger.warning(f'Ring Flash Attention library "yunchang" not found, '
f'using pytorch attention implementation')
logger.warning(
f'Ring Flash Attention library "yunchang" not found, '
f"using pytorch attention implementation"
)
return False

def check_diffusers_version(self):
if version.parse(version.parse(diffusers.__version__).base_version) < version.parse("0.30.0"):
raise RuntimeError(f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported,"
f"please upgrade to version > 0.30.0")
if version.parse(
version.parse(diffusers.__version__).base_version
) < version.parse("0.30.0"):
raise RuntimeError(
f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported,"
f"please upgrade to version > 0.30.0"
)
return version.parse(version.parse(diffusers.__version__).base_version)
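The reformatted check parses `diffusers.__version__` three times; hoisting the parse into a local keeps the comparison and the error message in sync. An equivalent sketch (a style assumption, not a behavioral change):

```python
# Sketch, not part of this commit: parse once, reuse everywhere.
from packaging import version

import diffusers

dv = version.parse(version.parse(diffusers.__version__).base_version)
if dv < version.parse("0.30.0"):
    raise RuntimeError(
        f"Diffusers version {dv} is not supported, "
        f"please upgrade to >= 0.30.0"
    )
```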

def get_packages_info(self):
return self.packages_info


PACKAGES_CHECKER = PackagesEnvChecker()
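`PACKAGES_CHECKER` is a module-level singleton (enforced by the `__new__` override above), so the capability probes run once per process. A hypothetical consumer sketch, using only names that appear in this diff:

```python
# Sketch, not part of this commit: the keys come from initialize() above.
from xfuser.envs import PACKAGES_CHECKER

info = PACKAGES_CHECKER.get_packages_info()
if not info["has_flash_attn"]:
    print("flash_attn unavailable; falling back to PyTorch attention")
```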


