flux-dev OOM with 2 GPUs (each GPU is 24576MiB) #345
Comments
--pipefusion_parallel_degree 2
Your command line is not valid. The parallel degree should be 2 in total.
@feifeibear When the command is "torchrun --nproc_per_node=2 ./examples/flux_example.py --model ./FLUX.1-dev/ --pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1 --height 512 --width 512 --no_use_resolution_binning --output_type latent --num_inference_steps 28 --warmup_steps 1 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --use_cfg_parallel --use_parallel_vae", it fails with an error saying the world size is not equal to 4.
You should not use --use_cfg_parallel; it doubles the required world size (2 for CFG x pipefusion degree 2 = 4), but only 2 processes were launched.
@feifeibear The command does not use --use_cfg_parallel, but it still hits an OOM error.
I see, your memory is really small. I have a very simple optimization to avoid the OOM: we can use FSDP to load the text encoder. We will add a PR for this ASAP.
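The FSDP change mentioned here was not posted in this thread, so the following is only a minimal sketch of the idea under stated assumptions: a plain diffusers FluxPipeline loaded on CPU, a process group launched via torchrun, and an illustrative size-based wrapping policy. It is not xDiT's actual implementation.

```python
import functools
import os

import torch
import torch.distributed as dist
from diffusers import FluxPipeline
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Load the whole pipeline on CPU first (bfloat16), then shard only the
# T5-XXL text encoder, which dominates the text-encoding memory footprint.
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev/", torch_dtype=torch.bfloat16)

pipe.text_encoder_2 = FSDP(
    pipe.text_encoder_2,
    device_id=torch.cuda.current_device(),  # FSDP moves the shards onto this GPU
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=int(1e8)
    ),
    use_orig_params=True,
)
```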
@feifeibear Thank you for your quick response. But when I use diffusers for inference with height=width=512, the problem does not occur. The code is:
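(The original snippet was not preserved in this extract; a typical diffusers-only baseline for this setting, with the local model path and CPU offload assumed, might look roughly like the sketch below.)

```python
import torch
from diffusers import FluxPipeline

# Hypothetical reconstruction; the original snippet was not included in the thread.
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev/", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # offload idle components to CPU to lower peak GPU memory

image = pipe(
    "brown dog laying on the ground with a metal bowl in front of him.",
    height=512,
    width=512,
    num_inference_steps=28,
).images[0]
image.save("flux_512.png")
```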
@algorithmconquer Hello, could you provide the error log of the OOM error? We need to check whether the OOM happened in the model loading process or the inference process. If it happened in the loading process, you could simply quantize the text encoder into FP8, which reduces the max memory use to 17GB without any quality loss. Firstly, install the dependencies by running the following command. Then, you could use the following code to replace the original examples/flux_example.py:
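(The install command and replacement script from this comment were not preserved in the extract. The dependency in question is the optimum-quanto package (pip install optimum-quanto); a minimal sketch of the suggested approach, quantizing only the T5 text encoder weights to FP8 with optimum-quanto before moving the pipeline to the GPU, could look like the following. The pipeline construction is assumed and is not the exact xDiT example.)

```python
import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

# Build the pipeline on CPU, quantize the T5-XXL text encoder weights to FP8,
# then move everything to the GPU; only the text encoder is quantized here.
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev/", torch_dtype=torch.bfloat16)

quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)

pipe = pipe.to("cuda")
```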
@Lay2000 Thank you for sharing the code. I was able to implement the inference pipeline for flux-dev in bfloat16 by using model shards with 2 GPUs (each GPU is 24576MiB). I now want to measure the inference performance of xDiT on the same device and environment (dtype=bfloat16, height=width=1024, 2 GPUs, each 24576MiB).
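(The sharded setup described here was not posted. One plausible way to shard a bfloat16 FluxPipeline across two 24 GiB GPUs with plain diffusers is device_map="balanced" loading, sketched below; the model path and per-GPU memory caps are assumptions.)

```python
import torch
from diffusers import FluxPipeline

# Sketch: let accelerate place the pipeline's components across both GPUs
# instead of replicating the full bfloat16 model on each one.
pipe = FluxPipeline.from_pretrained(
    "./FLUX.1-dev/",
    torch_dtype=torch.bfloat16,
    device_map="balanced",
    max_memory={0: "22GiB", 1: "22GiB"},  # assumed caps for two 24576MiB cards
)

image = pipe(
    "brown dog laying on the ground with a metal bowl in front of him.",
    height=1024,
    width=1024,
    num_inference_steps=28,
).images[0]
image.save("flux_1024.png")
```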
@Lay2000 The error log is:
@algorithmconquer Hello, we attempted to execute the same scripts on two GPUs, each equipped with 24576MiB of vRAM. However, Out of Memory (OOM) issues still arose, as the vRAM might not be sufficient to accommodate the whole BF16 FLUX model. In our latest code, we've added a
With the command above, I still got OOM on 2x RTX 4090. The OOM happens in the text_encoder loading process. The detailed traceback is below:

/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
[rank1]: Traceback (most recent call last):
[rank1]: File "/media/74nvme/research/test.py", line 231, in <module>
[rank1]: main()
[rank1]: File "/media/74nvme/research/test.py", line 186, in main
[rank1]: pipe = pipe.to(f"cuda:{local_rank}")
[rank1]: File "/media/74nvme/research/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 117, in to
[rank1]: self.module = self.module.to(*args, **kwargs)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 461, in to
[rank1]: module.to(device, dtype)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3164, in to
[rank1]: return super().to(*args, **kwargs)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
[rank1]: return self._apply(convert)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: [Previous line repeated 4 more times]
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
[rank1]: param_applied = fn(param)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1326, in convert
[rank1]: return t.to(
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 272, in __torch_function__
[rank1]: return func(*args, **kwargs)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 298, in __torch_dispatch__
[rank1]: return WeightQBytesTensor.create(
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 139, in create
[rank1]: return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad)
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/marlin/fp8/qbits.py", line 79, in __init__
[rank1]: data_packed = MarlinF8PackedTensor.pack(data) # pack fp8 data to in32, and apply marlier re-ordering.
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/marlin/fp8/packed.py", line 179, in pack
[rank1]: data_int32 = pack_fp8_as_int32(tensor.T) # pack fp8 data to in32.
[rank1]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/optimum/quanto/tensor/weights/marlin/fp8/packed.py", line 46, in pack_fp8_as_int32
[rank1]: packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 40.00 MiB. GPU 1 has a total capacity of 23.64 GiB of which 23.69 MiB is free. Process 1303781 has 23.61 GiB memory in use. Of the allocated memory 23.13 GiB is allocated by PyTorch, and 44.35 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]: Traceback (most recent call last):
[rank0]: File "/media/74nvme/research/test.py", line 231, in <module>
[rank0]: main()
[rank0]: File "/media/74nvme/research/test.py", line 186, in main
[rank0]: pipe = pipe.to(f"cuda:{local_rank}")
[rank0]: File "/media/74nvme/research/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 117, in to
[rank0]: self.module = self.module.to(*args, **kwargs)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 461, in to
[rank0]: module.to(device, dtype)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
[rank0]: return self._apply(convert)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: [Previous line repeated 1 more time]
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
[rank0]: param_applied = fn(param)
[rank0]: File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1326, in convert
[rank0]: return t.to(
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 72.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 41.69 MiB is free. Process 1303780 has 23.60 GiB memory in use. Of the allocated memory 22.92 GiB is allocated by PyTorch, and 244.90 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W103 10:46:26.792933189 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
W0103 10:47:22.272385 2611841 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2612188 closing signal SIGTERM
E0103 10:47:22.306004 2611841 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 2612189) of binary: /media/74nvme/software/miniconda3/envs/stable-fast/bin/python
Traceback (most recent call last):
File "/media/74nvme/software/miniconda3/envs/stable-fast/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
run(args)
File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

The xfuser and optimum-quanto version details are below:
Name: xfuser
Version: 0.4.0
Summary: xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters
Home-page: https://github.com/xdit-project/xDiT.
Author: xDiT Team
Author-email: fangjiarui123@gmail.com
License:
Location: /media/74nvme/research/xDiT
Editable project location: /media/74nvme/research/xDiT
Requires: accelerate, beautifulsoup4, distvae, flask, imageio, imageio-ffmpeg, opencv-python, optimum-quanto, pytest, ray, sentencepiece, torch, transformers, yunchang
Required-by:
Name: optimum-quanto
Version: 0.2.6
Summary: A pytorch quantization backend for optimum.
Home-page:
Author: David Corvoysier
Author-email:
License: Apache-2.0
Location: /media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages
Requires: huggingface-hub, ninja, numpy, safetensors, torch
Required-by: xfuser
The command is:
torchrun --nproc_per_node=2 ./examples/flux_example.py --model ./FLUX.1-dev/ --pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --output_type latent --num_inference_steps 28 --warmup_steps 1 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --use_cfg_parallel --use_parallel_vae
How to solve the problem?