Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sharded models when TorchAO quantization is enabled #10256

Merged
merged 10 commits into from
Dec 20, 2024

Conversation

a-r-r-o-w
Copy link
Member

When passing a device_map, whether as a finegrained dict or simply "auto", we cannot use the torchao quantization method currently.

Error
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump5.py", line 91, in <module>
    transformer = FluxTransformer2DModel.from_pretrained(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/modeling_utils.py", line 920, in from_pretrained
    accelerate.load_checkpoint_and_dispatch(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 1690, in load_checkpoint_in_model
    if os.path.isfile(checkpoint):
  File "/home/aryan/.pyenv/versions/3.10.14/lib/python3.10/genericpath.py", line 30, in isfile
    st = os.stat(path)
TypeError: stat: path should be string, bytes, os.PathLike or integer, not dict

This is a quick workaround to throw a cleaner error message. The reason why we error out is because:

Related:

Accelerate also provides load_checkpoint_in_model which might be usable here since we are working with a state dict here. Until we can figure out the best way to support this, let's raise a clean error. We can tackle in #10013 and work on refactoring too. model_file does not make sense as a variable name either when holding a state dict, which caused some confusions during debugging.

The missing case was found by @DN6, thanks! This was not detected by our fast/slow tests or me during testing device_map related changes because I was using unsharded single safetensors file for both SDXL and Flux. If the state dict is unsharded, device_map should work just fine whether you pass a string like auto or balanced, or if you pass a finegrained dict.

@a-r-r-o-w a-r-r-o-w requested review from DN6 and yiyixuxu December 17, 2024 06:12
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder or "",
)
if hf_quantizer is not None:
is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
Copy link
Collaborator

@yiyixuxu yiyixuxu Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we consolidate with this bnb check (remove the bnb check and extend this check for any quantization method)

is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"

this should not specific to any quantisation method, no? I run this test, for non-sharded checkpoint, both works for shared checkpoint, both throw same error

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16

# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")

torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print(f" testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)

print(f"torchao hf_device_map: {transformer.hf_device_map}")

transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path, 
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")


print(f" testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id, 
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
)
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

Copy link
Member Author

@a-r-r-o-w a-r-r-o-w Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think non-sharded works for both, no? non-sharded checkpoint only seems to work torchao at the moment. These are my results:

method/shard sharded non-sharded
torchao fails works
bnb fails fails

I tried with your code as well and get the following error when using BnB with unsharded on this branch:

NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.

Whatever the automatic infer of device_map thing is, we are still unable to pass device_map manually when state dict is sharded/unsharded, so I would put it in same bucket as failing.

Consolidating the checks together sounds good. Will update

@a-r-r-o-w a-r-r-o-w changed the title Raise an error if using TorchAO quantizer when using device_map with sharded checkpoint Add support for sharded models when TorchAO quantization is enabled Dec 18, 2024
@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Dec 18, 2024

@yiyixuxu As discussed, this PR now supports fully using both sharded/non-sharded checkpoints with the full glory of device_map, when using torchao as quantizer. For BnB, we will have to revisit in a separate PR.

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
    device_map="auto",
    cache_dir="/raid/.cache/huggingface"
)
if hasattr(transformer, "hf_device_map"):
    print(transformer.hf_device_map)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16, device_map="balanced", cache_dir="/raid/.cache/huggingface")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

In the code, notice that there is no call to pipe.to("cuda"). This would end up failing because pipeline_is_sequentially_offloaded is set to True, since accelerate adds a AlignDevicesHook to the transformer here.

So, one has to either:

  • Move the other components (text encoder, vae) to GPU manually.
  • Use a device_map on the pipeline as well. Since balanced is the only mode supported, the above example uses that
logs
{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks.0': 0, 'transformer_blocks.1': 0, 'transformer_blocks.2': 0, 'transformer_blocks.3': 0, 'transformer_blocks.4': 0, 'transformer_blocks.5': 0, 'transformer_blocks.6': 0, 'transformer_blocks.7': 0, 'transformer_blocks.8': 0, 'transformer_blocks.9': 0, 'transformer_blocks.10': 0, 'transformer_blocks.11': 0, 'transformer_blocks.12': 0, 'transformer_blocks.13': 1, 'transformer_blocks.14': 1, 'transformer_blocks.15': 1, 'transformer_blocks.16': 1, 'transformer_blocks.17': 1, 'transformer_blocks.18': 1, 'single_transformer_blocks.0': 1, 'single_transformer_blocks.1': 1, 'single_transformer_blocks.2': 1, 'single_transformer_blocks.3': 1, 'single_transformer_blocks.4': 1, 'single_transformer_blocks.5': 1, 'single_transformer_blocks.6': 1, 'single_transformer_blocks.7': 1, 'single_transformer_blocks.8': 1, 'single_transformer_blocks.9': 1, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 2, 'single_transformer_blocks.22': 2, 'single_transformer_blocks.23': 2, 'single_transformer_blocks.24': 2, 'single_transformer_blocks.25': 2, 'single_transformer_blocks.26': 2, 'single_transformer_blocks.27': 2, 'single_transformer_blocks.28': 2, 'single_transformer_blocks.29': 2, 'single_transformer_blocks.30': 2, 'single_transformer_blocks.31': 2, 'single_transformer_blocks.32': 2, 'single_transformer_blocks.33': 2, 'single_transformer_blocks.34': 2, 'single_transformer_blocks.35': 2, 'single_transformer_blocks.36': 2, 'single_transformer_blocks.37': 2, 'norm_out': 2, 'proj_out': 2}
Loading pipeline components...:  29%|████████████████████                                                  | 2/7 [00:02<00:05,  1.05s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:34<00:00, 17.46s/it]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████| 7/7 [00:39<00:00,  5.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:23<00:00,  1.28it/s]

image

@a-r-r-o-w a-r-r-o-w requested a review from yiyixuxu December 18, 2024 02:12
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for looking into this!

@yiyixuxu
Copy link
Collaborator

@a-r-r-o-w

In the code, notice that there is no call to pipe.to("cuda"). This would end up failing because pipeline_is_sequentially_offloaded is set to True, since accelerate adds a AlignDevicesHook to the transformer here.

you need to update this to not just bnb

if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:

@a-r-r-o-w
Copy link
Member Author

@yiyixuxu Added an error with possible suggestions. Could you give this another look?

@@ -420,6 +422,12 @@ def module_is_offloaded(module):
raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
elif pipeline_has_torchao:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif pipeline_has_torchao:
elif pipeline_is_sequentially_offloaded and pipeline_has_torchao:

@@ -388,6 +389,7 @@ def to(self, *args, **kwargs):

device = device or device_arg
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
pipeline_has_torchao = any(_check_torchao_status(module) for _, module in self.components.items())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to skip the error message for now?

the issue of not being able use pipe.to("cuda") when the pipe has a device mapped module, has nothing to do with torchAO, if we do not throw an error here and then we make sure not to move the module with device_map later here, it would work

if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):

some refactor is needed, though, and I don't think needs to be done in this PR

@yiyixuxu yiyixuxu added the roadmap Add to current release roadmap label Dec 20, 2024
@yiyixuxu yiyixuxu merged commit 41ba8c0 into main Dec 20, 2024
14 of 15 checks passed
@yiyixuxu yiyixuxu deleted the torchao-error-on-sharded-device-map branch December 20, 2024 01:42
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants