implement ability to save lora models as pretrained models
Signed-off-by: Oleg S <[email protected]>
RobotSail committed Oct 23, 2024
1 parent fb4a46c commit 78e2122
Showing 2 changed files with 167 additions and 99 deletions.
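For context, the end result of this change is that a LoRA run can be exported as an ordinary Hugging Face checkpoint by merging the trained adapters into the base weights. A minimal sketch of that pattern using the public peft/transformers APIs (paths and hyperparameters below are illustrative, not taken from this commit):

# Illustrative sketch: merge trained LoRA adapters into the base model and
# save the result as a plain pretrained checkpoint.
from peft import LoraConfig, LoraModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("path/to/base-model", device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")
lora_config = LoraConfig(
    r=4, lora_alpha=32, lora_dropout=0.1, bias="none",
    task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"],
)
lora_model = LoraModel(base, lora_config, "default")
# ... load the trained adapter weights into lora_model here ...
merged = lora_model.merge_and_unload()  # fold the adapters into the base weights
merged.save_pretrained("path/to/output", safe_serialization=True)
tokenizer.save_pretrained("path/to/output")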
71 changes: 26 additions & 45 deletions src/instructlab/training/main_ds.py
@@ -2,6 +2,7 @@

# Standard
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
import argparse
import math
@@ -17,7 +18,6 @@

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from peft import LoraModel, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
import torch
@@ -44,7 +44,7 @@
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
get_projection_layer_names,
create_lora_config,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
@@ -115,13 +115,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
base_model_args["use_padding_free_transformer"] = True
model = GPTDolomiteForCausalLM.from_pretrained(
**base_model_args,
use_padding_free_transformer=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

if len(tokenizer) > model.config.vocab_size:
print(
f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -175,49 +178,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
# it is handled differently for lora and full
# - with the exception of granite, which handles it
# in the later stanza
distributed_backend = args.distributed_training_framework

Check warning (GitHub Actions / pylint) on line 181 in src/instructlab/training/main_ds.py: W0612: Unused variable 'distributed_backend' (unused-variable)
if args.lora_r > 0:
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
lora_config = create_lora_config(model, args)
model = prepare_peft_model(
model,
lora_config,
args.distributed_training_framework,
gradient_checkpointing=not args.is_granite,
)
if args.distributed_training_framework == "fsdp":
model = LoraModel(model, peft_config, "default")
else:
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
)
args.lora_config = lora_config
elif not args.is_granite:
model.gradient_checkpointing_enable()

@@ -532,7 +502,11 @@ def main(args):
#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group("nccl")
nccl_timeout: timedelta | None = None
if args.debug:
# surely we won't need any more than this... right?
nccl_timeout = timedelta(days=1)
torch.distributed.init_process_group("nccl", timeout=nccl_timeout)
args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()
torch.distributed.all_reduce(tensor)
@@ -930,6 +904,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
),
)
# hidden argument for our own sake
parser.add_argument(
"--debug",
help="Enables settings for debugging. For example, the NCCL timeout increases so more time can be spent in breakpoints.",
action="store_true",
default=False,
)
parser.add_argument("--disable_flash_attn", action="store_true")
args = parser.parse_args()
set_random_seed(args.seed)
195 changes: 141 additions & 54 deletions src/instructlab/training/utils.py
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from argparse import Namespace
from collections import OrderedDict
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple
import importlib
import inspect
import logging
@@ -21,7 +22,7 @@

# Third Party
# pylint: disable=no-name-in-module
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
from instructlab.dolomite.hf_models import (
GPTDolomiteConfig,
export_to_huggingface,
@@ -35,8 +36,14 @@
apply_activation_checkpointing,
checkpoint_wrapper,
)
from transformers import PreTrainedModel
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
FullStateDictConfig,
StateDictType,
)
from transformers import PreTrainedModel, AutoModelForCausalLM, PreTrainedTokenizer
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F

@@ -306,10 +313,128 @@ def patch_target_module(
setattr(source, obj_name_to_patch, replace_with)


def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
"""Checks if a module or its children are an instance of one of the provided classes.
Args:
module (nn.Module): A PyTorch module.
wrapped_classes(Tuple): A tuple of potential classes the module could be.
Returns:
bool: True if the module or any of its children are instances of `transformers.PreTrainedModel`, False otherwise.
"""
if isinstance(module, wrapped_classes):
return True

for m in module.children():
if wraps(m, wrapped_classes):
return True

return False
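
A hedged usage sketch of the helper above (the FSDP-wrapped LoRA model named `model` is assumed context, not something constructed at this point in the diff):

# Illustrative usage of `wraps`: check that an FSDP-wrapped object still
# contains a PEFT LoraModel somewhere in its module tree.
from peft import LoraModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if wraps(model, FSDP) and wraps(model, LoraModel):
    print("model is FSDP-wrapped and contains LoRA adapters")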


def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig":
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

return LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
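
A minimal sketch of calling the helper above; the Namespace fields mirror the existing CLI flags (lora_r, lora_alpha, lora_dropout, lora_target_modules) and the values shown are illustrative:

# Illustrative call site for create_lora_config with hand-built args.
from argparse import Namespace
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
args = Namespace(lora_r=4, lora_alpha=32, lora_dropout=0.1, lora_target_modules=None)
lora_config = create_lora_config(model, args)  # also fills in args.lora_target_modules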


def save_fsdp_lora_model(
args: Namespace,
model: FSDP,
tokenizer: PreTrainedTokenizer,
accelerator: Accelerator,
output_dir: Path,
):
"""Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
model with the trained LoRA adapters merged into the copy.
This function creates a full copy of the model being trained and stores it in CPU memory.
If encountering OOM errors on CPU, this is likely a culprit.
Args:
args (Namespace): Args received by the ArgumentParser.
model (FSDP): FSDP model as prepared by `accelerate.Accelerator`
accelerator (Accelerator): The given accelerator object.
"""
from peft import LoraModel, LoraConfig

if accelerator.distributed_type != DistributedType.FSDP:
raise RuntimeError(
"`save_fsdp_lora_model` was called when FSDP was not being used."
)
if not wraps(model, FSDP):
raise RuntimeError(
"`save_fsdp_lora_model` was called but provided model is not an FSDP model."
)
if not wraps(model, LoraModel):
raise RuntimeError(
"`save_fsdp_lora_model` was called but provided model is not a LoRA model."
)

# okay now that validation is out of the way, we are free to implement saving
lora_conf: LoraConfig = args.lora_config
sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config):
state = model.state_dict()

if accelerator.is_main_process:
# remove device_map from args list so we can load the model on CPU
old_device_map = args.base_model_args.pop("device_map", None)
model_copy = AutoModelForCausalLM.from_pretrained(
**args.base_model_args, device_map="cpu"
)
model_copy = LoraModel(model_copy, lora_conf, "default")
model_copy.load_state_dict(state)
model_copy.merge_and_unload(progressbar=True)
model_copy.save_pretrained(output_dir, safe_serialization=True)
model.config.to_json_file(f"{output_dir}/config.json")
tokenizer.save_pretrained(output_dir)
del model_copy
if old_device_map:
# return the previous device_map so it can be used later on if needed
args.base_model_args["device_map"] = old_device_map

dist.barrier()
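
Once `save_fsdp_lora_model` has written the merged checkpoint, it loads like any other pretrained model; a short sketch (the output path is illustrative):

# The merged checkpoint written above is an ordinary pretrained model.
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_model = AutoModelForCausalLM.from_pretrained("path/to/output_dir")
merged_tokenizer = AutoTokenizer.from_pretrained("path/to/output_dir")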


def prepare_peft_model(
model: PreTrainedModel,
peft_config,
distributed_backend: DistributedBackend,
distributed_backend: str,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": True},
mixed_precision="bf16",
@@ -359,7 +484,7 @@ def make_inputs_require_grad(module, input, output):
make_inputs_require_grad
)

if distributed_backend == DistributedBackend.FSDP:
if distributed_backend == DistributedBackend.FSDP.value:
# FSDP doesn't like `get_peft_model` as it leads to dtype mismatches
model = LoraModel(model, peft_config, "default")
else:
@@ -689,63 +814,25 @@ def save_hf_format_accelerate(
CONFIG_NAME = "config.json"
output_config_file = output_dir / CONFIG_NAME

if is_lora and accelerator.distributed_type == DistributedType.FSDP:
save_fsdp_lora_model(
args=args,
model=model,
tokenizer=tokenizer,
accelerator=accelerator,
output_dir=output_dir,
)
dist.barrier()
return

get_state_dict_unpatched = accelerator.get_state_dict

def _get_state_dict_patched(model, unwrap=False):
return get_state_dict_unpatched(model, unwrap=unwrap)

accelerator.get_state_dict = _get_state_dict_patched
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
from transformers import AutoModelForCausalLM
from peft import LoraModel, LoraConfig

cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

local_rank = int(os.environ["LOCAL_RANK"])
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model: FSDP = model
state = model.state_dict()
torch.save(state, "./oleg-sanity-check-3-pre-state-dict.pt")
# state = {k: v.to(torch.bfloat16) for k, v in state.items()}
if accelerator.is_main_process:
# currently a copy of the exact settings.
# TODO: make this pull from the settings that the user passed in
config = LoraConfig(
r=4,
target_modules=["q_proj", "v_proj", "o_proj", "k_proj"],
lora_alpha=32,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
)
_model = AutoModelForCausalLM.from_pretrained(
"instructlab/granite-7b-lab",
device_map="cpu",
torch_dtype=torch.bfloat16,
)
# tokenizer2 = AutoModelForCausalLM.from_pretrained('instructlab/granite-7b-lab')
_model = LoraModel(_model, config, "default")
_model.load_state_dict(state)
_model.merge_and_unload()
_save_out_dir = "./oleg-sanity-check-6"
_model.save_pretrained(_save_out_dir, safe_serialization=True)
_model.config.to_json_file(f"{_save_out_dir}/config.json")
tokenizer.save_pretrained(_save_out_dir)
print(f"[rank {local_rank}]")
dist.barrier()
return

if accelerator.is_main_process:
from IPython import embed

embed()
dist.barrier()

if accelerator.is_main_process:
from IPython import embed

embed()
if is_lora:
model.module.merge_adapter()
model_state = model.module.state_dict()
