Commit: add load_compress func

dbogunowicz committed Apr 19, 2024
1 parent b614d9f commit e394eb2
Showing 8 changed files with 106 additions and 45 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -44,8 +44,6 @@ from torch import Tensor
 
 tensors: Dict[str, Tensor] = ...
 compression_config: Dict = save_compressed(tensors, "model.safetensors")
-
-
 ```
 
 ### Loading
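The README's loading section, just below this hunk, is backed by the new `load_compressed` helper added in `src/compressed_tensors/utils/helpers.py` further down. A minimal round trip, assuming `save_compressed` writes a plain safetensors file when no compression config is given (the tensor names and path here are illustrative only):

```python
from typing import Dict

import torch
from torch import Tensor

from compressed_tensors import load_compressed, save_compressed

# hypothetical dense state dict
tensors: Dict[str, Tensor] = {"embedding.weight": torch.randn(8, 16)}

# no compression config: the tensors are written as plain safetensors
save_compressed(tensors, save_path="model.safetensors")

# load_compressed with no config falls through to a plain safe_open read
loaded: Dict[str, Tensor] = load_compressed("model.safetensors")
assert all(torch.allclose(tensors[k], loaded[k]) for k in tensors)
```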
12 changes: 8 additions & 4 deletions src/compressed_tensors/compressors/base.py
@@ -45,12 +45,16 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """
         raise NotImplementedError()
 
-    def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
+    def decompress(
+        self, path_to_model_or_tensors: str
+    ) -> Generator[Tuple[str, Tensor], None, None]:
         """
-        Reads a compressed state dict located at model_path and returns a
-        generator for sequentially decompressing back to a dense state dict
+        Reads a compressed state dict located at path_to_model_or_tensors
+        and returns a generator for sequentially decompressing back to a
+        dense state dict
 
-        :param model_path: path to compressed safetensors model
+        :param path_to_model_or_tensors: path to compressed safetensors model
+            (directory with one or more safetensors files) or compressed tensors file
         :return: compressed state dict
         """
         raise NotImplementedError()
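Since `decompress` yields `(name, tensor)` pairs instead of materializing the whole state dict, a caller can stream weights one at a time or collect them all, which is how the new `load_compressed` helper consumes it. A sketch, assuming `compressor` is any registered `ModelCompressor` subclass and `process` is a hypothetical per-weight callback:

```python
# stream decompressed weights, holding at most one tensor at a time
for name, tensor in compressor.decompress("path/to/model"):
    process(name, tensor)

# or materialize the full dense state dict at once
state_dict = dict(compressor.decompress("path/to/model"))
```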
4 changes: 3 additions & 1 deletion src/compressed_tensors/compressors/dense.py
@@ -27,5 +27,7 @@ class DenseCompressor(ModelCompressor):
     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
         return model_state
 
-    def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
+    def decompress(
+        self, path_to_model_or_tensors: str
+    ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
16 changes: 10 additions & 6 deletions src/compressed_tensors/compressors/sparse_bitmask.py
@@ -70,22 +70,26 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
 
         return compressed_dict
 
-    def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
+    def decompress(
+        self, path_to_model_or_tensors: str, device: str = "cpu"
+    ) -> Generator[Tuple[str, Tensor], None, None]:
         """
-        Reads a bitmask compressed state dict located at model_path and returns a
-        generator for sequentially decompressing back to a dense state dict
+        Reads a bitmask compressed state dict located at path_to_model_or_tensors
+        and returns a generator for sequentially decompressing back to a dense state dict
 
-        :param model_path: path to compressed safetensors model
+        :param path_to_model_or_tensors: path to compressed safetensors model
+            (directory with one or more safetensors files) or compressed tensors file
+        :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
         weight_mappings = get_nested_weight_mappings(
-            model_path, self.COMPRESSION_PARAM_NAMES
+            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
                 full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device="cpu") as f:
+                with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             data = BitmaskTensor(**weight_data)
             decompressed = data.decompress()
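The new `device` argument lets decompressed weights be opened directly onto an accelerator instead of always landing on CPU. A sketch using the registry lookup that `load_compressed` performs below; the `format` string and config construction are assumptions, not taken from this diff:

```python
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import BitmaskConfig

# assumed format identifier; the real value comes from the repo's config registry
config = BitmaskConfig(format="sparse-bitmask")
compressor = ModelCompressor.load_from_registry(config.format, config=config)

# each tensor is read via safe_open straight onto the requested device
for name, tensor in compressor.decompress("model.safetensors", device="cuda:0"):
    assert tensor.device.type == "cuda"
```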
40 changes: 38 additions & 2 deletions src/compressed_tensors/utils/helpers.py
@@ -18,6 +18,7 @@
 from compressed_tensors.base import CONFIG_NAME
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import CompressionConfig
+from safetensors import safe_open
 from safetensors.torch import save_file
 from torch import Tensor
 from transformers import AutoConfig
@@ -87,5 +88,40 @@ def save_compressed(
     return {CONFIG_NAME: compression_config.model_dump(exclude_unset=True)}
 
 
-def load_compressed(compressed_tensors: Union[str, Path], device: str):
-    pass
+def load_compressed(
+    compressed_tensors: Union[str, Path],
+    compression_config: Optional[CompressionConfig] = None,
+    device: Optional[str] = "cpu",
+) -> Dict[str, Tensor]:
+    """
+    Load compressed tensors from disk. If tensors are not compressed,
+    load them as is.
+    :param compressed_tensors: path to compressed tensors
+    :param compression_config: compression config to use for decompressing tensors.
+        Can be either inferred from tensors or provided explicitly.
+    :param device: device to move tensors to. If None, tensors are loaded on CPU.
+    :return: decompressed tensors
+    """
+    if compressed_tensors is None or not Path(compressed_tensors).exists():
+        raise ValueError("No compressed tensors provided to load")
+
+    # create compression config if not provided
+    # TODO: Not implemented, need to get this in ASAP
+    # compression_config = compression_config or infer_compression_config(tensors)
+
+    if compression_config is None:
+        # no compression applied, read the file as plain safetensors
+        tensors = {}
+        with safe_open(compressed_tensors, framework="pt", device=device or "cpu") as f:
+            for key in f.keys():
+                tensors[key] = f.get_tensor(key)
+        return tensors
+
+    # decompress via the compressor registered for this config's format
+    compression_format = compression_config.format
+    compressor = ModelCompressor.load_from_registry(
+        compression_format, config=compression_config
+    )
+    return dict(compressor.decompress(compressed_tensors))
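`load_compressed` therefore has exactly two paths: with no config it is a thin wrapper over `safe_open` (single-file input only), and with a config it dispatches to the compressor registered under `config.format`. A usage sketch mirroring the new test at the bottom of this commit; the config values are illustrative assumptions:

```python
from compressed_tensors import load_compressed
from compressed_tensors.config import BitmaskConfig

# path 1: no config, read the file as-is
dense_tensors = load_compressed("dense_model.safetensors")

# path 2: a config routes decompression through ModelCompressor.load_from_registry
config = BitmaskConfig(
    format="sparse-bitmask",  # assumed format string
    global_sparsity=0.5,
    sparsity_structure="unstructured",
)
sparse_tensors = load_compressed("sparse_model.safetensors", config)
```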
51 changes: 31 additions & 20 deletions src/compressed_tensors/utils/safetensors_load.py
@@ -117,7 +117,7 @@ def merge_names(parent_name: str, child_name: str) -> str:
     return parent_name + "." + child_name
 
 
-def get_weight_mappings(model_path: str) -> Dict[str, str]:
+def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
     """
     Takes a path to a state dict saved in safetensors format and returns a mapping
     from parameterized layer name to file location.
@@ -131,31 +131,42 @@ def get_weight_mappings(model_path: str) -> Dict[str, str]:
     This generalizes to cases where the model is split into multiple safetensors files
 
-    :param model_path: path to safetensors state dict, must contain either a single
-        safetensors file or multiple files with an index
+    :param path_to_model_or_tensors: path to directory that contains
+        safetensors (must contain either a single file or multiple files with an index),
+        or a path to a single safetensors file
     :return: mapping of parameterized layer name to file location
     """
-    safetensors_path = os.path.join(model_path, SAFE_WEIGHTS_NAME)
-    index_path = os.path.join(model_path, SAFE_WEIGHTS_INDEX_NAME)
-    if os.path.exists(safetensors_path):
-        # we have a single safetensors file to read
-        header = get_safetensors_header(safetensors_path)
-        for key in header.keys():
-            header[key] = SAFE_WEIGHTS_NAME
-        header.pop("__metadata__", None)
-    elif os.path.exists(index_path):
-        # we have multiple safetensors file, read from index
-        with open(index_path, "r", encoding="utf-8") as f:
-            index = json.load(f)
-        header = index["weight_map"]
-    else:
-        raise ValueError(
-            f"Could not find a safetensors weight or index file at {model_path}"
-        )
-
-    # convert weight locations to full paths
-    for key, value in header.items():
-        header[key] = os.path.join(model_path, value)
+
+    if os.path.isfile(path_to_model_or_tensors):
+        # we have a single safetensors file to read
+        header = get_safetensors_header(path_to_model_or_tensors)
+        for key in header.keys():
+            header[key] = path_to_model_or_tensors
+        header.pop("__metadata__", None)
+    else:
+        # we have a directory with multiple safetensors files
+        safetensors_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_NAME)
+        index_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_INDEX_NAME)
+        if os.path.exists(safetensors_path):
+            # we have a single safetensors file to read
+            header = get_safetensors_header(safetensors_path)
+            for key in header.keys():
+                header[key] = SAFE_WEIGHTS_NAME
+            header.pop("__metadata__", None)
+        elif os.path.exists(index_path):
+            # we have multiple safetensors files, read from the index
+            with open(index_path, "r", encoding="utf-8") as f:
+                index = json.load(f)
+            header = index["weight_map"]
+        else:
+            raise ValueError(
+                "Could not find a safetensors weight "
+                f"or index file at {path_to_model_or_tensors}"
+            )
+
+        # convert weight locations to full paths
+        for key, value in header.items():
+            header[key] = os.path.join(path_to_model_or_tensors, value)
 
     return header

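Whichever branch runs, the caller receives one flat mapping from parameter name to the absolute path of the file storing it. An illustrative result under hypothetical paths (the index layout is the standard `model.safetensors.index.json` `weight_map`):

```python
from compressed_tensors.utils.safetensors_load import get_weight_mappings

# single-file input: every parameter maps to that one file
get_weight_mappings("/tmp/model.safetensors")
# -> {"layer.weight": "/tmp/model.safetensors", ...}

# directory input with an index: weight_map entries are joined onto the directory
get_weight_mappings("/tmp/checkpoint")
# -> {"layer.weight": "/tmp/checkpoint/model-00001-of-00002.safetensors", ...}
```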
1 change: 1 addition & 0 deletions tests/quantization/lifecycle/test_apply.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from compressed_tensors.quantization.lifecycle import apply_quantization_config
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
25 changes: 15 additions & 10 deletions tests/test_utils/test_helpers.py
@@ -14,7 +14,7 @@
 
 import pytest
 import torch
-from compressed_tensors import save_compressed
+from compressed_tensors import load_compressed, save_compressed
 from compressed_tensors.config import BitmaskConfig
 
 
@@ -44,15 +44,7 @@ def test_save_compressed_sparse(tmp_path, tensors_and_config_sparse):
 
     config_json = save_compressed(
         tensors,
-        compression_config=BitmaskConfig(
-            format=expected_config_json["compression_config"]["format"],
-            global_sparsity=expected_config_json["compression_config"][
-                "global_sparsity"
-            ],
-            sparsity_structure=expected_config_json["compression_config"][
-                "sparsity_structure"
-            ],
-        ),
+        compression_config=BitmaskConfig(**expected_config_json["compression_config"]),
         save_path=tmp_path / "model.safetensors",
     )
     assert (tmp_path / "model.safetensors").exists()
@@ -77,3 +69,16 @@ def test_save_compressed_empty():
 
     with pytest.raises(Exception):
         save_compressed(None, "")
+
+
+def test_load_compressed_sparse(tmp_path, tensors_and_config_sparse):
+    tensors, expected_config_json = tensors_and_config_sparse
+    compression_config = BitmaskConfig(**expected_config_json["compression_config"])
+    save_compressed(
+        tensors,
+        compression_config=compression_config,
+        save_path=tmp_path / "model.safetensors",
+    )
+    loaded_tensors = load_compressed(tmp_path / "model.safetensors", compression_config)
+    for key in tensors:
+        assert torch.allclose(tensors[key], loaded_tensors[key])