Merge pull request #403 from turboderp/dev
Merge dev branch
turboderp authored Apr 7, 2024
2 parents f6b7faa + 3e8e306 commit dafb508
Showing 55 changed files with 1,606 additions and 576 deletions.
297 changes: 172 additions & 125 deletions conversion/adaptivegptq.py

Large diffs are not rendered by default.

57 changes: 32 additions & 25 deletions conversion/compile.py
@@ -97,12 +97,17 @@ def compile_model(job, save_fn, model):
if isinstance(module, ExLlamaV2ParallelDecoder):

has_gate = model.config.arch.mlp_gate
has_qk_norm = model.config.use_qk_norm
d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
if has_gate: d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
if has_qk_norm:
d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
if has_gate:
d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)

@@ -206,27 +211,29 @@ def compile_model(job, save_fn, model):

# Add signature to config.json

ds = job["cal_dataset"]
if ds is not None: qcfg_ds = os.path.split(ds)[1]
else: qcfg_ds = "(default)"

qcfg = {
"quant_method": "exl2",
"version": __version__,
"bits": job["bits"],
"head_bits": job["head_bits"],
"calibration": {
"rows": job["dataset_rows"],
"length": job["length"],
"dataset": qcfg_ds
},
}

config_json = os.path.join(out_dir, "config.json")
with open(config_json, "r") as f:
config_dict = json.load(f)

config_dict["quantization_config"] = qcfg

with open(config_json, "w") as f:
f.write(json.dumps(config_dict, indent = 4))
if job["compile_full"] is not None:

ds = job["cal_dataset"]
if ds is not None: qcfg_ds = os.path.split(ds)[1]
else: qcfg_ds = "(default)"

qcfg = {
"quant_method": "exl2",
"version": __version__,
"bits": job["bits"],
"head_bits": job["head_bits"],
"calibration": {
"rows": job["dataset_rows"],
"length": job["length"],
"dataset": qcfg_ds
},
}

config_json = os.path.join(out_dir, "config.json")
with open(config_json, "r") as f:
config_dict = json.load(f)

config_dict["quantization_config"] = qcfg

with open(config_json, "w") as f:
f.write(json.dumps(config_dict, indent = 4))
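
For reference, the quantization_config entry that this change writes into config.json ends up shaped like the sketch below. The concrete values are illustrative only (a hypothetical 4.0 bpw job), not taken from this commit; "version" is whatever exllamav2's __version__ reports at conversion time.

# Illustrative shape of the "quantization_config" block added to config.json
quantization_config = {
    "quant_method": "exl2",
    "version": "x.y.z",        # placeholder for the exllamav2 version string
    "bits": 4.0,
    "head_bits": 6,
    "calibration": {
        "rows": 100,
        "length": 2048,
        "dataset": "(default)"
    }
}
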
34 changes: 30 additions & 4 deletions conversion/measure.py
@@ -133,7 +133,7 @@ def test_error(module, hidden_states, target_states, cache, attn_params):
return max(1e-6, 1 - (rfn_sum / rfn_count))


def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params):
def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False):

qjobs, qmaps = get_qparams_reduced(qparams_attn)
results = []
@@ -181,6 +181,10 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p
"o_proj": qjobs[3][o].get_dict() }
results.append(r)

for x in ["k_proj", "v_proj", "o_proj"] + (["q_proj"] if not keep_q else []):
if x in quantizers:
del quantizers[x]

return results


@@ -257,6 +261,9 @@ def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_pa
"down_proj": qjobs[2][d].get_dict() }
results.append(r)

for x in ["up_proj", "down_proj", "gate_proj"]:
if x in quantizers:
del quantizers[x]

return results

@@ -329,10 +336,22 @@ def measure_moe_mlp(module, hidden_states, target_states, quantizers, cache, att

def measure_parallel_decoder(module, hidden_states, target_states_attn, target_states_mlp, quantizers, cache, attn_params):

for i in range(len(hidden_states)):
hidden_states[i] = hidden_states[i].cpu()

print(f" -- Sublayer: {module.key}.self_attn")
results_attn = measure_attn(module.attn, hidden_states, target_states_attn, quantizers, cache, attn_params)
results_attn = measure_attn(module.attn, hidden_states, target_states_attn, quantizers, cache, attn_params, keep_q = True)

module.attn.unload()
gc.collect()
torch.cuda.empty_cache()

print(f" -- Sublayer: {module.key}.mlp")
results_mlp = measure_mlp(module.mlp, hidden_states, target_states_mlp, quantizers, cache, attn_params, "q_proj")

for i in range(len(hidden_states)):
hidden_states[i] = hidden_states[i].to("cuda:0")

r = { "attn": results_attn,
"mlp": results_mlp }
return r
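
The reworked measure_parallel_decoder above bounds peak VRAM by parking the hidden states on the CPU, measuring the attention and MLP sublayers one at a time, and unloading each sublayer before moving on. A minimal sketch of that offload pattern, with illustrative names (measure_sublayers, measure_fn and sublayers are stand-ins, not functions from this repo):

import gc
import torch

def measure_sublayers(sublayers, hidden_states, measure_fn):
    # Park the large activation tensors in system RAM while sublayers are measured
    for i in range(len(hidden_states)):
        hidden_states[i] = hidden_states[i].cpu()
    results = {}
    for name, sublayer in sublayers:
        results[name] = measure_fn(sublayer, hidden_states)
        sublayer.unload()            # drop this sublayer's full-precision weights
        gc.collect()
        torch.cuda.empty_cache()     # hand freed blocks back so the next sublayer fits
    # Restore activations to the GPU for the next module in the forward pass
    for i in range(len(hidden_states)):
        hidden_states[i] = hidden_states[i].to("cuda:0")
    return results
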
@@ -367,7 +386,9 @@ def measure_quant(job, save_fn, model):
accuracy_count = 0
overall_rolling_accuracy = 0

snapshot_interval = 10
last_snapshot_time = time.time()
snapshot_interval_s = 90

temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
measurement = job.get("measurement", {})
@@ -602,7 +623,10 @@ def measure_quant(job, save_fn, model):

# Checkpoint

if index % snapshot_interval == 0 or index == len(model.modules) - 1:
time_since_snapshot = time.time() - last_snapshot_time
if time_since_snapshot > snapshot_interval_s or index == len(model.modules) - 1:

print(" -- Saving checkpoint...")

save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
save_file(save_dict, temp_filename)
@@ -621,6 +645,8 @@ def measure_quant(job, save_fn, model):
del job["invalid"]
save_fn()

last_snapshot_time = time.time()

# Export measurement

exp_measurement = { "measurement": job["measurement"],
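Both here and in conversion/quantize.py below, checkpointing switches from a fixed every-10-modules schedule to a wall-clock interval, so slow layers no longer stretch the gap between saves and fast layers no longer checkpoint excessively. A minimal sketch of the trigger, assuming a save_checkpoint callable that persists the hidden states and job state:

import time

snapshot_interval_s = 90
last_snapshot_time = time.time()

def maybe_checkpoint(index, num_modules, save_checkpoint):
    global last_snapshot_time
    # Save if enough wall-clock time has passed, or if this is the last module
    if time.time() - last_snapshot_time > snapshot_interval_s or index == num_modules - 1:
        save_checkpoint()
        last_snapshot_time = time.time()   # reset the timer only after a save
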
21 changes: 17 additions & 4 deletions conversion/optimize.py
@@ -11,6 +11,9 @@ def optimize(job, save_fn, model):

error_norm = 2.4
max_step_size = 2
first_layer_bias = 10
bias_layers = 2
bias_iter = 10

key = "model.layers.0"
key_q = key + ".self_attn.q_proj"
@@ -60,8 +63,11 @@

measurement = job["measurement"]

def fn(x):
return 1 - ((1 - x) ** error_norm)
def fn(x, idx):
if idx < bias_layers:
return 1 - ((1 - x) ** error_norm) * first_layer_bias
else:
return 1 - ((1 - x) ** error_norm)

weights = []
values = []
@@ -74,7 +80,7 @@ def fn(x):
m1 = measurement["model.layers." + str(i) + ".self_attn"]
m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
for m in [m1, m2]:
v = [fn(e["accuracy"]) for e in m]
v = [fn(e["accuracy"], i) for e in m]
w = [e["total_bits"] for e in m]
weights.append(w)
values.append(v)
@@ -111,10 +117,13 @@ def fn(x):
value = 1
for i in range(num_layers * 2): value *= values[i][0]

iteration = 0

while True:
min_idx = -1
min_value = float("inf")
for i in range(num_layers * 2):
iteration += 1
for i in range(bias_layers if iteration < bias_iter else num_layers * 2):
s = f_solution[i]
if values[i][s] < min_value:
if s < len(weights[i]) - 1:
@@ -211,6 +220,7 @@ def improve(solution, s_weight, hold = None):

print(" -- Quantization strategy:")

errp = 1
job["strategy"] = {}
for layer_ in range(num_layers):

@@ -224,5 +234,8 @@ def improve(solution, s_weight, hold = None):
bpw = p["total_bits"] / n
err = 1 - p["accuracy"]
print(f" -- {k:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
errp *= (1 - err)

print(f" -- Total exp. error: {1 - errp:1.12f}")

xx = 0
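
The new first_layer_bias makes residual error in the first bias_layers layers count ten times more when candidate quantizations are scored, which in effect steers the optimizer toward spending extra bits early in the model, and the strategy printout now multiplies the per-module expected accuracies into a single total expected error. A small worked sketch with hypothetical accuracy values:

error_norm = 2.4
first_layer_bias = 10
bias_layers = 2

def fn(x, idx):
    # x = measured accuracy of a candidate quantization for one module
    if idx < bias_layers:
        return 1 - ((1 - x) ** error_norm) * first_layer_bias   # early layers: error penalized 10x
    else:
        return 1 - ((1 - x) ** error_norm)

print(fn(0.999, 0))   # ~0.99999937  (layer 0: the same accuracy scores worse)
print(fn(0.999, 5))   # ~0.999999937 (later layers: unbiased score)

# Total expected error is one minus the product of per-module expected accuracies
accuracies = [0.9990, 0.9970, 0.9995]          # hypothetical per-module values
errp = 1.0
for a in accuracies:
    errp *= a
print(f"Total exp. error: {1 - errp:1.12f}")   # ~0.004495 for the values above
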
41 changes: 28 additions & 13 deletions conversion/quantize.py
@@ -83,24 +83,30 @@ def quant_linear(job: dict,
recons_dict = {}
recons_keys = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]
if source.has_bias: recons_keys += ["bias"]
r_device = packed_dict[source.key + ".q_weight"].device
recons_linear.set_device_idx(r_device.index)
for k in recons_keys:
recons_dict[k] = packed_dict[source.key + "." + k]
recons_dict[k] = packed_dict[source.key + "." + k].to(r_device)
recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
recons_linear.load(recons_dict)
recons_linear.load(recons_dict, device_tensors = False)

# Sanity test to ensure reconstructed matrix matches unpacked matrix

quant_w = source.linear.weight.T
recons_w = recons_linear.get_weight_tensor_dq()

if quant_w.numel() <= 1e9:
ident = torch.eye(recons_linear.in_features, dtype = torch.half).cuda()
recons_w2 = recons_linear.forward(ident, force_cuda = True)
recons_w2.sub_(quant_w)
if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
recons_w2.abs_()
diff2 = torch.max(recons_w2)
else:
try:
if quant_w.numel() <= 1e9:
ident = torch.eye(recons_linear.in_features, dtype = torch.half, device = r_device)
recons_w2 = recons_linear.forward(ident, force_cuda = True)
recons_w2.sub_(quant_w)
if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
recons_w2.abs_()
diff2 = torch.max(recons_w2)
else:
diff2 = 0
except torch.cuda.OutOfMemoryError as e:
print(f" !! Warning, not enough VRAM for second sanity check of {source.key}")
diff2 = 0

quant_w.sub_(recons_w)
@@ -120,7 +126,7 @@

# Apply reconstructed matrix to source layer

source.linear.weight.data = recons_w.T
source.linear.weight.data = recons_w.T.to("cuda:0")


def quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat):
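
The second sanity check above pushes an identity matrix through the reconstructed layer, which materializes an in_features x in_features half-precision tensor and can transiently exhaust VRAM on large matrices; the change runs it on the device the packed tensors already live on and simply skips the check on OOM. A condensed sketch of the pattern (identity_check is a hypothetical helper, not a function in this repo, and bias handling is omitted):

import torch

def identity_check(recons_linear, quant_w, device):
    # Forward an identity matrix: the output is the dequantized weight matrix in the
    # same (in_features, out_features) layout as quant_w, so the two should match
    # up to quantization error.
    try:
        ident = torch.eye(recons_linear.in_features, dtype = torch.half, device = device)
        out = recons_linear.forward(ident, force_cuda = True)
        return (out - quant_w).abs().max().item()
    except torch.cuda.OutOfMemoryError:
        print(" !! Warning, not enough VRAM for second sanity check")
        return 0.0   # treated as "check skipped", mirroring diff2 = 0 above
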
@@ -246,7 +252,9 @@ def quant_parallel_decoder(job, module, hidden_states, target_states, quantizers
@torch.inference_mode()
def quant(job, save_fn, model):

snapshot_interval = 10
last_snapshot_time = time.time()
snapshot_interval_s = 90

temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
strategy = job["strategy"]
@@ -412,6 +420,7 @@ def quant(job, save_fn, model):
strat_mlp = strategy[module.key + ".mlp"]
quant_parallel_decoder(job, module, hidden_states, target_states, quantizers, attn_params, strat_attn, strat_mlp)

torch.cuda.synchronize()
quantizers.clear()
gc.collect()
torch.cuda.empty_cache()
@@ -421,6 +430,7 @@ def quant(job, save_fn, model):
if mode == "linear":
with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
cal_ids = f.get_tensor("input_ids")
module.linear.weight.data = module.linear.weight.data.to("cuda:0")

rfn_sum = 0
rfn_count = 0
@@ -494,7 +504,10 @@ def quant(job, save_fn, model):

# Checkpoint

if index % snapshot_interval == 0 or index == len(model.modules) - 1:
time_since_snapshot = time.time() - last_snapshot_time
if time_since_snapshot > snapshot_interval_s or index == len(model.modules) - 1:

print(" -- Saving checkpoint...")

if mode != "linear":
save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
@@ -512,3 +525,5 @@ def quant(job, save_fn, model):

del job["invalid"]
save_fn()

last_snapshot_time = time.time()
12 changes: 12 additions & 0 deletions convert.py
@@ -216,6 +216,12 @@ def save_job():

if progress == "measure_quant":
print(f" -- Measuring quantization impact...")

model.unload()
config.max_output_len = 16
model = ExLlamaV2(config)
model.load(lazy = True)

status = measure_quant(job, save_job, model) # capturing the graceful exits
if status == "interrupted":
print("Process interrupted. Exiting gracefully.")
@@ -227,6 +233,12 @@ def save_job():
job["progress"] = "finished"
save_job()

model.unload()
config.max_output_len = None
model = ExLlamaV2(config)
model.load(lazy = True)


if progress == "optimize":

print(f" -- Optimizing...")
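In convert.py, the model is now reloaded around the measurement pass with config.max_output_len = 16, presumably to keep the output buffer small while quantization error is measured, and reloaded with the default (None) afterwards. A minimal sketch of that reload pattern (reload_model is a hypothetical helper wrapping the lines above; the import mirrors convert.py's usage):

from exllamav2 import ExLlamaV2

def reload_model(model, config, max_output_len):
    # Rebuild the lazily-loaded ExLlamaV2 instance with a different output buffer limit
    model.unload()
    config.max_output_len = max_output_len   # None restores the default
    model = ExLlamaV2(config)
    model.load(lazy = True)
    return model

# model = reload_model(model, config, 16)     # before measure_quant
# ...
# model = reload_model(model, config, None)   # before optimize / quantize
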