From 3fe98656ab9b80f10322fe290b7eb4c5efaac211 Mon Sep 17 00:00:00 2001
From: viktorhargitai <105610130+viktorhargitai@users.noreply.github.com>
Date: Mon, 12 Jun 2023 00:44:04 +0200
Subject: [PATCH] implement LoRA layer choice

---
 qlora.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/qlora.py b/qlora.py
index 59e2a701..040d1b6a 100644
--- a/qlora.py
+++ b/qlora.py
@@ -7,6 +7,7 @@
 import os
 from os.path import exists, join, isdir
 from dataclasses import dataclass, field
+import re
 import sys
 from typing import Optional, Dict, Sequence
 import numpy as np
@@ -135,6 +136,10 @@ class TrainingArguments(transformers.Seq2SeqTrainingArguments):
         default=False,
         metadata={"help": "Use 8-bit adam."}
     )
+    lora_modules: str = field(
+        default=None,
+        metadata={"help": "Select layers to add LoRA to. Default or 'all' selects all linear layers except the head, 'attention' selects the attention layers, 'ffn' the feed-forward ones, other values are treated as regex patterns for exact matching layer names."}
+    )
     double_quant: bool = field(
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
@@ -219,12 +224,17 @@ class GenerationArguments:
     no_repeat_ngram_size: Optional[int] = field(default=0)
 
 
 def find_all_linear_names(args, model):
+    """Find the model's linear layer names for applying LoRA, and return them in a set.
+    If args.lora_modules is None or 'all', it selects all linear layers except the head, 'attention' selects the attention layers, 'ffn' the feed-forward ones."""
     cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
     lora_module_names = set()
     for name, module in model.named_modules():
         if isinstance(module, cls):
-            names = name.split('.')
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            if (args.lora_modules == 'all' or not args.lora_modules
+                    or (args.lora_modules == 'attention' and re.search('attn|attention|query|key|value', name.lower()))
+                    or (args.lora_modules == 'ffn' and re.search('mlp|ffn', name.lower()))):
+                names = name.split('.')
+                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
     if 'lm_head' in lora_module_names: # needed for 16-bit
@@ -320,7 +330,10 @@ def get_accelerate_model(args, checkpoint_dir):
             model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'), is_trainable=True)
         else:
             print(f'adding LoRA modules...')
-            modules = find_all_linear_names(args, model)
+            if not args.lora_modules or args.lora_modules in ['all', 'attention', 'ffn']:
+                modules = find_all_linear_names(args, model)
+            else:
+                modules = args.lora_modules
             config = LoraConfig(
                 r=args.lora_r,
                 lora_alpha=args.lora_alpha,
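
Below is a small standalone sketch (not part of the patch) of how the new lora_modules selection behaves. It mirrors the branching added to find_all_linear_names, restricted to the plain torch.nn.Linear case (args.bits neither 4 nor 8) so it runs without bitsandbytes. The Args dataclass, the ToyBlock module, and its submodule names (query, key, value, mlp_up, mlp_down, lm_head) are hypothetical and chosen only to match the patch's regexes; they are not taken from any real model.

import re
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class Args:
    # Only the field this sketch needs; the real script reads it from TrainingArguments.
    lora_modules: Optional[str] = None  # None/'all', 'attention', 'ffn', or a regex string


class ToyBlock(nn.Module):
    # Hypothetical module whose submodule names are chosen to match the patch's regexes.
    def __init__(self):
        super().__init__()
        self.query = nn.Linear(8, 8)
        self.key = nn.Linear(8, 8)
        self.value = nn.Linear(8, 8)
        self.mlp_up = nn.Linear(8, 32)
        self.mlp_down = nn.Linear(32, 8)
        self.lm_head = nn.Linear(8, 100)


def select_lora_modules(args, model):
    # Same selection logic as the patched find_all_linear_names, for torch.nn.Linear only.
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if (args.lora_modules == 'all' or not args.lora_modules
                    or (args.lora_modules == 'attention' and re.search('attn|attention|query|key|value', name.lower()))
                    or (args.lora_modules == 'ffn' and re.search('mlp|ffn', name.lower()))):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return sorted(lora_module_names)


model = ToyBlock()
print(select_lora_modules(Args(lora_modules=None), model))         # ['key', 'mlp_down', 'mlp_up', 'query', 'value']
print(select_lora_modules(Args(lora_modules='attention'), model))  # ['key', 'query', 'value']
print(select_lora_modules(Args(lora_modules='ffn'), model))        # ['mlp_down', 'mlp_up']
# Any other value (e.g. a regex such as 'query|value') skips this helper entirely and is
# handed to LoraConfig(target_modules=...) as-is, per the change in get_accelerate_model.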