Implement options for selecting layers to apply LoRA to during training #163

Open · wants to merge 1 commit into base: main
qlora.py (19 changes: 16 additions & 3 deletions)
@@ -7,6 +7,7 @@
 import os
 from os.path import exists, join, isdir
 from dataclasses import dataclass, field
+import re
 import sys
 from typing import Optional, Dict, Sequence
 import numpy as np
@@ -135,6 +136,10 @@ class TrainingArguments(transformers.Seq2SeqTrainingArguments):
         default=False,
         metadata={"help": "Use 8-bit adam."}
     )
+    lora_modules: str = field(
+        default=None,
+        metadata={"help": "Select layers to add LoRA to. Default or 'all' selects all linear layers except the head, 'attention' selects the attention layers, 'ffn' the feed-forward ones; other values are treated as regex patterns that must exactly match layer names."}
+    )
     double_quant: bool = field(
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
@@ -219,12 +224,17 @@ class GenerationArguments:
     no_repeat_ngram_size: Optional[int] = field(default=0)

 def find_all_linear_names(args, model):
+    """Find the model's linear layer names for applying LoRA, and return them in a set.
+    If args.lora_modules is None or 'all', it selects all linear layers except the head; 'attention' selects the attention layers, 'ffn' the feed-forward ones."""
     cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
     lora_module_names = set()
     for name, module in model.named_modules():
         if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            if (args.lora_modules == 'all' or not args.lora_modules
+                    or (args.lora_modules == 'attention' and re.search('attn|attention|query|key|value', name.lower()))
+                    or (args.lora_modules == 'ffn' and re.search('mlp|ffn', name.lower()))):
+                names = name.split('.')
+                lora_module_names.add(names[0] if len(names) == 1 else names[-1])


     if 'lm_head' in lora_module_names: # needed for 16-bit
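To make the selection behaviour concrete, here is a minimal standalone sketch (not part of the diff) that replays only the name-filtering condition from find_all_linear_names on a few LLaMA-style module names; the real function additionally requires each module to be an instance of the chosen Linear class and strips 'lm_head' afterwards:

import re

def selected(lora_modules, name):
    # Mirrors the condition added above: None or 'all' keeps every name,
    # 'attention' and 'ffn' keep only names matching the respective patterns.
    return bool(lora_modules == 'all' or not lora_modules
                or (lora_modules == 'attention' and re.search('attn|attention|query|key|value', name.lower()))
                or (lora_modules == 'ffn' and re.search('mlp|ffn', name.lower())))

names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.down_proj",
    "lm_head",
]
print({n.split('.')[-1] for n in names if selected('attention', n)})  # {'q_proj', 'v_proj'}
print({n.split('.')[-1] for n in names if selected('ffn', n)})        # {'gate_proj', 'down_proj'}
print({n.split('.')[-1] for n in names if selected(None, n)})         # all five names, 'lm_head' included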
@@ -320,7 +330,10 @@ def get_accelerate_model(args, checkpoint_dir):
             model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'), is_trainable=True)
         else:
             print(f'adding LoRA modules...')
-            modules = find_all_linear_names(args, model)
+            if not args.lora_modules or args.lora_modules in ['all', 'attention', 'ffn']:
+                modules = find_all_linear_names(args, model)
+            else:
+                modules = args.lora_modules
             config = LoraConfig(
                 r=args.lora_r,
                 lora_alpha=args.lora_alpha,
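Finally, a hedged sketch (not from the diff) of what the custom-value branch amounts to: the string in args.lora_modules is handed to LoraConfig unchanged, and peft treats a string target_modules as a regular expression (in current releases it has to match the full dotted module name, consistent with the "exact matching" wording in the help text), so an illustrative pattern would look like this:

from peft import LoraConfig

# Roughly what running with --lora_modules '.*\.(q_proj|v_proj)' would produce;
# the keyword values 'all'/'attention'/'ffn' instead go through find_all_linear_names.
config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=r".*\.(q_proj|v_proj)",  # regex over full module names
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)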