From 724560be26893083dee3f6ab3fa44bb8a40730a4 Mon Sep 17 00:00:00 2001 From: shitao Date: Thu, 28 Sep 2023 00:49:16 +0800 Subject: [PATCH] normlized --- .../baai_general_embedding/finetune/arguments.py | 2 +- FlagEmbedding/baai_general_embedding/finetune/run.py | 2 +- .../baai_general_embedding/finetune/trainer.py | 12 +++++++++--- examples/pretrain/README.md | 3 ++- setup.py | 3 ++- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/FlagEmbedding/baai_general_embedding/finetune/arguments.py b/FlagEmbedding/baai_general_embedding/finetune/arguments.py index 33ec3507..c56139c2 100644 --- a/FlagEmbedding/baai_general_embedding/finetune/arguments.py +++ b/FlagEmbedding/baai_general_embedding/finetune/arguments.py @@ -24,7 +24,6 @@ class ModelArguments: default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) - normlized: bool = field(default=True) @dataclass @@ -71,3 +70,4 @@ class RetrieverTrainingArguments(TrainingArguments): temperature: Optional[float] = field(default=0.02) fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"}) sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"}) + normlized: bool = field(default=True) diff --git a/FlagEmbedding/baai_general_embedding/finetune/run.py b/FlagEmbedding/baai_general_embedding/finetune/run.py index b8bd5e7a..e83e3519 100644 --- a/FlagEmbedding/baai_general_embedding/finetune/run.py +++ b/FlagEmbedding/baai_general_embedding/finetune/run.py @@ -69,7 +69,7 @@ def main(): logger.info('Config: %s', config) model = BiEncoderModel(model_name=model_args.model_name_or_path, - normlized=model_args.normlized, + normlized=training_args.normlized, sentence_pooling_method=training_args.sentence_pooling_method, negatives_cross_device=training_args.negatives_cross_device, temperature=training_args.temperature) diff --git a/FlagEmbedding/baai_general_embedding/finetune/trainer.py b/FlagEmbedding/baai_general_embedding/finetune/trainer.py index 91405bd1..0165359e 100644 --- a/FlagEmbedding/baai_general_embedding/finetune/trainer.py +++ b/FlagEmbedding/baai_general_embedding/finetune/trainer.py @@ -2,10 +2,14 @@ from transformers.trainer import * -def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls'): +def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool=True): word_embedding_model = models.Transformer(ckpt_dir) pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode) - model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu') + if normlized: + normlize_layer = models.Normalize() + model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer], device='cpu') + else: + model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu') model.save(ckpt_dir) @@ -29,7 +33,9 @@ def _save(self, output_dir: Optional[str] = None): # save the checkpoint for sentence-transformers library if self.is_world_process_zero(): - save_ckpt_for_sentence_transformers(output_dir, pooling_mode=self.args.sentence_pooling_method) + save_ckpt_for_sentence_transformers(output_dir, + pooling_mode=self.args.sentence_pooling_method, + normlized=self.args.normlized) def compute_loss(self, model, inputs, return_outputs=False): """ diff --git a/examples/pretrain/README.md b/examples/pretrain/README.md index a578250e..ffccfaa7 100644 --- a/examples/pretrain/README.md +++ b/examples/pretrain/README.md @@ -40,7 +40,8 @@ torchrun --nproc_per_node {number of gpus} \ --per_device_train_batch_size {batch size; set 1 for toy data} \ --dataloader_drop_last True \ --max_seq_length 512 \ ---logging_steps 10 +--logging_steps 10 \ +--dataloader_num_workers 12 ``` More training arguments please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). diff --git a/setup.py b/setup.py index 315a7710..6f802ecd 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ 'torch>=1.6.0', 'transformers>=4.18.0', 'datasets', - 'accelerate>=0.20.1' + 'accelerate>=0.20.1', + 'sentence_transformers', ], )