Pretrained Model Reload + SparseGPT Support #31

Merged: 3 commits, Apr 23, 2024
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
@@ -13,3 +13,4 @@
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
1 change: 1 addition & 0 deletions src/compressed_tensors/compressors/__init__.py
@@ -16,4 +16,5 @@

from .base import ModelCompressor
from .dense import DenseCompressor
from .helpers import infer_compressor_from_model_config
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
2 changes: 2 additions & 0 deletions src/compressed_tensors/compressors/base.py
@@ -18,6 +18,7 @@
from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.config import CompressionConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_safetensors_folder
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -62,6 +63,7 @@ def overwrite_weights(self, model_path: str, model: Module):
        :param model_path: path to compressed weights
        :param model: pytorch model to load decompressed weights into
        """
        model_path = get_safetensors_folder(model_path)
        dense_gen = self.decompress(model_path)
        for name, data in tqdm(dense_gen, desc="Decompressing model"):
            # loading the decompressed weights into the model
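A minimal reload sketch (not part of the diff) showing how the compressor pieces above fit together; the model path is a placeholder, and infer_compressor_from_model_config is assumed to return None when the model's config.json carries no sparsity_config entry:

from transformers import AutoModelForCausalLM

from compressed_tensors.compressors import infer_compressor_from_model_config

model_path = "path/to/compressed-model"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(model_path)

compressor = infer_compressor_from_model_config(model_path)
if compressor is not None:
    # resolves the safetensors folder internally, then writes the decompressed
    # weights back into the matching modules
    compressor.overwrite_weights(model_path, model)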
70 changes: 69 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@

import re
from collections import OrderedDict
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional

from compressed_tensors.quantization.lifecycle.calibration import (
    set_module_for_calibration,
@@ -28,14 +28,60 @@
    QuantizationStatus,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from torch.nn import Module


__all__ = [
"load_pretrained_quantization",
"apply_quantization_config",
"apply_quantization_status",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


def load_pretrained_quantization(model: Module, model_name_or_path: str):
"""
Loads the quantization parameters (scale and zero point) from model_name_or_path to
a model that has already been initialized with a quantization config

:param model: model to load pretrained quantization parameters to
:param model_name_or_path: Hugging Face stub or local folder containing a quantized
model, which is used to load quantization parameters
"""
model_path = get_safetensors_folder(model_name_or_path)
state_dict = get_quantization_state_dict(model_path)

for name, submodule in iter_named_leaf_modules(model):
if not is_module_quantized(submodule):
continue
if submodule.quantization_scheme.weights is not None:
base_name = "weight"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)
if submodule.quantization_scheme.input_activations is not None:
base_name = "input"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)
if submodule.quantization_scheme.output_activations is not None:
base_name = "output"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)


def apply_quantization_config(model: Module, config: QuantizationConfig):
"""
@@ -103,3 +149,25 @@ def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
        elif target == value:
            return target
    return None


def _load_quant_args_from_state_dict(
    base_name: str, module_name: str, module: Module, state_dict: Dict
):
    """
    Loads scale and zero point from a state_dict into the specified module

    :param base_name: quantization target, one of: weights, input_activations or
        output_activations
    :param module_name: pytorch module name to look up in state_dict
    :param module: pytorch module associated with module_name
    :param state_dict: state_dict to search for matching quantization parameters
    """
    scale_name = f"{base_name}_scale"
    zp_name = f"{base_name}_zero_point"
    device = next(module.parameters()).device

    scale = getattr(module, scale_name)
    zp = getattr(module, zp_name)
    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
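A hedged end-to-end sketch of the reload flow these additions enable; the model path is a placeholder, and QuantizationConfig.from_model_config comes from the quant_config.py change later in this PR:

from transformers import AutoModelForCausalLM

from compressed_tensors.quantization.lifecycle.apply import (
    apply_quantization_config,
    load_pretrained_quantization,
)
from compressed_tensors.quantization.quant_config import QuantizationConfig

model_path = "path/to/quantized-model"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(model_path)

# rebuild the quantization wrappers from the serialized config, then reload the
# saved scales and zero points into each quantized submodule
config = QuantizationConfig.from_model_config(model_path)
if config is not None:
    apply_quantization_config(model, config)
    load_pretrained_quantization(model, model_path)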
5 changes: 5 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -82,6 +82,7 @@ def wrapped_forward(self, *args, **kwargs):

        if scheme.weights is not None:
            # calibrate and (fake) quantize weights when applicable
            unquantized_weight = self.weight.data.clone()
            self.weight.data = _maybe_calibrate_or_quantize(
                module, self.weight, "weight", scheme.weights
            )
@@ -97,6 +98,10 @@
                module, output, "output", scheme.output_activations
            )

        # restore weights to their original unquantized values
        if scheme.weights is not None:
            self.weight.data = unquantized_weight

        return output

    # bind wrapped forward to module class so reference to `self` is correct
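The clone/restore pattern above keeps the stored weights in full precision and only swaps in (fake) quantized values for the duration of a forward call. A standalone sketch of the same idea, with an illustrative int8 round trip standing in for _maybe_calibrate_or_quantize:

import torch
from torch.nn import Linear


def fake_quant_forward(
    module: Linear, x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor
) -> torch.Tensor:
    # keep a copy of the dense weights, as wrapped_forward does above
    unquantized_weight = module.weight.data.clone()
    # quantize/dequantize round trip (illustrative only; the real helper also
    # handles calibration and other quantization arguments)
    q = torch.clamp(torch.round(module.weight.data / scale + zero_point), -128, 127)
    module.weight.data = (q - zero_point) * scale
    try:
        output = module(x)
    finally:
        # restore the original weights so later passes, and compressors such as
        # SparseGPT, still see full-precision values
        module.weight.data = unquantized_weight
    return output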
17 changes: 17 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -15,6 +15,7 @@
from enum import Enum
from typing import Dict, List, Optional

from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    calculate_compression_ratio,
@@ -24,6 +25,7 @@
)
from pydantic import BaseModel, Field
from torch.nn import Module
from transformers import AutoConfig


__all__ = [
@@ -98,6 +100,21 @@ class QuantizationConfig(BaseModel):
    global_compression_ratio: Optional[float] = None
    ignore: Optional[List[str]] = Field(default_factory=list)

    @staticmethod
    def from_model_config(model_name_or_path) -> "QuantizationConfig":
        """
        Given a path to a model config, extract a quantization config if it exists

        :param model_name_or_path: path to model config on disk or HF hub
        :return: instantiated QuantizationConfig if the config contains a
            quantization config, otherwise None
        """
        config = AutoConfig.from_pretrained(model_name_or_path)
        quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
        if quantization_config is None:
            return None

        return QuantizationConfig.parse_obj(quantization_config)

    @staticmethod
    def from_pretrained(model: Module) -> "QuantizationConfig":
        """
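A short usage sketch for the new from_model_config helper; the path is a placeholder and only fields visible in this diff are inspected:

from compressed_tensors.quantization.quant_config import QuantizationConfig

quant_config = QuantizationConfig.from_model_config("path/to/quantized-model")
if quant_config is not None:
    print(quant_config.global_compression_ratio)
    print(quant_config.ignore)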
1 change: 0 additions & 1 deletion src/compressed_tensors/utils/__init__.py
@@ -13,5 +13,4 @@
# limitations under the License.
# flake8: noqa

from .helpers import *
from .safetensors_load import *
32 changes: 31 additions & 1 deletion src/compressed_tensors/utils/safetensors_load.py
@@ -18,6 +18,8 @@
import struct
from typing import Dict, List, Optional

from safetensors import safe_open
from torch import Tensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file


@@ -28,6 +30,7 @@
"merge_names",
"get_weight_mappings",
"get_nested_weight_mappings",
"get_quantization_state_dict",
]


@@ -45,7 +48,7 @@ def get_safetensors_folder(
"""
if os.path.exists(pretrained_model_name_or_path):
# argument is a path to a local folder
return pretrained_model_name_or_path
return os.path.abspath(pretrained_model_name_or_path)

    safetensors_path = cached_file(
        pretrained_model_name_or_path,
@@ -194,3 +197,30 @@ def get_nested_weight_mappings(
            nested_weight_mappings[dense_param][param_name] = weight_mappings[key]

    return nested_weight_mappings


def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
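    """
    Loads all quantization parameters (scales and zero points) from a model's
    safetensors files into a single state_dict

    :param model_path: local folder containing the model's safetensors files
    :return: mapping of parameter name to tensor for every quantization parameter
    """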
    weight_mappings = get_weight_mappings(model_path)
    state_dict = {}
    for weight_name, safe_path in weight_mappings.items():
        if not _is_quantization_weight(weight_name):
            continue
        with safe_open(safe_path, framework="pt", device="cpu") as f:
            state_dict[weight_name] = f.get_tensor(weight_name)

    return state_dict


def _is_quantization_weight(name: str) -> bool:
"""
Checks is a parameter name is associated with a quantization parameter

:param name: parameter name to check
:return: True if parameter name is a quantization parameter, else False
"""
if name.endswith("_scale"):
return True
if name.endswith("zero_point"):
return True

return False
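A quick sketch of pulling just the quantization tensors out of a checkpoint; the path is a placeholder (a local directory or Hugging Face stub):

from compressed_tensors.utils.safetensors_load import (
    get_quantization_state_dict,
    get_safetensors_folder,
)

folder = get_safetensors_folder("path/to/quantized-model")  # placeholder
quant_state_dict = get_quantization_state_dict(folder)
for key, tensor in quant_state_dict.items():
    # keys end in "_scale" or "zero_point", e.g. the hypothetical
    # "model.layers.0.self_attn.q_proj.weight_scale"
    print(key, tuple(tensor.shape))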