simplify script
SunMarc committed Sep 27, 2023
1 parent 6d1ae0e commit 0c53c2f
Showing 1 changed file with 139 additions and 119 deletions.
258 changes: 139 additions & 119 deletions tests/benchmark/benchmark_gptq.py
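In short, the commit drops the manual GPTQ loading path (a local --gptq-model folder read through optimum's load_quantized_model with init_empty_weights) in favor of passing a GPTQConfig straight to from_pretrained, reads the quantization settings back from model.config.quantization_config, and splits the benchmark into an optional generation pass (--generate) and an optional perplexity pass (--ppl, via auto_gptq's Perplexity helper). A minimal sketch of the new loading path, mirroring the call added in the diff; the model id is hypothetical and the disable_exllama / disable_exllamav2 keywords assume a transformers/optimum version contemporary with this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "TheBloke/Llama-2-7B-GPTQ"  # hypothetical GPTQ checkpoint
# bits and the exllama switches mirror the GPTQConfig call added in this commit
quantization_config = GPTQConfig(bits=4, disable_exllama=False, disable_exllamav2=False)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)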
@@ -1,12 +1,10 @@
import argparse
import gc
import json
import os
import time

import numpy as np
import torch
from accelerate import init_empty_weights
from memory_tracker import MemoryTracker
from tqdm import tqdm
from transformers import (
@@ -16,11 +14,11 @@
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
GPTQConfig
)

from optimum.exporters import TasksManager
from optimum.gptq import load_quantized_model

from auto_gptq.utils import Perplexity

def get_parser():
parser = argparse.ArgumentParser()
@@ -45,13 +43,7 @@ def get_parser():
parser.add_argument(
"--model",
type=str,
help="Model to benchmark (in the non-quantized case), or reference architecture corresponding to the quantized model (GPTQ case)",
)
parser.add_argument(
"--gptq-model",
type=str,
default=None,
help="Path to a local GPTQ model.",
help="Model to benchmark",
)
parser.add_argument(
"--prompt-length",
@@ -90,6 +82,27 @@ def get_parser():
action="store_true",
help="Disable Exllama kernel, to rather use the AutoGPTQ CUDA (act-order case) or CUDA-old (no act-order case) kernels.",
)
parser.add_argument(
"--disable-exllamav2",
action="store_true",
help="Disable Exllama kernel, to rather use the AutoGPTQ CUDA (act-order case) or CUDA-old (no act-order case) kernels.",
)
parser.add_argument(
"--generate",
action="store_true",
help="Calculate the generate speed (prompt processing + token generation)",
)
parser.add_argument(
"--ppl",
action="store_true",
help="Calculate the perplexity on wikitext2 dataset",
)
parser.add_argument(
"--revision",
default=None,
help="Revision of the model to benchmark",
)

return parser
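With the new --generate, --ppl and --revision options, an illustrative invocation using only arguments visible in this diff would be python tests/benchmark/benchmark_gptq.py --model <hf-model-id> --gptq --generate --ppl (the --gptq flag is inferred from the args.gptq checks further down; the script may require additional arguments not shown in this hunk).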


@@ -266,7 +279,7 @@ def benchmark_memory(
device = torch.device("cuda:0")
memory_tracker = MemoryTracker()

tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, use_fast=False)

if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@@ -288,46 +301,14 @@ def benchmark_memory(
else:
is_decoder = False

act_order = None
bits = None
group_size = None
kernel = None
if args.gptq:
if not args.gptq_model:
raise ValueError("The argument --gptq-model needs to be provided when benchmarking GPTQ.")

with open(os.path.join(args.gptq_model, "quantization_config.json"), "r", encoding="utf-8") as f:
quantize_config_dict = json.load(f)

act_order = quantize_config_dict["desc_act"]
bits = quantize_config_dict["bits"]
group_size = quantize_config_dict["group_size"]

if not args.disable_exllama:
# Exllama kernel can handle both the act-order / no act-order cases.
kernel = "exllama"
elif act_order:
kernel = "autotogptq-cuda"
else:
kernel = "autogptq-cuda-old"

load_start = time.time_ns()
if args.gptq:
with init_empty_weights():
empty_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16)
empty_model.tie_weights()
model = load_quantized_model(
empty_model,
save_folder=args.gptq_model,
state_dict_name="model.safetensors",
device_map="auto",
disable_exllama=args.disable_exllama,
)
quantization_config = GPTQConfig(bits=4, disable_exllama=args.disable_exllama, disable_exllamav2=args.disable_exllamav2)
model = autoclass.from_pretrained(args.model, revision=args.revision, quantization_config=quantization_config, torch_dtype=torch.float16, device_map="auto")
elif args.bitsandbytes:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="fp4", bnb_4bit_compute_dtype=torch.float16
)

model = autoclass.from_pretrained(
args.model, quantization_config=quantization_config, device_map="auto", torch_dtype=torch.float16
)
@@ -337,6 +318,27 @@ def benchmark_memory(
torch.cuda.synchronize()
load_end = time.time_ns()

act_order = None
bits = None
group_size = None
kernel = None

if args.gptq:
quantization_config_dict = model.config.quantization_config.to_dict()
act_order = quantization_config_dict["desc_act"]
bits = quantization_config_dict["bits"]
group_size = quantization_config_dict["group_size"]

if not args.disable_exllamav2:
kernel = "exllamav2"
elif not args.disable_exllama:
# Exllama kernel can handle both the act-order / no act-order cases.
kernel = "exllama"
elif act_order:
kernel = "autotogptq-cuda"
else:
kernel = "autogptq-cuda-old"

load_time = (load_end - load_start) * 1e-9
print(f"Model load time: {load_time:.1f} s")

@@ -364,82 +366,100 @@ def benchmark_memory(
file_name = file_name + "_noquant"
quantization = None

file_name = file_name + ".csv"
output_file = open(file_name, "w")
header = "quantization, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
output_file.write(header)

latencies = {}
throughputs = {}
all_max_mem = {}
print(
"WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit."
)

for batch_size in tqdm(batch_sizes):
for prompt_length in tqdm(prompt_lengths):
for new_token in tqdm(new_tokens):
print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)
if args.ppl:
output_file = open(file_name + "_perplexity.csv", "w")
header = "quantization, act_order, bits, group_size, kernel, perplexity\n"
output_file.write(header)
ppl = Perplexity(model, tokenizer)
ppl_value = np.mean(ppl.calculate_perplexity())
line = "{},{},{},{},{},{}\n".format(
quantization,
act_order,
bits,
group_size,
kernel,
f"{ppl_value:.2f}",
)
print(header)
print(line)
output_file.write(line)
output_file.close()

if args.generate:
output_file = open(file_name + ".csv", "w")
header = "quantization, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
output_file.write(header)

latencies = {}
throughputs = {}
all_max_mem = {}
print(
"WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit."
)

with torch.no_grad():
max_mem = benchmark_memory(
model,
input_ids,
masks,
args.num_batches,
is_decoder,
new_token,
tokenizer.pad_token_id,
memory_tracker=memory_tracker,
for batch_size in tqdm(batch_sizes):
for prompt_length in tqdm(prompt_lengths):
for new_token in tqdm(new_tokens):
print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)

with torch.no_grad():
max_mem = benchmark_memory(
model,
input_ids,
masks,
args.num_batches,
is_decoder,
new_token,
tokenizer.pad_token_id,
memory_tracker=memory_tracker,
)

mean_latency = benchmark_latency(
model,
input_ids,
masks,
args.num_batches,
is_decoder,
new_token,
tokenizer.pad_token_id,
memory_tracker=memory_tracker,
)

index = (batch_size, prompt_length, new_token)

per_token_latency = mean_latency / new_token
latencies[index] = per_token_latency

throughput = batch_size / (per_token_latency * 1e-3)
throughputs[index] = throughput
all_max_mem[index] = max_mem

print(
f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
)

mean_latency = benchmark_latency(
model,
input_ids,
masks,
line = "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
quantization,
act_order,
bits,
group_size,
kernel,
args.num_batches,
is_decoder,
batch_size,
prompt_length,
new_token,
tokenizer.pad_token_id,
memory_tracker=memory_tracker,
f"{load_time:.2f}",
f"{per_token_latency:.2f}",
f"{throughput:.2f}",
f"{max_mem:.2f}",
)

index = (batch_size, prompt_length, new_token)

per_token_latency = mean_latency / new_token
latencies[index] = per_token_latency

throughput = batch_size / (per_token_latency * 1e-3)
throughputs[index] = throughput
all_max_mem[index] = max_mem

print(
f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
)

line = "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
quantization,
act_order,
bits,
group_size,
kernel,
args.num_batches,
batch_size,
prompt_length,
new_token,
f"{load_time:.2f}",
f"{per_token_latency:.2f}",
f"{throughput:.2f}",
f"{max_mem:.2f}",
)
print(header)
print(line)
output_file.write(line)

output_file.close()
print(header)
print(line)
output_file.write(line)
output_file.close()
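
The CSV written under --generate can be read back with the standard csv module; a hedged sketch, assuming a hypothetical output file name (the real name is built from the quantization flags) and the column names taken from the header string above:

import csv

with open("benchmark_noquant.csv", newline="") as f:  # hypothetical file name
    # skipinitialspace handles the ", " separators used in the header line
    for row in csv.DictReader(f, skipinitialspace=True):
        print(row["batch_size"], row["prompt_length"], row["new_tokens"],
              row["Throughput (tok/s)"], row["Max memory (MB)"])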
