-
Notifications
You must be signed in to change notification settings - Fork 23
/
autograd_4bit.py
150 lines (123 loc) · 5.24 KB
/
autograd_4bit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import matmul_utils_4bit as mm4b
import torch
import torch.nn as nn
import time
import math
from safetensors import safe_open
import numpy as np
class AutogradMatmul4bit(torch.autograd.Function):
    """Autograd ``Function`` wrapping the ``mm4b`` 4-bit "recons" matmul kernels.

    Forward dispatches to ``mm4b._matmul4bit_v1_recons`` when ``groupsize == -1``
    and to ``mm4b._matmul4bit_v2_recons`` otherwise. Backward calls the same
    kernel on ``grad_output`` with ``transpose=True`` and returns the result as
    the gradient for ``x`` only. ``qweight``, ``scales``, ``zeros`` and
    ``groupsize`` receive ``None``.
    """

    @staticmethod
    def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
        """Run the 4-bit matmul of ``x`` against the packed weight.

        Args:
            ctx: autograd context object.
            x: input activations.
            qweight: packed int32 quantized weight.
            scales: dequantization scales.
            zeros: zero points. ``Autograd4bitQuantLinear`` passes ``qzeros`` in
                grouped mode and ``zeros`` otherwise.
            groupsize: ``-1`` selects the v1 (ungrouped) kernel. Any other value
                is forwarded to the v2 (grouped) kernel.

        Returns:
            A clone of the kernel output.
        """
        ctx.save_for_backward(qweight, scales, zeros)
        # groupsize is a plain Python scalar, so keep it on ctx directly.
        # Wrapping it in a numpy array -> CUDA tensor on every forward, then
        # reading it back with .cpu() in backward, cost an extra allocation,
        # a host->device copy and a device->host sync per call. It was also
        # always placed on the *current* CUDA device.
        ctx.groupsize = groupsize
        if groupsize == -1:
            output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
        else:
            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
        # NOTE(review): clone kept from the original. Presumably it stops the
        # result from aliasing a kernel-owned buffer -- confirm before removing.
        return output.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x`` and ``None`` for every other input.

        The same ``*_recons`` kernel as in forward is called with
        ``transpose=True``. This presumably computes ``grad_output @ W^T``;
        confirm in ``matmul_utils_4bit``.
        """
        qweight, scales, zeros = ctx.saved_tensors
        groupsize = ctx.groupsize
        if groupsize == -1:
            grad = mm4b._matmul4bit_v1_recons(
                grad_output, qweight, scales, zeros, transpose=True
            )
        else:
            grad = mm4b._matmul4bit_v2_recons(
                grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True
            )
        return grad, None, None, None, None
# Packing layout: 256 four-bit values occupy 256 * bits / 32 == bits * 8 int32
# words. Packed row counts are therefore computed as ``n // 256 * (bits * 8)``.
# That only equals ``n * bits // 32`` when ``n`` is a multiple of 256, so
# in_features (and out_features in grouped mode) are assumed divisible by 256.
class Autograd4bitQuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit integers.

    All buffers are allocated with ``torch.empty`` and must be filled by
    loading a checkpoint:

    * ungrouped (``groupsize == -1``): ``zeros`` and ``scales``, each of shape
      ``(out_features, 1)``;
    * grouped: int32 ``qzeros`` of shape
      ``(ceil(in_features / groupsize), out_features // 256 * 32)`` and
      ``scales`` of shape ``(ceil(in_features / groupsize), out_features)``;
    * always: ``bias`` of shape ``(out_features,)`` and int32 ``qweight`` of
      shape ``(in_features // 256 * 32, out_features)``.
    """

    def __init__(self, infeatures, outfeatures, groupsize=-1):
        super().__init__()
        bits = 4
        words_per_256 = bits * 8  # int32 words holding 256 packed values
        self.in_features = infeatures
        self.out_features = outfeatures
        self.bits = bits
        self.groupsize = groupsize

        if groupsize != -1:
            n_groups = math.ceil(infeatures / groupsize)
            self.register_buffer(
                'qzeros',
                torch.empty((n_groups, outfeatures // 256 * words_per_256), dtype=torch.int),
            )
            self.register_buffer('scales', torch.empty((n_groups, outfeatures)))
        else:
            self.register_buffer('zeros', torch.empty((outfeatures, 1)))
            self.register_buffer('scales', torch.empty((outfeatures, 1)))

        self.register_buffer('bias', torch.empty(outfeatures))
        self.register_buffer(
            'qweight',
            torch.empty((infeatures // 256 * words_per_256, outfeatures), dtype=torch.int),
        )

    def forward(self, x):
        """Return the 4-bit matmul of ``x`` plus ``bias``.

        With autograd enabled, the call goes through ``AutogradMatmul4bit`` so
        gradients can flow to ``x``. Otherwise ``mm4b.matmul4bit`` is called
        directly.
        """
        zero_points = self.zeros if self.groupsize == -1 else self.qzeros
        if torch.is_grad_enabled():
            result = AutogradMatmul4bit.apply(
                x, self.qweight, self.scales, zero_points, self.groupsize
            )
        else:
            result = mm4b.matmul4bit(
                x, self.qweight, self.scales, zero_points, self.groupsize
            )
        result += self.bias
        return result
def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
    """Recursively replace selected submodules with ``Autograd4bitQuantLinear``.

    Args:
        module: module to walk. An ``Autograd4bitQuantLinear`` is returned
            from immediately.
        names: container of dotted submodule paths (relative to the root of
            the walk) to replace, e.g. the dict returned by ``find_layers``.
            Only membership (``in``) is tested.
        name: dotted path of ``module`` itself ('' for the root).
        groupsize: forwarded to every new ``Autograd4bitQuantLinear``.

    The replacement reads ``in_features``/``out_features`` from the module it
    replaces, so every targeted module must expose those attributes (as
    ``nn.Linear`` does).
    """
    if isinstance(module, Autograd4bitQuantLinear):
        return
    # Only registered children can be targets, so walk them directly. The old
    # ``dir()`` + ``getattr`` scan evaluated every attribute (properties
    # included) at every level, which is slow and can raise or have side
    # effects. Snapshot the children because we setattr while iterating.
    for child_name, child in list(module.named_children()):
        full_name = name + '.' + child_name if name != '' else child_name
        if full_name in names:
            setattr(
                module,
                child_name,
                Autograd4bitQuantLinear(child.in_features, child.out_features, groupsize=groupsize),
            )
        else:
            make_quant_for_4bit_autograd(child, names, full_name, groupsize=groupsize)
def model_to_half(model):
    """Cast ``model`` to float16, then re-cast each quantized layer's float buffers.

    After ``model.half()``, every ``Autograd4bitQuantLinear`` has its
    ``scales`` and ``bias`` reassigned to ``.half()`` copies. Ungrouped layers
    (``groupsize == -1``) also have ``zeros`` re-cast. Integer buffers such as
    ``qweight``/``qzeros`` are not touched here. Prints a confirmation line
    when done.
    """
    model.half()
    # NOTE(review): nn.Module.half() is documented to cast floating-point
    # buffers already, so the explicit casts below look redundant. They are
    # kept unchanged.
    quant_layers = (
        mod for mod in model.modules() if isinstance(mod, Autograd4bitQuantLinear)
    )
    for layer in quant_layers:
        if layer.groupsize == -1:
            layer.zeros = layer.zeros.half()
        layer.scales = layer.scales.half()
        layer.bias = layer.bias.half()
    print('Converted as Half.')
def model_to_float(model):
    """Cast ``model`` to float32, then re-cast each quantized layer's float buffers.

    After ``model.float()``, every ``Autograd4bitQuantLinear`` has its
    ``scales`` and ``bias`` reassigned to ``.float()`` copies. Ungrouped
    layers (``groupsize == -1``) also have ``zeros`` re-cast. Integer buffers
    such as ``qweight``/``qzeros`` are not touched here. Prints a confirmation
    line when done.
    """
    model.float()
    # NOTE(review): nn.Module.float() is documented to cast floating-point
    # buffers already, so the explicit casts below look redundant. They are
    # kept unchanged.
    quant_layers = (
        mod for mod in model.modules() if isinstance(mod, Autograd4bitQuantLinear)
    )
    for layer in quant_layers:
        if layer.groupsize == -1:
            layer.zeros = layer.zeros.float()
        layer.scales = layer.scales.float()
        layer.bias = layer.bias.float()
    print('Converted as Float.')
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Collect submodules whose *exact* type is listed in ``layers``.

    Matching uses ``type(m) in layers``, so subclasses do not match. Once a
    module matches, its own children are not searched.

    Args:
        module: root module to search.
        layers: container of module types to collect. Defaults to
            ``(nn.Conv2d, nn.Linear)``. The default is an immutable tuple,
            replacing the former shared mutable list default; lists passed by
            callers still work.
        name: dotted path prefix for ``module`` ('' for the root).

    Returns:
        Dict mapping dotted paths to matching modules. If ``module`` itself
        matches, the result is ``{name: module}``.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = name + '.' + child_name if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
    """Build a LLaMA causal LM with 4-bit quantized linear layers and load its weights.

    Steps:
      1. Under ``accelerate.init_empty_weights`` (no real weight storage is
         materialized), build ``LlamaForCausalLM`` from ``config_path`` in
         eval mode.
      2. Replace every ``nn.Conv2d``/``nn.Linear`` except ``lm_head`` with
         ``Autograd4bitQuantLinear``.
      3. Load ``model_path`` with ``accelerate.load_checkpoint_and_dispatch``,
         never splitting a ``LlamaDecoderLayer`` across devices.

    Args:
        config_path: directory holding the LLaMA config and tokenizer.
        model_path: checkpoint passed to ``load_checkpoint_and_dispatch``.
        groupsize: quantization group size (``-1`` = ungrouped).
        half: if true, run ``model_to_half`` on the loaded model.
        device_map: forwarded to ``load_checkpoint_and_dispatch``.
        seqlen: stored on the model as ``model.seqlen``.

    Returns:
        ``(model, tokenizer)``. The tokenizer's ``truncation_side`` is
        ``'left'``.
    """
    # Heavy optional dependencies are imported lazily.
    import accelerate
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    print("Loading Model ...")
    t0 = time.time()

    with accelerate.init_empty_weights():
        llama_config = LlamaConfig.from_pretrained(config_path)
        model = LlamaForCausalLM(llama_config).eval()
        targets = find_layers(model)
        # The output projection is never quantized.
        targets.pop('lm_head', None)
        make_quant_for_4bit_autograd(model, targets, groupsize=groupsize)

    model = accelerate.load_checkpoint_and_dispatch(
        model=model,
        checkpoint=model_path,
        device_map=device_map,
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    model.seqlen = seqlen

    if half:
        model_to_half(model)

    tokenizer = LlamaTokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer