[Transforms] Apply, serialize, deserialize #276
base: transform_arg_support
Changes from all commits
@@ -41,6 +41,9 @@
    iter_named_leaf_modules,
    iter_named_quantizable_modules,
)
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.transform_config import TransformationConfig
from compressed_tensors.transforms.transform_data import TransformData
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder

@@ -49,20 +52,50 @@

__all__ = [
    "load_pretrained_quantization",
    "load_transforms",
    "apply_quantization_config",
    "apply_quantization_status",
    "find_name_or_class_matches",
    "expand_target_names",
    "is_target",
    "process_transforms_config",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.utils.safetensors_load import (
    get_quantization_state_dict,
    get_weight_mappings,
)
from safetensors import safe_open


_LOGGER = logging.getLogger(__name__)


def load_transforms(model: Module, model_name_or_path: str):
    model_path = get_safetensors_folder(model_name_or_path)
    weight_mappings = get_weight_mappings(model_path)

    # collect every serialized transform tensor from the checkpoint files
    state_dict = {}
    for weight_name, safe_path in weight_mappings.items():
        if "transform" in weight_name:
            with safe_open(safe_path, framework="pt", device="cpu") as f:
                state_dict[weight_name] = f.get_tensor(weight_name)

    for name, submodule in iter_named_leaf_modules(model):
        transform_data = getattr(submodule, "transform_data", None)

        if transform_data:
            for transform_name, transform_values in transform_data.data.items():
                full_name = f"{name}.{transform_name}"
                transform_tensor = state_dict.get(full_name, None)
                transform = transform_values.get("transform")
                transform.register_to_module(name=transform_name, module=submodule)
                transform.update_transform(
                    module=submodule, data=transform_tensor, name=transform_name
                )


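The lookup key format is what ties serialization and deserialization together: `load_transforms` only pulls checkpoint entries whose name contains "transform" and resolves them per submodule as `f"{name}.{transform_name}"`. A small illustration of that key format, using a hypothetical layer name (the transform name follows the `f"{module_targets[0]}_transform_{idx}"` convention from `process_transforms_config` below):

```python
# Illustration only: the checkpoint key that load_transforms resolves for a
# hypothetical submodule; the layer name is invented for this example.
name = "model.layers.0.self_attn.q_proj"  # submodule name from iter_named_leaf_modules
transform_name = "weight_transform_0"     # naming convention used in process_transforms_config
full_name = f"{name}.{transform_name}"
assert full_name == "model.layers.0.self_attn.q_proj.weight_transform_0"
```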
def load_pretrained_quantization(model: Module, model_name_or_path: str):
    """
    Loads the quantization parameters (scale and zero point) from model_name_or_path to

@@ -104,8 +137,94 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
        )


def process_transforms_config(
Review comment: I missed that this is actually modifying
    transforms_config: TransformationConfig,
Review comment: Another example where I think consistency in naming convention will help users. Something like this?
Suggested change
    model: torch.nn.Module,
    quantization_status: Optional[QuantizationStatus] = QuantizationStatus.INITIALIZED,
):
    for _, group in transforms_config.transform_groups.items():
        # Each group/scheme targets one type of transform
        transform_type = group.transform_type
        transform_creation_args = group.transform_creation_args

        # Need a better name - too many groups
        for transform_arg in group.groups:
            module_targets = transform_arg.module_targets

            for name, submodule in model.named_modules():
                if len(transform_arg.ignore) > 0:
                    if matches := find_name_or_class_matches(
                        name, submodule, transform_arg.ignore
                    ):
                        for match in matches:
                            print("ignoring", match, name)
                        continue  # layer matches ignore list, continue

                targets = find_name_or_class_matches(
                    name, submodule, transform_arg.targets
                )

                if targets:
                    # Every layer which matches gets its own transform;
                    # the same transform type and args are used, however.

                    # Attach the transform to the submodule. Because there can be
                    # more than one transform, some form of key is needed to fetch it,
                    # OR it is stored in the dictionary and cpu-offloading is handled
                    # separately.

                    if hasattr(submodule, "transform_data"):
                        idx = submodule.transform_data.idx + 1
                    else:
                        idx = 0
                    # only support weight parameters for now, assume one value in
                    # module targets
                    transform_name = f"{module_targets[0]}_transform_{idx}"

                    # create an empty tensor OR create a new transform
                    dtype = getattr(submodule, module_targets[0]).dtype
                    if quantization_status in [
                        QuantizationStatus.COMPRESSED,
                        QuantizationStatus.FROZEN,
                    ]:
                        transform = Transforms.load_from_registry(
                            transform_type,
                            dtype=dtype,
                            empty=True,
                            **transform_creation_args,
                        )
                    else:
                        transform = Transforms.load_from_registry(
                            transform_type,
                            dtype=dtype,
                            **transform_creation_args,
                        )
                    transform.register_to_module(
                        name=transform_name, module=submodule
                    )

                    # add relevant transform data to the submodule as well
                    data = {
                        transform_name: {
                            "transform": transform,
                            "call_args": transform_arg.call_args,
                        }
                    }

                    if hasattr(submodule, "transform_data"):
                        submodule.transform_data.data.update(data)
                        submodule.transform_data.idx = idx
                    else:
                        transform_data = TransformData(data=OrderedDict(data))
                        submodule.transform_data = transform_data
    return model


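To make the resulting per-module state easier to see, here is a small hypothetical helper (not part of this PR) that walks a processed model and prints what was attached; it relies only on the `transform_data.data` layout built above, i.e. a `"transform"` and a `"call_args"` entry per transform name:

```python
# Hypothetical inspection helper, not part of this diff.
from torch.nn import Module


def summarize_attached_transforms(model: Module) -> None:
    """Print every transform attached by process_transforms_config."""
    for name, submodule in model.named_modules():
        transform_data = getattr(submodule, "transform_data", None)
        if transform_data is None:
            continue
        for transform_name, values in transform_data.data.items():
            print(
                f"{name}.{transform_name}: "
                f"type={type(values['transform']).__name__}, "
                f"call_args={values['call_args']}"
            )
```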
def apply_quantization_config(
    model: Module,
    config: Union[QuantizationConfig, None],
    run_compressed: bool = False,
    transforms_config=None,
) -> OrderedDict:
    """
    Initializes the model for quantization in-place based on the given config.

@@ -184,6 +303,12 @@ def apply_quantization_config(
            f"{set(config.ignore) - set(ignored_submodules)}"
        )

    if transforms_config:
        model.transforms_config = transforms_config
        model = process_transforms_config(
            transforms_config, model, config.quantization_status
        )

    # apply current quantization status across all targeted layers
    apply_quantization_status(model, config.quantization_status)
    return names_to_scheme

Review comment: To me, it seems like this information is essentially duplicating the information of the scheme/args? Why not attach the scheme/args to the module, rather than creating a new abstraction?
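Putting the pieces together, a rough sketch of the intended flow; the import paths and the helper names here are assumptions for illustration and are not part of this diff:

```python
# Sketch only; the import location is assumed from this file's __all__, and the
# wrapper functions are hypothetical.
from compressed_tensors.quantization import (  # assumed re-export location
    apply_quantization_config,
    load_transforms,
)


def initialize_with_transforms(model, quant_config, transforms_config):
    """At quantization time: attach quantization params and, with this PR,
    per-layer transforms to the matching submodules."""
    return apply_quantization_config(
        model, quant_config, transforms_config=transforms_config
    )


def reload_transforms(model, model_name_or_path):
    """At load time: repopulate the registered transform tensors from the
    serialized safetensors checkpoint."""
    load_transforms(model, model_name_or_path)
```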