diff --git a/torchao/quantization/pt2e/_affine_quantization.py b/torchao/quantization/pt2e/_affine_quantization.py
index e02bee03ce..906629bd8a 100644
--- a/torchao/quantization/pt2e/_affine_quantization.py
+++ b/torchao/quantization/pt2e/_affine_quantization.py
@@ -113,7 +113,8 @@ def _get_reduction_params(block_size, input_size):
         shape_for_reduction: (3, 3, 5, 2, 10)
         reduction_dim: [0, 1, 3, 4]
     """
-    assert len(block_size) == len(input_size)
+    assert block_size == [-1] or len(block_size) == len(input_size)
+    block_size = [-1] * len(input_size) if block_size == [-1] else block_size
     shape_for_reduction = []
     reduction_dims = []
     cur_dim = 0
diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py
index f6534308d8..3d09255cd1 100644
--- a/torchao/quantization/pt2e/observer.py
+++ b/torchao/quantization/pt2e/observer.py
@@ -1793,7 +1793,7 @@ def get_block_size(
         "Please provide an instance of Granularity, not subclass of it"
     )
     if isinstance(granularity, PerTensor):
-        return input_shape
+        return [-1]
     elif isinstance(granularity, PerAxis):
         block_size = list(input_shape)
         block_size[granularity.axis] = 1
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index d13ac330a0..ca740ef862 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -365,6 +365,9 @@ def _quantize_affine(
     # torch.uintx dtypes yet
     if output_dtype in _SUB_BYTE_UINT_BOUNDS:
         output_dtype = torch.uint8
+    if block_size == [-1]:
+        # per-tensor quantization
+        block_size = [-1] * input.dim()
     return _quantize_affine_no_dtype_cast(
         input,
         block_size,
@@ -520,6 +523,9 @@ def _dequantize_affine(
         torch.float16,
         torch.bfloat16,
     ], f"Unsupported output dtype: {output_dtype}"
+    if block_size == [-1]:
+        # per-tensor quantization
+        block_size = [-1] * input.dim()
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
     return _dequantize_affine_no_dtype_check(
         input,
@@ -878,6 +884,9 @@ def _choose_qparams_affine(
         scale_dtype = input.dtype
     if eps is None:
         eps = torch.finfo(input.dtype).eps
+    if block_size == [-1]:
+        # per-tensor quantization
+        block_size = [-1] * input.dim()
     assert len(block_size) == input.dim(), (
         f"Got input dim:{input.dim()}, block_size: {block_size}"
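
For reference, a minimal standalone sketch of the convention this diff introduces: per-tensor granularity is encoded as the shape-independent sentinel `block_size == [-1]`, and each affine primitive expands it to `[-1] * input.dim()` before computing reduction parameters. The sketch does not import torchao, and the helper name `expand_block_size` is hypothetical, not a torchao API.

```python
# Minimal sketch of the [-1] per-tensor sentinel handling added in this diff.
# Standalone illustration only; expand_block_size is a hypothetical helper.
import torch


def expand_block_size(block_size: list[int], input: torch.Tensor) -> list[int]:
    if block_size == [-1]:
        # per-tensor quantization: one -1 per input dimension
        return [-1] * input.dim()
    # otherwise block_size must already match the input rank
    assert len(block_size) == input.dim()
    return block_size


x = torch.randn(3, 4, 5)
print(expand_block_size([-1], x))       # [-1, -1, -1] -> per-tensor
print(expand_block_size([1, 4, 5], x))  # unchanged    -> per-axis (axis 0)
```

Presumably the point of returning `[-1]` from `get_block_size` for `PerTensor`, instead of `input_shape`, is that the per-tensor block size no longer depends on the concrete shape of the observed tensor.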