[Transforms] Apply, serialize, deserialize #276


Open

dsikka wants to merge 3 commits into transform_arg_support from transform_apply_support

Conversation

@dsikka (Collaborator) commented Mar 11, 2025

Summary

  • Add support for applying transforms to models during quantization, currently targeting layer weights only

Process

  • This includes processing the provided transforms config, generating a transform_data object that is attached to each layer to indicate runtime args, and attaching each of the individual transforms to the model layers.
  • Since multiple transforms can be applied to a single parameter, an index is appended to the parameter name to differentiate them. Transform parameters therefore follow the naming convention {parameter_type}_transform_{idx}, e.g. "weight_transform_0" or "input_activation_transform_0", matching the convention of weights, input_activations, and output_activations (see the sketch after this list).
  • TransformData holds a dictionary of all the transform-relevant runtime data and is attached to the layer as "transform_data". The keys of the dictionary correspond to the transform parameter names. Note: if we later add a layer that infers runtime/call args on the fly, we could potentially remove TransformData, but that is an optimization to discuss in the future.
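
For illustration, the registration described above amounts to roughly the following sketch (the values and the direct register_parameter call are hypothetical; the real work happens inside process_transforms_config):

import torch

layer = torch.nn.Linear(512, 512)

# Stand-in for a real 512x512 Hadamard matrix
hadamard = torch.ones(512, 512, dtype=torch.bfloat16)

# {parameter_type}_transform_{idx}: the index differentiates multiple
# transforms applied to the same parameter
layer.register_parameter(
    "weight_transform_0",
    torch.nn.Parameter(hadamard, requires_grad=False),
)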

Apply

  • Utils have also been added to apply the transforms to the weights when applying QDQ. This functionality will be extended further, and will likely be moved out of the forward method once support for activation transforms is added. It is currently handled by apply_transforms_to_parameter and apply_inverse_transforms_to_parameter, which sandwich the QDQ step (sketched below).
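
A minimal sketch of that sandwich pattern (fake_quantize is a toy stand-in for the real QDQ logic, and an orthonormal transform is assumed so the inverse is just the transpose):

import torch

def fake_quantize(x: torch.Tensor) -> torch.Tensor:
    # Toy symmetric int8 quantize-dequantize, purely illustrative
    scale = x.abs().max().clamp(min=1e-8) / 127
    return torch.round(x / scale).clamp(-128, 127) * scale

def qdq_weight_with_transforms(module: torch.nn.Module) -> torch.Tensor:
    t = module.weight_transform_0        # transform registered on the module
    transformed = module.weight @ t      # apply_transforms_to_parameter
    qdq = fake_quantize(transformed)     # QDQ runs in the transformed space
    return qdq @ t.T                     # apply_inverse_transforms_to_parameter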

Serialize/Deserialize

  • Serialization does not yet compress the transforms; they are saved to disk uncompressed (we will either fuse them in or compress them in a follow-up). The quantization_config in config.json is also extended with a transforms_config (sketched below).
  • For deserialization, an additional load_transforms function has been added to load the parameters from disk and attach the relevant runtime information. Note that this requires the transformers PR listed under Dependencies below.
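
For illustration, the extended config.json might look roughly like this (the field names are assumptions based on the description above, not the exact serialized format):

quantization_config = {
    # ... existing quantization fields ...
    "transforms_config": {
        "transform_groups": {
            "transform_0": {
                "transform_type": "hadamard",
                "groups": [{"targets": ["Linear"], "module_targets": ["weight"]}],
                "transform_creation_args": {"size": 512},
            },
        },
    },
}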

Examples:

# Apply a transform config to a model

import torch

from compressed_tensors.quantization import process_transforms_config
# Note: the import path for the transform classes below is assumed
from compressed_tensors.transforms import (
    ModuleTarget,
    TransformationArgs,
    TransformationConfig,
    TransformationScheme,
)

targets = ["Linear"]
module_targets = [ModuleTarget.WEIGHT]
linear_layer_args = TransformationArgs(
    targets=targets, module_targets=module_targets
)

scheme = TransformationScheme(
    transform_type="hadamard",
    groups=[linear_layer_args],
    transform_creation_args={"size": 512},
)
config = TransformationConfig(
    transform_groups={
        "transform_0": scheme,
    }
)

model = torch.nn.Linear(512, 512)

model = process_transforms_config(model=model, transforms_config=config)

# Once processed, the model will have the following parameters:
>> model.weight_transform_0
Parameter containing:
        tensor([[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
                [ 1., -1.,  1.,  ..., -1.,  1., -1.],
                [ 1.,  1., -1.,  ...,  1., -1., -1.],
                ...,
                [ 1., -1.,  1.,  ..., -1.,  1., -1.],
                [ 1.,  1., -1.,  ...,  1., -1., -1.],
                [ 1., -1., -1.,  ..., -1., -1.,  1.]], dtype=torch.bfloat16)

>> model.transform_data
TransformData(data={'weight_transform_0':
      {
          'call_args': defaultdict(),
          'type': "hadamard"
      }
  }
)

# Run a forward pass; the transforms are applied to the weight around QDQ
some_dummy_data = torch.randn(1, 512)
model(some_dummy_data)

Testing

Dependencies:

@dsikka dsikka changed the title add apply, serialize, deserialize support [Transforms] Apply, serialize, deserialize Mar 11, 2025
@dsikka dsikka marked this pull request as ready for review March 11, 2025 21:53
@@ -104,8 +132,92 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
)


def process_transforms_config(
    transforms_config: TransformationConfig,
Contributor commented:
Another example where I think consistency in naming convention will help users. Something like this?

Suggested change
transforms_config: TransformationConfig,
transform_config: TransformConfig,

@dsikka force-pushed the transform_arg_support branch from 358075b to fadaaf8 on March 22, 2025 21:02
@dsikka force-pushed the transform_apply_support branch from e7cdea4 to 063d62d on March 22, 2025 21:06
@dsikka force-pushed the transform_arg_support branch from fadaaf8 to 86e805d on March 23, 2025 02:28
@dsikka force-pushed the transform_apply_support branch from 28c7af5 to 579337d on March 23, 2025 02:35
state_dict[weight_name] = f.get_tensor(weight_name)

for name, submodule in iter_named_leaf_modules(model):
    transform_data = getattr(submodule, "transform_data", None)
@kylesayrs (Contributor) commented Apr 28, 2025:

> TransformData includes a dictionary of all the transforms-relevant runtime data, and is attached to the layer as "transform_data"

To me, it seems like this information essentially duplicates the information in the scheme/args. Why not attach the scheme/args to the module, rather than creating a new abstraction?

@@ -104,8 +137,94 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
)


def process_transforms_config(
@brian-dellabetta (Contributor) commented Apr 28, 2025:

I missed that this is actually modifying model by registering transforms on it. The name process_transforms_config is pretty innocuous for something that modifies model this significantly. Consider renaming to add_transforms_to_model or something more explicit?
