diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 0f9c9203c926..872d4d73d41f 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -802,7 +802,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     revision=revision,
                     subfolder=subfolder or "",
                 )
-                if hf_quantizer is not None:
+                if hf_quantizer is not None and is_bnb_quantization_method:
                     model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 58c1d3613daf..6f9980c006ac 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """
-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
+        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-
-        with tempfile.TemporaryDirectory() as offload_folder:
-            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-            quantized_model = FluxTransformer2DModel.from_pretrained(
-                "hf-internal-testing/tiny-flux-pipe",
-                subfolder="transformer",
-                quantization_config=quantization_config,
-                device_map=device_map_offload,
-                torch_dtype=torch.bfloat16,
-                offload_folder=offload_folder,
-            )
-
-            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
-
-            output = quantized_model(**inputs)[0]
-            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+
+        for device_map in device_maps:
+            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
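For reference, a minimal usage sketch of the code path the updated test exercises: loading a torchao int4 weight-only quantized model with device_map="auto" and inspecting the resulting hf_device_map. This is illustrative only (not part of the patch) and simply reuses the tiny test checkpoint and TorchAoConfig arguments from the test above.

# Minimal sketch (assumption: illustrative only, not part of the patch);
# mirrors the "auto" device map case covered by test_device_map.
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# On a single-GPU machine, "auto" places the whole model on device 0,
# which is why the test compares hf_device_map against {"": 0}.
print(transformer.hf_device_map)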