diff --git a/GPTQ.py b/GPTQ.py
index 806ffad..e1279bd 100644
--- a/GPTQ.py
+++ b/GPTQ.py
@@ -150,9 +150,9 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0] for multi in inputs]
+        one_input = [multi.values[0].cpu() for multi in inputs]
         exported_model = torch._dynamo.export(
-            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
         self.new_state_dict = model.state_dict()
diff --git a/model.py b/model.py
index dbf24e5..e70e87a 100644
--- a/model.py
+++ b/model.py
@@ -78,10 +78,8 @@ def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val: [B, H, S, D]
         assert input_pos.shape[0] == k_val.shape[2]
 
-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
 
         return k_out, v_out
 
diff --git a/quantize.py b/quantize.py
index db47775..a9b3f79 100644
--- a/quantize.py
+++ b/quantize.py
@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_til
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -485,11 +483,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
@@ -597,7 +595,7 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
 
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..b772fa3
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,17 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+# echo "quant good"
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
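
Note on the KVCache.update hunk in model.py: the two torch.ops.aten.index_put_ calls are intended to write exactly the same values as the removed advanced-indexing assignments, just expressed as an explicit ATen op, presumably so the cache update keeps working under the torch._dynamo.export tracing used in GPTQ.py. A minimal standalone sketch of that equivalence (the shapes below are illustrative, not taken from any model config):

    import torch

    B, H, S, D = 1, 2, 8, 4                            # illustrative cache shape
    input_pos = torch.tensor([3, 4])                   # positions being written
    k_val = torch.randn(B, H, input_pos.shape[0], D)   # new key slots

    k_cache = torch.zeros(B, H, S, D)
    k_cache_ref = k_cache.clone()

    k_cache_ref[:, :, input_pos] = k_val               # original slice assignment
    torch.ops.aten.index_put_(k_cache, [None, None, input_pos], k_val)  # patched form

    assert torch.equal(k_cache, k_cache_ref)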