diff --git a/benchmarks/_models/eval_hf_models.py b/benchmarks/_models/eval_hf_models.py
new file mode 100644
index 0000000000..2bca1fe5f0
--- /dev/null
+++ b/benchmarks/_models/eval_hf_models.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import itertools
+import subprocess
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+
+from benchmarks.microbenchmarks.utils import string_to_config
+from torchao.quantization import *  # noqa: F401, F403
+from torchao.quantization.utils import _lm_eval_available
+
+
+def quantize_model_and_save(model_id, quant_config, output_dir="results"):
+    """Quantize the model and save it to the output directory."""
+    print("Quantizing model with config: ", quant_config)
+    if quant_config is None:
+        quantization_config = None
+    else:
+        quantization_config = TorchAoConfig(quant_type=quant_config)
+    quantized_model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        quantization_config=quantization_config,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    quantized_model.save_pretrained(output_dir, safe_serialization=False)
+    tokenizer.save_pretrained(output_dir)
+    return quantized_model, tokenizer
+
+
+def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
+    """Run the lm_eval command using subprocess."""
+    tasks_str = ",".join(tasks_list)
+    command = [
+        "lm_eval",
+        "--model",
+        "hf",
+        "--model_args",
+        f"pretrained={model_dir}",
+        "--tasks",
+        f"{tasks_str}",
+        "--device",
+        f"{device}",
+        "--batch_size",
+        f"{batch_size}",
+    ]
+    subprocess.run(command, check=True)
+
+
+def get_model_size_in_bytes(model, ignore_embeddings=False):
+    """
+    Returns the model size in bytes. The option to ignore embeddings
+    is useful for models with disproportionately large embeddings compared
+    to other model parameters that get quantized/sparsified.
+    """
+
+    def flat_size(tensor):
+        if hasattr(tensor, "__tensor_flatten__"):
+            size = 0
+            # 0th element is a list of attributes that
+            # hold tensors
+            for attr_name in tensor.__tensor_flatten__()[0]:
+                sub_tensor = getattr(tensor, attr_name)
+                size += flat_size(sub_tensor)
+            return size
+        else:
+            return tensor.numel() * tensor.element_size()
+
+    model_size = 0
+    for _, child in model.named_children():
+        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
+            for p in itertools.chain(
+                child.parameters(recurse=False), child.buffers(recurse=False)
+            ):
+                model_size += flat_size(p)
+            model_size += get_model_size_in_bytes(child, ignore_embeddings)
+    return model_size
+
+
+def run(
+    model_id,
+    quantization,
+    tasks,
+    device,
+    batch_size,
+    model_output_dir,
+):
+    print(f"Running model {model_id} with quantization {quantization}")
+    model_name = model_id.split("/")[-1]
+    model_output_dir = f"{model_output_dir}/{model_name}-{quantization}"
+    quant_config = string_to_config(quantization, None)
+    quantized_model, tokenizer = quantize_model_and_save(
+        model_id, quant_config=quant_config, output_dir=model_output_dir
+    )
+    print("Compiling model ...")
+    quantized_model = torch.compile(
+        quantized_model,
+        mode="reduce-overhead",
+        fullgraph=True,
+    )
+    run_lm_eval(
+        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
+    )
+    model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9
+    print(f"Model size: {model_size:.2f} GB")
+
+
+if __name__ == "__main__":
+    if not _lm_eval_available:
+        print(
+            "lm_eval is required to run this script. Please install it using pip install lm-eval."
+        )
+        exit(0)
+
+    # Set up argument parser
+    parser = argparse.ArgumentParser(
+        description="Quantize a model and evaluate it with lm-eval."
+    )
+    parser.add_argument(
+        "--model_id",
+        type=str,
+        default="meta-llama/Llama-3.1-8B",
+        help="The model ID to use.",
+    )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default=None,
+        help="The quantization method to use.",
+    )
+    parser.add_argument(
+        "--tasks",
+        nargs="+",
+        type=str,
+        default=["wikitext"],
+        help="List of lm-eval tasks to evaluate. Usage: --tasks task1 task2",
+    )
+    parser.add_argument(
+        "--device", type=str, default="cuda:0", help="Device to run the model on."
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size for lm_eval."
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="What are we having for dinner?",
+        help="Prompt for model throughput evaluation.",
+    )
+    parser.add_argument(
+        "--max_new_tokens",
+        type=int,
+        default=10,
+        help="Max new tokens to generate for throughput evaluation.",
+    )
+    parser.add_argument(
+        "--num_runs",
+        type=int,
+        default=5,
+        help="Number of runs to average over for throughput evaluation.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="quantized_models",
+        help="Output directory for quantized model.",
+    )
+    args = parser.parse_args()
+
+    # Use parsed arguments
+    run(
+        model_id=args.model_id,
+        quantization=args.quantization,
+        tasks=args.tasks,
+        device=args.device,
+        batch_size=args.batch_size,
+        model_output_dir=args.output_dir,
+    )
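A checkpoint written by `quantize_model_and_save` can be reloaded like any other HuggingFace model. A minimal sketch for reference; the checkpoint path and prompt are illustrative, not part of this PR:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative path produced by one run of eval_hf_models.py.
model_dir = "quantized_models/Llama-3.1-8B-float8dq-row"

model = AutoModelForCausalLM.from_pretrained(
    model_dir, device_map="auto", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```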
+ """ + + def flat_size(tensor): + if hasattr(tensor, "__tensor_flatten__"): + size = 0 + # 0th element is a list of attributes that + # hold tensors + for attr_name in tensor.__tensor_flatten__()[0]: + sub_tensor = getattr(tensor, attr_name) + size += flat_size(sub_tensor) + return size + else: + return tensor.numel() * tensor.element_size() + + model_size = 0 + for _, child in model.named_children(): + if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): + for p in itertools.chain( + child.parameters(recurse=False), child.buffers(recurse=False) + ): + model_size += flat_size(p) + model_size += get_model_size_in_bytes(child, ignore_embeddings) + return model_size + + +def run( + model_id, + quantization, + tasks, + device, + batch_size, + model_output_dir, +): + print(f"Running model {model_id} with quantization {quantization}") + model_name = model_id.split("/")[-1] + model_output_dir = f"quantized_model/{model_name}-{quantization}" + quant_config = string_to_config(quantization, None) + quantized_model, tokenizer = quantize_model_and_save( + model_id, quant_config=quant_config, output_dir=model_output_dir + ) + print("Compiling model ....") + quantized_model = torch.compile( + quantized_model, + mode="reduce-overhead", + fullgraph=True, + ) + run_lm_eval( + model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size + ) + model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9 + print(f"Model size: {model_size:.2f} GB") + + +if __name__ == "__main__": + if not _lm_eval_available: + print( + "lm_eval is required to run this script. Please install it using pip install lm-eval." + ) + exit(0) + + # Set up argument parser + parser = argparse.ArgumentParser( + description="Quantize a model and evaluate its throughput." + ) + parser.add_argument( + "--model_id", + type=str, + default="meta-llama/Llama-3.1-8B", + help="The model ID to use.", + ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help="The quantization method to use.", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + parser.add_argument( + "--device", type=str, default="cuda:0", help="Device to run the model on." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for lm_eval." + ) + parser.add_argument( + "--prompt", + type=str, + default="What are we having for dinner?", + help="Prompt for model throughput evaluation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=10, + help="Max new tokens to generate for throughput evaluation.", + ) + parser.add_argument( + "--num_runs", + type=int, + default=5, + help="Number of runs to average over for throughput evaluation.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="quantized_models", + help="Output directory for quantized model.", + ) + args = parser.parse_args() + + # Use parsed arguments + run( + model_id=args.model_id, + quantization=args.quantization, + tasks=args.tasks, + device=args.device, + batch_size=args.batch_size, + model_output_dir=args.output_dir, + ) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh new file mode 100644 index 0000000000..14feef7505 --- /dev/null +++ b/benchmarks/_models/eval_hf_models.sh @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# For llama3.1-8B
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128 --tasks wikitext hellaswag
+
+
+# For llama3.2-3B
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128 --tasks wikitext hellaswag
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index f591ec3669..f6e450226b 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -18,6 +18,7 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     FPXWeightOnlyConfig,
+    GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
@@ -287,6 +288,15 @@ def string_to_config(
         else:
             granularity = PerTensor()
         return Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+    if "gemlitewo" in quantization:
+        group_size = int(quantization.split("-")[1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"gemlitewo group_size needs to be one of [32, 64, 128, 256] but got {group_size}"
+        return GemliteUIntXWeightOnlyConfig(group_size=group_size)
     return None
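The new `gemlitewo` branch in `string_to_config` parses the group size out of the config string. A quick illustrative check (the second positional argument is the sparsity spec, passed as `None` just as `eval_hf_models.py` does):

```
from benchmarks.microbenchmarks.utils import string_to_config

config = string_to_config("gemlitewo-128", None)
print(type(config).__name__)  # GemliteUIntXWeightOnlyConfig
print(config.group_size)      # 128
```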
diff --git a/torchao/_models/README.md b/torchao/_models/README.md
index 074adf884c..3157844a3b 100644
--- a/torchao/_models/README.md
+++ b/torchao/_models/README.md
@@ -1,3 +1,42 @@
+# TODO: Add info for _models here
+
+## Eval on Llama 3.1 8B and Llama 3.2 3B
+
+We use lm-eval tasks to evaluate TorchAO quantization APIs on HuggingFace models. Acc and Acc Norm come from hellaswag; Word Perplexity comes from wikitext. The results are in the tables below:
+
+| Model Name   | Quantization Technique | Acc   | Acc Norm | Word Perplexity | Throughput (tokens/sec) | Model Size (GB) |
+|--------------|------------------------|-------|----------|-----------------|-------------------------|-----------------|
+| Llama 3.1 8B | None                   | 60.01 | 78.84    | 7.33            | 44.95                   | 15.01           |
+| Llama 3.1 8B | int4wo-128             | 58.10 | 77.06    | 8.25            | 33.95                   | 4.76            |
+| Llama 3.1 8B | int8wo                 | 59.92 | 78.95    | 7.34            | 28.65                   | 8.04            |
+| Llama 3.1 8B | int8dq                 | 60.01 | 78.82    | 7.45            | 4.75                    | 8.03            |
+| Llama 3.1 8B | float8wo               | 59.83 | 78.61    | 7.37            | 17.84                   | 8.03            |
+| Llama 3.1 8B | float8dq (PerRow)      | 59.86 | 78.57    | 7.41            | 10.96                   | 8.04            |
+| Llama 3.1 8B | float8dq (PerTensor)   | 59.95 | 78.66    | 7.42            | 10.63                   | 8.03            |
+| Llama 3.1 8B | gemlitewo-128          | 58.48 | 77.34    | 8.07            | 14.42                   | 4.76            |
+
+| Model Name   | Quantization Technique | Acc   | Acc Norm | Word Perplexity | Throughput (tokens/sec) | Model Size (GB) |
+|--------------|------------------------|-------|----------|-----------------|-------------------------|-----------------|
+| Llama 3.2 3B | None                   | 55.27 | 73.70    | 9.26            | 53.08                   | 6.43            |
+| Llama 3.2 3B | int4wo-128             | 53.13 | 71.31    | 10.36           | 36.36                   | 2.29            |
+| Llama 3.2 3B | int8wo                 | 55.15 | 73.44    | 9.28            | 36.30                   | 3.61            |
+| Llama 3.2 3B | int8dq                 | 55.00 | 73.29    | 9.43            | 5.45                    | 3.61            |
+| Llama 3.2 3B | float8wo               | 55.18 | 73.58    | 9.31            | 28.95                   | 3.61            |
+| Llama 3.2 3B | float8dq (PerRow)      | 55.18 | 73.37    | 9.33            | 12.56                   | 3.61            |
+| Llama 3.2 3B | float8dq (PerTensor)   | 55.16 | 73.53    | 9.35            | 12.21                   | 3.61            |
+| Llama 3.2 3B | gemlitewo-128          | 53.71 | 71.99    | 10.05           | 16.52                   | 2.29            |
+
+To generate the above results, run:
+```
+sh benchmarks/_models/eval_hf_models.sh
+```
+
+To run lm-eval for a different HuggingFace model with a TorchAO quantization technique, run:
+```
+python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
+```
+Replace the model id, quantization technique, and tasks with your desired values. Please refer to the [HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao) integration docs for more details about the supported quantization techniques.
+
 ## SAM2
 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc