forked from xdit-project/xDiT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parallel.py
54 lines (47 loc) · 2.04 KB
/
parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from pathlib import Path
from xfuser.config.config import InputConfig
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from xfuser.config import EngineConfig
from xfuser.core.distributed.parallel_state import (
get_data_parallel_rank,
get_data_parallel_world_size,
is_dp_last_group,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister
# Module-level logger created via xfuser.logger.init_logger.
# NOTE(review): not referenced by any code visible in this module; the class
# below reports saved files with print() instead.
logger = init_logger(__name__)
class xDiTParallel:
    """Run a pipeline through its registered xFuser parallel wrapper.

    The wrapper class is looked up in ``xFuserPipelineWrapperRegister`` for the
    given ``pipe``, constructed with ``engine_config``, and prepared once with
    ``input_config``. Calling an instance runs the wrapped pipeline and keeps
    its output on ``self.result`` so that :meth:`save` can write it afterwards.
    """

    def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
        """Wrap ``pipe`` for parallel execution and prepare it for ``input_config``.

        Args:
            pipe: Pipeline object to wrap; its wrapper class is resolved via
                ``xFuserPipelineWrapperRegister.get_class(pipe)``.
            engine_config: Engine configuration. It is kept on ``self.config``
                and :meth:`save` reads its ``parallel_config`` to tag file names.
            input_config: Forwarded to the wrapper's ``prepare_run``.
        """
        xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
        self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
        self.config = engine_config
        self.pipe.prepare_run(input_config)

    def __call__(self, *args, **kwargs):
        """Run the wrapped pipeline, store its output on ``self.result``, and return it."""
        self.result = self.pipe(*args, **kwargs)
        return self.result

    def save(self, directory: str, prefix: str):
        """Write the images from the most recent call as PNG files.

        Only ranks for which ``is_dp_last_group()`` is true write anything.
        Each file is named
        ``{directory}/{prefix}_result_{parallel_info}_dprank{dp_rank}_image{i}.png``.
        ``parallel_info`` encodes the dp/cfg/ulysses/ring/pp degrees and the
        number of pipeline patches. The path of every written file is printed.

        Args:
            directory: Output directory; created (with parents) if missing.
            prefix: File-name prefix.

        Assumes a prior call set ``self.result`` to an object with an
        ``images`` sequence whose items have a ``save(path)`` method
        (presumably PIL images -- TODO confirm). Calling this before the
        pipeline has run raises ``AttributeError`` on the writing ranks.
        """
        dp_rank = get_data_parallel_rank()
        parallel_config = self.config.parallel_config
        parallel_info = (
            f"dp{parallel_config.dp_degree}_cfg{parallel_config.cfg_degree}_"
            f"ulysses{parallel_config.ulysses_degree}_ring{parallel_config.ring_degree}_"
            f"pp{parallel_config.pp_degree}_patch{parallel_config.pp_config.num_pipeline_patch}"
        )
        if not is_dp_last_group():
            return

        out_dir = Path(directory)
        # The permission bits must be written in octal. A bare decimal 755 equals
        # 0o1363 (sticky, owner -wx), which leaves the owner unable to list the
        # directory it just created.
        out_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
        stem = out_dir / f"{prefix}_result_{parallel_info}_dprank{dp_rank}"
        for i, image in enumerate(self.result.images):
            image_path = f"{stem}_image{i}.png"
            image.save(image_path)
            print(image_path)

    def __del__(self):
        """Tear down the distributed environment when this object is finalized.

        ``destory_distributed_env`` is the spelling of the method on the runtime
        state object; keep it as-is to match that API.
        NOTE(review): teardown tied to garbage collection timing can be
        unpredictable; an explicit close/teardown call may be preferable.
        """
        get_runtime_state().destory_distributed_env()