# Fix smoothquant ignore, Fix typing, Add glm mappings (#1015)
## Purpose ##
* Fix regex targets in the `ignore` list not actually being ignored
* Fix pydantic type checking so that `mappings` may be passed as lists as well as tuples (see the snippet below)
* Add ChatGLM mappings (identical to the bloom mappings)
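
A minimal recipe exercising the first two fixes (list-style mappings and a regex `ignore` entry), mirroring the testing script below:

```python3
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

modifier = SmoothQuantModifier(
    smoothing_strength=0.8,
    # lists (e.g. as parsed from YAML recipes) now pass pydantic validation, not just tuples
    mappings=[
        [["re:.*query_key_value"], "re:.*input_layernorm"],
    ],
    # "re:"-prefixed entries are now resolved as regex patterns rather than exact layer names
    ignore=["transformer.output_layer", "re:transformer.vision.*"],
)
```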

## Issues ##
* Fixes #105
* Partially fixes #886
* Related to #1003

## Testing ##
<details><summary>glm.py</summary>

```python3
import requests
from PIL import Image
from io import BytesIO

from transformers import AutoProcessor
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.smoothquant.utils import BLOOM_SMOOTHQUANT_MAPPINGS
from datasets import load_dataset
from llmcompressor.transformers.tracing import ChatGLMForConditionalGeneration

from llmcompressor.transformers.utils.data_collator import glm_data_collator

MODEL_ID = "THUDM/glm-4v-9b"
model = ChatGLMForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

NUM_CALIBRATION_SAMPLES = 1  # 512 for a full calibration run
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)

# The dataset stores relative COCO paths; fetch each image and tokenize with the chat template
def preprocess(example):
    url_part = "/".join(example["image"].split("/")[1:])
    url = f"http://images.cocodataset.org/{url_part}"
    response = requests.get(url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")

    return processor.apply_chat_template(
        [
            {
                "role": "user",
                "image": image,
                "content": example["conversations"][0]["value"],
            }
        ],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

ds = ds.map(preprocess, remove_columns=ds.column_names)

# Configure the quantization algorithms
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            [["re:.*query_key_value"], "re:.*input_layernorm"],
            [["re:.*dense_h_to_4h"], "re:.*post_attention_layernorm"],
        ],
        ignore=["transformer.output_layer", "re:transformer.vision.*"]
    ),
    # GPTQ stage left disabled while testing SmoothQuant in isolation
    # GPTQModifier(
    #     targets="Linear",
    #     scheme="W8A8",
    #     sequential_targets=["GLMBlock"],
    #     ignore=["transformer.output_layer", "re:transformer.vision.*"],
    # ),
]

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=glm_data_collator,
)
```
</details>

Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 15, 2025 (commit 4087d9d, parent 58800af)
Showing 2 changed files with 11 additions and 6 deletions.
**src/llmcompressor/modifiers/smoothquant/base.py** (8 additions, 4 deletions)

```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.utils.offload import is_module_offloaded
@@ -14,7 +14,11 @@
 )
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer
+from llmcompressor.utils.pytorch.module import (
+    get_layers,
+    get_matching_layer,
+    match_targets,
+)
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
@@ -95,7 +99,7 @@ class SmoothQuantModifier(Modifier):
     """
 
     smoothing_strength: float = 0.5
-    mappings: Optional[List[Tuple]] = None
+    mappings: Optional[List[Union[Tuple, List]]] = None
    ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None
@@ -176,7 +180,7 @@ def _resolve_mappings(self, model: Module) -> List:
         for to_balance, to_smooth in self.mappings:
             to_smooth_layers = get_layers(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
-                if layer_name not in self.ignore:
+                if not match_targets(layer_name, self.ignore)[0]:
                     balance_layers = []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
```
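The effect of the `match_targets` change, sketched with a standalone helper. This is an illustrative stand-in for `match_targets` (which lives in `llmcompressor.utils.pytorch.module`), showing only the assumed "re:" prefix convention, not the library implementation:

```python3
import re
from typing import List

def matches_ignore(layer_name: str, ignore: List[str]) -> bool:
    """Treat "re:"-prefixed entries as regex patterns, all others as exact names."""
    for target in ignore:
        if target.startswith("re:"):
            if re.match(target[3:], layer_name):
                return True
        elif layer_name == target:
            return True
    return False

# exact names matched before and still do
assert matches_ignore("transformer.output_layer", ["transformer.output_layer"])
# regex targets are now honored; the old `layer_name not in self.ignore` check missed these
assert matches_ignore("transformer.vision.patch_embedding", ["re:transformer.vision.*"])
assert not matches_ignore("transformer.encoder.layers.0.input_layernorm", ["re:transformer.vision.*"])
```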
**src/llmcompressor/modifiers/smoothquant/utils.py** (3 additions, 2 deletions)

```diff
@@ -24,7 +24,7 @@
         smooth_layers="re:.*post_attention_layernorm",
     ),
 ]
-MIXTRAL_MAPPINGS: List[LayerMap] = [
+MIXTRAL_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
     LayerMap(
         balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
         smooth_layers="re:.*input_layernorm",
@@ -49,10 +49,11 @@
 # Add more mappings here
 MAPPINGS_REGISTRY: Dict[str, List[LayerMap]] = {
     "LlamaForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
-    "MixtralForCausalLM": MIXTRAL_MAPPINGS,
+    "MixtralForCausalLM": MIXTRAL_SMOOTHQUANT_MAPPINGS,
     "MistralForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
     "Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
     "BloomForCausalLM": BLOOM_SMOOTHQUANT_MAPPINGS,
+    "ChatGLMForConditionalGeneration": BLOOM_SMOOTHQUANT_MAPPINGS,
 }
```
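With the new registry entry, ChatGLM models pick up the bloom mappings without an explicit `mappings=` argument; as a rough sanity check:

```python3
from llmcompressor.modifiers.smoothquant.utils import (
    BLOOM_SMOOTHQUANT_MAPPINGS,
    MAPPINGS_REGISTRY,
)

# ChatGLM uses bloom-style query_key_value / dense_h_to_4h projections,
# so it reuses the bloom SmoothQuant layer map
assert MAPPINGS_REGISTRY["ChatGLMForConditionalGeneration"] is BLOOM_SMOOTHQUANT_MAPPINGS
```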
