test_quantization.py
import os
import math

# Cap thread counts for the numerical backends before they are imported,
# so CPU-side quantization does not oversubscribe the machine.
max_threads = str(min(8, os.cpu_count()))
os.environ['OMP_NUM_THREADS'] = max_threads
os.environ['OPENBLAS_NUM_THREADS'] = max_threads
os.environ['MKL_NUM_THREADS'] = max_threads
os.environ['VECLIB_MAXIMUM_THREADS'] = max_threads
os.environ['NUMEXPR_NUM_THREADS'] = max_threads
os.environ['NUMEXPR_MAX_THREADS'] = max_threads

import tempfile
import unittest

import torch.cuda
from parameterized import parameterized
from transformers import AutoTokenizer

from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.quantization import CHECKPOINT_FORMAT, QUANT_CONFIG_FILENAME, BaseQuantizeConfig

class TestQuantization(unittest.TestCase):
    @parameterized.expand([(False,), (True,)])
    def test_quantize(self, use_marlin: bool):
        pretrained_model_dir = "saibo/llama-1B"

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
        examples = [
            tokenizer(
                "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
            ),
            tokenizer("Today I am in Paris and it is a wonderful day."),
        ]

        quantize_config = BaseQuantizeConfig(
            bits=4,
            group_size=128,
            desc_act=False,
            checkpoint_format=CHECKPOINT_FORMAT.MARLIN if use_marlin else CHECKPOINT_FORMAT.GPTQ,
        )

        model = AutoGPTQForCausalLM.from_pretrained(
            pretrained_model_dir,
            quantize_config=quantize_config,
            use_flash_attention_2=False,
        )
        model.quantize(examples)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # save the quantized model, then reload it in the requested checkpoint format
            model.save_pretrained(tmpdirname)
            model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device="cuda:0", use_marlin=use_marlin)
            del model
            torch.cuda.empty_cache()

            # test compat: 1) with simple dict type 2) is_marlin_format
            compat_quantize_config = {
                "bits": 4,
                "group_size": 128,
                "desc_act": False,
                "is_marlin_format": use_marlin,
            }
            model = AutoGPTQForCausalLM.from_quantized(
                tmpdirname, device="cuda:0", quantize_config=compat_quantize_config
            )
            assert isinstance(model.quantize_config, BaseQuantizeConfig)
            del model
            torch.cuda.empty_cache()

            # test checkpoint_format hint to from_quantized() after the quant
            # config file has been removed from the checkpoint directory
            os.remove(f"{tmpdirname}/{QUANT_CONFIG_FILENAME}")
            compat_quantize_config = {
                "bits": 4,
                "group_size": 128,
                "desc_act": False,
            }
            model = AutoGPTQForCausalLM.from_quantized(
                tmpdirname,
                device="cuda:0",
                quantize_config=compat_quantize_config,
                checkpoint_format=CHECKPOINT_FORMAT.MARLIN if use_marlin else None,
            )
            assert isinstance(model.quantize_config, BaseQuantizeConfig)