diff --git a/docs/package_reference/sentence_transformer/index.rst b/docs/package_reference/sentence_transformer/index.rst index 0e724e78b..476c1b88a 100644 --- a/docs/package_reference/sentence_transformer/index.rst +++ b/docs/package_reference/sentence_transformer/index.rst @@ -7,6 +7,7 @@ Sentence Transformer SentenceTransformer trainer training_args + sampler losses evaluation datasets diff --git a/docs/package_reference/sentence_transformer/sampler.md b/docs/package_reference/sentence_transformer/sampler.md new file mode 100644 index 000000000..183612ee6 --- /dev/null +++ b/docs/package_reference/sentence_transformer/sampler.md @@ -0,0 +1,39 @@ + +# Samplers + +## BatchSamplers +```eval_rst +.. autoclass:: sentence_transformers.training_args.BatchSamplers + :members: +``` + +```eval_rst +.. autoclass:: sentence_transformers.sampler.DefaultBatchSampler + :members: +``` + +```eval_rst +.. autoclass:: sentence_transformers.sampler.NoDuplicatesBatchSampler + :members: +``` + +```eval_rst +.. autoclass:: sentence_transformers.sampler.GroupByLabelBatchSampler + :members: +``` + +## MultiDatasetBatchSamplers +```eval_rst +.. autoclass:: sentence_transformers.training_args.MultiDatasetBatchSamplers + :members: +``` + +```eval_rst +.. autoclass:: sentence_transformers.sampler.RoundRobinBatchSampler + :members: +``` + +```eval_rst +.. autoclass:: sentence_transformers.sampler.ProportionalBatchSampler + :members: +``` \ No newline at end of file diff --git a/docs/package_reference/sentence_transformer/training_args.md b/docs/package_reference/sentence_transformer/training_args.md index 0c68fe97c..2ca8c14b5 100644 --- a/docs/package_reference/sentence_transformer/training_args.md +++ b/docs/package_reference/sentence_transformer/training_args.md @@ -7,15 +7,3 @@ :members: :inherited-members: ``` - -## BatchSamplers -```eval_rst -.. autoclass:: sentence_transformers.training_args.BatchSamplers - :members: -``` - -## MultiDatasetBatchSamplers -```eval_rst -.. autoclass:: sentence_transformers.training_args.MultiDatasetBatchSamplers - :members: -``` \ No newline at end of file diff --git a/sentence_transformers/losses/AdaptiveLayerLoss.py b/sentence_transformers/losses/AdaptiveLayerLoss.py index 43fabbe91..a65416e2f 100644 --- a/sentence_transformers/losses/AdaptiveLayerLoss.py +++ b/sentence_transformers/losses/AdaptiveLayerLoss.py @@ -150,17 +150,17 @@ def __init__( Requirements: 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`. - Relations: - - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`MatryoshkaLoss` which allows for - output dimensionality reduction for faster downstream tasks (e.g. retrieval). - - Input: + Inputs: +---------------------------------------+--------+ | Texts | Labels | +=======================================+========+ | any | any | +---------------------------------------+--------+ + Relations: + - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`MatryoshkaLoss` which allows for + output dimensionality reduction for faster downstream tasks (e.g. retrieval). 
+ Example: :: diff --git a/sentence_transformers/losses/AnglELoss.py b/sentence_transformers/losses/AnglELoss.py index 323224e29..d7ff5f852 100644 --- a/sentence_transformers/losses/AnglELoss.py +++ b/sentence_transformers/losses/AnglELoss.py @@ -33,10 +33,6 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0) -> None: Requirements: - Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1]. - Relations: - - :class:`CoSENTLoss` is AnglELoss with ``pairwise_cos_sim`` as the metric, rather than ``pairwise_angle_sim``. - - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than ``CoSENTLoss`` or ``AnglELoss``. - Inputs: +--------------------------------+------------------------+ | Texts | Labels | @@ -44,6 +40,10 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0) -> None: | (sentence_A, sentence_B) pairs | float similarity score | +--------------------------------+------------------------+ + Relations: + - :class:`CoSENTLoss` is AnglELoss with ``pairwise_cos_sim`` as the metric, rather than ``pairwise_angle_sim``. + - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than ``CoSENTLoss`` or ``AnglELoss``. + Example: :: diff --git a/sentence_transformers/losses/BatchAllTripletLoss.py b/sentence_transformers/losses/BatchAllTripletLoss.py index f396a3cbf..cd337c9a2 100644 --- a/sentence_transformers/losses/BatchAllTripletLoss.py +++ b/sentence_transformers/losses/BatchAllTripletLoss.py @@ -39,12 +39,6 @@ def __init__( 1. Each sentence must be labeled with a class. 2. Your dataset must contain at least 2 examples per labels class. - Relations: - * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets. - * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets. - Also, it does not require setting a margin. - * :class:`BatchSemiHardTripletLoss` uses only semi-hard triplets, valid triplets, rather than all possible, valid triplets. - Inputs: +------------------+--------+ | Texts | Labels | @@ -52,6 +46,16 @@ def __init__( | single sentences | class | +------------------+--------+ + Recommendations: + - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that each batch contains 2+ examples per label class. + + Relations: + * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets. + * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets. + Also, it does not require setting a margin. + * :class:`BatchSemiHardTripletLoss` uses only semi-hard triplets, valid triplets, rather than all possible, valid triplets. + Example: :: diff --git a/sentence_transformers/losses/BatchHardSoftMarginTripletLoss.py b/sentence_transformers/losses/BatchHardSoftMarginTripletLoss.py index 45a769598..db381622c 100644 --- a/sentence_transformers/losses/BatchHardSoftMarginTripletLoss.py +++ b/sentence_transformers/losses/BatchHardSoftMarginTripletLoss.py @@ -44,9 +44,6 @@ def __init__( 2. Your dataset must contain at least 2 examples per labels class. 3. Your dataset should contain hard positives and negatives. - Relations: - * :class:`BatchHardTripletLoss` uses a user-specified margin, while this loss does not require setting a margin.
- Inputs: +------------------+--------+ | Texts | Labels | @@ -54,6 +51,13 @@ def __init__( | single sentences | class | +------------------+--------+ + Recommendations: + - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that each batch contains 2+ examples per label class. + + Relations: + * :class:`BatchHardTripletLoss` uses a user-specified margin, while this loss does not require setting a margin. + Example: :: diff --git a/sentence_transformers/losses/BatchHardTripletLoss.py b/sentence_transformers/losses/BatchHardTripletLoss.py index 957f0a428..c407694eb 100644 --- a/sentence_transformers/losses/BatchHardTripletLoss.py +++ b/sentence_transformers/losses/BatchHardTripletLoss.py @@ -102,6 +102,10 @@ def __init__( | single sentences | class | +------------------+--------+ + Recommendations: + - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that each batch contains 2+ examples per label class. + Relations: * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only the hardest positive and negative samples. * :class:`BatchSemiHardTripletLoss` uses only semi-hard triplets, valid triplets, rather than only the hardest positive and negative samples. diff --git a/sentence_transformers/losses/BatchSemiHardTripletLoss.py b/sentence_transformers/losses/BatchSemiHardTripletLoss.py index e6d904f51..e29a65b95 100644 --- a/sentence_transformers/losses/BatchSemiHardTripletLoss.py +++ b/sentence_transformers/losses/BatchSemiHardTripletLoss.py @@ -50,12 +50,6 @@ def __init__( 2. Your dataset must contain at least 2 examples per labels class. 3. Your dataset should contain semi hard positives and negatives. - Relations: - * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than only semi hard positive and negatives. - * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only semi hard positive and negatives. - * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than only semi hard positive and negatives. - Also, it does not require setting a margin. - Inputs: +------------------+--------+ | Texts | Labels | @@ -63,6 +57,16 @@ def __init__( | single sentences | class | +------------------+--------+ + Recommendations: + - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that each batch contains 2+ examples per label class. + + Relations: + * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than only semi hard positive and negatives. + * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only semi hard positive and negatives. + * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than only semi hard positive and negatives. + Also, it does not require setting a margin. + Example: :: diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index 243022d64..83e755d1c 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -99,9 +99,6 @@ def __init__( 1. (anchor, positive) pairs or (anchor, positive, negative pairs) 2.
Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss` - Relations: - - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -111,6 +108,13 @@ def __init__( | (anchor, positive, negative) triplets | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes + Example: :: diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py index e2f1c5d65..e0665eedc 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py @@ -104,11 +104,6 @@ def __init__( 1. (anchor, positive) pairs or (anchor, positive, negative pairs) 2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss` - Relations: - - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes - (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than - :class:`MultipleNegativesRankingLoss`. - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -118,6 +113,15 @@ def __init__( | (anchor, positive, negative) triplets | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes + (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than + :class:`MultipleNegativesRankingLoss`. + Example: :: diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py index 98911ff9e..83fe1e06f 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py @@ -74,10 +74,6 @@ def __init__( 1. (anchor, positive) pairs 2. Should be used with large batch sizes for superior performance, but has slower training time than non-cached versions - Relations: - - Like :class:`MultipleNegativesRankingLoss`, but with an additional symmetric loss term and caching mechanism. - - Inspired by :class:`CachedMultipleNegativesRankingLoss`, adapted for symmetric loss calculation. - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -85,6 +81,14 @@ def __init__( | (anchor, positive) pairs | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - Like :class:`MultipleNegativesRankingLoss`, but with an additional symmetric loss term and caching mechanism.
+ - Inspired by :class:`CachedMultipleNegativesRankingLoss`, adapted for symmetric loss calculation. + Example: :: diff --git a/sentence_transformers/losses/CoSENTLoss.py b/sentence_transformers/losses/CoSENTLoss.py index 378ed0bb4..00b304695 100644 --- a/sentence_transformers/losses/CoSENTLoss.py +++ b/sentence_transformers/losses/CoSENTLoss.py @@ -40,10 +40,6 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f Requirements: - Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1]. - Relations: - - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. - - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended. - Inputs: +--------------------------------+------------------------+ | Texts | Labels | @@ -51,6 +47,10 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f | (sentence_A, sentence_B) pairs | float similarity score | +--------------------------------+------------------------+ + Relations: + - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. + - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended. + Example: :: diff --git a/sentence_transformers/losses/ContrastiveLoss.py b/sentence_transformers/losses/ContrastiveLoss.py index c797a7419..6fdb02180 100644 --- a/sentence_transformers/losses/ContrastiveLoss.py +++ b/sentence_transformers/losses/ContrastiveLoss.py @@ -45,10 +45,6 @@ def __init__( Requirements: 1. (anchor, positive/negative) pairs - Relations: - - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs. - It often yields better results. - Inputs: +-----------------------------------------------+------------------------------+ | Texts | Labels | @@ -56,6 +52,10 @@ def __init__( | (anchor, positive/negative) pairs | 1 if positive, 0 if negative | +-----------------------------------------------+------------------------------+ + Relations: + - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs. + It often yields better results. + Example: :: diff --git a/sentence_transformers/losses/ContrastiveTensionLoss.py b/sentence_transformers/losses/ContrastiveTensionLoss.py index bc5eb2c49..7af82c253 100644 --- a/sentence_transformers/losses/ContrastiveTensionLoss.py +++ b/sentence_transformers/losses/ContrastiveTensionLoss.py @@ -33,9 +33,6 @@ class ContrastiveTensionLoss(nn.Module): * Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF * `Unsupervised Learning > CT <../../examples/unsupervised_learning/CT/README.html>`_ - Relations: - * :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss. - Inputs: +------------------+--------+ | Texts | Labels | @@ -43,6 +40,9 @@ class ContrastiveTensionLoss(nn.Module): | single sentences | none | +------------------+--------+ + Relations: + * :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss. 
+ Example: :: diff --git a/sentence_transformers/losses/CosineSimilarityLoss.py b/sentence_transformers/losses/CosineSimilarityLoss.py index 4dcf913a6..5a769a588 100644 --- a/sentence_transformers/losses/CosineSimilarityLoss.py +++ b/sentence_transformers/losses/CosineSimilarityLoss.py @@ -37,10 +37,6 @@ def __init__( Requirements: 1. Sentence pairs with corresponding similarity scores in range `[0, 1]` - Relations: - - :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended. - - :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss. - Inputs: +--------------------------------+------------------------+ | Texts | Labels | @@ -48,6 +44,10 @@ def __init__( | (sentence_A, sentence_B) pairs | float similarity score | +--------------------------------+------------------------+ + Relations: + - :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended. + - :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss. + Example: :: diff --git a/sentence_transformers/losses/GISTEmbedLoss.py b/sentence_transformers/losses/GISTEmbedLoss.py index c1df1e0ae..f1bb833bd 100644 --- a/sentence_transformers/losses/GISTEmbedLoss.py +++ b/sentence_transformers/losses/GISTEmbedLoss.py @@ -37,11 +37,6 @@ def __init__( 1. (anchor, positive, negative) triplets 2. (anchor, positive) pairs - Relations: - - :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use - a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields - a stronger training signal at the cost of some training overhead. - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -51,6 +46,15 @@ def __init__( | (anchor, positive) pairs | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use + a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields + a stronger training signal at the cost of some training overhead. + Example: :: diff --git a/sentence_transformers/losses/MSELoss.py b/sentence_transformers/losses/MSELoss.py index e8b53ee33..02ec3a568 100644 --- a/sentence_transformers/losses/MSELoss.py +++ b/sentence_transformers/losses/MSELoss.py @@ -28,10 +28,7 @@ def __init__(self, model: SentenceTransformer) -> None: Requirements: 1. Usually uses a finetuned teacher M in a knowledge distillation setup - Relations: - - :class:`MarginMSELoss` is equivalent to this loss, but with a margin through a negative pair.
- - Input: + Inputs: +-----------------------------------------+-----------------------------+ | Texts | Labels | +=========================================+=============================+ @@ -40,6 +37,9 @@ def __init__(self, model: SentenceTransformer) -> None: | sentence_1, sentence_2, ..., sentence_N | model sentence embeddings | +-----------------------------------------+-----------------------------+ + Relations: + - :class:`MarginMSELoss` is equivalent to this loss, but with a margin through a negative pair. + Example: :: diff --git a/sentence_transformers/losses/MarginMSELoss.py b/sentence_transformers/losses/MarginMSELoss.py index b34444d2c..d11f9bd19 100644 --- a/sentence_transformers/losses/MarginMSELoss.py +++ b/sentence_transformers/losses/MarginMSELoss.py @@ -32,9 +32,6 @@ def __init__(self, model: SentenceTransformer, similarity_fct=util.pairwise_dot_ 1. (query, passage_one, passage_two) triplets 2. Usually used with a finetuned teacher M in a knowledge distillation setup - Relations: - - :class:`MSELoss` is equivalent to this loss, but without a margin through the negative pair. - Inputs: +-----------------------------------------------+-----------------------------------------------+ | Texts | Labels | @@ -42,6 +39,9 @@ def __init__(self, model: SentenceTransformer, similarity_fct=util.pairwise_dot_ | (query, passage_one, passage_two) triplets | M(query, passage_one) - M(query, passage_two) | +-----------------------------------------------+-----------------------------------------------+ + Relations: + - :class:`MSELoss` is equivalent to this loss, but without a margin through the negative pair. + Example: With gold labels, e.g. if you have hard scores for sentences. Imagine you want a model to embed sentences diff --git a/sentence_transformers/losses/Matryoshka2dLoss.py b/sentence_transformers/losses/Matryoshka2dLoss.py index b64631363..4b77b9c74 100644 --- a/sentence_transformers/losses/Matryoshka2dLoss.py +++ b/sentence_transformers/losses/Matryoshka2dLoss.py @@ -81,17 +81,17 @@ def __init__( Requirements: 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`. - Relations: - - :class:`MatryoshkaLoss` is used in this loss, and it is responsible for the dimensionality reduction. - - :class:`AdaptiveLayerLoss` is used in this loss, and it is responsible for the layer reduction. - - Input: + Inputs: +---------------------------------------+--------+ | Texts | Labels | +=======================================+========+ | any | any | +---------------------------------------+--------+ + Relations: + - :class:`MatryoshkaLoss` is used in this loss, and it is responsible for the dimensionality reduction. + - :class:`AdaptiveLayerLoss` is used in this loss, and it is responsible for the layer reduction. + Example: :: diff --git a/sentence_transformers/losses/MatryoshkaLoss.py b/sentence_transformers/losses/MatryoshkaLoss.py index 9964c2425..e4a6dd851 100644 --- a/sentence_transformers/losses/MatryoshkaLoss.py +++ b/sentence_transformers/losses/MatryoshkaLoss.py @@ -86,17 +86,17 @@ def __init__( Requirements: 1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`. - Relations: - - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`AdaptiveLayerLoss` which allows for - layer reduction for faster inference. 
- - Input: + Inputs: +---------------------------------------+--------+ | Texts | Labels | +=======================================+========+ | any | any | +---------------------------------------+--------+ + Relations: + - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`AdaptiveLayerLoss` which allows for + layer reduction for faster inference. + Example: :: diff --git a/sentence_transformers/losses/MegaBatchMarginLoss.py b/sentence_transformers/losses/MegaBatchMarginLoss.py index ae4659697..a964eb726 100644 --- a/sentence_transformers/losses/MegaBatchMarginLoss.py +++ b/sentence_transformers/losses/MegaBatchMarginLoss.py @@ -45,13 +45,17 @@ def __init__( 1. (anchor, positive) pairs 2. Large batches (500 or more examples) - Input: + Inputs: +---------------------------------------+--------+ | Texts | Labels | +=======================================+========+ | (anchor, positive) pairs | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + Example: :: diff --git a/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/sentence_transformers/losses/MultipleNegativesRankingLoss.py index 5f43d9a64..a45d93191 100644 --- a/sentence_transformers/losses/MultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/MultipleNegativesRankingLoss.py @@ -48,14 +48,6 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f Requirements: 1. (anchor, positive) pairs or (anchor, positive, negative) triplets - Relations: - - :class:`CachedMultipleNegativesRankingLoss` is equivalent to this loss, but it uses caching that allows for - much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly - slower. - - :class:`MultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but with an additional loss term. - - :class:`GISTEmbedLoss` is equivalent to this loss, but uses a guide model to guide the in-batch negative - sample selection. `GISTEmbedLoss` yields a stronger training signal at the cost of some training overhead. - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -65,6 +57,18 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f | (anchor, positive, negative) triplets | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - :class:`CachedMultipleNegativesRankingLoss` is equivalent to this loss, but it uses caching that allows for + much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly + slower. + - :class:`MultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but with an additional loss term. + - :class:`GISTEmbedLoss` is equivalent to this loss, but uses a guide model to guide the in-batch negative + sample selection. `GISTEmbedLoss` yields a stronger training signal at the cost of some training overhead.
+ Example: :: diff --git a/sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py index a09dd1824..a512f5404 100644 --- a/sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py +++ b/sentence_transformers/losses/MultipleNegativesSymmetricRankingLoss.py @@ -35,12 +35,6 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f Requirements: 1. (anchor, positive) pairs - Relations: - - Like :class:`MultipleNegativesRankingLoss`, but with an additional loss term. - - :class:`CachedMultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but it uses caching that - allows for much higher batch sizes (and thus better performance) without extra memory usage. However, it - is slightly slower. - Inputs: +---------------------------------------+--------+ | Texts | Labels | @@ -48,6 +42,16 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f | (anchor, positive) pairs | none | +---------------------------------------+--------+ + Recommendations: + - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to + ensure that no in-batch negatives are duplicates of the anchor or positive samples. + + Relations: + - Like :class:`MultipleNegativesRankingLoss`, but with an additional loss term. + - :class:`CachedMultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but it uses caching that + allows for much higher batch sizes (and thus better performance) without extra memory usage. However, it + is slightly slower. + Example: :: diff --git a/sentence_transformers/losses/OnlineContrastiveLoss.py b/sentence_transformers/losses/OnlineContrastiveLoss.py index 647281c0f..e9fbb28bc 100644 --- a/sentence_transformers/losses/OnlineContrastiveLoss.py +++ b/sentence_transformers/losses/OnlineContrastiveLoss.py @@ -34,10 +34,6 @@ def __init__( 1. (anchor, positive/negative) pairs 2. Data should include hard positives and hard negatives - Relations: - - :class:`ContrastiveLoss` is similar, but does not use hard positive and hard negative pairs. - :class:`OnlineContrastiveLoss` often yields better results. - Inputs: +-----------------------------------------------+------------------------------+ | Texts | Labels | @@ -45,6 +41,10 @@ def __init__( | (anchor, positive/negative) pairs | 1 if positive, 0 if negative | +-----------------------------------------------+------------------------------+ + Relations: + - :class:`ContrastiveLoss` is similar, but does not use hard positive and hard negative pairs. + :class:`OnlineContrastiveLoss` often yields better results. + Example: :: diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 350f80045..615c72f99 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -32,23 +32,39 @@ def set_epoch(self, epoch: int) -> None: class DefaultBatchSampler(SetEpochMixin, BatchSampler): - pass + """ + This sampler is the default batch sampler used in the SentenceTransformer library. + It is equivalent to the PyTorch BatchSampler. + + Args: + sampler (Sampler or Iterable): The sampler used for sampling elements from the dataset, + such as SubsetRandomSampler. + batch_size (int): Number of samples per batch. + drop_last (bool): If True, drop the last incomplete batch if the dataset size + is not divisible by the batch size.
+ """ class GroupByLabelBatchSampler(SetEpochMixin, BatchSampler): """ This sampler groups samples by their labels and aims to create batches such that each batch contains samples where the labels are as homogeneous as possible. - This sampler is meant to be used alongside the `Batch...TripletLoss` classes, which + This sampler is meant to be used alongside the ``Batch...TripletLoss`` classes, which require that each batch contains at least 2 examples per label class. + Recommended for: + - :class:`~sentence_transformers.losses.BatchAllTripletLoss` + - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss` + - :class:`~sentence_transformers.losses.BatchHardTripletLoss` + - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss` + Args: dataset (Dataset): The dataset to sample from. batch_size (int): Number of samples per batch. Must be divisible by 2. - drop_last (bool): If True, drop the last incomplete batch, if the dataset size + drop_last (bool): If True, drop the last incomplete batch if the dataset size is not divisible by the batch size. valid_label_columns (List[str]): List of column names to check for labels. - The first column name found in the dataset will + The first column name from ``valid_label_columns`` found in the dataset will be used as the label column. generator (torch.Generator, optional): Optional random number generator for shuffling the indices. @@ -123,6 +139,32 @@ def __init__( generator: torch.Generator = None, seed: int = 0, ) -> None: + """ + This sampler creates batches such that each batch contains samples where the values are unique, + even across columns. This is useful when losses consider other samples in a batch to be in-batch + negatives, and you want to ensure that the negatives are not duplicates of the anchor/positive sample. + + Recommended for: + - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.MegaBatchMarginLoss` + - :class:`~sentence_transformers.losses.GISTEmbedLoss` + - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss` + + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): Number of samples per batch. + drop_last (bool): If True, drop the last incomplete batch if the dataset size + is not divisible by the batch size. + valid_label_columns (List[str]): List of column names to check for labels. + The first column name from ``valid_label_columns`` found in the dataset will + be used as the label column. + generator (torch.Generator, optional): Optional random number generator for shuffling + the indices. + seed (int, optional): Seed for the random number generator to ensure reproducibility. + """ super().__init__(dataset, batch_size, drop_last) if label_columns := set(dataset.column_names) & (set(valid_label_columns) | {"dataset_name"}): dataset = dataset.remove_columns(label_columns) @@ -227,6 +269,16 @@ def __init__( generator: torch.Generator, seed: int, ) -> None: + """ + Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously. + With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently. + + Args: + dataset (ConcatDataset): A concatenation of multiple datasets. 
+ batch_samplers (List[BatchSampler]): A list of batch samplers, one for each dataset in the ConcatDataset. + generator (torch.Generator, optional): A generator for reproducible sampling. Defaults to None. + seed (int, optional): A seed for the generator. Defaults to None. + """ super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last) self.dataset = dataset self.batch_samplers = batch_samplers diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 8cb3da3b8..c54b2e8ed 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -464,6 +464,25 @@ def get_batch_sampler( valid_label_columns: list[str] | None = None, generator: torch.Generator | None = None, ) -> BatchSampler: + """ + Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``. + This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler`` + to create the :class:`torch.utils.data.DataLoader`. + + .. note:: + Override this method to provide a custom batch sampler. + + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): Number of samples per batch. + drop_last (bool): If True, drop the last incomplete batch if the dataset size + is not divisible by the batch size. + valid_label_columns (List[str]): List of column names to check for labels. + The first column name from ``valid_label_columns`` found in the dataset will + be used as the label column. + generator (torch.Generator, optional): Optional random number generator for shuffling + the indices. + """ if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES: return NoDuplicatesBatchSampler( dataset=dataset, @@ -495,6 +514,20 @@ def get_multi_dataset_batch_sampler( generator: torch.Generator | None = None, seed: int | None = 0, ) -> BatchSampler: + """ + Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument + in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the + ``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`. + + .. note:: + Override this method to provide a custom multi-dataset batch sampler. + + Args: + dataset (ConcatDataset): The concatenation of all datasets. + batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset. + generator (torch.Generator, optional): Optional random number generator for shuffling the indices. + seed (int, optional): Optional seed for the random number generator + """ if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN: return RoundRobinBatchSampler( dataset=dataset, diff --git a/sentence_transformers/training_args.py b/sentence_transformers/training_args.py index 39ba4ac04..ca1e15da8 100644 --- a/sentence_transformers/training_args.py +++ b/sentence_transformers/training_args.py @@ -17,9 +17,58 @@ class BatchSamplers(ExplicitEnum): The batch sampler is responsible for determining how samples are grouped into batches during training. Valid options are: - - ``BatchSamplers.BATCH_SAMPLER``: The default PyTorch batch sampler. - - ``BatchSamplers.NO_DUPLICATES``: Ensures no duplicate samples in a batch. - - ``BatchSamplers.GROUP_BY_LABEL``: Ensures each batch has 2+ samples from the same label. + - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`, the default + PyTorch batch sampler. 
+ - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`, + ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as: + + - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.MegaBatchMarginLoss` + - :class:`~sentence_transformers.losses.GISTEmbedLoss` + - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss` + - ``BatchSamplers.GROUP_BY_LABEL``: Uses :class:`~sentence_transformers.sampler.GroupByLabelBatchSampler`, + ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple + samples from the same label, such as: + + - :class:`~sentence_transformers.losses.BatchAllTripletLoss` + - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss` + - :class:`~sentence_transformers.losses.BatchHardTripletLoss` + - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss` + + If you want to use a custom batch sampler, you can create a new Trainer class that inherits from + :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` and overrides the + :meth:`~sentence_transformers.trainer.SentenceTransformerTrainer.get_batch_sampler` method. The + method must return a class instance that supports ``__iter__`` and ``__len__`` methods. The former + should yield a list of indices for each batch, and the latter should return the number of batches. + + Usage: + :: + + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments + from sentence_transformers.training_args import BatchSamplers + from sentence_transformers.losses import MultipleNegativesRankingLoss + from datasets import Dataset + + model = SentenceTransformer("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + loss = MultipleNegativesRankingLoss(model) + args = SentenceTransformerTrainingArguments( + output_dir="checkpoints", + batch_sampler=BatchSamplers.NO_DUPLICATES, + ) + trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() """ BATCH_SAMPLER = "batch_sampler" @@ -34,11 +83,56 @@ class MultiDatasetBatchSamplers(ExplicitEnum): The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple datasets during training. Valid options are: - - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Round-robin sampling from each dataset until one is exhausted. + - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Uses :class:`~sentence_transformers.sampler.RoundRobinBatchSampler`, + which uses round-robin sampling from each dataset until one is exhausted. With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled from equally. - - ``MultiDatasetBatchSamplers.PROPORTIONAL``: Sample from each dataset in proportion to its size [default]. 
+ - ``MultiDatasetBatchSamplers.PROPORTIONAL``: **[default]** Uses :class:`~sentence_transformers.sampler.ProportionalBatchSampler`, + which samples from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently. + + Usage: + :: + + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments + from sentence_transformers.training_args import MultiDatasetBatchSamplers + from sentence_transformers.losses import CoSENTLoss + from datasets import Dataset, DatasetDict + + model = SentenceTransformer("microsoft/mpnet-base") + train_general = Dataset.from_dict({ + "sentence_A": ["It's nice weather outside today.", "He drove to work."], + "sentence_B": ["It's so sunny.", "He took the car to the bank."], + "score": [0.9, 0.4], + }) + train_medical = Dataset.from_dict({ + "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."], + "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."], + "score": [0.8, 0.6, 0.7], + }) + train_legal = Dataset.from_dict({ + "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."], + "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."], + "score": [0.7, 0.8], + }) + train_dataset = DatasetDict({ + "general": train_general, + "medical": train_medical, + "legal": train_legal, + }) + + loss = CoSENTLoss(model) + args = SentenceTransformerTrainingArguments( + output_dir="checkpoints", + multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL, + ) + trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() """ ROUND_ROBIN = "round_robin" # Round-robin sampling from each dataset
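The ``BatchSamplers`` docstring above explains that a custom batch sampler can be plugged in by subclassing ``SentenceTransformerTrainer`` and overriding ``get_batch_sampler``, returning any object that supports ``__iter__`` and ``__len__``. The following is a minimal sketch of such an override; the subclass name ``MyTrainer`` is hypothetical, and the ``DefaultBatchSampler`` call assumes the ``(sampler, batch_size, drop_last)`` arguments described in its docstring above.

```python
from torch.utils.data import SubsetRandomSampler

from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import DefaultBatchSampler


class MyTrainer(SentenceTransformerTrainer):
    def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
        # Any object that yields lists of dataset indices from __iter__ and reports
        # the number of batches from __len__ is accepted here. This override simply
        # wraps a shuffled index sampler in the library's DefaultBatchSampler.
        return DefaultBatchSampler(
            SubsetRandomSampler(range(len(dataset)), generator=generator),
            batch_size=batch_size,
            drop_last=drop_last,
        )
```

For the common cases no subclass is needed: ``BatchSamplers.NO_DUPLICATES`` already covers the in-batch-negative losses listed above, and ``BatchSamplers.GROUP_BY_LABEL`` covers the ``Batch...TripletLoss`` classes.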