Add support for sharded models when TorchAO quantization is enabled #10256
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            subfolder=subfolder or "",
        )
        if hf_quantizer is not None:
            is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
Can we consolidate this with the bnb check (remove the bnb check and extend this one to cover any quantization method)?
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
This should not be specific to any quantization method, no? I ran this test: for a non-sharded checkpoint, both work; for a sharded checkpoint, both throw the same error.
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch
sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16
# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     sharded_model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")
torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)
print(f" testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"torchao hf_device_map: {transformer.hf_device_map}")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")
print(f" testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")
I don't think non-sharded works for both, no? A non-sharded checkpoint only seems to work with torchao at the moment. These are my results:
method/shard | sharded | non-sharded |
---|---|---|
torchao | fails | works |
bnb | fails | fails |
I tried with your code as well and got the following error when using BnB with an unsharded checkpoint on this branch:
NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.
Whatever the automatic inference of device_map is, we are still unable to pass device_map manually whether the state dict is sharded or unsharded, so I would put it in the same bucket as failing.
Consolidating the checks together sounds good. Will update
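For reference, a minimal sketch of what the consolidated check could look like (function and argument names are illustrative, not the actual diffusers internals): instead of one flag per backend (is_bnb_quantization_method, is_torchao_quantization_method), the presence of any hf_quantizer triggers the same handling, and the method name is only read back for the error message.

# Illustrative sketch only, assuming an `hf_quantizer` that exposes
# `quantization_config.quant_method` as in the diff above.
def validate_sharded_quantized_load(hf_quantizer, device_map, is_sharded):
    if hf_quantizer is None:
        # No quantization involved; nothing to validate.
        return
    quant_method = hf_quantizer.quantization_config.quant_method
    # One guard covers bitsandbytes, torchao, and any backend added later.
    if is_sharded and device_map is not None:
        raise NotImplementedError(
            f"Loading sharded checkpoints with `device_map` is currently not "
            f"supported for the {quant_method!r} quantization method."
        )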
@yiyixuxu As discussed, this PR now fully supports using both sharded and non-sharded checkpoints with the full glory of device_map:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
    device_map="auto",
    cache_dir="/raid/.cache/huggingface",
)
if hasattr(transformer, "hf_device_map"):
    print(transformer.hf_device_map)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16, device_map="balanced", cache_dir="/raid/.cache/huggingface")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png") In the code, notice that there is no call to
So, one has to either:
logs

{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks.0': 0, 'transformer_blocks.1': 0, 'transformer_blocks.2': 0, 'transformer_blocks.3': 0, 'transformer_blocks.4': 0, 'transformer_blocks.5': 0, 'transformer_blocks.6': 0, 'transformer_blocks.7': 0, 'transformer_blocks.8': 0, 'transformer_blocks.9': 0, 'transformer_blocks.10': 0, 'transformer_blocks.11': 0, 'transformer_blocks.12': 0, 'transformer_blocks.13': 1, 'transformer_blocks.14': 1, 'transformer_blocks.15': 1, 'transformer_blocks.16': 1, 'transformer_blocks.17': 1, 'transformer_blocks.18': 1, 'single_transformer_blocks.0': 1, 'single_transformer_blocks.1': 1, 'single_transformer_blocks.2': 1, 'single_transformer_blocks.3': 1, 'single_transformer_blocks.4': 1, 'single_transformer_blocks.5': 1, 'single_transformer_blocks.6': 1, 'single_transformer_blocks.7': 1, 'single_transformer_blocks.8': 1, 'single_transformer_blocks.9': 1, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 2, 'single_transformer_blocks.22': 2, 'single_transformer_blocks.23': 2, 'single_transformer_blocks.24': 2, 'single_transformer_blocks.25': 2, 'single_transformer_blocks.26': 2, 'single_transformer_blocks.27': 2, 'single_transformer_blocks.28': 2, 'single_transformer_blocks.29': 2, 'single_transformer_blocks.30': 2, 'single_transformer_blocks.31': 2, 'single_transformer_blocks.32': 2, 'single_transformer_blocks.33': 2, 'single_transformer_blocks.34': 2, 'single_transformer_blocks.35': 2, 'single_transformer_blocks.36': 2, 'single_transformer_blocks.37': 2, 'norm_out': 2, 'proj_out': 2}
Loading pipeline components...: 29%|████████████████████ | 2/7 [00:02<00:05, 1.05s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:34<00:00, 17.46s/it]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████| 7/7 [00:39<00:00, 5.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:23<00:00, 1.28it/s] |
Thanks for looking into this!
You need to update this to cover more than just bnb.
@yiyixuxu Added an error with possible suggestions. Could you give this another look?
@@ -420,6 +422,12 @@ def module_is_offloaded(module):
            raise ValueError(
                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
            )
        elif pipeline_has_torchao:
Suggested change:
-        elif pipeline_has_torchao:
+        elif pipeline_is_sequentially_offloaded and pipeline_has_torchao:
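A hedged sketch of the control flow this suggestion produces (placeholder booleans instead of the real pipeline state, not the exact diffusers source): the error is only raised when the pipeline is both sequentially offloaded and contains torchao-quantized models, so a torchao pipeline that is not offloaded can still be moved with `.to()`.

def validate_to(pipeline_is_sequentially_offloaded: bool, pipeline_has_torchao: bool) -> None:
    # Only block `.to()` for the combination the suggestion targets.
    if pipeline_is_sequentially_offloaded and pipeline_has_torchao:
        raise ValueError(
            "Cannot call `.to()` on a sequentially offloaded pipeline that "
            "contains torchao-quantized models."
        )


validate_to(pipeline_is_sequentially_offloaded=False, pipeline_has_torchao=True)  # no error raised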
@@ -388,6 +389,7 @@ def to(self, *args, **kwargs):

        device = device or device_arg
        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
        pipeline_has_torchao = any(_check_torchao_status(module) for _, module in self.components.items())
Do you want to skip the error message for now? The issue of not being able to use pipe.to("cuda") when the pipe has a device-mapped module has nothing to do with torchAO. If we do not throw an error here, and we make sure not to move the device-mapped module later here, it would work:

if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):

Some refactor is needed, though, and I don't think it needs to be done in this PR.
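A minimal sketch of the alternative described above (illustrative only, not the actual DiffusionPipeline.to() logic), assuming components expose the usual hf_device_map attribute when loaded with a device_map: instead of raising, the move step could simply skip any device-mapped component so its accelerate placement is preserved.

import torch


def move_components(components: dict, device: str) -> None:
    # Sketch of moving pipeline components while leaving device-mapped ones alone.
    for name, module in components.items():
        if not isinstance(module, torch.nn.Module):
            continue
        if getattr(module, "hf_device_map", None) is not None:
            # Loaded with `device_map`: keep the placement accelerate assigned.
            continue
        module.to(device)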
When passing a device_map, whether as a fine-grained dict or simply "auto", we cannot use the torchao quantization method currently.

Error

This is a quick workaround to throw a cleaner error message. The reason why we error out is because:
diffusers/src/diffusers/models/modeling_utils.py
Line 806 in 7ca64fd
Here, we merge the sharded checkpoints because hf_quantizer is not None and set is_sharded to False. This causes model_file to be a state dict instead of a string.

diffusers/src/diffusers/models/modeling_utils.py
Line 914 in 7ca64fd
Accelerate expects a file path when load_checkpoint_and_dispatch is called, but we try to pass a state dict.

Related:
Accelerate also provides load_checkpoint_in_model, which might be usable here since we are working with a state dict. Until we can figure out the best way to support this, let's raise a clean error. We can tackle this in #10013 and work on refactoring too.

model_file does not make sense as a variable name either when holding a state dict, which caused some confusion during debugging.

The missing case was found by @DN6, thanks! It was not detected by our fast/slow tests, or by me while testing the device_map-related changes, because I was using an unsharded single safetensors file for both SDXL and Flux. If the state dict is unsharded, device_map should work just fine whether you pass a string like "auto" or "balanced", or a fine-grained dict.
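A hedged sketch of the kind of clean error this quick workaround raises (helper name and wording are illustrative, not the exact message added in the PR), with the two suggestions that follow from the notes above: drop device_map, or re-save the checkpoint unsharded as in the earlier test script.

# Illustrative sketch only; the real check would live inside from_pretrained.
def raise_if_unsupported(is_sharded: bool, device_map, uses_torchao: bool) -> None:
    if uses_torchao and is_sharded and device_map is not None:
        raise ValueError(
            "Loading a sharded checkpoint with `device_map` is not supported when "
            "torchao quantization is enabled. Either load the model without "
            "`device_map`, or re-save it as a single unsharded checkpoint "
            "(e.g. `save_pretrained(..., max_shard_size='100GB')`) and retry."
        )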