[Ray] Add Support for Disaggregating VAE and DiT #422

Merged · 5 commits · Jan 11, 2025
21 changes: 12 additions & 9 deletions examples/ray/ray_flux_example.py
@@ -6,7 +6,7 @@
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
from xfuser.model_executor.pipelines import xFuserFluxPipeline

def main():
os.environ["MASTER_ADDR"] = "localhost"
@@ -50,14 +50,17 @@ def main():
print(f"elapsed time:{elapsed_time}")
if not os.path.exists("results"):
os.mkdir("results")
# output is a list of results from each worker, we take the last one
for i, image in enumerate(output[-1].images):
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)

for i, images in enumerate(output):
if images is not None:
image = images[0]
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
break


if __name__ == "__main__":
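Why the saving loop changed: with the VAE disaggregated into its own Ray worker, only that worker's entry in the returned per-worker list carries decoded images, so the driver now scans for the first non-None result instead of always indexing the last worker. A minimal sketch of that pattern, with a hypothetical helper name that is not part of this PR:

# Minimal sketch of the result-collection pattern above (helper name is
# hypothetical): scan per-worker outputs and return the first non-None entry,
# since only the worker that ran the VAE produces decoded images.
def first_worker_images(per_worker_outputs):
    for rank, images in enumerate(per_worker_outputs):
        if images is not None:
            return rank, images
    return None, None
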
9 changes: 6 additions & 3 deletions examples/ray/ray_run.sh
@@ -29,9 +29,9 @@ mkdir -p ./results
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


N_GPUS=2
N_GPUS=3 # world size
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

VAE_PARALLEL_SIZE=1
# CFG_ARGS="--use_cfg_parallel"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
@@ -49,7 +49,8 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

export CUDA_VISIBLE_DEVICES=0,1
# It is necessary to set CUDA_VISIBLE_DEVICES for the Ray driver and workers.
export CUDA_VISIBLE_DEVICES=4,5,6,7

python ./examples/ray/$SCRIPT \
--model $MODEL_ID \
@@ -66,3 +67,5 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
--use_parallel_vae \
--vae_parallel_size $VAE_PARALLEL_SIZE
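
The GPU count in this script follows from the two settings above: the DiT partition needs pipefusion_parallel_degree * ulysses_degree * ring_degree processes, and the VAE adds VAE_PARALLEL_SIZE more, which is why N_GPUS moves from 2 to 3. A quick standalone check of that arithmetic (plain Python, not part of the script):

# Standalone check of the world-size arithmetic used in ray_run.sh above;
# the values are copied from the script and the snippet is only illustrative.
pipefusion_degree, ulysses_degree, ring_degree = 2, 1, 1
vae_parallel_size = 1
dit_world_size = pipefusion_degree * ulysses_degree * ring_degree  # 2 GPUs for DiT
n_gpus = dit_world_size + vae_parallel_size                        # 3 GPUs total
assert n_gpus == 3  # matches N_GPUS=3 in the script
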
50 changes: 25 additions & 25 deletions examples/ray/ray_sd3_example.py
@@ -6,21 +6,7 @@
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_rank,
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size
from xfuser.model_executor.pipelines import xFuserStableDiffusion3Pipeline


def main():
@@ -32,7 +18,19 @@ def main():
engine_config, input_config = engine_args.create_config()
model_name = engine_config.model_config.model.split("/")[-1]
PipelineClass = xFuserStableDiffusion3Pipeline
text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)

# Equivalent to:
# text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
# but the encoder is loaded inside the Ray worker instead of in the driver.
encoder_kwargs = {
'text_encoder_3': {
'model_class': T5EncoderModel,
'pretrained_model_name_or_path': engine_config.model_config.model,
'subfolder': 'text_encoder_3',
'torch_dtype': torch.float16
},
}

if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
print(f"rank {local_rank} quantizing text encoder 2")
@@ -44,7 +42,7 @@ def main():
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder_3=text_encoder_3,
**encoder_kwargs
)
pipe.prepare_run(input_config)

@@ -63,14 +61,16 @@ def main():
print(f"elapsed time:{elapsed_time}")
if not os.path.exists("results"):
os.mkdir("results")
# output is a list of results from each worker, we take the last one
for i, image in enumerate(output[-1].images):
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
for i, images in enumerate(output):
if images is not None:
image = images[0]
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
break


if __name__ == "__main__":
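The encoder_kwargs dict above replaces a driver-side T5EncoderModel.from_pretrained call with a declarative spec, so the heavy text encoder is instantiated inside the Ray worker rather than being built in the driver and shipped to it. A hedged sketch of how a worker could materialize such a spec; the helper name and exact handling are assumptions, not code from this PR:

from typing import Any, Dict

# Hypothetical helper (not from this PR): rebuild a deferred component from its
# serialized spec so the worker loads the weights locally instead of receiving
# a fully constructed model from the driver.
def materialize_component(spec: Dict[str, Any]):
    spec = dict(spec)  # avoid mutating the caller's spec
    model_class = spec.pop("model_class")
    return model_class.from_pretrained(**spec)

# e.g. text_encoder_3 = materialize_component(encoder_kwargs["text_encoder_3"])
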
19 changes: 14 additions & 5 deletions xfuser/config/args.py
@@ -82,6 +82,7 @@ class xFuserArgs:
# ray arguments
use_ray: bool = False
ray_world_size: int = 1
vae_parallel_size: int = 0
# pipefusion parallel
pipefusion_parallel_degree: int = 1
num_pipeline_patch: Optional[int] = None
@@ -210,6 +211,12 @@ def add_cli_args(parser: FlexibleArgumentParser):
default=1,
help="Tensor parallel degree.",
)
parallel_group.add_argument(
"--vae_parallel_size",
type=int,
default=0,
help="Number of processes for VAE parallelization. 0: no seperate process for VAE, 1: run VAE in a separate process, >1: distribute VAE across multiple processes.",
)
parallel_group.add_argument(
"--split_scheme",
type=str,
@@ -345,7 +352,7 @@ def create_config(
self.world_size = self.ray_world_size
else:
self.world_size = torch.distributed.get_world_size()

self.dit_world_size = self.world_size - self.vae_parallel_size # FIXME: Lack of scalability
model_config = ModelConfig(
model=self.model,
download_dir=self.download_dir,
@@ -366,25 +373,27 @@
dp_config=DataParallelConfig(
dp_degree=self.data_parallel_degree,
use_cfg_parallel=self.use_cfg_parallel,
world_size=self.world_size,
world_size=self.dit_world_size,
),
sp_config=SequenceParallelConfig(
ulysses_degree=self.ulysses_degree,
ring_degree=self.ring_degree,
world_size=self.world_size,
world_size=self.dit_world_size,
),
tp_config=TensorParallelConfig(
tp_degree=self.tensor_parallel_degree,
split_scheme=self.split_scheme,
world_size=self.world_size,
world_size=self.dit_world_size,
),
pp_config=PipeFusionParallelConfig(
pp_degree=self.pipefusion_parallel_degree,
num_pipeline_patch=self.num_pipeline_patch,
attn_layer_num_for_pp=self.attn_layer_num_for_pp,
world_size=self.world_size,
world_size=self.dit_world_size,
),
world_size=self.world_size,
dit_world_size=self.dit_world_size,
vae_parallel_size=self.vae_parallel_size,
)

fast_attn_config = FastAttnConfig(
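The new vae_parallel_size argument splits the Ray world into a DiT partition and a VAE partition: dit_world_size = world_size - vae_parallel_size, and the data/sequence/tensor/PipeFusion parallel configs are now sized against the DiT partition only. A small standalone sketch that mirrors the split computed in create_config() above:

# Standalone sketch mirroring the split above: the DiT parallel groups only see
# world_size minus the processes reserved for the VAE.
def split_world(world_size: int, vae_parallel_size: int) -> int:
    dit_world_size = world_size - vae_parallel_size
    assert dit_world_size > 0, "at least one process must remain for the DiT"
    return dit_world_size

# With the ray_run.sh example: split_world(3, 1) == 2, matching
# pipefusion_parallel_degree 2 * ulysses_degree 1 * ring_degree 1.
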
5 changes: 3 additions & 2 deletions xfuser/config/config.py
@@ -193,7 +193,8 @@ class ParallelConfig:
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
worker_cls: str = "xfuser.ray.worker.worker.Worker"
dit_world_size: int = 1
vae_parallel_size: int = 1 # 0 means the VAE runs in the same process as the diffusion model

def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
@@ -207,7 +208,7 @@ def __post_init__(self):
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
world_size = self.world_size
world_size = self.dit_world_size
assert parallel_world_size == world_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {self.world_size}"
14 changes: 14 additions & 0 deletions xfuser/core/distributed/__init__.py
@@ -23,6 +23,13 @@
initialize_model_parallel,
model_parallel_is_initialized,
get_tensor_model_parallel_world_size,
get_vae_parallel_group,
get_vae_parallel_rank,
get_vae_parallel_world_size,
get_dit_world_size,
init_vae_group,
init_dit_group,
get_dit_group,
)
from .runtime_state import (
get_runtime_state,
@@ -58,4 +65,11 @@
"get_runtime_state",
"runtime_state_is_initialized",
"initialize_runtime_state",
"get_dit_world_size",
"get_vae_parallel_group",
"get_vae_parallel_rank",
"get_vae_parallel_world_size",
"init_vae_group",
"init_dit_group",
"get_dit_group",
]
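
The new exports give the rest of the code handles on the disaggregated process groups. A hedged usage sketch; it assumes these getters mirror the existing parallel-state helpers and return plain ints for ranks and world sizes, which the diff above does not show directly:

# Hedged usage sketch (assumes the getters behave like the existing
# parallel-state helpers and return plain ints for ranks and sizes).
from xfuser.core.distributed import (
    get_dit_world_size,
    get_vae_parallel_rank,
    get_vae_parallel_world_size,
)

def describe_partition() -> str:
    return (
        f"DiT world size: {get_dit_world_size()}, "
        f"VAE rank/size: {get_vae_parallel_rank()}/{get_vae_parallel_world_size()}"
    )
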
5 changes: 5 additions & 0 deletions xfuser/core/distributed/group_coordinator.py
@@ -134,6 +134,11 @@ def __init__(
else:
self.device = torch.device("cpu")

@property
def size(self):
"""Return the size of the process group (alias for world_size)"""
return self.world_size

@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
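With size exposed as a property, callers can treat it interchangeably with world_size. An illustrative check, assuming get_vae_parallel_group() (exported above) returns a GroupCoordinator instance:

# Illustrative only: both attributes report the number of ranks in the group,
# assuming get_vae_parallel_group() returns a GroupCoordinator instance.
from xfuser.core.distributed import get_vae_parallel_group

vae_group = get_vae_parallel_group()
assert vae_group.size == vae_group.world_size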