
Commit b04cf98

codereview updates
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 6faebcb commit b04cf98

File tree

2 files changed: 35 additions & 37 deletions


README.md

Lines changed: 0 additions & 2 deletions
@@ -33,8 +33,6 @@ Some of the exciting new features include:
 * **DeepSeekV3-style Block Quantization Support**: This allows for more efficient compression of large language models without needing a calibration dataset. Quantize a Qwen3 model to [W8A8](examples/quantization_w8a8_fp8/fp8_block_example.py).
 * **Llama4 Quantization Support**: Quantize a Llama4 model to [W4A16](examples/multimodal_vision/llama4_example.py) or [NVFP4](examples/quantization_w4a4_fp4/llama4_example.py). The checkpoint produced can seamlessly run in vLLM.
 * **FP4 Quantization - now with MoE and non-uniform support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [fp4 activation support](examples/quantization_w4a4_fp4/llama3_example.py), [MoE support](examples/quantization_w4a4_fp4/qwen_30b_a3b.py), and [Non-uniform quantization support](examples/quantization_non_uniform) where some layers are selectively quantized to fp8 for better recovery. You can also mix other quantization schemes, such as int8 and int4.
-* **Large Model Support with Sequential Onloading**: As of llm-compressor>=0.6.0, you can now quantize very large language models on a single GPU. Models are broken into disjoint layers which are then onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeek-R1 Example](examples/quantizing_moe/deepseek_r1_example.py).
-* **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
 
 ### Supported Formats
 * Activation Quantization: W8A8 (int8 and fp8)
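
For context on the block-quantization bullet retained above: a data-free FP8 block quantization run with llm-compressor might look roughly like the sketch below. This is an illustration, not the contents of the linked fp8_block_example.py; the FP8_BLOCK scheme name, the specific Qwen3 checkpoint, and the ignore list are assumptions.

# Hedged sketch of data-free FP8 block (W8A8 fp8) quantization with llm-compressor.
# Scheme name, checkpoint, and ignore list below are assumptions, not copied from the example.
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-8B"  # the README bullet only says "a Qwen3 model"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Block-wise FP8 scales are derived from the weights and activations are
# quantized dynamically at runtime, so no calibration dataset is passed.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
oneshot(model=model, recipe=recipe)

SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8-BLOCK"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)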

examples/quantization_non_uniform/quantization_multiple_modifiers.py

Lines changed: 35 additions & 35 deletions
@@ -1,4 +1,5 @@
 import argparse
+
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -19,8 +20,6 @@ def parse_args():
     return parser.parse_args()
 
 
-args = parse_args()
-
 # Select model and load it.
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
@@ -35,10 +34,6 @@ def parse_args():
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
 
-# Load dataset and preprocess.
-ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
-ds = ds.shuffle(seed=42)
-
 
 def preprocess(example):
     return {
@@ -49,9 +44,6 @@ def preprocess(example):
     }
 
 
-ds = ds.map(preprocess)
-
-
 # Tokenize inputs.
 def tokenize(sample):
     return tokenizer(
@@ -63,8 +55,6 @@ def tokenize(sample):
     )
 
 
-ds = ds.map(tokenize, remove_columns=ds.column_names)
-
 # Configure the quantization algorithm to run.
 # * quantize self_attn layers to W8A8 with GPTQ
 # * quantize mlp layers to W4A16 with AWQ
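
The comments above describe a non-uniform recipe that quantizes self_attn layers to W8A8 with GPTQ and mlp layers to W4A16 with AWQ; the recipe body itself falls outside the hunks shown in this diff. As a rough sketch only, with hypothetical target regexes and ignore list that are not copied from the example:

# Hypothetical sketch of the recipe described by the comments above; the actual
# recipe in quantization_multiple_modifiers.py is not shown in this diff.
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = [
    # W8A8 (int8) GPTQ on the attention projections
    GPTQModifier(targets=r"re:.*self_attn.*", scheme="W8A8", ignore=["lm_head"]),
    # W4A16 AWQ on the MLP projections
    AWQModifier(targets=r"re:.*mlp.*", scheme="W4A16", ignore=["lm_head"]),
]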
@@ -87,27 +77,37 @@ def tokenize(sample):
     ),
 ]
 
-# Apply algorithms.
-oneshot(
-    model=model,
-    dataset=ds,
-    recipe=recipe,
-    max_seq_length=MAX_SEQUENCE_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    pipeline="independent" if args.independent else "sequential",
-)
-
-# Confirm generations of the quantized model look sane.
-print("\n\n")
-print("========== SAMPLE GENERATION ==============")
-dispatch_for_generation(model)
-sample = tokenizer("Hello my name is", return_tensors="pt")
-sample = {key: value.to(model.device) for key, value in sample.items()}
-output = model.generate(**sample, max_new_tokens=100)
-print(tokenizer.decode(output[0]))
-print("==========================================\n\n")
-
-# Save to disk compressed.
-SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-gptq-w8a8-self_attn-awq-w4a16-mlp"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
+if __name__ == "__main__":
+    args = parse_args()
+    # Load dataset and preprocess.
+    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+    ds = ds.shuffle(seed=42)
+    ds = ds.map(preprocess)
+    ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+    # Apply algorithms.
+    oneshot(
+        model=model,
+        dataset=ds,
+        recipe=recipe,
+        max_seq_length=MAX_SEQUENCE_LENGTH,
+        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+        pipeline="independent" if args.independent else "sequential",
+    )
+
+    # Confirm generations of the quantized model look sane.
+    print("\n\n")
+    print("========== SAMPLE GENERATION ==============")
+    dispatch_for_generation(model)
+    sample = tokenizer("Hello my name is", return_tensors="pt")
+    sample = {key: value.to(model.device) for key, value in sample.items()}
+    output = model.generate(**sample, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
+    print("==========================================\n\n")
+
+    # Save to disk compressed.
+    SAVE_DIR = (
+        model_id.rstrip("/").split("/")[-1] + "-gptq-w8a8-self_attn-awq-w4a16-mlp"
+    )
+    model.save_pretrained(SAVE_DIR, save_compressed=True)
+    tokenizer.save_pretrained(SAVE_DIR)
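
The refactor above gathers all runtime work under if __name__ == "__main__":, the usual guard that keeps the script from executing when it is imported, and it keeps the pipeline choice driven by an --independent flag read in parse_args(). The parser definition sits outside the shown hunks; a minimal sketch of what it might look like follows, where the flag name comes from args.independent but the description, help text, and defaults are assumptions.

# Hypothetical sketch of the parser implied by args.independent; the real
# parse_args() in the example is not shown in this diff.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description="Non-uniform quantization with multiple modifiers"
    )
    parser.add_argument(
        "--independent",
        action="store_true",
        help="Calibrate each modifier in its own pass instead of one sequential pipeline",
    )
    return parser.parse_args()

With a parser like this, the example would run as python quantization_multiple_modifiers.py for the sequential pipeline, or with --independent to calibrate each modifier independently.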
