int4 gptq shape fix #142

Merged: 3 commits, Mar 27, 2024
GPTQ.py (4 changes: 2 additions & 2 deletions)

@@ -150,9 +150,9 @@ def __init__(
         }

         # trace model for one input
-        one_input = [multi.values[0] for multi in inputs]
+        one_input = [multi.values[0].cpu() for multi in inputs]
         exported_model = torch._dynamo.export(
-            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
         self.new_state_dict = model.state_dict()
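Note on the GPTQ.py change above: both the sample input and the model are moved to CPU before export, so the fake-tensor trace no longer depends on the device the checkpoint was loaded onto. A minimal sketch of the same pattern, using a stand-in nn.Linear instead of the real transformer (the export call mirrors the one in the diff):

import torch
import torch.nn as nn

model = nn.Linear(16, 16)                    # stand-in for the model being quantized
one_input = (torch.randn(2, 16).cpu(),)      # sample inputs moved to CPU, as in multi.values[0].cpu()

exported_model = torch._dynamo.export(
    model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
graph_module = exported_model.graph_module   # the GPTQ runner wraps this traced graph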
quantize.py (32 changes: 15 additions & 17 deletions)

@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_til
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros

+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)

 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
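The new _calc_padded_size helper defers to find_multiple from model.py. Assuming the usual round-up-to-a-multiple definition of that helper, the padding arithmetic works out as in this sketch:

def find_multiple(n: int, k: int) -> int:
    # assumed definition of model.find_multiple: round n up to the nearest multiple of k
    if n % k == 0:
        return n
    return n + k - (n % k)

# e.g. an in_features of 4544 is padded up to 5120, while 4096 is already a
# multiple of 1024 and stays unchanged
assert find_multiple(4544, 1024) == 5120
assert find_multiple(4096, 1024) == 4096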
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

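With the rename above, callers state once whether padding is allowed, and the per-layer decision moves into WeightOnlyInt4Linear. A hedged usage sketch on a toy module (layer sizes chosen only for illustration; assumes gpt-fast's quantize.py is importable):

import torch.nn as nn
from quantize import replace_linear_int4

# toy stand-in for a model: the second layer's in_features (4544) fails the k check
# (4544 % 128 != 0), while the first (4096) passes
toy = nn.Sequential(nn.Linear(4096, 4544, bias=False), nn.Linear(4544, 4096, bias=False))

# swap every nn.Linear for WeightOnlyInt4Linear in place; padding_allowed=True means layers
# that fail _check_linear_int4_k are padded instead of being skipped
replace_linear_int4(toy, groupsize=128, inner_k_tiles=8, padding_allowed=True, use_cuda=False)

# the replaced modules still expect packed weights, e.g. from create_quantized_state_dict()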
Expand All @@ -417,7 +415,7 @@ def create_quantized_state_dict(self):

weight = mod.weight.data
if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
if self.padding:
if self.padding_allowed:
from model import find_multiple
import torch.nn.functional as F
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
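Given the imports in the padding_allowed branch above, the weight is presumably padded along in_features to the same multiple of 1024 before packing (the exact lines sit in the collapsed part of the diff). A hedged sketch of that step:

import torch
import torch.nn.functional as F
from model import find_multiple

weight = torch.randn(4096, 4544, dtype=torch.bfloat16)               # illustrative [out, in] weight
padded_in_features = find_multiple(weight.shape[1], 1024)             # 4544 -> 5120
weight = F.pad(weight, pad=(0, padded_in_features - weight.shape[1]))
assert weight.shape == (4096, 5120)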
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict

     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
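The constructor now derives padding from the same shape check used at quantization time, so the runtime module and the quantized checkpoint agree on the padded in_features. A hedged sketch of that decision with illustrative numbers (assumes quantize.py's _check_linear_int4_k and model.py's find_multiple are in scope):

from model import find_multiple
from quantize import _check_linear_int4_k

in_features, groupsize, inner_k_tiles = 4544, 128, 8                         # 4544 % 128 != 0, so the check fails

padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)    # True here
if padding:
    in_features = find_multiple(in_features, 1024)                           # 4544 -> 5120, matching _calc_padded_size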
@@ -597,7 +595,7 @@ def quantize(

         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
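The added {device} segment records which device GPTQ ran on as part of the checkpoint name. A small illustration with made-up values:

# made-up values purely to show the resulting file name
base_name, label, groupsize, device = "model.pth", ".", 32, "cuda"
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
assert new_base_name == "model.int4-gptq.g32.cuda.pth"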