Summary: redoing 5bf70c1 in a way that doesn't get reverted

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4307b5c7904d59fdbd40a0455687ae90e84585d4
Pull Request resolved: #142
HDCharles committed Mar 20, 2024
1 parent c955dac commit 9875178
Showing 4 changed files with 36 additions and 23 deletions.
4 changes: 2 additions & 2 deletions GPTQ.py
@@ -150,9 +150,9 @@ def __init__(
         }

         # trace model for one input
-        one_input = [multi.values[0] for multi in inputs]
+        one_input = [multi.values[0].cpu() for multi in inputs]
         exported_model = torch._dynamo.export(
-            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
         self.new_state_dict = model.state_dict()
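The GPTQ.py change moves both the example input and the model to CPU before tracing. Below is a minimal sketch of that export call on a toy module; the two-layer toy model and its input shape are assumptions for illustration, not the repo's Transformer.

import torch

# hypothetical stand-in model; only the export call pattern mirrors the diff
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
one_input = (torch.randn(2, 8).cpu(),)

# trace an aten-level, pre-dispatch graph with fake tensors, on CPU
exported_model = torch._dynamo.export(
    model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
print(exported_model.graph_module.graph)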
6 changes: 2 additions & 4 deletions model.py
@@ -78,10 +78,8 @@ def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val: [B, H, S, D]
         assert input_pos.shape[0] == k_val.shape[2]

-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)

         return k_out, v_out

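The model.py change replaces in-place advanced indexing on the KV cache with an explicit torch.ops.aten.index_put_ call. A small sketch showing the two forms write the same values into a toy cache tensor; the shapes here are made up for illustration.

import torch

B, H, S_max, D = 1, 2, 16, 4                      # hypothetical cache shape
input_pos = torch.tensor([3, 4, 5])               # positions being updated
k_val = torch.randn(B, H, input_pos.shape[0], D)

# explicit aten op, as in the diff: None entries mean "take the whole dim"
k_cache = torch.zeros(B, H, S_max, D)
k_out = torch.ops.aten.index_put_(k_cache, [None, None, input_pos], k_val)

# equivalent eager-mode indexing (the old code path)
k_ref = torch.zeros(B, H, S_max, D)
k_ref[:, :, input_pos] = k_val

assert torch.equal(k_out, k_ref)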
32 changes: 15 additions & 17 deletions quantize.py
@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros

+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)

 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):

                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict

     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
@@ -597,7 +595,7 @@ def quantize(

         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")

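The quantize.py changes fold the old padding flag into a padding_allowed check, padding in_features up to a multiple of 1024 at quantization time. Below is a rough sketch of what that padding path does to a weight whose in_features is not already aligned; find_multiple is assumed to behave like the round-up helper imported from model.py, and the layer shape is hypothetical.

import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # assumption: matches model.find_multiple, rounding n up to a multiple of k
    return n if n % k == 0 else n + k - (n % k)

in_features, out_features = 4000, 11008           # hypothetical layer shape
weight = torch.randn(out_features, in_features, dtype=torch.bfloat16)

padded_in_features = find_multiple(in_features, 1024)    # 4096
# zero-pad the reduction (in_features) dimension so the int4 kernel's
# k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 checks pass
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
print(weight.shape)                               # torch.Size([11008, 4096])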
17 changes: 17 additions & 0 deletions run.sh
@@ -0,0 +1,17 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+# echo "quant good"
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
