Add docs for adding new models #403

Merged
merged 26 commits on Dec 20, 2024
Commits
d631586
split text in MM-DiT
xibosun Nov 22, 2024
d48b78b
check the use_parallel_vae flag for CogVideo
xibosun Nov 22, 2024
d35e495
fix dimensions in all_gather
xibosun Nov 27, 2024
faafcd1
optimizations on H100
xibosun Nov 27, 2024
a41b7c1
support optimized USP in Flux
xibosun Nov 28, 2024
55b8711
do not split text if undivisible by sp_degree
xibosun Nov 28, 2024
726f402
polish optimized USP
xibosun Nov 28, 2024
61b4b90
update diffusers versio in setup.py
xibosun Nov 28, 2024
b8b0b10
merge to main
xibosun Nov 28, 2024
d43176d
fix bugs
xibosun Nov 29, 2024
5d7c886
unify USP interface
xibosun Nov 29, 2024
3c17dea
optimized USP in CogVideo
xibosun Nov 29, 2024
cc1f2da
use optimized USP in cogvideo
xibosun Dec 3, 2024
73a071b
add CogVideoX1.5-5B performance on H20 and L20
xibosun Dec 3, 2024
7a687df
merge upstream main
xibosun Dec 3, 2024
8d6de01
rename files and update docs
xibosun Dec 3, 2024
c2fdea2
Merge remote-tracking branch 'upstream/main' into text_slice
xibosun Dec 5, 2024
0172a69
decouple retime state from USP
xibosun Dec 5, 2024
d45155f
add doc for adding new models
xibosun Dec 17, 2024
e5e21c7
Merge remote-tracking branch 'upstream/main' into text_slice
xibosun Dec 17, 2024
3fd6809
fix typos
xibosun Dec 17, 2024
a5318fb
add docs for adding models
xibosun Dec 18, 2024
a88cf5b
fix docs for adding models
xibosun Dec 18, 2024
57db9de
fix docs for adding new models
xibosun Dec 18, 2024
564a483
fix docs for adding models
xibosun Dec 19, 2024
01a68a2
add figure to illustrate USP
xibosun Dec 19, 2024
149 changes: 149 additions & 0 deletions docs/developer/adding_models/adding_model_cfg.md
@@ -0,0 +1,149 @@
# Parallelize new models with CFG parallelism provided by xDiT

This tutorial focuses on utilizing CFG parallelism in the context of the CogVideoX text-to-video model. It provides step-by-step instructions on how to apply CFG parallelism to a new DiT model.

The diffusion process involves receiving Gaussian noise as input, iteratively predicting and denoising using the *CogVideoX Transformer*, and generating the output video. This process, typically executed on a single GPU within `diffusers`, is outlined in the following figure.

<div align="center">
<img src="https://raw.githubusercontent.com/xdit-project/xdit_assets/main/developer/single-gpu-cfg.png"
alt="single-gpu.png">
</div>

The Transformer's input comprises timesteps, a text sequence, and an image sequence. CogVideoX employs classifier-free guidance (CFG) to enhance video quality: in each iteration, besides feeding the timesteps, text sequence, and image sequence into the transformer, the pipeline also constructs an empty (unconditional) text sequence, forwards it through the transformer together with the same timesteps and image sequence, and combines the two outputs into the noise prediction at the end of the iteration, as sketched in the snippet below. Consequently, when a single prompt is passed to the model, the timesteps, text sequence, and image sequence each have a batch size of 2.
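
To make the effect of CFG concrete, here is a simplified, illustrative sketch of how a `diffusers`-style pipeline batches the conditional and unconditional inputs in one denoising step; the function and variable names are ours, not taken from the CogVideoX source.

```python
import torch

def cfg_denoise_step(transformer, latents, prompt_embeds, negative_prompt_embeds, t, guidance_scale=6.0):
    # duplicate the image latents so the conditional and unconditional passes share one call
    latent_model_input = torch.cat([latents] * 2)                     # batch size 2
    text_input = torch.cat([negative_prompt_embeds, prompt_embeds])   # batch size 2
    timestep = t.expand(latent_model_input.shape[0])                  # batch size 2
    noise_pred = transformer(latent_model_input, text_input, timestep=timestep)[0]
    # combine the two predictions into a single guided noise estimate
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```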

CFG parallelism, depicted in the following figure, leverages 2 GPUs to process the two batches. At the beginning of each iteration, CFG parallelism splits the input tensors along the batch dimension and distributes one part to each GPU. At the end of the iteration, the two GPUs merge their outputs through the `all_gather` primitive; a minimal PyTorch sketch of this pattern follows the figure.


<div align="center">
<img src="https://raw.githubusercontent.com/xdit-project/xdit_assets/main/developer/multiple-gpus-cfg.png"
alt="multiple-gpus.png">
</div>
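
As a minimal sketch of this pattern with plain PyTorch collectives (illustrative only; the rest of this tutorial uses the helpers that xDiT provides for exactly this purpose):

```python
import torch
import torch.distributed as dist

def cfg_parallel_step(latents_2b, denoise_fn):
    """latents_2b has batch size 2: [unconditional batch, conditional batch]."""
    rank, world = dist.get_rank(), dist.get_world_size()  # world is assumed to be 2
    # each rank keeps one of the two batches
    local = torch.chunk(latents_2b, world, dim=0)[rank]
    local_out = denoise_fn(local)
    # gather both predictions so every rank sees the full batch again
    gathered = [torch.empty_like(local_out) for _ in range(world)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=0)
```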

Note that CFG parallelism cannot be applied to DiT models without CFG functionality, such as Flux and HunyuanVideo.

To accelerate CogVideoX inference using CFG parallelism, two modifications to the original diffusion process are required. First, the xDiT environment must be initialized at the beginning of the program, using functions such as `init_distributed_environment`, `initialize_model_parallel`, and `get_world_group` provided by xDiT. Second, in `diffusers` the CogVideoX model is encapsulated in the `CogVideoXTransformer3DModel` class located at [diffusers/models/transformers/cogvideox_transformer_3d.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/cogvideox_transformer_3d.py), and the sequences must be split and merged before and after the `forward` function of `CogVideoXTransformer3DModel`.

## 1. Initialization

Begin by setting up the distributed environment with the following code snippet:

```python
import torch.distributed as dist
from xfuser.core.distributed import init_distributed_environment

dist.init_process_group("nccl")
init_distributed_environment(
    rank=dist.get_rank(),
    world_size=dist.get_world_size()
)
# Ensure the world size is 2 for CFG parallelism
```

Specify the degree of CFG parallelism:

```python
from xfuser.core.distributed import initialize_model_parallel

initialize_model_parallel(
    classifier_free_guidance_degree=2,
)
```
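
As a quick sanity check (our suggestion, not part of the xDiT API), you can verify that the launched world size matches the CFG degree before proceeding:

```python
import torch.distributed as dist

# illustrative check: CFG parallelism with degree 2 expects exactly 2 processes
assert dist.get_world_size() == 2, (
    f"classifier_free_guidance_degree=2 requires 2 GPUs, got {dist.get_world_size()}"
)
```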

Ensure the model checkpoint is loaded on every GPU by moving the pipeline from the CPU to each device:

```python
import torch
from xfuser.core.distributed import get_world_group

local_rank = get_world_group().local_rank
device = torch.device(f"cuda:{local_rank}")
pipe.to(device)
```
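
For completeness, `pipe` itself can be constructed on the CPU beforehand, mirroring the complete example at the end of this tutorial (the checkpoint path is taken from the command line here):

```python
import sys
import torch
from diffusers import CogVideoXPipeline

# load the CogVideoX pipeline on the CPU; it is moved to the local GPU afterwards
pipe = CogVideoXPipeline.from_pretrained(
    pretrained_model_name_or_path=sys.argv[1],
    torch_dtype=torch.bfloat16,
)
```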


## 2. Splitting and Merging Sequences

The `forward` function of `CogVideoXTransformer3DModel` orchestrates the inference process for a single denoising step; its signature is outlined below:

```python
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    ...
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    )
```

To parallelize the inference process, we apply `parallelize_transformer` to `pipe`. Within this function, a `new_forward` function is introduced with the same input and output parameters as the original `forward`. The `new_forward` function performs the following steps:

- Splits the timesteps, text sequence, and image sequence along the batch dimension, allocating one batch to each GPU.
- Executes the original forward pass on each of the two GPUs.
- Merges the predicted noise using `all_gather`.

The code snippet below uses `@functools.wraps` to decorate `new_forward`, so that details such as the function name, docstring, and argument list are inherited from `original_forward`. Because `forward` is a method of a class instance, the `__get__` function is used to bind `transformer` as the first argument of `new_forward` before assigning `new_forward` to `transformer.forward`.

```python
def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    # definition of the new forward (full version shown below)
    @functools.wraps(transformer.__class__.forward)
    def new_forward(self, *args, **kwargs):
        ...

    # bind new_forward to the transformer instance and replace its forward
    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

parallelize_transformer(pipe)
```
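
The `__get__` binding can be seen in isolation with a toy class; this standalone illustration is unrelated to `diffusers`:

```python
class Greeter:
    def hello(self):
        return "original"

g = Greeter()

def new_hello(self):
    # `self` is bound to `g` by __get__, just as `transformer` is bound above
    return f"patched for {type(self).__name__}"

g.hello = new_hello.__get__(g)  # bind the plain function to the instance
print(g.hello())                # -> "patched for Greeter"
```
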
The input parameters `timestep`, `hidden_states`, and `encoder_hidden_states` represent the timesteps, the input video sequence, and the input text sequence, respectively. These tensors must be divided. Their shapes are outlined below:

- `timestep`: (batch_size)
- `hidden_states`: (batch_size, temporal_length, channels, height, width)
- `encoder_hidden_states`: (batch_size, text_length, hidden_dim)

where the batch size is 2. xDiT provides helper functions for CFG parallelism: `get_classifier_free_guidance_rank()` and `get_classifier_free_guidance_world_size()` return the rank of the current GPU and the number of GPUs in the CFG group, respectively, and `get_cfg_group()` returns the CFG process group, whose `all_gather()` operation merges the sequences after `forward`. The new forward function is outlined as follows:

```python
@functools.wraps(transformer.__class__.forward)
def new_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: torch.LongTensor = None,
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
):
    # Step 1: split tensors
    timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
    hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
    encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]

    # Step 2: perform the original forward
    output = original_forward(
        hidden_states,
        encoder_hidden_states,
        timestep=timestep,
        timestep_cond=timestep_cond,
        ofs=ofs,
        image_rotary_emb=image_rotary_emb,
        **kwargs,
    )

    return_dict = not isinstance(output, tuple)
    sample = output[0]
    # Step 3: merge the output from two GPUs
    sample = get_cfg_group().all_gather(sample, dim=0)

    if return_dict:
        return output.__class__(sample, *output[1:])

    return (sample, *output[1:])
```

A complete example script can be found in [adding_model_cfg.py](adding_model_cfg.py).
103 changes: 103 additions & 0 deletions docs/developer/adding_models/adding_model_cfg.py
@@ -0,0 +1,103 @@
# Example of parallelizing a new model with CFG parallelism
# run with
# torchrun --nproc_per_node=2 \
#     adding_model_cfg.py <cogvideox-checkpoint-path>
import sys
import functools
from typing import List, Optional, Tuple, Union

import time
import torch

from diffusers import DiffusionPipeline, CogVideoXPipeline

import torch.distributed as dist
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    get_world_group,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
)

from diffusers.utils import export_to_video

def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: torch.LongTensor = None,
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        # split the inputs along the batch dimension; each rank keeps its own chunk
        timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]

        # run the original forward on the local chunk
        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            timestep=timestep,
            timestep_cond=timestep_cond,
            ofs=ofs,
            image_rotary_emb=image_rotary_emb,
            **kwargs,
        )

        return_dict = not isinstance(output, tuple)
        sample = output[0]
        # gather the predicted noise from both ranks
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    # bind new_forward to the transformer instance and replace its forward
    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

if __name__ == "__main__":
    dist.init_process_group("nccl")
    init_distributed_environment(
        rank=dist.get_rank(),
        world_size=dist.get_world_size()
    )
    initialize_model_parallel(
        classifier_free_guidance_degree=2,
    )
    pipe = CogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=sys.argv[1],
        torch_dtype=torch.bfloat16,
    )
    local_rank = get_world_group().local_rank
    device = torch.device(f"cuda:{local_rank}")
    pipe = pipe.to(device)

    pipe.vae.enable_tiling()

    parallelize_transformer(pipe)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        num_frames=9,
        prompt="A little girl is riding a bicycle at high speed. Focused, detailed, realistic.",
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time

    if local_rank == 0:
        export_to_video(output, "output.mp4", fps=8)
        print(f"epoch time: {elapsed_time:.2f} sec")

    dist.destroy_process_group()
24 changes: 24 additions & 0 deletions docs/developer/adding_models/adding_model_cfg_usp.md
@@ -0,0 +1,24 @@
# Parallelize new models with CFG parallelism and USP provided by xDiT

The following two tutorials provide detailed instructions on how to implement CFG parallelism and USP (Unified Sequence Parallelism) supported by xDiT for a new DiT model:

[Parallelize new models with CFG parallelism provided by xDiT](adding_model_cfg.md)

[Parallelize new models with USP provided by xDiT](adding_model_usp.md)

[Parallelize new models with USP provided by xDiT (text replica)](adding_model_usp_text_replica.md)

Both parallelization techniques can be employed concurrently. To achieve this, specify the degrees of both CFG parallelism and USP as demonstrated below. The number of GPUs must be twice the product of the Ulysses attention degree and the ring attention degree:

```python
from xfuser.core.distributed import initialize_model_parallel
initialize_model_parallel(
    sequence_parallel_degree=<ring_degree x ulysses_degree>,
    ring_degree=<ring_degree>,
    ulysses_degree=<ulysses_degree>,
    classifier_free_guidance_degree=2,
)
# restriction: dist.get_world_size() == 2 x <ring_degree> x <ulysses_degree>
```

Following this, both CFG parallelism and USP can be simultaneously implemented. For a comprehensive example script showcasing this approach, refer to [adding_model_cfg_usp.py](adding_model_cfg_usp.py).
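
Inside `new_forward`, the two techniques compose: the inputs are first split along the batch dimension across the CFG group and then along the sequence dimension across the sequence-parallel group, and the output is merged in the reverse order. The sketch below illustrates this composition; the sequence-parallel helper names (`get_sequence_parallel_world_size`, `get_sequence_parallel_rank`, `get_sp_group`) and the choice of sequence dimension are assumptions to be checked against the `xfuser` API and the target model, so treat it as a rough outline rather than a drop-in implementation.

```python
import torch
from xfuser.core.distributed import (
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,   # assumed helper names; verify against xfuser
    get_sequence_parallel_rank,
    get_sp_group,
)

def split_inputs(hidden_states, encoder_hidden_states, timestep, seq_dim=1):
    # 1) split along the batch dimension across the CFG group
    cfg_ws, cfg_rank = get_classifier_free_guidance_world_size(), get_classifier_free_guidance_rank()
    timestep = torch.chunk(timestep, cfg_ws, dim=0)[cfg_rank]
    hidden_states = torch.chunk(hidden_states, cfg_ws, dim=0)[cfg_rank]
    encoder_hidden_states = torch.chunk(encoder_hidden_states, cfg_ws, dim=0)[cfg_rank]

    # 2) split the image sequence across the sequence-parallel group;
    #    seq_dim is model-specific (the frame/temporal axis for CogVideoX)
    sp_ws, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
    hidden_states = torch.chunk(hidden_states, sp_ws, dim=seq_dim)[sp_rank]
    return hidden_states, encoder_hidden_states, timestep

def merge_output(sample, seq_dim=1):
    # undo the splits in reverse order: sequence dimension first, then batch dimension
    sample = get_sp_group().all_gather(sample, dim=seq_dim)
    sample = get_cfg_group().all_gather(sample, dim=0)
    return sample
```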