From d4787e23db3871ce922c0feb4e6ad021076d8a9e Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 25 Apr 2024 15:45:10 +0200 Subject: [PATCH] [Release 0.3.0] Basic Readme and user-facing pathways (#30) * initial commit * is this a version problem * or wrong find_packages logic * all_right * initial commit * add load_compress func * More tests (loading dense tensors) * simplify UX * cosmetic changes * finishing the PR * finalize the PR * Update src/compressed_tensors/compressors/sparse_bitmask.py * disable ipynb test --- README.md | 81 ++++++ examples/bitmask_compression.ipynb | 252 ++++++++++++++++++ makefile | 6 +- setup.py | 2 +- src/compressed_tensors/README.md | 162 ----------- .../compressors/__init__.py | 7 +- src/compressed_tensors/compressors/base.py | 16 +- src/compressed_tensors/compressors/dense.py | 7 +- src/compressed_tensors/compressors/helpers.py | 124 ++++++++- .../compressors/sparse_bitmask.py | 19 +- src/compressed_tensors/config/base.py | 8 +- src/compressed_tensors/config/dense.py | 6 +- .../config/sparse_bitmask.py | 6 +- .../utils/safetensors_load.py | 51 ++-- .../test_bitmask_compression_ipynb.py | 31 +++ tests/test_registry.py | 10 +- tests/test_utils/test_helpers.py | 151 +++++++++++ 17 files changed, 725 insertions(+), 214 deletions(-) create mode 100644 examples/bitmask_compression.ipynb delete mode 100644 src/compressed_tensors/README.md create mode 100644 tests/test_examples/test_bitmask_compression_ipynb.py create mode 100644 tests/test_utils/test_helpers.py diff --git a/README.md b/README.md index 05fa83a3..361a68f9 100644 --- a/README.md +++ b/README.md @@ -1 +1,82 @@ # compressed-tensors + +This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation. + +## Motivation + +### Reduce disk space by saving sparse tensors in a compressed format + +The compressed format stores the data much more efficiently by taking advantage of two properties of tensors: + +- Sparse tensors -> due to a large number of entries that are equal to zero. +- Quantized -> due to their low precision representation. + +### Introduce an elegant interface to save/load compressed tensors + +The library provides the user with the ability to compress/decompress tensors. The properties of tensors are defined by human-readable configs, allowing the users to understand the compression format at a quick glance. + +## Installation + +### Pip + +```bash +pip install compressed-tensors +``` + +### From source + +```bash +git clone https://github.com/neuralmagic/compressed-tensors +cd compressed-tensors +pip install -e . +``` + +## Getting started + +### Saving/Loading Compressed Tensors (Bitmask Compression) + +The function `save_compressed` uses the `compression_format` argument to apply compression to tensors. +The function `load_compressed` reverses the process: converts the compressed weights on disk to decompressed weights in device memory. + +```python +from compressed_tensors import save_compressed, load_compressed, BitmaskConfig +from torch import Tensor +from typing import Dict + +# the example BitmaskConfig method efficiently compresses +# tensors with large number of zero entries +compression_config = BitmaskConfig() + +tensors: Dict[str, Tensor] = {"tensor_1": Tensor( + [[0.0, 0.0, 0.0], + [1.0, 1.0, 1.0]] +)} +# compress tensors using BitmaskConfig compression format (save them efficiently on disk) +save_compressed(tensors, "model.safetensors", compression_format=compression_config.format) + +# decompress tensors (load_compressed returns a generator for memory efficiency) +decompressed_tensors = {} +for tensor_name, tensor in load_compressed("model.safetensors", compression_config = compression_config): + decompressed_tensors[tensor_name] = tensor +``` + +## Saving/Loading Compressed Models (Bitmask Compression) + +We can apply bitmask compression to a whole model. For more detailed example see `example` directory. +```python +from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig +from transformers import AutoModelForCausalLM + +model_name = "neuralmagic/llama2.c-stories110M-pruned50" +model = AutoModelForCausalLM.from_pretrained(model_name) + +original_state_dict = model.state_dict() + +compression_config = BitmaskConfig() + +# save compressed model weights +save_compressed_model(model, "compressed_model.safetensors", compression_format=compression_config.format) + +# load compressed model weights (`dict` turns generator into a dictionary) +state_dict = dict(load_compressed("compressed_model.safetensors", compression_config)) +``` diff --git a/examples/bitmask_compression.ipynb b/examples/bitmask_compression.ipynb new file mode 100644 index 00000000..995629c4 --- /dev/null +++ b/examples/bitmask_compression.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bitmask Compression Example ##\n", + "\n", + "Bitmask compression allows for storing sparse tensors efficiently on the disk. \n", + "\n", + "Instead of storing each zero element represented as an actual number, we use bitmask to indicate which tensor entries correspond to zero elements. This approach is useful when the matrix is mostly zero values, as it saves space by not wastefully storing those zeros explicitly.\n", + "\n", + "The example below shows how to save and load sparse tensors using bitmask compression. It also demonstrates the benefits of the bitmask compression over \"dense\" representation, and finally, introduces the enhanced `safetensors` file format for storing sparse weights." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "from safetensors import safe_open\n", + "from safetensors.torch import save_model\n", + "from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig\n", + "from transformers import AutoModelForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(32000, 768)\n", + " (layers): ModuleList(\n", + " (0-11): 12 x LlamaDecoderLayer(\n", + " (self_attn): LlamaSdpaAttention(\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (o_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (down_proj): Linear(in_features=2048, out_features=768, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# load a tiny, pruned llama2 model\n", + "model_name = \"neuralmagic/llama2.c-stories110M-pruned50\"\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The example layer model.layers.0.self_attn.q_proj.weight has sparsity 0.50%\n" + ] + } + ], + "source": [ + "# most of the weights of the model are pruned to 50% (except for few layers such as lm_head or embeddings)\n", + "state_dict = model.state_dict()\n", + "state_dict.keys()\n", + "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n", + "print(f\"The example layer {example_layer} has sparsity {torch.sum(state_dict[example_layer] == 0).item() / state_dict[example_layer].numel():.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The model is 31.67% sparse overall\n" + ] + } + ], + "source": [ + "# we can inspect to total sparisity of the state_dict\n", + "total_num_parameters = 0\n", + "total_num_zero_parameters = 0\n", + "for key in state_dict:\n", + " total_num_parameters += state_dict[key].numel()\n", + " total_num_zero_parameters += state_dict[key].eq(0).sum().item()\n", + "print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.2f}% sparse overall\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.92it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the model's weights on disk using safetensors: 417.83 MB\n", + "Size of the model's weights on disk using compressed-tensors: 366.82 MB\n", + "The compression ratio is x1.14\n" + ] + } + ], + "source": [ + "# let's save the model on disk using safetensors and compressed-tensors and compare the size on disk\n", + "\n", + "## save the model using safetensors ##\n", + "save_model(model, \"model.safetensors\")\n", + "size_on_disk_mb = os.path.getsize('model.safetensors') / 1024 / 1024\n", + "\n", + "## save the model using compressed-tensors ##\n", + "save_compressed_model(model, \"compressed_model.safetensors\", compression_format=\"sparse-bitmask\")\n", + "compressed_size_on_disk_mb = os.path.getsize('compressed_model.safetensors') / 1024 / 1024\n", + "\n", + "print(f\"Size of the model's weights on disk using safetensors: {size_on_disk_mb:.2f} MB\")\n", + "print(f\"Size of the model's weights on disk using compressed-tensors: {compressed_size_on_disk_mb:.2f} MB\")\n", + "print(\"The compression ratio is x{:.2f}\".format(size_on_disk_mb / compressed_size_on_disk_mb))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storing weights with around 30% of zero entries requires significantly less disk space when using `compressed-tensors`. The compression ratio improves radically for more sparse models. \n", + "\n", + "We can load back the `state_dict` from the compressed and uncompressed representation on disk and confirm, that they represent same tensors in memory." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Once loaded, the state_dicts from safetensors and compressed-tensors are equal: True\n" + ] + } + ], + "source": [ + "# load the safetensor and the compressed-tensor and show that they have the same representation\n", + "\n", + "## load the uncompressed safetensors to memory ##\n", + "state_dict_1 = {}\n", + "with safe_open('model.safetensors', framework=\"pt\") as f:\n", + " for key in f.keys():\n", + " state_dict_1[key] = f.get_tensor(key)\n", + "\n", + "## load the compressed-tensors to memory ##\n", + "config = BitmaskConfig() # we need to specify the method for decompression\n", + "state_dict_2 = dict(load_compressed(\"compressed_model.safetensors\", config)) # load_compressed returns a generator, we convert it to a dict\n", + "\n", + "tensors_equal = all(torch.equal(state_dict_1[key], state_dict_2[key]) for key in state_dict_1)\n", + "\n", + "print(f\"Once loaded, the state_dicts from safetensors and compressed-tensors are equal: {tensors_equal}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SafeTensors File Format\n", + "\n", + "The reason why the introduced bitmask compression is much more efficient, is imbibing the information about the compression in the header of the `.safetensors` file.\n", + "For each parameter in the uncompressed `state_dict`, we store the following attributes needed for decompression in the compressed `state_dict`:\n", + "\n", + "* Compressed tensor\n", + "* Bitmask\n", + "* Uncompressed shape\n", + "* Row offsets\n", + "\n", + "```bash\n", + "# Dense\n", + "{\n", + " PARAM_NAME: uncompressed_tensor\n", + "}\n", + "\n", + "# Compressed\n", + "{\n", + " PARAM_NAME.compressed: compressed_tensor, # 1d tensor\n", + " PARAM_NAME.bitmask: value, # 2d bitmask tensor (nrows x (ncols / 8))\n", + " PARAM_NAME.shape: value, # Uncompressed shape tensor\n", + " PARAM_NAME.row_offsets: value # 1d offsets tensor\n", + "}\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/makefile b/makefile index 435a37b9..255514f9 100644 --- a/makefile +++ b/makefile @@ -1,4 +1,4 @@ -BUILDDIR := $(PWD) + PYCHECKDIRS := src tests PYCHECKGLOBS := 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' 'examples/**/*.py' setup.py # run checks on all files for the repo @@ -23,6 +23,10 @@ test: @echo "Running python tests"; pytest tests; +# creates wheel file +build: + python3 setup.py sdist bdist_wheel $(BUILD_ARGS) + # clean package clean: @echo "Cleaning up"; diff --git a/setup.py b/setup.py index 225d7b8d..c6e4b380 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def _setup_install_requires() -> List: return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"] def _setup_extras() -> Dict: - return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0",]} + return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]} setup( name="compressed-tensors", diff --git a/src/compressed_tensors/README.md b/src/compressed_tensors/README.md deleted file mode 100644 index 5b1c8ece..00000000 --- a/src/compressed_tensors/README.md +++ /dev/null @@ -1,162 +0,0 @@ -# Save/Load Compressed SafeTensors - -## Motivation - -* Reduce disk space by saving in a compressed format for sparse models. Models in this compressed format will be loaded by vLLM for more efficient inference -* Set up the save/load architecture such that we can easily expand to additional compression formats in the future. The config should be human readable so users can understand the compression format at a quick glance - -## SafeTensors File Format - -For each parameter in the uncompressed state_dict, we store the following attributes -needed for decompression in the compressed state_dict: - -* compressed tensor -* bitmask -* uncompressed shape -* row offsets - -```python -# dense -{ - PARAM_NAME: uncompressed_tensor -} - -# compressed -{ - PARAM_NAME.compressed: compressed_tensor # 1d tensor - PARAM_NAME.bitmask: value # 2d bitmask tensor (nrows x (ncols / 8)) - PARAM_NAME.shape: value # uncompressed shape tensor - PARAM_NAME.row_offsets: value # 1d offsets tensor -} -``` - -Config information gets stored in the HF config file -```json -// config.json -{ - "sparsity_config": { - "format": "sparse_bitmask", // "dense_sparsity" for original tensor format - - // informational - "sparsity_structure": "unstructured", // or 2:4, 8:16 etc... - "global_sparsity": "0.5" - } -} -``` - -## Saving/Loading Interface - -Loading in a compressed model requires no interface changes - -```python -from sparseml.transformers.utils import SparseAutoModelForCausalLM - -# should contain model.safetensors or model.safetensors.index.json -model_path = "/PATH/TO/COMPRESSED_MODEL" - -model = SparseAutoModelForCausalLM.from_pretrained( - model_name_or_path=model_path, - **model_kwargs, -) -``` - -Saving a compressed model with an explicitly provided compression config. The config -is saved to the model's `config.json` file. **Note:** the model must have been -initialized with SparseAutoModelForCausalLM.from_pretrained() - -```python -from compressed_tensors import BitmaskConfig - -output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL" -sparsity_config = BitmaskConfig() - -model.save_pretrained( - save_directory=output_dir, - sparsity_config=sparsity_config, -) -``` - -Saving a compressed model, inferring the config from the model attributes - -```python -model.save_pretrained( - save_directory=output_dir, - save_compressed=True -) -``` - -Saving a model in the dense format. If the model has at least 5% global sparsity a -sparsity config will still be included in `config.json` with format `dense_sparsity` - -```python -model.save_pretrained( - save_directory=output_dir -) -``` - -Saving a model in the dense format, bypassing the sparsity config calculation. When the -`skip_compression_stats` flag is set, no sparsity config will be written to -`config.json` - -```python -model.save_pretrained( - save_directory=output_dir - skip_compression_stats=True -) -``` - -## Enable Compression During One-Shot and Sparse Finetunining -Models that are saved in a supported compressed format on disk will automatically be -decompressed when loaded as input to `sparseml.transformers.oneshot` or -`sparseml.transformers.train` - -To enable compression on save after oneshot or finetuning simply add the -`save_compressed=True` argument to `sparseml.transformers.oneshot` or -`sparseml.transformers.train` - -```python -from sparseml.transformers import train - -train( - save_compressed=True, - model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4", - recipe=RECIPE, - dataset=DATASET -) -``` - - -## Example Code - -Loads a 60% sparse model, compresses it using the inferred bitmask compression, then -reloads the compressed model. - -```python -from sparseml.transformers import SparseAutoModelForCausalLM -from sparseml.utils.pytorch.utils import measure_cuda_memory -import torch - -MODEL_PATH = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" -OUTPUT_PATH = "./test_compress_output" -RECIPE = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" - -torch.cuda.set_device(0) -with measure_cuda_memory() as m: - model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0") -print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") - -sparsity_config = getattr(model,"sparsity_config", None) -print(f"Sparsity config before compression: {sparsity_config}") -with measure_cuda_memory() as m: - model.save_pretrained(OUTPUT_PATH, save_compressed=True) -print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") - -torch.cuda.set_device(1) -with measure_cuda_memory() as m: - model_again = SparseAutoModelForCausalLM.from_pretrained( - OUTPUT_PATH, device_map="cuda:1" - ) -print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") -sparsity_config = getattr(model_again,"sparsity_config", None) -print(f"Sparsity config after compression: {sparsity_config}") -``` diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index 50d569e4..c93f1346 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -16,5 +16,10 @@ from .base import ModelCompressor from .dense import DenseCompressor -from .helpers import infer_compressor_from_model_config +from .helpers import ( + infer_compressor_from_model_config, + load_compressed, + save_compressed, + save_compressed_model, +) from .sparse_bitmask import BitmaskCompressor, BitmaskTensor diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 5ef34076..98abe61c 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -13,7 +13,7 @@ # limitations under the License. import operator -from typing import Dict, Generator, Tuple +from typing import Dict, Generator, Optional, Tuple from compressed_tensors.base import SPARSITY_CONFIG_NAME from compressed_tensors.config import CompressionConfig @@ -34,7 +34,7 @@ class ModelCompressor(RegistryMixin): :param config: config specifying compression parameters """ - def __init__(self, config: CompressionConfig): + def __init__(self, config: Optional[CompressionConfig] = None): self.config = config def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: @@ -46,12 +46,16 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: """ raise NotImplementedError() - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: + def decompress( + self, path_to_model_or_tensors: str + ) -> Generator[Tuple[str, Tensor], None, None]: """ - Reads a compressed state dict located at model_path and returns a - generator for sequentially decompressing back to a dense state dict + Reads a compressed state dict located at path_to_model_or_tensors + and returns a generator for sequentially decompressing back to a + dense state dict - :param model_path: path to compressed safetensors model + :param model_path: path to compressed safetensors model (directory with + one or more safetensors files) or compressed tensors file :return: compressed state dict """ raise NotImplementedError() diff --git a/src/compressed_tensors/compressors/dense.py b/src/compressed_tensors/compressors/dense.py index 6e8785bc..c9a1c00c 100644 --- a/src/compressed_tensors/compressors/dense.py +++ b/src/compressed_tensors/compressors/dense.py @@ -15,10 +15,11 @@ from typing import Dict, Generator, Tuple from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.config import CompressionFormat from torch import Tensor -@ModelCompressor.register(name="dense_sparsity") +@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value) class DenseCompressor(ModelCompressor): """ Identity compressor for dense models, returns the original state_dict @@ -27,5 +28,7 @@ class DenseCompressor(ModelCompressor): def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: return model_state - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: + def decompress( + self, path_to_model_or_tensors: str, device: str + ) -> Generator[Tuple[str, Tensor], None, None]: return iter([]) diff --git a/src/compressed_tensors/compressors/helpers.py b/src/compressed_tensors/compressors/helpers.py index ac9ed229..1ba75636 100644 --- a/src/compressed_tensors/compressors/helpers.py +++ b/src/compressed_tensors/compressors/helpers.py @@ -12,16 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from typing import Dict, Generator, Optional, Tuple, Union -from typing import Optional - +import torch from compressed_tensors.base import SPARSITY_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor -from compressed_tensors.config import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat +from compressed_tensors.utils.safetensors_load import get_weight_mappings +from safetensors import safe_open +from safetensors.torch import save_file +from torch import Tensor from transformers import AutoConfig -__all__ = ["infer_compressor_from_model_config"] +__all__ = [ + "infer_compressor_from_model_config", + "load_compressed", + "save_compressed", + "save_compressed_model", +] def infer_compressor_from_model_config( @@ -43,3 +53,109 @@ def infer_compressor_from_model_config( sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config) compressor = ModelCompressor.load_from_registry(format, config=sparsity_config) return compressor + + +def save_compressed( + tensors: Dict[str, Tensor], + save_path: Union[str, Path], + compression_format: Optional[CompressionFormat] = None, +): + """ + Save compressed tensors to disk. If tensors are not compressed, + save them as is. + + :param tensors: dictionary of tensors to compress + :param save_path: path to save compressed tensors + :param compression_format: compression format used for the tensors + :return: compression config, if tensors were compressed - None otherwise + """ + if tensors is None or len(tensors) == 0: + raise ValueError("No tensors or empty tensors provided to compress") + + # if no compression_format specified, default to `dense_sparsity` + compression_format = compression_format or CompressionFormat.dense_sparsity.value + + if not ( + compression_format in ModelCompressor.registered_names() + or compression_format in ModelCompressor.registered_aliases() + ): + raise ValueError( + f"Unknown compression format: {compression_format}. " + f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501 + ) + + # compress + compressor = ModelCompressor.load_from_registry(compression_format) + # save compressed tensors + compressed_tensors = compressor.compress(tensors) + save_file(compressed_tensors, save_path) + + +def load_compressed( + compressed_tensors: Union[str, Path], + compression_config: CompressionConfig = None, + device: Optional[str] = "cpu", +) -> Generator[Tuple[str, Tensor], None, None]: + """ + Load compressed tensors from disk. + If tensors are not compressed, load them as is. + + :param compressed_tensors: path to compressed tensors. + This can be a path to a file or a directory containing + one or multiple safetensor files (if multiple - in the format + assumed by huggingface) + :param compression_config: compression config to use for decompressing tensors. + :param device: device to move tensors to. If None, tensors are loaded on CPU. + :param return_dict: if True, return a dictionary of decompressed tensors + :return a generator that yields the name and tensor of the decompressed tensor + """ + if compressed_tensors is None or not Path(compressed_tensors).exists(): + raise ValueError("No compressed tensors provided to load") + + if ( + compression_config is None + or compression_config.format == CompressionFormat.dense_sparsity.value + ): + # if no compression_config specified, or `dense_sparsity` format specified, + # assume tensors are not compressed on disk + weight_mappings = get_weight_mappings(compressed_tensors) + for weight_name, file_with_weight_name in weight_mappings.items(): + with safe_open(file_with_weight_name, framework="pt", device=device) as f: + weight = f.get_tensor(weight_name) + yield weight_name, weight + else: + # decompress tensors + compression_format = compression_config.format + compressor = ModelCompressor.load_from_registry( + compression_format, config=compression_config + ) + yield from compressor.decompress(compressed_tensors, device=device) + + +def save_compressed_model( + model: torch.nn.Module, + filename: str, + compression_format: Optional[CompressionFormat] = None, + force_contiguous: bool = True, +): + """ + Wrapper around safetensors `save_model` helper function, which allows for + saving compressed model to disk. + + Note: The model is assumed to have a + state_dict with unique entries + + :param model: model to save on disk + :param filename: filename location to save the file + :param compression_format: compression format used for the model + :param force_contiguous: forcing the state_dict to be saved as contiguous tensors + """ + state_dict = model.state_dict() + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + try: + save_compressed(state_dict, filename, compression_format=compression_format) + except ValueError as e: + msg = str(e) + msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats." # noqa E501 + raise ValueError(msg) diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py index f6f03f0b..dec359c3 100644 --- a/src/compressed_tensors/compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_bitmask.py @@ -18,6 +18,7 @@ import numpy import torch from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.config import CompressionFormat from compressed_tensors.utils import get_nested_weight_mappings, merge_names from safetensors import safe_open from torch import Tensor @@ -36,7 +37,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@ModelCompressor.register(name="sparse_bitmask") +@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value) class BitmaskCompressor(ModelCompressor): """ Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d @@ -70,22 +71,26 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: return compressed_dict - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu" + ) -> Generator[Tuple[str, Tensor], None, None]: """ - Reads a bitmask compressed state dict located at model_path and returns a - generator for sequentially decompressing back to a dense state dict + Reads a bitmask compressed state dict located at path_to_model_or_tensors + and returns a generator for sequentially decompressing back to a dense state dict - :param model_path: path to compressed safetensors model + :param model_path: path to compressed safetensors model (directory with + one or more safetensors files) or compressed tensors file + :param device: device to load decompressed weights onto :return: iterator for generating decompressed weights """ weight_mappings = get_nested_weight_mappings( - model_path, self.COMPRESSION_PARAM_NAMES + path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES ) for weight_name in weight_mappings.keys(): weight_data = {} for param_name, safe_path in weight_mappings[weight_name].items(): full_name = merge_names(weight_name, param_name) - with safe_open(safe_path, framework="pt", device="cpu") as f: + with safe_open(safe_path, framework="pt", device=device) as f: weight_data[param_name] = f.get_tensor(full_name) data = BitmaskTensor(**weight_data) decompressed = data.decompress() diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index f58b11f8..96778995 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from typing import Optional from compressed_tensors.registry import RegistryMixin from pydantic import BaseModel -__all__ = ["CompressionConfig"] +__all__ = ["CompressionConfig", "CompressionFormat"] + + +class CompressionFormat(Enum): + dense_sparsity = "dense-sparsity" + sparse_bitmask = "sparse-bitmask" class CompressionConfig(RegistryMixin, BaseModel): diff --git a/src/compressed_tensors/config/dense.py b/src/compressed_tensors/config/dense.py index aa23220c..0a18309e 100644 --- a/src/compressed_tensors/config/dense.py +++ b/src/compressed_tensors/config/dense.py @@ -14,13 +14,13 @@ from typing import Optional -from compressed_tensors.config import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat __all__ = ["DenseSparsityConfig"] -@CompressionConfig.register(name="dense_sparsity") +@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value) class DenseSparsityConfig(CompressionConfig): """ Identity configuration for storing a sparse model in @@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig): "unstructured", "2:4", "8:16" etc """ - format: str = "dense_sparsity" + format: str = CompressionFormat.dense_sparsity.value global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" diff --git a/src/compressed_tensors/config/sparse_bitmask.py b/src/compressed_tensors/config/sparse_bitmask.py index 9b9cf211..9d2015f3 100644 --- a/src/compressed_tensors/config/sparse_bitmask.py +++ b/src/compressed_tensors/config/sparse_bitmask.py @@ -14,13 +14,13 @@ from typing import Optional -from compressed_tensors.config.base import CompressionConfig +from compressed_tensors.config import CompressionConfig, CompressionFormat __all__ = ["BitmaskConfig"] -@CompressionConfig.register(name="sparse_bitmask") +@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value) class BitmaskConfig(CompressionConfig): """ Configuration for storing a sparse model using @@ -31,6 +31,6 @@ class BitmaskConfig(CompressionConfig): "unstructured", "2:4", "8:16" etc """ - format: str = "sparse_bitmask" + format: str = CompressionFormat.sparse_bitmask.value global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 7a9973dc..ee8e6ddd 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -120,7 +120,7 @@ def merge_names(parent_name: str, child_name: str) -> str: return parent_name + "." + child_name -def get_weight_mappings(model_path: str) -> Dict[str, str]: +def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]: """ Takes a path to a state dict saved in safetensors format and returns a mapping from parameterized layer name to file location. @@ -134,31 +134,42 @@ def get_weight_mappings(model_path: str) -> Dict[str, str]: This generalizes to cases where the model is split into multiple safetensors files - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index + :param path_to_model_or_tensors: path to directory that contains + safetensors (must contain either a single file or multiple files with an index), + or a path to a single safetensors file :return: mapping of parameterized layer name to file location """ - safetensors_path = os.path.join(model_path, SAFE_WEIGHTS_NAME) - index_path = os.path.join(model_path, SAFE_WEIGHTS_INDEX_NAME) - if os.path.exists(safetensors_path): + + if os.path.isfile(path_to_model_or_tensors): # we have a single safetensors file to read - header = get_safetensors_header(safetensors_path) + header = get_safetensors_header(path_to_model_or_tensors) for key in header.keys(): - header[key] = SAFE_WEIGHTS_NAME + header[key] = path_to_model_or_tensors header.pop("__metadata__", None) - elif os.path.exists(index_path): - # we have multiple safetensors file, read from index - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - header = index["weight_map"] else: - raise ValueError( - f"Could not find a safetensors weight or index file at {model_path}" - ) - - # convert weight locations to full paths - for key, value in header.items(): - header[key] = os.path.join(model_path, value) + # we have a directory with multiple safetensors files + safetensors_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_NAME) + index_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_INDEX_NAME) + if os.path.exists(safetensors_path): + # we have a single safetensors file to read + header = get_safetensors_header(safetensors_path) + for key in header.keys(): + header[key] = SAFE_WEIGHTS_NAME + header.pop("__metadata__", None) + elif os.path.exists(index_path): + # we have multiple safetensors file, read from index + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + header = index["weight_map"] + else: + raise ValueError( + "Could not find a safetensors weight " + f"or index file at {path_to_model_or_tensors}" + ) + + # convert weight locations to full paths + for key, value in header.items(): + header[key] = os.path.join(path_to_model_or_tensors, value) return header diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py new file mode 100644 index 00000000..a3ee9d2a --- /dev/null +++ b/tests/test_examples/test_bitmask_compression_ipynb.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nbformat +import pytest +from nbconvert.preprocessors import ExecutePreprocessor + + +@pytest.mark.skip( + reason="GHA not setup yet to run those tests. The test should work locally" +) +@pytest.mark.parametrize("notebook", ["examples/bitmask_compression.ipynb"]) +def test_notebook_exec(notebook): + with open(notebook) as f: + nb = nbformat.read(f, as_version=4) + ep = ExecutePreprocessor(timeout=600, kernel_name="python3") + try: + assert ep.preprocess(nb) is not None, f"Got empty notebook for {notebook}" + except Exception: + assert False, f"Failed executing {notebook}" diff --git a/tests/test_registry.py b/tests/test_registry.py index a183d77d..ffe66b85 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -17,6 +17,7 @@ BitmaskCompressor, BitmaskConfig, CompressionConfig, + CompressionFormat, DenseCompressor, DenseSparsityConfig, ModelCompressor, @@ -26,8 +27,8 @@ @pytest.mark.parametrize( "name,type", [ - ["sparse_bitmask", BitmaskConfig], - ["dense_sparsity", DenseSparsityConfig], + [CompressionFormat.sparse_bitmask.value, BitmaskConfig], + [CompressionFormat.dense_sparsity.value, DenseSparsityConfig], ], ) def test_configs(name, type): @@ -38,7 +39,10 @@ def test_configs(name, type): @pytest.mark.parametrize( "name,type", - [["sparse_bitmask", BitmaskCompressor], ["dense_sparsity", DenseCompressor]], + [ + [CompressionFormat.sparse_bitmask.value, BitmaskCompressor], + [CompressionFormat.dense_sparsity.value, DenseCompressor], + ], ) def test_compressors(name, type): compressor = ModelCompressor.load_from_registry( diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py new file mode 100644 index 00000000..7ae0799d --- /dev/null +++ b/tests/test_utils/test_helpers.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch +from compressed_tensors import load_compressed, save_compressed, save_compressed_model +from compressed_tensors.config import BitmaskConfig +from safetensors.torch import save_model +from transformers import AutoModelForCausalLM + + +@pytest.fixture +def tensors(): + tensors = {"tensor_1": torch.Tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])} + return tensors + + +@pytest.fixture +def llama_model(tmp_path): + model_name = "neuralmagic/llama2.c-stories110M-pruned50" + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=tmp_path) + yield model + + +def test_save_compressed_sparse_bitmask(tmp_path, tensors): + save_compressed( + tensors, + compression_format="sparse-bitmask", + save_path=tmp_path / "model.safetensors", + ) + assert (tmp_path / "model.safetensors").exists() + + +def test_save_compressed_dense_sparsity(tmp_path, tensors): + save_compressed( + tensors, + compression_format="dense-sparsity", + save_path=tmp_path / "model.safetensors", + ) + assert (tmp_path / "model.safetensors").exists() + + +def test_save_compressed_no_compression(tmp_path, tensors): + save_compressed( + tensors, + save_path=tmp_path / "model.safetensors", + ) + assert (tmp_path / "model.safetensors").exists() + + +def test_save_compressed_error(tmp_path): + with pytest.raises(Exception): + save_compressed({}, "") + + with pytest.raises(Exception): + save_compressed(None, "") + + with pytest.raises(Exception): + save_compressed( + tensors, + compression_format="this_is_not_a_valid_format", + save_path=tmp_path / "model.safetensors", + ) + + +def test_load_compressed_sparse_bitmask(tmp_path, tensors): + save_compressed( + tensors, + compression_format="sparse-bitmask", + save_path=tmp_path / "model.safetensors", + ) + compression_config = BitmaskConfig( + format="sparse-bitmask", + ) + loaded_tensors = dict( + load_compressed(tmp_path / "model.safetensors", compression_config) + ) + for key in tensors: + assert torch.allclose(tensors[key], loaded_tensors[key]) + + +def test_load_compressed_dense_sparsity(tmp_path, tensors): + save_compressed( + tensors, + compression_format="dense-sparsity", + save_path=tmp_path / "model.safetensors", + ) + save_compressed( + tensors, + save_path=tmp_path / "model_.safetensors", + ) + + loaded_tensors = dict(load_compressed(tmp_path / "model.safetensors")) + loaded_tensors_ = dict(load_compressed(tmp_path / "model_.safetensors")) + # loaded_tensors should be equal to loaded_tensors_ + for key in tensors: + assert torch.allclose(loaded_tensors[key], loaded_tensors_[key]) + + +def test_load_compressed_sharded(tmp_path, llama_model): + sharded_model_path = tmp_path / "shared_model" + llama_model.save_pretrained(sharded_model_path, max_shard_size="2MB") + # make sure that model is shared on disk + assert len(os.listdir(sharded_model_path)) > 1 + loaded_state_dict = dict(load_compressed(sharded_model_path)) + for key, value in llama_model.state_dict().items(): + if key == "lm_head.weight": + # lm_head doesn't have separate weights. + # It shares its weight tensor with the token embedding layer. + continue + assert torch.allclose(value, loaded_state_dict[key]) + + +def test_save_compressed_model(tmp_path, llama_model): + path_to_uncompressed = tmp_path / "model_uncompressed.safetensors" + path_to_compressed = tmp_path / "model_compressed.safetensors" + + # save uncompressed model + save_model(llama_model, path_to_uncompressed) + size_uncompressed_kb = path_to_uncompressed.stat().st_size / 1024 + + # save compressed model + save_compressed_model( + llama_model, path_to_compressed, compression_format="sparse-bitmask" + ) + size_compressed_kb = path_to_compressed.stat().st_size / 1024 + + # compare that the are the same after loading + state_dict_1 = dict(load_compressed(path_to_uncompressed)) + state_dict_2 = dict( + load_compressed(path_to_compressed, BitmaskConfig(format="sparse-bitmask")) + ) + assert all( + torch.allclose(state_dict_1[key], state_dict_2[key]) for key in state_dict_1 + ) + # make sure that compressed model is smaller + # than uncompressed by roughly 1.14 (value established empirically) + assert pytest.approx(size_uncompressed_kb / size_compressed_kb, 0.01) == 1.14