Finetune patch #40

Open · wants to merge 3 commits into main
32 changes: 27 additions & 5 deletions finetune/finetune.py
@@ -180,12 +180,16 @@ def train_tokenize_function(examples, tokenizer):
    return data_dict

def build_model(model_args, training_args, checkpoint_dir):
    if not model_args.use_lora: assert model_args.bits in [16, 32]
    logger.info("Starting model building process...")
    if not model_args.use_lora:
        assert model_args.bits in [16, 32]
        logger.info(f"Not using LoRA. Model bits: {model_args.bits}")
    compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
    logger.info(f"Compute dtype: {compute_dtype}")

    logger.info(f"Loading model from: {model_args.model_name_or_path}")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        load_in_4bit=model_args.bits == 4,
        load_in_8bit=model_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=model_args.bits == 4,
            load_in_8bit=model_args.bits == 8,
@@ -197,43 +201,61 @@ def build_model(model_args, training_args, checkpoint_dir):
        ) if model_args.use_lora else None,
        torch_dtype=compute_dtype,
        trust_remote_code=True,
        attn_implementation=model_args.attn_implementation,
    )
    logger.info("Model loaded successfully")

    if compute_dtype == torch.float16 and model_args.bits == 4:
        if torch.cuda.is_bf16_supported():
            logger.info('='*80)
            logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            logger.info('='*80)

    logger.info("Setting model attributes...")
    setattr(model, 'model_parallel', True)
    setattr(model, 'is_parallelizable', True)
    model.config.torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32
    # Tokenizer
    logger.info(f"Model torch dtype set to: {model.config.torch_dtype}")

    if model_args.use_lora and model_args.bits < 16:
        logger.info("Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
        logger.info("Model prepared for k-bit training")

    if model_args.use_lora:
        logger.info("LoRA is enabled. Proceeding with LoRA setup...")
        if checkpoint_dir is not None:
            logger.info(f"Loading adapters from {checkpoint_dir}.")
            # os.path.join(checkpoint_dir, 'adapter_model')
            model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
        else:
            logger.info(f'Init LoRA modules...')
            target_modules = model_args.trainable.split(',')
            logger.info(f"Target modules for LoRA: {target_modules}")

            modules_to_save = model_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
                logger.info(f"Modules to save: {modules_to_save}")
            else:
                logger.info("No modules to save specified")

            lora_rank = model_args.lora_rank
            lora_dropout = model_args.lora_dropout
            lora_alpha = model_args.lora_alpha
            logger.info(f"LoRA parameters: rank={lora_rank}, dropout={lora_dropout}, alpha={lora_alpha}")

            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                target_modules=target_modules,
                inference_mode=False,
                r=lora_rank, lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                modules_to_save=modules_to_save)
            logger.info(f"LoRA configuration: {peft_config}")

            model = get_peft_model(model, peft_config)
            logger.info("LoRA model preparation completed")

    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
@@ -291,7 +313,7 @@ def train():
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        num_proc=os.cpu_count(),
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True, # not args.overwrite_cache
        desc="Running Encoding",
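
Outside of the added logging, the hunk above appears to replace the hard-coded num_proc=32 with os.cpu_count(), so the number of tokenization workers follows the machine's core count. Below is a minimal standalone sketch of the same datasets.map pattern; it is an illustration only, not code from this PR, and the data file, field name, and tokenize stub are hypothetical.

    import os
    from datasets import load_dataset

    def tokenize(batch):
        # Stand-in for the PR's train_tokenize_function(examples, tokenizer).
        return {"n_chars": [len(t) for t in batch["text"]]}

    # Hypothetical JSONL file with a "text" field.
    raw_train_datasets = load_dataset("json", data_files="train.jsonl", split="train")
    train_dataset = raw_train_datasets.map(
        tokenize,
        batched=True,
        batch_size=3000,
        num_proc=os.cpu_count(),  # one worker per available core instead of a fixed 32
        remove_columns=raw_train_datasets.column_names,
        desc="Running Encoding",
    )

On hosts with fewer than 32 cores this avoids oversubscribing the CPU; on larger hosts it uses the extra cores for preprocessing.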