From 89410118e751b1bb0bd1b3f25ca1bb42066b7b42 Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 11:36:59 -0400
Subject: [PATCH 1/6] Adding weight merging for LoRA/QLoRA

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/utils.py | 57 ++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 4ea8ac77..f6a2d58f 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -23,6 +23,7 @@
     export_to_huggingface,
     import_from_huggingface,
 )
+from peft.utils import _get_submodules
 from rich.logging import RichHandler
 from safetensors.torch import save_file
 from torch import distributed as dist
@@ -611,7 +612,56 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
     #     print(msg)
 
 
-def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True):
+def _dequantize_model(model, dtype=torch.bfloat16, device="cpu"):
+    """
+    'model': the peftmodel you loaded with qlora.
+    'tokenizer': the model's corresponding hf's tokenizer.
+    'to': directory to save the dequantized model
+    'dtype': dtype that the model was trained using
+    'device': device to load the model to
+    FROM: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
+    """
+    # Third Party
+    from bitsandbytes.functional import dequantize_4bit
+    import bitsandbytes as bnb
+
+    cls = bnb.nn.Linear4bit
+
+    with torch.no_grad():
+        for name, module in model.named_modules():
+            if isinstance(module, cls):
+                print(f"Dequantizing `{name}`...")
+                quant_state = copy.deepcopy(module.weight.quant_state)
+
+                quant_state[2] = dtype
+
+                weights = dequantize_4bit(
+                    module.weight.data, quant_state=quant_state, quant_type="nf4"
+                ).to(dtype)
+
+                new_module = torch.nn.Linear(
+                    module.in_features, module.out_features, bias=None, dtype=dtype
+                )
+                new_module.weight = torch.nn.Parameter(weights)
+                new_module.to(device=device, dtype=dtype)
+
+                parent, _, target_name = _get_submodules(model, name)
+                setattr(parent, target_name, new_module)
+
+    model.is_loaded_in_4bit = False
+
+    return model
+
+
+def save_hf_format_ds(
+    args,
+    model,
+    tokenizer,
+    samples_seen,
+    convert_granite=True,
+    is_lora=False,
+    is_quant=False,
+):
     model_to_save = model.module
     log_rank_0(
         f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
@@ -628,6 +678,11 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
     WEIGHTS_NAME = "pytorch_model.bin"
     output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
     if torch.distributed.get_rank() == 0:
+        if is_lora:
+            if is_quant:
+                model = _dequantize_model(model)
+            model_to_save = model.merge_and_unload()
+
         model_state = model_to_save.state_dict()
         output_dir.mkdir(parents=True, exist_ok=True)
         output_model_file = output_dir / WEIGHTS_NAME

From 32c37425e9a1b72aa2b9f873992c673984024423 Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 11:49:40 -0400
Subject: [PATCH 2/6] Connect new merging function to main save call

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/main_ds.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 6464b1a7..d50c5f17 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -445,6 +445,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
                 model,
                 tokenizer,
                 global_step * args.samples_per_gpu * world_size,
+                bool(args.lora_r),
+                bool(args.lora_r) and bool(args.lora_quant_bits),
             )
 
         if (
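Note on PATCH 1/6: `merge_and_unload()` folds the LoRA deltas into the base `nn.Linear` weights and strips the PEFT wrappers, but it cannot operate on bitsandbytes 4-bit tensors, which is why the QLoRA path dequantizes first. A minimal sketch of the flow the new rank-0 save branch implements — the wrapper function name is illustrative, not part of the patch, and `_dequantize_model` is the helper added above:

```python
def merge_qlora_for_export(model, is_quant: bool):
    """Hypothetical standalone version of the merge logic in PATCH 1/6.

    `model` is assumed to be a peft.PeftModel produced by LoRA/QLoRA training.
    """
    if is_quant:
        # Swap every bnb Linear4bit for a plain nn.Linear holding dequantized
        # bf16 weights, so the merge below has real floating-point tensors
        # to add the LoRA deltas into.
        model = _dequantize_model(model)
    # merge_and_unload() computes W += scaling * (B @ A) for each adapted
    # layer and returns the unwrapped HF model.
    return model.merge_and_unload()
```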
From b6e7c9071a21c4c9adb3320aa0f2ff5b84d36f1b Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 12:44:47 -0400
Subject: [PATCH 3/6] Bug fixes

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/main_ds.py | 4 ++--
 src/instructlab/training/utils.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index d50c5f17..2c31b510 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -445,8 +445,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
                 model,
                 tokenizer,
                 global_step * args.samples_per_gpu * world_size,
-                bool(args.lora_r),
-                bool(args.lora_r) and bool(args.lora_quant_bits),
+                is_lora=bool(args.lora_r),
+                is_quant=bool(args.lora_r) and bool(args.lora_quant_bits),
             )
 
         if (
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index f6a2d58f..c2d5fa45 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -680,8 +680,8 @@ def save_hf_format_ds(
     if torch.distributed.get_rank() == 0:
         if is_lora:
             if is_quant:
-                model = _dequantize_model(model)
-            model_to_save = model.merge_and_unload()
+                model_to_save = _dequantize_model(model_to_save)
+            model_to_save = model_to_save.merge_and_unload()
 
         model_state = model_to_save.state_dict()
         output_dir.mkdir(parents=True, exist_ok=True)

From 440992760afb4784e4df615ede900c8fa7afb1b0 Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 15:40:23 -0400
Subject: [PATCH 4/6] Create cleaned state dict for lora-less save

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/utils.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index c2d5fa45..2c8f75c3 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -1,6 +1,7 @@
 # Standard
+from collections import OrderedDict
 from contextlib import contextmanager
-from copy import copy
+from copy import copy, deepcopy
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -615,8 +616,6 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
 def _dequantize_model(model, dtype=torch.bfloat16, device="cpu"):
     """
     'model': the peftmodel you loaded with qlora.
-    'tokenizer': the model's corresponding hf's tokenizer.
-    'to': directory to save the dequantized model
     'dtype': dtype that the model was trained using
     'device': device to load the model to
     FROM: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
@@ -631,7 +630,7 @@ def _dequantize_model(model, dtype=torch.bfloat16, device="cpu"):
         for name, module in model.named_modules():
             if isinstance(module, cls):
                 print(f"Dequantizing `{name}`...")
-                quant_state = copy.deepcopy(module.weight.quant_state)
+                quant_state = deepcopy(module.weight.quant_state)
 
                 quant_state[2] = dtype
 
@@ -653,6 +652,16 @@ def _dequantize_model(model, dtype=torch.bfloat16, device="cpu"):
     return model
 
 
+def _copy_no_lora_dict(state_dict):
+    cleaned_state_dict = OrderedDict()
+    for param_tensor in state_dict:
+        if not "lora" in param_tensor:
+            cleaned_state_dict[
+                param_tensor.replace(".base_layer", "").replace("base_model.model.", "")
+            ] = deepcopy(state_dict[param_tensor]).cpu()
+    return cleaned_state_dict
+
+
 def save_hf_format_ds(
     args,
     model,
@@ -681,9 +690,11 @@ def save_hf_format_ds(
     if torch.distributed.get_rank() == 0:
         if is_lora:
             if is_quant:
                 model_to_save = _dequantize_model(model_to_save)
-            model_to_save = model_to_save.merge_and_unload()
+            model_to_save.merge_adapter()
 
         model_state = model_to_save.state_dict()
+        if is_lora:
+            model_state = _copy_no_lora_dict(model_state)
         output_dir.mkdir(parents=True, exist_ok=True)
         output_model_file = output_dir / WEIGHTS_NAME
         output_config_file = output_dir / CONFIG_NAME
@@ -713,6 +724,9 @@ def save_hf_format_ds(
             tmp_conf.to_json_file(str(output_config_file))
         tokenizer.save_pretrained(str(output_dir))
 
+        if is_lora:
+            model_to_save.unmerge_adapter()
+
     dist.barrier()
     log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
     log_rank_0(f"saving took {time.time() - start} seconds")
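Note on PATCH 4/6: unlike `merge_and_unload()`, `merge_adapter()` merges in place and keeps the PEFT wrapper modules, so the state dict still carries `lora_*` tensors and wrapper prefixes; `_copy_no_lora_dict` strips both so the saved file matches a plain transformers checkpoint. A small self-contained check, assuming `_copy_no_lora_dict` is importable from `src/instructlab/training/utils.py`; the key names are illustrative of a LLaMA-style PEFT model, not taken from the patch:

```python
from collections import OrderedDict

import torch

# Fabricated two-entry state dict mimicking a PEFT-wrapped q_proj layer.
state_dict = OrderedDict(
    {
        "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(4, 4),
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.zeros(2, 4),
    }
)

cleaned = _copy_no_lora_dict(state_dict)

# The lora_A tensor is dropped; ".base_layer" and "base_model.model." are
# stripped, leaving the key a vanilla transformers checkpoint would use.
assert list(cleaned) == ["model.layers.0.self_attn.q_proj.weight"]
```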
From ac772cb3748ccf7f8b6cc5499f5c0df917d4ca3f Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 16:38:56 -0400
Subject: [PATCH 5/6] Remove special quant case (auto-handled)

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/main_ds.py |  1 -
 src/instructlab/training/utils.py   | 42 -----------------------------
 2 files changed, 43 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 2c31b510..b56d5b85 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -446,7 +446,6 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
                 tokenizer,
                 global_step * args.samples_per_gpu * world_size,
                 is_lora=bool(args.lora_r),
-                is_quant=bool(args.lora_r) and bool(args.lora_quant_bits),
             )
 
         if (
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 2c8f75c3..cf106aa8 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -613,45 +613,6 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
     #     print(msg)
 
 
-def _dequantize_model(model, dtype=torch.bfloat16, device="cpu"):
-    """
-    'model': the peftmodel you loaded with qlora.
-    'dtype': dtype that the model was trained using
-    'device': device to load the model to
-    FROM: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
-    """
-    # Third Party
-    from bitsandbytes.functional import dequantize_4bit
-    import bitsandbytes as bnb
-
-    cls = bnb.nn.Linear4bit
-
-    with torch.no_grad():
-        for name, module in model.named_modules():
-            if isinstance(module, cls):
-                print(f"Dequantizing `{name}`...")
-                quant_state = deepcopy(module.weight.quant_state)
-
-                quant_state[2] = dtype
-
-                weights = dequantize_4bit(
-                    module.weight.data, quant_state=quant_state, quant_type="nf4"
-                ).to(dtype)
-
-                new_module = torch.nn.Linear(
-                    module.in_features, module.out_features, bias=None, dtype=dtype
-                )
-                new_module.weight = torch.nn.Parameter(weights)
-                new_module.to(device=device, dtype=dtype)
-
-                parent, _, target_name = _get_submodules(model, name)
-                setattr(parent, target_name, new_module)
-
-    model.is_loaded_in_4bit = False
-
-    return model
-
-
 def _copy_no_lora_dict(state_dict):
     cleaned_state_dict = OrderedDict()
     for param_tensor in state_dict:
@@ -669,7 +630,6 @@ def save_hf_format_ds(
     args,
     model,
     tokenizer,
     samples_seen,
     convert_granite=True,
     is_lora=False,
-    is_quant=False,
 ):
     model_to_save = model.module
@@ -688,8 +648,6 @@ def save_hf_format_ds(
     output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
     if torch.distributed.get_rank() == 0:
         if is_lora:
-            if is_quant:
-                model_to_save = _dequantize_model(model_to_save)
             model_to_save.merge_adapter()
 
         model_state = model_to_save.state_dict()
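Note on PATCH 5/6: per the commit message, current PEFT handles dequantization of quantized layers during the merge itself, so the hand-rolled `_dequantize_model` pass is dropped along with the `is_quant` flag. What remains is a merge/save/unmerge bracket around the existing write. A condensed sketch of the rank-0 path after this patch — the `torch.save` line stands in for the existing write to `output_model_file`, which these diffs do not show:

```python
if is_lora:
    model_to_save.merge_adapter()  # fold LoRA deltas into base weights, in place

model_state = model_to_save.state_dict()
if is_lora:
    model_state = _copy_no_lora_dict(model_state)  # drop lora_* keys, strip prefixes

torch.save(model_state, output_model_file)  # assumed existing save call

if is_lora:
    model_to_save.unmerge_adapter()  # subtract the deltas back out so training continues
```

Merging in place and unmerging afterward is what lets mid-training checkpoints be exported as full models without disturbing the live PEFT model.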
From 6647fd446ce69e8dfdef3ddb4d34181de76b67cc Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Wed, 17 Jul 2024 16:49:39 -0400
Subject: [PATCH 6/6] Remove unused import

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index cf106aa8..b68988f9 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -24,7 +24,6 @@
     export_to_huggingface,
     import_from_huggingface,
 )
-from peft.utils import _get_submodules
 from rich.logging import RichHandler
 from safetensors.torch import save_file
 from torch import distributed as dist
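End result of the series: LoRA and QLoRA runs now emit fully merged checkpoints under `<output_dir>/hf_format/samples_<n>`, loadable without PEFT. A sanity check one might run against the output — the path is a placeholder, and the sample count depends on the run:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "outputs/hf_format/samples_4096"  # placeholder path

model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# No adapter tensors should survive the merge-and-clean save path.
assert not any("lora" in key for key in model.state_dict())
```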