Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Apr 19, 2024
1 parent 4ac3f5f commit b614d9f
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 9 deletions.
99 changes: 99 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,100 @@
# compressed-tensors

This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.

## Motivation

### Reduce disk space by saving sparse tensors in a compressed format

The compressed format stores the data much more efficiently by taking advantage of two properties of tensors:

- Sparse tensors -> due to a large number of entries that are equal to zero.
- Quantized -> due to their low precision representation.


### Introduce an elegant interface to save/load compressed tensors

The library provides the user with the ability to compress/decompress tensors. The properties of tensors are defined by human-readable configs, allowing the users to understand the compression format at a quick glance.

## Installation

### Pip

```bash
pip install compressed-tensors
```

### From source

```bash
git clone https://github.com/neuralmagic/compressed-tensors
cd compressed-tensors
pip install -e .
```

## Getting started

### Saving

The function `save_compressed` returns an optional `compression_config` (if compression has been applied). It can be used to inspect the applied compression.

```python
from compressed_tensors import save_compressed
from torch import Tensor

tensors: Dict[str, Tensor] = ...
compression_config: Dict = save_compressed(tensors, "model.safetensors")


```

### Loading

```python
from compressed_tensors import load_compressed
from torch import Tensor

tensors: Dict[str, Tensor] = load_compressed("model.safetensors", device="cpu")
```

## Benefits
TODO

## SafeTensors File Format

For each parameter in the uncompressed state_dict, we store the following attributes needed for decompression in the compressed state_dict:

- Compressed tensor
- Bitmask
- Uncompressed shape
- Row offsets

```python
# Dense
{
PARAM_NAME: uncompressed_tensor
}

# Compressed
{
PARAM_NAME.compressed: compressed_tensor, # 1d tensor
PARAM_NAME.bitmask: value, # 2d bitmask tensor (nrows x (ncols / 8))
PARAM_NAME.shape: value, # Uncompressed shape tensor
PARAM_NAME.row_offsets: value # 1d offsets tensor
}
```

The library provides pathways to automatically add the config information to the HF config file.

```json
// config.json
{
"sparsity_config": {
"format": "sparse_bitmask", // "dense_sparsity" for the original tensor format

// Informational
"sparsity_structure": "unstructured", // Or 2:4, 8:16, etc.
"global_sparsity": "0.5"
}
}
```
2 changes: 1 addition & 1 deletion src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
CONFIG_NAME = "compression_config"
4 changes: 2 additions & 2 deletions src/compressed_tensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import operator
from typing import Dict, Generator, Tuple

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.base import CONFIG_NAME
from compressed_tensors.config import CompressionConfig
from compressed_tensors.registry import RegistryMixin
from torch import Tensor
Expand Down Expand Up @@ -70,4 +70,4 @@ def overwrite_weights(self, model_path: str, model: Module):
data_old = operator.attrgetter(name)(model)
data_old.data = data_new.data

setattr(model, SPARSITY_CONFIG_NAME, self.config)
setattr(model, CONFIG_NAME, self.config)
56 changes: 51 additions & 5 deletions src/compressed_tensors/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Optional, Union

from typing import Optional

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.base import CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig
from safetensors.torch import save_file
from torch import Tensor
from transformers import AutoConfig


__all__ = ["infer_compressor_from_model_config"]
__all__ = ["infer_compressor_from_model_config", "load_compressed", "save_compressed"]


def infer_compressor_from_model_config(
Expand All @@ -35,11 +37,55 @@ def infer_compressor_from_model_config(
:return: matching compressor if config contains a sparsity config
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
sparsity_config = getattr(config, CONFIG_NAME, None)
if sparsity_config is None:
return None

format = sparsity_config.get("format")
sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
return compressor


def save_compressed(
tensors: Dict[str, Tensor],
save_path: Union[str, Path],
compression_config: Optional[CompressionConfig] = None,
) -> Optional[CompressionConfig]:
"""
Save compressed tensors to disk. If tensors are not compressed,
save them as is.
:param tensors: dictionary of tensors to compress
:param save_path: path to save compressed tensors
:param compression_config: compression config to use for compressing tensors.
Can be either inferred from tensors or provided explicitly
:return: compression config, if tensors were compressed - None otherwise
"""
if tensors is None or len(tensors) == 0:
raise ValueError("No tensors or empty tensors provided to compress")

# create compression config if not provided
# TODO: Not implemented, need to get this in ASAP
# compression_config = compression_config or infer_compression_config(tensors)

if compression_config is None:
# no compression applied
save_file(tensors, save_path)
return None

# compress
compression_format = compression_config.format
compressor = ModelCompressor.load_from_registry(
compression_format, config=compression_config
)
# save compressed tensors
compressed_tensors = compressor.compress(tensors)
save_file(compressed_tensors, save_path)

# return compression_config as dict
return {CONFIG_NAME: compression_config.model_dump(exclude_unset=True)}


def load_compressed(compressed_tensors: Union[str, Path], device: str):
pass
1 change: 0 additions & 1 deletion tests/quantization/lifecycle/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from compressed_tensors.quantization.lifecycle import apply_quantization_config
from compressed_tensors.quantization.quant_config import (
QuantizationConfig,
Expand Down
79 changes: 79 additions & 0 deletions tests/test_utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from compressed_tensors import save_compressed
from compressed_tensors.config import BitmaskConfig


@pytest.fixture
def tensors_and_config_sparse():
tensors = {"tensor_1": torch.Tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])}
expected_config_json = {
"compression_config": {
"format": "sparse_bitmask",
"global_sparsity": (
tensors["tensor_1"].sum() / tensors["tensor_1"].numel()
).item(),
"sparsity_structure": "unstructured",
}
}
return tensors, expected_config_json


@pytest.fixture
def tensors_dense():
tensors = {"tensor_1": torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])}
return tensors


def test_save_compressed_sparse(tmp_path, tensors_and_config_sparse):
tensors, expected_config_json = tensors_and_config_sparse

config_json = save_compressed(
tensors,
compression_config=BitmaskConfig(
format=expected_config_json["compression_config"]["format"],
global_sparsity=expected_config_json["compression_config"][
"global_sparsity"
],
sparsity_structure=expected_config_json["compression_config"][
"sparsity_structure"
],
),
save_path=tmp_path / "model.safetensors",
)
assert (tmp_path / "model.safetensors").exists()
assert config_json == expected_config_json


def test_save_compressed_dense(tmp_path, tensors_dense):
tensors = tensors_dense

config_json = save_compressed(
tensors,
save_path=tmp_path / "model.safetensors",
)
assert (tmp_path / "model.safetensors").exists()
assert config_json is None


def test_save_compressed_empty():
# make sure function raises error
with pytest.raises(Exception):
save_compressed({}, "")

with pytest.raises(Exception):
save_compressed(None, "")

0 comments on commit b614d9f

Please sign in to comment.