Commit

Merge pull request #6078 from wtmlon/support-efficient-tokens-calculation

support effective tokens calculation on sft/dpo
hiyouga authored Nov 20, 2024
2 parents fdcc78b + 40627c6 commit bd639a1
Showing 4 changed files with 37 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/llamafactory/extras/misc.py
@@ -20,6 +20,7 @@
from typing import TYPE_CHECKING, Tuple, Union

import torch
import torch.distributed as dist
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
@@ -263,3 +264,11 @@ def use_modelscope() -> bool:

def use_openmind() -> bool:
    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
    r"""
    calculate effective tokens.
    """
    result = effective_token_num * epoch / train_runtime
    return result / dist.get_world_size() if dist.is_initialized() else result
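For a quick sense of what this helper reports: effective (non-padding) tokens per pass over the dataset, scaled by the number of epochs and divided by wall-clock training time, then averaged per process when torch.distributed is initialized. A minimal single-process sketch with made-up numbers (in the workflows below, the epoch and runtime values come from train_result.metrics):

# Single-process sketch of the effective-tokens-per-second metric;
# with one process, dist.is_initialized() is False, so no world-size division applies.
total_effective_tokens = 1_000_000  # tokens counted over one pass of the train dataset (illustrative)
num_epochs = 3.0                    # corresponds to train_result.metrics["epoch"]
runtime_sec = 600.0                 # corresponds to train_result.metrics["train_runtime"]

effective_tokens_per_sec = total_effective_tokens * num_epochs / runtime_sec
print(effective_tokens_per_sec)  # 5000.0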
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},
    )
    include_effective_tokens_per_second: bool = field(
        default=False,
        metadata={"help": "Whether or not to compute effective tokens per second."},
    )

    def __post_init__(self):
        def split_arg(arg):
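Because the new field is an ordinary dataclass field, it should surface as a CLI/config flag through HfArgumentParser like the other finetuning arguments. A hedged sketch, assuming FinetuningArguments can be imported and parsed on its own (neither is shown in this commit):

from transformers import HfArgumentParser

from llamafactory.hparams.finetuning_args import FinetuningArguments  # assumed import path

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(["--include_effective_tokens_per_second", "true"])
print(finetuning_args.include_effective_tokens_per_second)  # True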
13 changes: 13 additions & 0 deletions src/llamafactory/train/dpo/workflow.py
@@ -19,6 +19,7 @@

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
@@ -64,6 +65,12 @@ def run_dpo(
    # Update arguments
    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

    effective_token_num = 0.0
    if finetuning_args.include_effective_tokens_per_second:
        for data in dataset_module["train_dataset"]:
            effective_token_num += len(data["chosen_input_ids"])
            effective_token_num += len(data["rejected_input_ids"])

    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,
@@ -79,6 +86,12 @@
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
            )

        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
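A self-contained sketch of the counting step added to run_dpo: for DPO, both the chosen and the rejected sequence of each preference pair count toward the effective-token total. The toy examples below only mimic the keys the loop reads; in the workflow the examples come tokenized from get_dataset:

# Toy pairwise examples standing in for dataset_module["train_dataset"];
# token id values are arbitrary, only the lengths matter here.
train_dataset = [
    {"chosen_input_ids": [1, 2, 3, 4], "rejected_input_ids": [1, 2, 3]},
    {"chosen_input_ids": [1, 2], "rejected_input_ids": [1, 2, 3, 4, 5]},
]

effective_token_num = 0.0
for data in train_dataset:
    effective_token_num += len(data["chosen_input_ids"])
    effective_token_num += len(data["rejected_input_ids"])

print(effective_token_num)  # 14.0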
12 changes: 11 additions & 1 deletion src/llamafactory/train/sft/workflow.py
@@ -19,7 +19,7 @@

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.misc import cal_effective_tokens, get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
@@ -65,6 +65,11 @@ def run_sft(
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
    training_args.remove_unused_columns = False  # important for multimodal dataset

    effective_token_num = 0.0
    if finetuning_args.include_effective_tokens_per_second:
        for data in dataset_module["train_dataset"]:
            effective_token_num += len(data["input_ids"])

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
@@ -94,6 +99,11 @@
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
            )

        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
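Putting the SFT pieces together, a minimal end-to-end sketch: count input_ids over the training set, then convert the count into a rate with the helper from extras/misc.py. The dataset, epoch, and runtime values are illustrative; in run_sft they come from dataset_module and train_result.metrics:

import torch.distributed as dist


def cal_effective_tokens(effective_token_num, epoch, train_runtime):
    # Mirrors the helper added in extras/misc.py: tokens * epochs / seconds,
    # divided by world size when torch.distributed is initialized.
    result = effective_token_num * epoch / train_runtime
    return result / dist.get_world_size() if dist.is_initialized() else result


# Toy SFT dataset: only the length of each "input_ids" list matters.
train_dataset = [{"input_ids": list(range(128))}, {"input_ids": list(range(64))}]
effective_token_num = sum(len(data["input_ids"]) for data in train_dataset)  # 192

metrics = {"epoch": 2.0, "train_runtime": 8.0}  # illustrative values
metrics["effective_tokens_per_sec"] = cal_effective_tokens(
    effective_token_num, metrics["epoch"], metrics["train_runtime"]
)
print(metrics["effective_tokens_per_sec"])  # 48.0 on a single process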
