Skip to content

Commit

Permalink
Merge pull request #142 from instructlab/lora-ckpt
Browse files Browse the repository at this point in the history
Adding weight merging for LoRA/QLoRA ckpts
  • Loading branch information
aldopareja authored Jul 18, 2024
2 parents 41b0857 + 6647fd4 commit 9fdeb87
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
model,
tokenizer,
global_step * args.samples_per_gpu * world_size,
is_lora=bool(args.lora_r),
)

if (
Expand Down
30 changes: 28 additions & 2 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard
from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
from copy import copy, deepcopy
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -611,7 +612,24 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
# print(msg)


def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True):
def _copy_no_lora_dict(state_dict):
cleaned_state_dict = OrderedDict()
for param_tensor in state_dict:
if not "lora" in param_tensor:
cleaned_state_dict[
param_tensor.replace(".base_layer", "").replace("base_model.model.", "")
] = deepcopy(state_dict[param_tensor]).cpu()
return cleaned_state_dict


def save_hf_format_ds(
args,
model,
tokenizer,
samples_seen,
convert_granite=True,
is_lora=False,
):
model_to_save = model.module
log_rank_0(
f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
Expand All @@ -628,7 +646,12 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
WEIGHTS_NAME = "pytorch_model.bin"
output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
if torch.distributed.get_rank() == 0:
if is_lora:
model_to_save.merge_adapter()

model_state = model_to_save.state_dict()
if is_lora:
model_state = _copy_no_lora_dict(model_state)
output_dir.mkdir(parents=True, exist_ok=True)
output_model_file = output_dir / WEIGHTS_NAME
output_config_file = output_dir / CONFIG_NAME
Expand Down Expand Up @@ -658,6 +681,9 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
tmp_conf.to_json_file(str(output_config_file))
tokenizer.save_pretrained(str(output_dir))

if is_lora:
model_to_save.unmerge_adapter()

dist.barrier()
log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
log_rank_0(f"saving took {time.time() - start} seconds")
Expand Down

0 comments on commit 9fdeb87

Please sign in to comment.