From 6dfaec348780c6153a4cfd03a01972a291d67f82 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Mon, 23 Dec 2024 19:52:21 -1000
Subject: [PATCH] make style for https://github.com/huggingface/diffusers/pull/10368 (#10370)

* fix bug for torch.uint1-7 not support in torch<2.6

* up

---------

Co-authored-by: baymax591
---
 .../quantizers/torchao/torchao_quantizer.py | 39 +++++++++++--------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 25cd4ad448e7..5770e32c909e 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -23,7 +23,7 @@
 
 from packaging import version
 
-from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
+from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
 from ..base import DiffusersQuantizer
 
 
@@ -35,21 +35,28 @@
     import torch
     import torch.nn as nn
 
-    SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
-        # At the moment, only int8 is supported for integer quantization dtypes.
-        # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
-        # to support more quantization methods, such as intx_weight_only.
-        torch.int8,
-        torch.float8_e4m3fn,
-        torch.float8_e5m2,
-        torch.uint1,
-        torch.uint2,
-        torch.uint3,
-        torch.uint4,
-        torch.uint5,
-        torch.uint6,
-        torch.uint7,
-    )
+    if is_torch_version(">=", "2.5"):
+        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+            # At the moment, only int8 is supported for integer quantization dtypes.
+            # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+            # to support more quantization methods, such as intx_weight_only.
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.uint1,
+            torch.uint2,
+            torch.uint3,
+            torch.uint4,
+            torch.uint5,
+            torch.uint6,
+            torch.uint7,
+        )
+    else:
+        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+            torch.int8,
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        )
 
 if is_torchao_available():
     from torchao.quantization import quantize_
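
For context, a self-contained sketch of the version-gating pattern this patch applies. Assumptions are hedged: torch_version_at_least below is a hypothetical stand-in for diffusers' internal is_torch_version(">=", ...) helper, written here against the public packaging API; the dtype tuples simply mirror the patched code.

    # Standalone sketch of the gating pattern from the patch above.
    # torch_version_at_least is a HYPOTHETICAL stand-in for diffusers'
    # internal is_torch_version(">=", ...) helper, not a real diffusers API.
    import torch
    from packaging import version


    def torch_version_at_least(minimum: str) -> bool:
        # Strip local/pre-release suffixes such as "+cu121" before comparing,
        # so "2.5.0+cu121" still counts as >= "2.5".
        installed = version.parse(version.parse(torch.__version__).base_version)
        return installed >= version.parse(minimum)


    # torch.uint1 ... torch.uint7 are not defined on older torch builds, so an
    # unconditional reference raises AttributeError at import time. Gating the
    # tuple on the torch version avoids ever touching those attributes.
    if torch_version_at_least("2.5"):
        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.uint1,
            torch.uint2,
            torch.uint3,
            torch.uint4,
            torch.uint5,
            torch.uint6,
            torch.uint7,
        )
    else:
        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        )

The design point of the fix: the check runs once at module import, so on older torch builds the uint1-uint7 attributes are never evaluated, instead of torchao_quantizer.py failing to import with an AttributeError.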