
Commit 01acf5f

a-r-r-o-w authored and sayakpaul committed
Bump minimum TorchAO version to 0.7.0 (#10293)
* bump min torchao version to 0.7.0
* update
1 parent 2bc919f commit 01acf5f

File tree: 3 files changed, +52 −51 lines changed


src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 5 additions & 0 deletions
@@ -93,6 +93,11 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
             )
+        torchao_version = version.parse(importlib.metadata.version("torchao"))
+        if torchao_version < version.parse("0.7.0"):
+            raise RuntimeError(
+                f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+            )
 
         self.offload = False
 
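For reference, the new floor can be checked up front in a user environment. A minimal standalone sketch (not diffusers code) that mirrors the guard added to `validate_environment`, using `importlib.metadata` and `packaging`:

import importlib.metadata

from packaging import version

installed = version.parse(importlib.metadata.version("torchao"))  # raises PackageNotFoundError if torchao is absent
if installed < version.parse("0.7.0"):
    # Same remedy the quantizer's RuntimeError suggests.
    print(f"torchao {installed} is below the new minimum; run `pip install -U torchao`.")
else:
    print(f"torchao {installed} satisfies the 0.7.0 minimum.")
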
src/diffusers/utils/testing_utils.py

Lines changed: 2 additions & 2 deletions
@@ -490,11 +490,11 @@ def decorator(test_case):
     return decorator
 
 
-def require_torchao_version_greater(torchao_version):
+def require_torchao_version_greater_or_equal(torchao_version):
     def decorator(test_case):
         correct_torchao_version = is_torchao_available() and version.parse(
             version.parse(importlib.metadata.version("torchao")).base_version
-        ) > version.parse(torchao_version)
+        ) >= version.parse(torchao_version)
         return unittest.skipUnless(
             correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
         )(test_case)
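With the rename, a test module gates on the new floor like this (a hypothetical test class for illustration; only the decorator and its import come from this diff):

import unittest

from diffusers.utils.testing_utils import require_torchao_version_greater_or_equal


@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSmokeTest(unittest.TestCase):
    def test_runs_only_on_new_torchao(self):
        # The decorator wraps unittest.skipUnless, so this test is skipped when torchao < 0.7.0.
        self.assertTrue(True)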

tests/quantization/torchao/test_torchao.py

Lines changed: 45 additions & 49 deletions
@@ -36,7 +36,7 @@
     nightly,
     require_torch,
     require_torch_gpu,
-    require_torchao_version_greater,
+    require_torchao_version_greater_or_equal,
     slow,
     torch_device,
 )
@@ -74,13 +74,13 @@ def forward(self, input, *args, **kwargs):
 
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
-    from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.utils import get_model_size_in_bytes
 
 
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -125,7 +125,7 @@ def test_repr(self):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -139,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -212,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
     def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
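Because every component is now loaded directly in bfloat16, the assembled pipeline only needs a device move rather than a dtype cast. A hedged end-to-end sketch of the same pattern outside the test harness (the model id is the tiny test checkpoint used above; passing the quantized transformer into `FluxPipeline.from_pretrained` is assumed to follow standard diffusers component-override usage):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "hf-internal-testing/tiny-flux-pipe"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)
# Remaining components load in bfloat16; the quantized transformer is passed through as-is.
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")  # device move only; no dtype argument needed anymore
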
@@ -276,7 +278,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
-        self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
     def test_device_map(self):
         """
@@ -341,21 +342,33 @@ def test_device_map(self):
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
-        quantized_model = FluxTransformer2DModel.from_pretrained(
+        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
             "hf-internal-testing/tiny-flux-pipe",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
 
-        unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
+        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
         self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
         self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
         self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
 
-        quantized_layer = quantized_model.proj_out
+        quantized_layer = quantized_model_with_not_convert.proj_out
         self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
-        self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
+
+        quantization_config = TorchAoConfig("int8_weight_only")
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+
+        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+        size_quantized = get_model_size_in_bytes(quantized_model)
+
+        self.assertTrue(size_quantized < size_quantized_with_not_convert)
 
     def test_training(self):
         quantization_config = TorchAoConfig("int8_weight_only")
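`get_model_size_in_bytes` comes from `torchao.utils` (imported above) and replaces the hand-rolled footprint helper removed further down. A minimal standalone sketch of the size comparison it enables, assuming torchao >= 0.7.0 exposes `quantize_` and `int8_weight_only` under `torchao.quantization`:

import copy

import torch
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import get_model_size_in_bytes

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024)).to(torch.bfloat16)
quantized = copy.deepcopy(model)
quantize_(quantized, int8_weight_only())  # replaces Linear weights with int8 affine-quantized tensors in place

# The quantized copy should report fewer bytes (int8 weights vs. bfloat16).
print(get_model_size_in_bytes(model), get_model_size_in_bytes(quantized))
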
@@ -406,23 +419,6 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
 
-    @staticmethod
-    def _get_memory_footprint(module):
-        quantized_param_memory = 0.0
-        unquantized_param_memory = 0.0
-
-        for param in module.parameters():
-            if param.__class__.__name__ == "AffineQuantizedTensor":
-                data, scale, zero_point = param.layout_tensor.get_plain()
-                quantized_param_memory += data.numel() + data.element_size()
-                quantized_param_memory += scale.numel() + scale.element_size()
-                quantized_param_memory += zero_point.numel() + zero_point.element_size()
-            else:
-                unquantized_param_memory += param.data.numel() * param.data.element_size()
-
-        total_memory = quantized_param_memory + unquantized_param_memory
-        return total_memory, quantized_param_memory, unquantized_param_memory
-
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the
@@ -433,20 +429,18 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
         transformer_bf16 = self.get_dummy_components(None)["transformer"]
 
-        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
-        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
-            transformer_int4wo_gs32
-        )
-        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
-        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
-
-        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
-        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
-        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
-        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+        total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+        # Latter has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(quantized_int8wo < quantized_int4wo)
+        self.assertTrue(total_int8wo < total_int4wo)
+        # int4wo does not quantize too many layers because of default group size, but for the layers it does
+        # there is additional overhead of scales and zero points
+        self.assertTrue(total_bf16 < total_int4wo)
 
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
@@ -456,7 +450,7 @@ def test_wrong_config(self):
 # This class is not to be run as a test by itself. See the tests that follow this class
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
     quant_method, quant_method_kwargs = None, None
@@ -565,7 +559,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
-@require_torchao_version_greater("0.6.0")
+@require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
@@ -581,11 +575,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -617,7 +613,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
 
     def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
-        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()
 
         inputs = self.get_dummy_inputs(torch_device)
