From bde749560b429e39a87171052662a1b46518815f Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 12 Dec 2024 09:10:51 +0000
Subject: [PATCH] fix qwen2vl mrope

---
 src/llamafactory/data/collator.py      | 9 +++++++++
 src/llamafactory/train/dpo/workflow.py | 1 +
 src/llamafactory/train/kto/workflow.py | 1 +
 src/llamafactory/train/ppo/workflow.py | 2 +-
 src/llamafactory/train/pt/trainer.py   | 5 ++++-
 src/llamafactory/train/rm/trainer.py   | 5 ++++-
 src/llamafactory/train/rm/workflow.py  | 4 +++-
 src/llamafactory/train/sft/trainer.py  | 5 ++++-
 src/llamafactory/train/sft/workflow.py | 1 +
 9 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 84f8006143..ebeb8e8d4e 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -121,6 +121,15 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
                 feature["token_type_ids"] = token_type_ids[i]
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+
+        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+            features["position_ids"], _ = self.model.get_rope_index(
+                input_ids=features["input_ids"],
+                image_grid_thw=mm_inputs.get("image_grid_thw", None),
+                video_grid_thw=mm_inputs.get("video_grid_thw", None),
+                attention_mask=features["attention_mask"],
+            )
+
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
             seq_len = features["input_ids"].size(1)
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index e3d6e66022..e513f5c464 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -48,6 +48,7 @@ def run_dpo(
 
     data_collator = PairwiseDataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index f0dd2acaa7..e98510d5b7 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -47,6 +47,7 @@ def run_kto(
 
     data_collator = KTODataCollatorWithPadding(
         template=template,
+        model=model,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         **tokenizer_module,
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 92262f99e2..50210583c0 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -46,7 +46,7 @@ def run_ppo(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
 
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)
 
     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 3f902dcd6f..37dcadfd96 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -18,7 +18,7 @@
 from transformers import Trainer
 from typing_extensions import override
 
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -38,6 +38,9 @@ class CustomTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
 
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 6469550c79..bccfdef5df 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -25,7 +25,7 @@
 from typing_extensions import override
 
 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -48,6 +48,9 @@ class PairwiseTrainer(Trainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index e3f1b762c7..bced101859 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -44,7 +44,9 @@ def run_rm(
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
+    data_collator = PairwiseDataCollatorWithPadding(
+        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
+    )
 
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 85ce6e8af0..1d3df9151a 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,7 +27,7 @@
 
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -51,6 +51,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
     ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
 
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index bc7ccb50e5..f93920050d 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -56,6 +56,7 @@ def run_sft(
 
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
+        model=model,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
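
Why the collator hunk matters: Qwen2-VL's multimodal rotary embedding (M-RoPE) needs 3D position ids derived from the padded batch and the image/video grid sizes, so the collator asks the model for them via get_rope_index instead of leaving positions to be inferred from padded inputs; the workflow hunks only pass model=model so the collator can reach that method. Below is a minimal sketch of the same idea outside the patch, assuming a Qwen2-VL-style model whose get_rope_index returns (position_ids, rope_deltas) with the keyword arguments shown in the diff; the helper name collate_mrope_batch is illustrative, not part of LLaMA-Factory.

    from typing import Dict, Optional

    import torch


    def collate_mrope_batch(
        model: "torch.nn.Module",
        input_ids: "torch.Tensor",       # (batch, seq_len), already padded
        attention_mask: "torch.Tensor",  # (batch, seq_len), 0 on padding positions
        image_grid_thw: Optional["torch.Tensor"] = None,
        video_grid_thw: Optional["torch.Tensor"] = None,
    ) -> Dict[str, "torch.Tensor"]:
        """Mirror of the collator hunk: precompute M-RoPE position ids for a padded batch."""
        batch = {"input_ids": input_ids, "attention_mask": attention_mask}
        if hasattr(model, "get_rope_index"):  # only Qwen2-VL-style models expose this
            # position_ids has shape (3, batch, seq_len): temporal, height and width indices
            position_ids, _ = model.get_rope_index(
                input_ids=input_ids,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                attention_mask=attention_mask,
            )
            batch["position_ids"] = position_ids

        return batch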
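The repeated trainer hunks address a separate incompatibility: starting with transformers 4.46, Trainer takes the tokenizer/processor as processing_class and deprecates the tokenizer keyword, so the custom trainers remap the argument before calling super().__init__. A hedged sketch of that guard in isolation, assuming the check reduces to a ">= 4.46" version comparison (the real is_transformers_version_greater_than lives in extras.packages; adapt_trainer_kwargs is an illustrative name):

    import importlib.metadata

    from packaging import version


    def is_transformers_version_greater_than(target: str) -> bool:
        """Illustrative stand-in for the helper imported in the patch."""
        return version.parse(importlib.metadata.version("transformers")) >= version.parse(target)


    def adapt_trainer_kwargs(kwargs: dict) -> dict:
        """On transformers >= 4.46, pass the tokenizer/processor as processing_class."""
        if is_transformers_version_greater_than("4.46") and "tokenizer" in kwargs:
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        return kwargs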