diff --git a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py b/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
similarity index 52%
rename from test/quantization/quantize_/test_int4_groupwise_preshuffle.py
rename to test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
index 9bfe6dffdb..f120d4500b 100644
--- a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py
+++ b/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
@@ -4,14 +4,18 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 import unittest
 
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -23,6 +27,45 @@
     is_sm_at_least_90,
 )
 
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -32,33 +75,23 @@
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +107,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )
 
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 7df6995955..ce16897d76 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether to preshuffle the weights or not
+        activation_dtype_for_int4 (str): the activation dtype for int4 weight, either bf16 or fp8
     """
 
     input_dtype: torch.dtype
@@ -2067,7 +2069,9 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     ):
         if config.preshuffle:
             weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+                module.weight,
+                config.block_size,
+                activation_dtype="bf16",
             )
         else:
             weight = to_fbgemm_int4(
@@ -2077,6 +2081,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
+    if (
+        (config.input_dtype == e4m3_dtype)
+        and (config.weight_dtype == torch.int4)
+        and (config.output_dtype == torch.bfloat16)
+    ):
+        if config.preshuffle:
+            weight = Int4GroupwisePreshuffleTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype="fp8",
+            )
+        module.weight = torch.nn.Parameter(weight, requires_grad=False)
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
diff --git a/torchao/quantization/quantize_/__init__.py b/torchao/quantization/quantize_/__init__.py
index 049b71631b..235e18bf3c 100644
--- a/torchao/quantization/quantize_/__init__.py
+++ b/torchao/quantization/quantize_/__init__.py
@@ -1,4 +1,4 @@
-from .int4_groupwise_preshuffle_tensor import (
+from .int4 import (
     Int4GroupwisePreshuffleTensor,
 )
 
diff --git a/torchao/quantization/quantize_/int4/__init__.py b/torchao/quantization/quantize_/int4/__init__.py
new file mode 100644
index 0000000000..6ebbb55f0f
--- /dev/null
+++ b/torchao/quantization/quantize_/int4/__init__.py
@@ -0,0 +1,7 @@
+from .int4_groupwise_preshuffle_tensor import (
+    Int4GroupwisePreshuffleTensor,
+)
+
+__all__ = [
+    "Int4GroupwisePreshuffleTensor",
+]
diff --git a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py b/torchao/quantization/quantize_/int4/int4_groupwise_preshuffle_tensor.py
similarity index 60%
rename from torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py
rename to torchao/quantization/quantize_/int4/int4_groupwise_preshuffle_tensor.py
index 1313be5128..502397fb15 100644
--- a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py
+++ b/torchao/quantization/quantize_/int4/int4_groupwise_preshuffle_tensor.py
@@ -6,7 +6,7 @@
 
 
 import importlib.util
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -26,8 +26,12 @@
 
 if importlib.util.find_spec("fbgemm_gpu") is None:
     quantize_int4_preshuffle = None
+    quantize_fp8_row = None
 else:
-    from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
+    from fbgemm_gpu.experimental.gen_ai.quantize import (
+        quantize_fp8_row,
+        quantize_int4_preshuffle,
+    )
 
 
 class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor):
@@ -36,10 +40,16 @@ class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor):
 
     Tensor Attributes:
         packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
-        group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
-          dtype is the same as the original Tensor dtype
-        group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
-          dtype is the same as the original Tensor dtype
+        for bf16 activation:
+            group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
+              dtype is the same as the original Tensor dtype
+            group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
+              dtype is the same as the original Tensor dtype
+        for float8 activation:
+            group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
+              dtype is float8
+            row_scale: (N,) for 2D Tensor, (B, N) for 3D Tensor
+              dtype is the same as the original Tensor dtype
 
     Non-Tensor Attributes:
         group_size: the group size for groupwise quantization
@@ -48,8 +58,7 @@ class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor):
           a 2D tensor, the shape_multiplier will be [1, 2]
         shape: shape of the original Tensor
 
-    Note:
-        Details for preshuffle for fbgemm kernel:
+    Note on preshuffle details for the fbgemm kernel:
         We use WGMMA instruction for efficient matrix multiplication in H100 Tensor Core.
         To address a major inefficiency in how WGMMA tiles are loaded into shared memory before
         loads so having to load all four groups is wasteful. We can optimize weight loading by shuffling
         the order of elements such that all 4 groups are sequential in memory. This allows us to perform
         a single 64 bit load to move all needed weights for the thread into register memory.
+
+    Note for float8 activation int4 weight kernel:
+        float8 activation int4 weight kernel doesn't work with zero_point, since it uses a table lookup approach, which
+        requires symmetric quantization
     """
 
-    tensor_data_attrs = ["packed_weight", "group_scale", "group_zero"]
+    tensor_data_attrs = ["packed_weight", "group_scale"]
+    optional_tensor_data_attr1 = "group_zero"
+    optional_tensor_data_attr2 = "row_scale"
     tensor_attributes = ["group_size", "shape_multiplier", "shape"]
 
     def __new__(
-        cls, packed_weight, group_scale, group_zero, group_size, shape_multiplier, shape
+        cls,
+        packed_weight,
+        group_scale,
+        group_zero,
+        row_scale,
+        group_size,
+        shape_multiplier,
+        shape,
     ):
         kwargs = {}
         kwargs["device"] = packed_weight.device
@@ -77,36 +99,53 @@ def __new__(
 
     def __init__(
         self,
-        packed_weight,
-        group_scale,
-        group_zero,
-        group_size,
-        shape_multiplier,
-        shape,
+        packed_weight: torch.Tensor,
+        group_scale: torch.Tensor,
+        group_zero: Optional[torch.Tensor],
+        row_scale: Optional[torch.Tensor],
+        group_size: int,
+        shape_multiplier: List[int],
+        shape: List[int],
     ):
         self.packed_weight = packed_weight
         self.group_scale = group_scale
         self.group_zero = group_zero
+        self.row_scale = row_scale
         self.shape_multiplier = shape_multiplier
         self.group_size = group_size
 
     def __tensor_flatten__(self):
-        return self.tensor_data_attrs, [
-            getattr(self, attr) for attr in self.tensor_attributes
-        ]
+        if getattr(self, self.optional_tensor_data_attr1) is None:
+            assert getattr(self, self.optional_tensor_data_attr2) is not None
+            return self.tensor_data_attrs + [self.optional_tensor_data_attr2], [
+                getattr(self, attr) for attr in self.tensor_attributes
+            ]
+        else:
+            assert getattr(self, self.optional_tensor_data_attr1) is not None
+            return self.tensor_data_attrs + [self.optional_tensor_data_attr1], [
+                getattr(self, attr) for attr in self.tensor_attributes
+            ]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
+        tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
+        tensors.append(tensor_data_dict.get(cls.optional_tensor_data_attr1, None))
+        tensors.append(tensor_data_dict.get(cls.optional_tensor_data_attr2, None))
         return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
+            *tensors,
             *tensor_attributes,
         )
 
     def _apply_fn_to_data(self, fn):
+        tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
+        t1 = getattr(self, self.optional_tensor_data_attr1)
+        tensors.append(fn(t1) if t1 is not None else None)
+        t2 = getattr(self, self.optional_tensor_data_attr2)
+        tensors.append(fn(t2) if t2 is not None else None)
         return self.__class__(
-            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
+            *tensors,
             *[getattr(self, attr) for attr in self.tensor_attributes],
         )
 
@@ -126,7 +165,8 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.packed_weight.to(device),
             self.group_scale.to(device),
-            self.group_zero.to(device),
+            self.group_zero.to(device) if self.group_zero is not None else None,
+            self.row_scale.to(device) if self.row_scale is not None else None,
             self.group_size,
             self.shape_multiplier,
             self.shape,
@@ -137,10 +177,17 @@ def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        activation_dtype: str = "bf16",
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
         )
+
+        _SUPPORTED_ACT_DTYPES = ["fp8", "bf16"]
+        assert activation_dtype in _SUPPORTED_ACT_DTYPES, (
+            f"activation dtype {activation_dtype} is not supported, supported ones are: {_SUPPORTED_ACT_DTYPES}"
+        )
+
         if quantize_int4_preshuffle is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
 
@@ -149,17 +196,26 @@ def from_float(
         if w.ndim >= 3:
             wq, scales = zip(
-                *[quantize_int4_preshuffle(i.cuda(), dtype="bf16") for i in w]
+                *[quantize_int4_preshuffle(i.cuda(), dtype=activation_dtype) for i in w]
             )
             wq = torch.stack(wq, dim=0)
-            group_scale, group_zero = zip(*scales)
-            group_zero = torch.stack(group_zero, dim=0).contiguous()
+            group_scale, group_zero_or_row_scale = zip(*scales)
+            group_zero_or_row_scale = torch.stack(
+                group_zero_or_row_scale, dim=0
+            ).contiguous()
             group_scale = torch.stack(group_scale, dim=0).contiguous()
         else:
-            wq, (group_scale, group_zero) = quantize_int4_preshuffle(
-                w.cuda(), dtype="bf16"
+            wq, (group_scale, group_zero_or_row_scale) = quantize_int4_preshuffle(
+                w.cuda(), dtype=activation_dtype
             )
 
+        if activation_dtype == "bf16":
+            group_zero = group_zero_or_row_scale
+            row_scale = None
+        else:
+            group_zero = None
+            row_scale = group_zero_or_row_scale
+
         shape_multiplier = [1] * wq.ndim
         shape_multiplier[-1] = 2
 
@@ -168,6 +224,7 @@ def from_float(
             packed_weight=wq,
             group_scale=group_scale,
             group_zero=group_zero,
+            row_scale=row_scale,
             group_size=group_size,
             shape_multiplier=shape_multiplier,
             shape=original_shape,
@@ -189,21 +246,42 @@ def _(func, types, args, kwargs):
 
     wq = weight_tensor.packed_weight.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
-    group_zero = weight_tensor.group_zero.contiguous()
-
-    if input_tensor.dim() == 3:
-        B, M, _ = input_tensor.shape
-        _, N, _ = wq.shape
-        res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16)
-        for i in range(B):
-            res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
-                input_tensor[i], wq[i], group_scale[i], group_zero[i]
+    # bf16 activation
+    if weight_tensor.group_zero is not None:
+        group_zero = weight_tensor.group_zero.contiguous()
+        if input_tensor.ndim == 3 and wq.ndim == 3:
+            B, M, _ = input_tensor.shape
+            _, N, _ = wq.shape
+            res = torch.empty(
+                (B, M, N), device=input_tensor.device, dtype=torch.bfloat16
+            )
+            for i in range(B):
+                res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
+                    input_tensor[i], wq[i], group_scale[i], group_zero[i]
+                )
+        else:
+            # Otherwise run gemm normally.
+            res = torch.ops.fbgemm.bf16i4bf16_shuffled(
+                input_tensor, wq, group_scale, group_zero
             )
     else:
-        # Otherwise run gemm normally.
-        res = torch.ops.fbgemm.bf16i4bf16_shuffled(
-            input_tensor, wq, group_scale, group_zero
-        )
+        assert weight_tensor.row_scale is not None
+        row_scale = weight_tensor.row_scale.contiguous()
+        xq, x_scale = quantize_fp8_row(input_tensor)
+        # From: https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1654
+        if xq.dim() == 3:
+            B, M, _ = xq.shape
+            _, N, _ = wq.shape
+            res = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+            for i in range(B):
+                res[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+                )
+        else:
+            # Otherwise run gemm normally.
+            res = torch.ops.fbgemm.f8i4bf16_shuffled(
+                xq, wq, x_scale, row_scale, group_scale
+            )
 
     res = res.reshape(*orig_input_size[:-1], orig_out_features)
     if bias is not None:
@@ -221,17 +299,27 @@ def _(func, types, args, kwargs):
     orig_out_features = weight_tensor.shape[-2]
     assert weight_tensor.shape_multiplier[-1] == 2
 
-    wq = weight_tensor.packed_weight
-    group_scale = weight_tensor.group_scale
-    group_zero = weight_tensor.group_zero
-    # from https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1715-L1722
-    B, M, _ = input_tensor.shape
-    _, N, _ = wq.shape
-    res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16)
-    for i in range(B):
-        res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
-            input_tensor[i], wq[i], group_scale[i], group_zero[i]
+    wq = weight_tensor.packed_weight.contiguous()
+    group_scale = weight_tensor.group_scale.contiguous()
+    if weight_tensor.group_zero is not None:
+        group_zero = weight_tensor.group_zero.contiguous()
+        res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
+            input_tensor, wq, group_scale, group_zero
         )
+    else:
+        assert weight_tensor.row_scale is not None
+        row_scale = weight_tensor.row_scale.contiguous()
+        xq, x_scale = quantize_fp8_row(input_tensor)
+        # From: https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1654
+        assert xq.dim() == 3
+        B, M, _ = xq.shape
+        _, N, _ = wq.shape
+        res = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+        for i in range(B):
+            res[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+            )
+
     res = res.reshape(*orig_input_size[:-1], orig_out_features)
     return res
 
@@ -259,7 +347,16 @@ def _same_metadata(
         and self.shape == src.shape
         and self.packed_weight.shape == src.packed_weight.shape
         and self.group_scale.shape == src.group_scale.shape
-        and self.group_zero.shape == src.group_zero.shape
+        and (
+            self.group_zero.shape == src.group_zero.shape
+            if self.group_zero is not None
+            else src.group_zero is None
+        )
+        and (
+            self.row_scale.shape == src.row_scale.shape
+            if self.row_scale is not None
+            else src.row_scale is None
+        )
         and self.group_size == src.group_size
         and self.shape_multiplier == src.shape_multiplier
     )
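
A minimal end-to-end sketch of the fp8-activation int4 preshuffle path introduced above (not part of the patch itself). It assumes a CUDA SM90+ device, PyTorch 2.8+, and fbgemm-gpu-genai >= 1.2.0, and only uses names that appear in this PR (FbgemmConfig, quantize_, e4m3_dtype); the SQNR check is computed inline instead of relying on torchao test helpers.

import torch

from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

# Mirrors FP8_ACT_CONFIG from the test file above: fp8 rowwise activation,
# int4 groupwise weight (group_size 128), preshuffled for the fbgemm kernel.
fp8_act_config = FbgemmConfig(
    input_dtype=e4m3_dtype,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
ref = linear(x)

# Swaps linear.weight for an Int4GroupwisePreshuffleTensor carrying group_scale
# and row_scale (no group_zero, since the fp8 kernel is symmetric-only).
quantize_(linear, fp8_act_config)
out = linear(x)

# Loose sanity check on quantization error (SQNR in dB), computed inline.
sqnr = 20 * torch.log10(ref.norm() / (ref - out).norm())
print(f"SQNR: {sqnr.item():.1f} dB")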