Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix smoothquant ignore, Fix typing, Add glm mappings #1015

Merged
merged 4 commits into from
Jan 10, 2025

Conversation

kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Dec 27, 2024

Purpose

  • Fix regex targets not being ignored
  • Fix pydantic type checking to allow lists to be used instead of tuples
  • Add ChatGLM mappings (which are the same as the bloom mappings)

Issues

Testing

glm.py
import requests
from PIL import Image
from io import BytesIO

from transformers import AutoProcessor
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.smoothquant.utils import BLOOM_SMOOTHQUANT_MAPPINGS
from datasets import load_dataset
from llmcompressor.transformers.tracing import ChatGLMForConditionalGeneration

from llmcompressor.transformers.utils.data_collator import glm_data_collator

MODEL_ID = "THUDM/glm-4v-9b"
model = ChatGLMForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

NUM_CALIBRATION_SAMPLES = 1 #512
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)

def preprocess(example):
    url_part = "/".join(example["image"].split("/")[1:])
    url = f"http://images.cocodataset.org/{url_part}"
    response = requests.get(url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert('RGB')

    return processor.apply_chat_template(
        [
            {
                "role": "user",
                "image": image,
                "content": example["conversations"][0]["value"],
            }
        ],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )


ds = ds.map(preprocess, remove_columns=ds.column_names)

# Configure the quantization algorithms
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            [["re:.*query_key_value"], "re:.*input_layernorm"],
            [["re:.*dense_h_to_4h"], "re:.*post_attention_layernorm"],
        ],
        ignore=["transformer.output_layer", "re:transformer.vision.*"]
    ),
    #GPTQModifier(
    #    targets="Linear",
    #    scheme="W8A8",
    #    sequential_targets=["GLMBlock"],
    #    ignore=["transformer.output_layer", "re:transformer.vision.*"],
    #),
]

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=glm_data_collator,
)

Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@kylesayrs kylesayrs requested a review from rahul-tuli January 1, 2025 19:27
Copy link
Collaborator

@rahul-tuli rahul-tuli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx!

@kylesayrs kylesayrs self-assigned this Jan 4, 2025
@dsikka dsikka merged commit 4d06685 into main Jan 10, 2025
6 of 7 checks passed
@dsikka dsikka deleted the kylesayrs/smoothquant-ignore-glm branch January 10, 2025 20:27
kylesayrs added a commit that referenced this pull request Jan 15, 2025
## Purpose ##
* Fix regex targets not being ignored
* Fix pydantic type checking to allow lists to be used instead of tuples
* Add ChatGLM mappings (which are the same as the bloom mappings)

## Issues ##
* Fixes #105
* Fixes (partially) #886
* Related to #1003

## Testing ##
<details><summary>glm.py</summary>

```python3
import requests
from PIL import Image
from io import BytesIO

from transformers import AutoProcessor
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.smoothquant.utils import BLOOM_SMOOTHQUANT_MAPPINGS
from datasets import load_dataset
from llmcompressor.transformers.tracing import ChatGLMForConditionalGeneration

from llmcompressor.transformers.utils.data_collator import glm_data_collator

MODEL_ID = "THUDM/glm-4v-9b"
model = ChatGLMForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

NUM_CALIBRATION_SAMPLES = 1 #512
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)

def preprocess(example):
    url_part = "/".join(example["image"].split("/")[1:])
    url = f"http://images.cocodataset.org/{url_part}"
    response = requests.get(url)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert('RGB')

    return processor.apply_chat_template(
        [
            {
                "role": "user",
                "image": image,
                "content": example["conversations"][0]["value"],
            }
        ],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

ds = ds.map(preprocess, remove_columns=ds.column_names)

# Configure the quantization algorithms
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            [["re:.*query_key_value"], "re:.*input_layernorm"],
            [["re:.*dense_h_to_4h"], "re:.*post_attention_layernorm"],
        ],
        ignore=["transformer.output_layer", "re:transformer.vision.*"]
    ),
    #GPTQModifier(
    #    targets="Linear",
    #    scheme="W8A8",
    #    sequential_targets=["GLMBlock"],
    #    ignore=["transformer.output_layer", "re:transformer.vision.*"],
    #),
]

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=glm_data_collator,
)
```
</details>

Signed-off-by: Kyle Sayers <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Yaml parsing fails with a custom mapping provided to SmoothQuantModifier recipe
4 participants