39 changes: 2 additions & 37 deletions pipelinerl/finetune/checkpoints.py
@@ -19,7 +19,6 @@
from transformers.models.auto.modeling_auto import _BaseAutoModelClass

from .context import get_accelerator, logger
from .lora import has_lora_checkpoint, lora_load, lora_save, prepare_lora_model
from .types import ModelClass, TrainingMetrics


@@ -101,23 +100,6 @@ def load_model(args, model_class, current_dir):

if args.load_as_bf16:
loading_args["torch_dtype"] = torch.bfloat16
if args.lora.enabled:
if is_ds_zero_3:
raise Exception("LoRA is not compatible with Deepspeed zero stage 3")
if args.lora.base_model_8bit:
loading_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=False,
load_in_8bit=True,
llm_int8_has_fp16_weight=args.load_as_bf16,
)
elif args.lora.base_model_4bit:
loading_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
load_in_8bit=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_dtype=torch.bfloat16,
)
if args.auto_device_map:
loading_args["device_map"] = "auto"
model_cls = get_auto_model_class(model_class)
@@ -130,24 +112,14 @@ def load_model(args, model_class, current_dir):
# Size mismatch errors here may be due to improper use of Deepspeed+save_pretrained()
# instead, always call save_model_only() in all processes

# when LoRA enabled, always preload the original model, the lora weights will be loaded later
model_to_load = args.config_name if args.lora.enabled else str(current_dir)
model_to_load = args.config_name
logger.info(f"Loading model {model_cls} weights from {current_dir}")
else: # from scratch
logger.info(f"Initializing model {model_cls} from {args.config_name}")

logger.info(f"Loading args: {loading_args}")
model = model_cls.from_pretrained(model_to_load, **loading_args)

if args.lora.enabled:
model = prepare_lora_model(args.lora, model, args.gradient_checkpointing)
if has_lora_checkpoint(current_dir):
lora_load(current_dir, model)
elif args.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": args.reentrant_checkpointing}
)

get_accelerator().wait_for_everyone()
return model

@@ -300,7 +272,6 @@ def save_model_and_tokenizer(
output_dir: Path,
model: transformers.PreTrainedModel,
tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast,
lora: bool = False,
safe_serialization: bool = False,
):
logger.info("Saving model and tokenizer")
@@ -309,7 +280,6 @@
temp_dir,
model,
unwrap=True,
lora=lora,
safe_serialization=safe_serialization,
)
save_tokenizer_only(temp_dir, tokenizer)
@@ -319,7 +289,6 @@ def save_model_only(
output_dir: Path,
model,
unwrap: bool = True,
lora: bool = False,
safe_serialization: bool = False,
):
"""
@@ -344,12 +313,8 @@
logger.info(f"Save model to {output_dir}")

unwrapped_model = get_accelerator().unwrap_model(model) if unwrap else model
if lora:
lora_save(output_dir, unwrapped_model)
return

# for non-deepspeed models
elif isinstance(unwrapped_model, transformers.PreTrainedModel):
if isinstance(unwrapped_model, transformers.PreTrainedModel):
logger.info("Saving model using transformers save_pretrained")
unwrapped_model.save_pretrained( # type: ignore
output_dir,
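With the LoRA and quantization branches removed, `load_model` reduces to assembling `loading_args` and making a single `from_pretrained` call. A condensed sketch of that surviving path, based only on the context lines above (the auto-model class and the helper name `load_model_sketch` are illustrative, not from the repository):

```python
import torch
import transformers


def load_model_sketch(args):
    """Condensed, illustrative view of the load path after this change."""
    loading_args = {}
    if args.load_as_bf16:
        # bf16 weights, matching the retained load_as_bf16 branch.
        loading_args["torch_dtype"] = torch.bfloat16
    if args.auto_device_map:
        # Let transformers place layers across the available devices.
        loading_args["device_map"] = "auto"
    # The resume branch now also sets model_to_load = args.config_name,
    # so the model name comes from the config in both branches.
    return transformers.AutoModelForCausalLM.from_pretrained(args.config_name, **loading_args)
```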
165 changes: 0 additions & 165 deletions pipelinerl/finetune/lora.py

This file was deleted.

3 changes: 1 addition & 2 deletions pipelinerl/finetune/optim.py
@@ -1,12 +1,11 @@
import torch
from peft.peft_model import PeftModel
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from transformers import Adafactor, PreTrainedModel


def get_grouped_params(
model: PreTrainedModel | PeftModel,
model: PreTrainedModel,
weight_decay: float,
no_decay: list[str] = ["bias", "LayerNorm.weight"],
):
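With the `PeftModel` union dropped, `get_grouped_params` only accepts a plain `PreTrainedModel`. Its body is not shown in the diff; a helper with this signature conventionally splits parameters into a weight-decay group and a no-decay group, roughly as sketched below (the body is an assumption, only the signature comes from the diff):

```python
from transformers import PreTrainedModel


def get_grouped_params(
    model: PreTrainedModel,
    weight_decay: float,
    no_decay: list[str] = ["bias", "LayerNorm.weight"],
):
    # Conventional split: biases and LayerNorm weights are exempt from weight decay.
    decay, exempt = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (exempt if any(nd in name for nd in no_decay) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]
```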
58 changes: 18 additions & 40 deletions pipelinerl/finetune_loop.py
@@ -712,45 +712,9 @@ def batch_generator_fn():
dist.all_gather(all_samples, local_samples)
total_samples = sum(int(tensor.item()) for tensor in all_samples)
do_optimizer_step = total_samples == target_samples
using_deepspeed = isinstance(model, deepspeed.DeepSpeedEngine)

def backward(loss, is_final_micro_batch=False):
"""Perform backward pass with appropriate gradient accumulation boundary"""
if using_deepspeed:
# Tell DeepSpeed whether this is a boundary for gradient accumulation
model.set_gradient_accumulation_boundary(is_final_micro_batch)
# DeepSpeed's backward
model.backward(loss)
else:
# accelerator's backward
get_accelerator().backward(loss)

def optimizer_step_and_zero_grad():
"""Perform optimizer step and zero gradients"""
if using_deepspeed:
# Final boundary before optimizer step
model.set_gradient_accumulation_boundary(True)
model.step()
grad_norm = model.get_global_grad_norm() if hasattr(model, "get_global_grad_norm") else None
if isinstance(training_metrics.grad_norm, torch.Tensor):
grad_norm = grad_norm.item()
training_metrics.grad_norm = grad_norm if grad_norm is not None else -1.0
else:
max_grad_norm = args.get("gradient_clipping_threshold", None)
training_metrics.grad_norm = get_accelerator().clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()
optimizer.zero_grad()

@contextlib.contextmanager
def toggle_sync(sync: bool):
"""Wrap accelerate.no_sync() if sync is False."""
if sync:
yield # do not enforce no_sync mode
else:
with get_accelerator().no_sync(model):
yield

with toggle_sync(do_optimizer_step):
# Perform backward pass with appropriate gradient accumulation boundary
with get_accelerator().accumulate(model):
# Choose RL step function based on seq_packing config
loss, this_step_rl_metrics = rl_step(
model, batch, training_metrics.completed_steps, final_train_steps, rl_config
@@ -765,7 +729,22 @@

training_metrics.lr = optimizer.param_groups[0]["lr"]

backward(loss, is_final_micro_batch=do_optimizer_step)
# Use accelerator's unified backward method
get_accelerator().backward(loss)

# Only perform optimizer step when sync_gradients is True
if get_accelerator().sync_gradients:
# Clip gradients
max_grad_norm = args.get("gradient_clipping_threshold", None)
if max_grad_norm is not None:
grad_norm = get_accelerator().clip_grad_norm_(model.parameters(), max_grad_norm)
training_metrics.grad_norm = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
else:
training_metrics.grad_norm = -1.0

optimizer.step()
optimizer.zero_grad()
torch.cuda.empty_cache()

if not is_sentinel_batch:
passes_took.append(time.time() - time_before_pass)
@@ -795,7 +774,6 @@ def toggle_sync(sync: bool):
except Exception as e:
logger.warning(f"Synchronization error: {e}. Continuing anyway...")

optimizer_step_and_zero_grad()
lr_scheduler.step()

metrics_dict = {}
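The loop above replaces the hand-rolled DeepSpeed boundary handling (`set_gradient_accumulation_boundary`, `model.backward`, `model.step`) and the `toggle_sync`/`no_sync` wrapper with Accelerate's `accumulate()` context manager and `sync_gradients` flag. A minimal self-contained sketch of that pattern, with a toy model and data standing in for the real trainer state:

```python
import torch
from accelerate import Accelerator

# Minimal sketch of the Accelerate gradient-accumulation pattern used above.
# The toy model, optimizer, and batches are placeholders for the real trainer state.
accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

model, optimizer = accelerator.prepare(model, optimizer)

for inputs, targets in batches:
    inputs, targets = inputs.to(accelerator.device), targets.to(accelerator.device)
    # accumulate() skips gradient synchronization on non-boundary micro-batches,
    # which is what the old explicit no_sync()/boundary toggling did by hand.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)

        # sync_gradients is True only on the micro-batch that closes an
        # accumulation window, so clipping and the optimizer step run once per window.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
```

Because `accumulate()` handles both the DDP and DeepSpeed backends, the trainer no longer needs to special-case `deepspeed.DeepSpeedEngine`.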
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -15,14 +15,16 @@ authors = [
dependencies = [
"torch>=2.6",
"vllm==0.8.3",
"accelerate==1.7.0",
"Tapeagents[finetune]==0.1.15",
"Tapeagents==0.1.15",
"deepspeed~=0.15.4",
"accelerate==1.8.0",
"transformers==4.51.0",
"flash-attn==2.7.4.post1",
"math-verify[antlr4_9_3]==0.7.0",
"orjson==3.10.16",
"redis==5.2.1",
"hydra-core>=1.3.2",
"wandb~=0.19",
]

[tool.setuptools.packages.find]