diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 6464b1a7..b56d5b85 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -445,6 +445,7 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): model, tokenizer, global_step * args.samples_per_gpu * world_size, + is_lora=bool(args.lora_r), ) if ( diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 4ea8ac77..b68988f9 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -1,6 +1,7 @@ # Standard +from collections import OrderedDict from contextlib import contextmanager -from copy import copy +from copy import copy, deepcopy from functools import partial from pathlib import Path from tempfile import TemporaryDirectory @@ -611,7 +612,24 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False): # print(msg) -def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True): +def _copy_no_lora_dict(state_dict): + cleaned_state_dict = OrderedDict() + for param_tensor in state_dict: + if not "lora" in param_tensor: + cleaned_state_dict[ + param_tensor.replace(".base_layer", "").replace("base_model.model.", "") + ] = deepcopy(state_dict[param_tensor]).cpu() + return cleaned_state_dict + + +def save_hf_format_ds( + args, + model, + tokenizer, + samples_seen, + convert_granite=True, + is_lora=False, +): model_to_save = model.module log_rank_0( f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m", @@ -628,7 +646,12 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True WEIGHTS_NAME = "pytorch_model.bin" output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}" if torch.distributed.get_rank() == 0: + if is_lora: + model_to_save.merge_adapter() + model_state = model_to_save.state_dict() + if is_lora: + model_state = _copy_no_lora_dict(model_state) output_dir.mkdir(parents=True, exist_ok=True) output_model_file = output_dir / WEIGHTS_NAME output_config_file = output_dir / CONFIG_NAME @@ -658,6 +681,9 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True tmp_conf.to_json_file(str(output_config_file)) tokenizer.save_pretrained(str(output_dir)) + if is_lora: + model_to_save.unmerge_adapter() + dist.barrier() log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True) log_rank_0(f"saving took {time.time() - start} seconds")