From 3a6ccc8839b5cf8336ab334299ca19b1431e3081 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 13:57:29 +0000 Subject: [PATCH 1/4] Add: Sparse24_compressor + tests --- .../sparse_compressors/__init__.py | 1 + .../sparse_compressors/sparse_24.py | 92 +++++++++++++++++++ src/compressed_tensors/config/__init__.py | 1 + src/compressed_tensors/config/base.py | 1 + src/compressed_tensors/config/sparse_24.py | 37 ++++++++ .../utils/semi_structured_conversions.py | 19 +++- .../test_semi_structured_conversions.py | 66 +++++++++++++ 7 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 src/compressed_tensors/compressors/sparse_compressors/sparse_24.py create mode 100644 src/compressed_tensors/config/sparse_24.py create mode 100644 tests/test_utils/test_semi_structured_conversions.py diff --git a/src/compressed_tensors/compressors/sparse_compressors/__init__.py b/src/compressed_tensors/compressors/sparse_compressors/__init__.py index de4fd887..f1b59ad3 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/__init__.py +++ b/src/compressed_tensors/compressors/sparse_compressors/__init__.py @@ -15,4 +15,5 @@ from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py new file mode 100644 index 00000000..70974e68 --- /dev/null +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict + +from compressed_tensors.compressors.base import BaseCompressor +from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor +from compressed_tensors.config import CompressionFormat, SparsityStructure +from compressed_tensors.utils import ( + merge_names, + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, + tensor_follows_mask_structure, +) +from torch import Tensor + + +@BaseCompressor.register(name=CompressionFormat.sparse_24.value) +class Sparse24Compressor(BaseSparseCompressor): + """ + Compresses a with 2:4 sparsity structure for inference + with sparse 2:4 kernels for float/float16/bfloat16. + https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py + """ + + COMPRESSION_PARAM_NAMES = ["sparse_24_packed_weight", "meta"] + + @staticmethod + def validate_sparsity_structure(name: str, weight: Tensor) -> bool: + """ + Checks if a tensor fits the required 2:4 sparsity structure + :param name: name of the tensor to check + :param weight: tensor to check for sparsity structure + :return: True if all rows match the 2:4 sparsity structure, raises + ValueError otherwise + """ + + if not tensor_follows_mask_structure( + weight, mask=SparsityStructure.TWO_FOUR.value + ): + raise ValueError( + "Sparse24Compressor is only compatible with weights that have " + f"a 2:4 sparsity structure. Found segments in {name} " + "that do not match the expected structure." 
+ ) + + return True + + def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]: + """ + Compresses a given weight with 2:4 sparsity structure. + :param name: name of the tensor in state dict of uncompressed model + :param value: 2:4 sparse tensor to compress + :return: dictionary containing the compressed weight and associated + metadata + """ + weight_suffix = ".weight" + if not name.endswith(weight_suffix): + return {} + + prefix = name[: -len(weight_suffix)] + self.validate_sparsity_structure(name=prefix, weight=value) + sparse_24_packed_weight, meta = sparse_semi_structured_from_dense_cutlass( + dense=value + ) + return { + merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(), + merge_names(name, "meta"): meta.cpu(), + } + + def decompress_weight(self, weight_data): + assert ( + "sparse_24_packed_weight" in weight_data + ), "sparse_24_packed_weight not found in weight_data" + assert "meta" in weight_data, "meta not found in weight_data" + + return sparse_semi_structured_to_dense_cutlass( + sparse=weight_data["sparse_24_packed_weight"], + meta_reordered=weight_data["meta"], + ) diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py index ff83f5af..f021f284 100644 --- a/src/compressed_tensors/config/__init__.py +++ b/src/compressed_tensors/config/__init__.py @@ -15,4 +15,5 @@ # flake8: noqa from .base import * from .dense import * +from .sparse_24 import * from .sparse_bitmask import * diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 79a4fcdd..2d280330 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -26,6 +26,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + sparse_24 = "sparse-24" int_quantized = "int-quantized" float_quantized = "float-quantized" naive_quantized = "naive-quantized" diff --git a/src/compressed_tensors/config/sparse_24.py 
b/src/compressed_tensors/config/sparse_24.py new file mode 100644 index 00000000..2a5ed384 --- /dev/null +++ b/src/compressed_tensors/config/sparse_24.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) + + +__all__ = ["Sparse24Config"] + + +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24.value) +class Sparse24Config(SparsityCompressionConfig): + """ + Configuration for storing a sparse model using 2:4 compression + :param global_sparsity: average sparsity of the entire model + :param sparsity_structure: structure of the sparsity, "2:4" + """ + + format: str = CompressionFormat.sparse_24.value + global_sparsity: Optional[float] = 0.0 + sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value diff --git a/src/compressed_tensors/utils/semi_structured_conversions.py b/src/compressed_tensors/utils/semi_structured_conversions.py index ef318a48..480d1b48 100644 --- a/src/compressed_tensors/utils/semi_structured_conversions.py +++ b/src/compressed_tensors/utils/semi_structured_conversions.py @@ -75,6 +75,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device # This function converts dense matrix into sparse semi-structured # representation, producing "compressed" matrix, in the 
layout used by # CUTLASS backend, and corresponding metadata matrix. +# Modified from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L47 def sparse_semi_structured_from_dense_cutlass(dense): if dense.dim() != 2: raise RuntimeError( @@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): device = dense.device meta_dtype = torch.int8 - if dense.dtype == torch.int8: + if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 @@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: + if dense.dtype == torch.float8_e4m3fn: + dense_4 = dense_4.view(torch.int8) sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1) ) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + if dense.dtype == torch.float8_e4m3fn: + sparse = sparse.view(torch.float8_e4m3fn) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2 @@ -213,6 +218,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): # reconstructs dense matrix from a pair of "compressed" matrix, given # in the layout used by CUTLASS backend, and accompanying metadata # matrix. 
+# Copied from https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/_semi_structured_conversions.py#L180 def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if sparse.dim() != 2: raise RuntimeError( @@ -298,16 +304,21 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): torch.arange(0, 2 * m * k // ksparse, device=device) * 4 ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8 + dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + if sparse.dtype == torch.float8_e4m3fn: + dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1)) + else: + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: dense.view(torch.half).scatter_( 0, dense_offsets, sparse.view(torch.half).view(-1) ) - return dense.view(m, 2 * k) + result = dense.view(m, 2 * k) + return result.view(sparse.dtype) def mask_creator(tensor): diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py new file mode 100644 index 00000000..e74722fb --- /dev/null +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from compressed_tensors.utils.semi_structured_conversions import ( + sparse_semi_structured_from_dense_cutlass, + sparse_semi_structured_to_dense_cutlass, +) + + +def supported_dtypes(): + return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn] + + +def get_random_mat(M, K, dtype): + rand_tensor_dtype = dtype + if dtype in [torch.int8, torch.float8_e4m3fn]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() + mat = mat.masked_fill_(mat == 0, 1) + return mat.to(dtype) + + +def generate_pruned_semi_structured_mat(M, K, dtype): + mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() + rand_tensor_dtype = dtype + if dtype in [torch.int8, torch.float8_e4m3fn]: + rand_tensor_dtype = torch.float16 + mat = torch.rand(M, K, dtype=rand_tensor_dtype) + mat = mat.masked_fill_(mat == 0, 1) + if dtype == torch.float8_e4m3fn: + # some float8_e4m3fn operations are not supported on CPU + mat = mat.cuda() + mask = mask.cuda() + mat = mat * mask + return mat.to(dtype) + + +@pytest.mark.parametrize("dtype", supported_dtypes()) +def test_inverse_property_from_dense_then_to_dense(dtype): + M, K = 1024, 1024 + dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype) + compressed_matrix, meta = sparse_semi_structured_from_dense_cutlass(dense_matrix) + result = sparse_semi_structured_to_dense_cutlass(compressed_matrix, meta) + + assert ( + dense_matrix.dtype == result.dtype + ), f"Dtype Mis-match: {dense_matrix.dtype} and {result.dtype}" + assert ( + dense_matrix.shape == result.shape + ), f"Shape Mis-match: {dense_matrix.shape} and {result.shape}" + assert torch.equal( + dense_matrix, result + ), f"Failed for dtype: {dense_matrix.dtype} and input: {dense_matrix}" From 8fd469f0f317877e636600eb5b01eee1d7bfef43 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 27 Nov 2024 14:13:47 
+0000 Subject: [PATCH 2/4] Run float8 test only if cuda is available and device capability is greater than 90 --- tests/test_utils/test_semi_structured_conversions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_semi_structured_conversions.py b/tests/test_utils/test_semi_structured_conversions.py index e74722fb..eb25b34a 100644 --- a/tests/test_utils/test_semi_structured_conversions.py +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -21,7 +21,12 @@ def supported_dtypes(): - return [torch.int8, torch.float16, torch.bfloat16, torch.float8_e4m3fn] + dtypes = [torch.int8, torch.float16, torch.bfloat16] + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major > 9 or (major == 9 and minor >= 0): + dtypes += [torch.float8_e4m3fn] + return dtypes def get_random_mat(M, K, dtype): From b07961c1e97a84a1dc52e15ad6eea1984a9ade62 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 3 Dec 2024 08:12:57 +0000 Subject: [PATCH 3/4] Review comments from @dsikka and @kylesayrs --- .../compressors/sparse_compressors/sparse_24.py | 16 +++++++++++++++- .../utils/semi_structured_conversions.py | 16 ++++++++-------- .../test_semi_structured_conversions.py | 9 +++++---- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py index 70974e68..e2219a8d 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -30,7 +30,7 @@ @BaseCompressor.register(name=CompressionFormat.sparse_24.value) class Sparse24Compressor(BaseSparseCompressor): """ - Compresses a with 2:4 sparsity structure for inference + Compresses a model with 2:4 sparsity structure for inference with sparse 2:4 kernels for float/float16/bfloat16. 
https://github.com/pytorch/pytorch/blob/78cf8df4a019e919e8eac5f5d048d8842d4fc692/torch/sparse/semi_structured.py """ @@ -81,6 +81,20 @@ def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]: } def decompress_weight(self, weight_data): + """ + Decompresses the given weight data from its compressed representation to its + dense form. + + The weight_data dictionary must contain the keys 'sparse_24_packed_weight' and + 'meta', which represent the sparse-compressed weight and its associated meta + tensor. + + :param weight_data: A dictionary containing: + - sparse_24_packed_weight: The sparse-compressed representation of + the weight. + - meta: The meta tensor associated with the compressed weight. + :return: The dense representation of the weight. + """ assert ( "sparse_24_packed_weight" in weight_data ), "sparse_24_packed_weight not found in weight_data" diff --git a/src/compressed_tensors/utils/semi_structured_conversions.py b/src/compressed_tensors/utils/semi_structured_conversions.py index 480d1b48..17ea6ef2 100644 --- a/src/compressed_tensors/utils/semi_structured_conversions.py +++ b/src/compressed_tensors/utils/semi_structured_conversions.py @@ -20,7 +20,7 @@ # limitations under the License. 
import torch - +from compressed_tensors.quantization import FP8_DTYPE __all__ = [ "sparse_semi_structured_from_dense_cutlass", @@ -85,8 +85,8 @@ def sparse_semi_structured_from_dense_cutlass(dense): m, k = dense.shape device = dense.device - meta_dtype = torch.int8 - if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn: + meta_dtype = None + if dense.dtype == torch.int8 or dense.dtype == FP8_DTYPE: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 @@ -166,15 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: - if dense.dtype == torch.float8_e4m3fn: + if dense.dtype == FP8_DTYPE: dense_4 = dense_4.view(torch.int8) sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1) ) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - if dense.dtype == torch.float8_e4m3fn: - sparse = sparse.view(torch.float8_e4m3fn) + if dense.dtype == FP8_DTYPE: + sparse = sparse.view(FP8_DTYPE) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2 @@ -304,11 +304,11 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): torch.arange(0, 2 * m * k // ksparse, device=device) * 4 ).view(-1, 1).repeat(1, 2).view(-1) - sparse_dtype = sparse.dtype if sparse.dtype != torch.float8_e4m3fn else torch.int8 + sparse_dtype = sparse.dtype if sparse.dtype != FP8_DTYPE else torch.int8 dense = torch.zeros((m * 2 * k,), dtype=sparse_dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) - if sparse.dtype == torch.float8_e4m3fn: + if sparse.dtype == FP8_DTYPE: dense.scatter_(0, dense_offsets, sparse.view(torch.int8).view(-1)) else: dense.scatter_(0, dense_offsets, sparse.reshape(-1)) diff --git a/tests/test_utils/test_semi_structured_conversions.py 
b/tests/test_utils/test_semi_structured_conversions.py index eb25b34a..c2c198c6 100644 --- a/tests/test_utils/test_semi_structured_conversions.py +++ b/tests/test_utils/test_semi_structured_conversions.py @@ -14,6 +14,7 @@ import pytest import torch +from compressed_tensors.quantization import FP8_DTYPE from compressed_tensors.utils.semi_structured_conversions import ( sparse_semi_structured_from_dense_cutlass, sparse_semi_structured_to_dense_cutlass, @@ -25,13 +26,13 @@ def supported_dtypes(): if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() if major > 9 or (major == 9 and minor >= 0): - dtypes += [torch.float8_e4m3fn] + dtypes += [FP8_DTYPE] return dtypes def get_random_mat(M, K, dtype): rand_tensor_dtype = dtype - if dtype in [torch.int8, torch.float8_e4m3fn]: + if dtype in [torch.int8, FP8_DTYPE]: rand_tensor_dtype = torch.float16 mat = torch.rand(M, K, dtype=rand_tensor_dtype).cuda() mat = mat.masked_fill_(mat == 0, 1) @@ -41,11 +42,11 @@ def get_random_mat(M, K, dtype): def generate_pruned_semi_structured_mat(M, K, dtype): mask = torch.Tensor([0, 0, 1, 1]).tile((M, K // 4)).bool() rand_tensor_dtype = dtype - if dtype in [torch.int8, torch.float8_e4m3fn]: + if dtype in [torch.int8, FP8_DTYPE]: rand_tensor_dtype = torch.float16 mat = torch.rand(M, K, dtype=rand_tensor_dtype) mat = mat.masked_fill_(mat == 0, 1) - if dtype == torch.float8_e4m3fn: + if dtype == FP8_DTYPE: # some float8_e4m3fn operations are not supported on CPU mat = mat.cuda() mask = mask.cuda() From c6ef4f96d0db7ba4cb5c2a7a17a008d4dc25bbe9 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 3 Dec 2024 16:54:01 +0000 Subject: [PATCH 4/4] remove extra .weight --- .../compressors/sparse_compressors/sparse_24.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py index e2219a8d..040daa50 100644 --- 
a/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24.py @@ -76,8 +76,10 @@ def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]: dense=value ) return { - merge_names(name, "sparse_24_packed_weight"): sparse_24_packed_weight.cpu(), - merge_names(name, "meta"): meta.cpu(), + merge_names( + prefix, "sparse_24_packed_weight" + ): sparse_24_packed_weight.cpu(), + merge_names(prefix, "meta"): meta.cpu(), } def decompress_weight(self, weight_data):