Commit 724560b

normlized

shitao committed Sep 27, 2023 · 1 parent 982f810
Showing 5 changed files with 15 additions and 7 deletions.
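In short: the normlized flag moves from ModelArguments to RetrieverTrainingArguments, run.py reads it from training_args instead of model_args, the sentence-transformers export appends a Normalize module when the flag is set, the pretrain README gains a --dataloader_num_workers example, and setup.py picks up the sentence_transformers dependency (plus a missing comma).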
2 changes: 1 addition & 1 deletion FlagEmbedding/baai_general_embedding/finetune/arguments.py
@@ -24,7 +24,6 @@ class ModelArguments:
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
     )
 
-    normlized: bool = field(default=True)
 
 
 @dataclass
@@ -71,3 +70,4 @@ class RetrieverTrainingArguments(TrainingArguments):
     temperature: Optional[float] = field(default=0.02)
     fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
     sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
+    normlized: bool = field(default=True)
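
Because the flag now sits on RetrieverTrainingArguments, it is parsed together with the rest of the training arguments. A minimal sketch of the new behavior (the package-style import path and the placeholder flag values are assumptions; run.py does the equivalent parsing internally):

```python
from transformers import HfArgumentParser

# Import path assumed from the file layout above; DataArguments omitted for brevity.
from FlagEmbedding.baai_general_embedding.finetune.arguments import (
    ModelArguments,
    RetrieverTrainingArguments,
)

parser = HfArgumentParser((ModelArguments, RetrieverTrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "BAAI/bge-small-en",  # placeholder model
        "--output_dir", "./output",                   # required by TrainingArguments
    ]
)

print(training_args.normlized)           # True by default
print(hasattr(model_args, "normlized"))  # False: the field moved off ModelArguments
```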
2 changes: 1 addition & 1 deletion FlagEmbedding/baai_general_embedding/finetune/run.py
@@ -69,7 +69,7 @@ def main():
     logger.info('Config: %s', config)
 
     model = BiEncoderModel(model_name=model_args.model_name_or_path,
-                           normlized=model_args.normlized,
+                           normlized=training_args.normlized,
                            sentence_pooling_method=training_args.sentence_pooling_method,
                            negatives_cross_device=training_args.negatives_cross_device,
                            temperature=training_args.temperature)
12 changes: 9 additions & 3 deletions FlagEmbedding/baai_general_embedding/finetune/trainer.py
@@ -2,10 +2,14 @@
 from transformers.trainer import *
 
 
-def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls'):
+def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool=True):
     word_embedding_model = models.Transformer(ckpt_dir)
     pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
-    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
+    if normlized:
+        normlize_layer = models.Normalize()
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer], device='cpu')
+    else:
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
     model.save(ckpt_dir)


@@ -29,7 +33,9 @@ def _save(self, output_dir: Optional[str] = None):
 
         # save the checkpoint for sentence-transformers library
         if self.is_world_process_zero():
-            save_ckpt_for_sentence_transformers(output_dir, pooling_mode=self.args.sentence_pooling_method)
+            save_ckpt_for_sentence_transformers(output_dir,
+                                                pooling_mode=self.args.sentence_pooling_method,
+                                                normlized=self.args.normlized)
 
     def compute_loss(self, model, inputs, return_outputs=False):
         """
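The practical effect: checkpoints exported with normlized=True now carry the L2 normalization inside the saved sentence-transformers pipeline. A minimal sketch of verifying this after training (the checkpoint path is a placeholder):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

# Placeholder path; use the output_dir written by save_ckpt_for_sentence_transformers.
model = SentenceTransformer("./output/checkpoint-1000")

embeddings = model.encode(["a sample query", "a sample passage"])

# With the Normalize module appended, encode() returns unit-length vectors,
# so cosine similarity reduces to a plain dot product.
print(np.linalg.norm(embeddings, axis=1))  # expected: values close to 1.0
```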
3 changes: 2 additions & 1 deletion examples/pretrain/README.md
@@ -40,7 +40,8 @@ torchrun --nproc_per_node {number of gpus} \
     --per_device_train_batch_size {batch size; set 1 for toy data} \
     --dataloader_drop_last True \
     --max_seq_length 512 \
-    --logging_steps 10
+    --logging_steps 10 \
+    --dataloader_num_workers 12
 ```
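
The new flag is standard transformers.TrainingArguments plumbing: --dataloader_num_workers sets how many subprocesses PyTorch's DataLoader uses to prepare batches (0 means loading in the main process), so raising it helps keep GPUs fed when data loading is the bottleneck; 12 here is a throughput tuning choice, not a requirement.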

For more training arguments, please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).
3 changes: 2 additions & 1 deletion setup.py
@@ -16,6 +16,7 @@
         'torch>=1.6.0',
         'transformers>=4.18.0',
         'datasets',
-        'accelerate>=0.20.1'
+        'accelerate>=0.20.1',
+        'sentence_transformers',
     ],
 )
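
The comma added to 'accelerate>=0.20.1' matters here: without it, Python's implicit string-literal concatenation would fuse it with the new entry into the single bogus requirement 'accelerate>=0.20.1sentence_transformers'.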
