Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sharded models when TorchAO quantization is enabled #10256

Merged
10 commits merged on Dec 20, 2024
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None:
if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
Expand Down
70 changes: 47 additions & 23 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

def test_offload(self):
def test_device_map(self):
"""
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
that the device map is correctly set (in the `hf_device_map` attribute of the model).
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
"""

device_map_offload = {
custom_device_map_dict = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
"x_embedder": torch_device,
Expand All @@ -293,27 +294,50 @@ def test_offload(self):
"norm_out": torch_device,
"proj_out": "cpu",
}
device_maps = ["auto", custom_device_map_dict]

inputs = self.get_dummy_tensor_inputs(torch_device)

with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_offload)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])

for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map

# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
Expand Down
Loading