[docs] Heavily extend sampler documentation (UKPLab#2921)
* Heavily extend sampler documentation

* Add new Samplers tab in Package Reference
* Add "Recommended for losses such as:" in BatchSamplers
* Add Usage in BatchSamplers and MultiDatasetBatchSamplers
* Add "Recommendations" section to some losses, pointing to specific samplers

* Add explanation about custom batch sampler
tomaarsen authored Sep 9, 2024
1 parent ac38c1c commit 2e13ee6
Showing 28 changed files with 364 additions and 113 deletions.
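For orientation before the diff: the samplers documented in this commit are selected through two enums on the training arguments. A minimal sketch, assuming the sentence-transformers v3 API; the output directory and the particular enum choices are illustrative, not taken from this commit:

```python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",  # hypothetical path, not from this commit
    # How batches are drawn from a single dataset:
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # How batches are drawn across datasets when training on several at once:
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
```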
1 change: 1 addition & 0 deletions docs/package_reference/sentence_transformer/index.rst
@@ -7,6 +7,7 @@ Sentence Transformer
     SentenceTransformer
     trainer
     training_args
+    sampler
     losses
     evaluation
     datasets
39 changes: 39 additions & 0 deletions docs/package_reference/sentence_transformer/sampler.md
@@ -0,0 +1,39 @@
+
+# Samplers
+
+## BatchSamplers
+```eval_rst
+.. autoclass:: sentence_transformers.training_args.BatchSamplers
+    :members:
+```
+
+```eval_rst
+.. autoclass:: sentence_transformers.sampler.DefaultBatchSampler
+    :members:
+```
+
+```eval_rst
+.. autoclass:: sentence_transformers.sampler.NoDuplicatesBatchSampler
+    :members:
+```
+
+```eval_rst
+.. autoclass:: sentence_transformers.sampler.GroupByLabelBatchSampler
+    :members:
+```
+
+## MultiDatasetBatchSamplers
+```eval_rst
+.. autoclass:: sentence_transformers.training_args.MultiDatasetBatchSamplers
+    :members:
+```
+
+```eval_rst
+.. autoclass:: sentence_transformers.sampler.RoundRobinBatchSampler
+    :members:
+```
+
+```eval_rst
+.. autoclass:: sentence_transformers.sampler.ProportionalBatchSampler
+    :members:
+```
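The commit message also mentions an explanation of custom batch samplers. A rough sketch of that idea, assuming the trainer exposes a `get_batch_sampler` hook with roughly this signature (an assumption; check the rendered docs for the exact name); the sampler itself is a toy example:

```python
from torch.utils.data import BatchSampler, SequentialSampler

from sentence_transformers import SentenceTransformerTrainer


class EveryOtherBatchSampler(BatchSampler):
    """Toy sampler that batches every second example (illustration only)."""

    def __iter__(self):
        indices = list(self.sampler)[::2]  # keep every second index
        for start in range(0, len(indices) - self.batch_size + 1, self.batch_size):
            yield indices[start : start + self.batch_size]


class CustomSamplerTrainer(SentenceTransformerTrainer):
    # Assumed hook: override how the trainer builds its batch sampler.
    def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
        return EveryOtherBatchSampler(SequentialSampler(dataset), batch_size, drop_last)
```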
12 changes: 0 additions & 12 deletions docs/package_reference/sentence_transformer/training_args.md
@@ -7,15 +7,3 @@
     :members:
     :inherited-members:
 ```
-
-## BatchSamplers
-```eval_rst
-.. autoclass:: sentence_transformers.training_args.BatchSamplers
-    :members:
-```
-
-## MultiDatasetBatchSamplers
-```eval_rst
-.. autoclass:: sentence_transformers.training_args.MultiDatasetBatchSamplers
-    :members:
-```
10 changes: 5 additions & 5 deletions sentence_transformers/losses/AdaptiveLayerLoss.py
@@ -150,17 +150,17 @@ def __init__(
     Requirements:
         1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`.
 
-    Relations:
-        - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`MatryoshkaLoss`, which allows for
-          output dimensionality reduction for faster downstream tasks (e.g. retrieval).
-
-    Input:
+    Inputs:
     +---------------------------------------+--------+
     | Texts                                 | Labels |
     +=======================================+========+
     | any                                   | any    |
     +---------------------------------------+--------+
 
+    Relations:
+        - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`MatryoshkaLoss`, which allows for
+          output dimensionality reduction for faster downstream tasks (e.g. retrieval).
+
     Example:
         ::
8 changes: 4 additions & 4 deletions sentence_transformers/losses/AnglELoss.py
@@ -33,17 +33,17 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0) -> None:
     Requirements:
         - Sentence pairs with corresponding similarity scores in the range of the similarity function. Default is [-1, 1].
 
-    Relations:
-        - :class:`CoSENTLoss` is AnglELoss with ``pairwise_cos_sim`` as the metric, rather than ``pairwise_angle_sim``.
-        - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than ``CoSENTLoss`` or ``AnglELoss``.
-
     Inputs:
     +--------------------------------+------------------------+
     | Texts                          | Labels                 |
     +================================+========================+
     | (sentence_A, sentence_B) pairs | float similarity score |
     +--------------------------------+------------------------+
 
+    Relations:
+        - :class:`CoSENTLoss` is AnglELoss with ``pairwise_cos_sim`` as the metric, rather than ``pairwise_angle_sim``.
+        - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than ``CoSENTLoss`` or ``AnglELoss``.
+
     Example:
         ::
16 changes: 10 additions & 6 deletions sentence_transformers/losses/BatchAllTripletLoss.py
@@ -39,19 +39,23 @@ def __init__(
     Requirements:
         1. Each sentence must be labeled with a class.
         2. Your dataset must contain at least 2 examples per label class.
 
-    Relations:
-        * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
-        * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
-          Also, it does not require setting a margin.
-        * :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than all possible, valid triplets.
-
     Inputs:
     +------------------+--------+
     | Texts            | Labels |
     +==================+========+
     | single sentences | class  |
     +------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that each batch contains 2+ examples per label class.
+
+    Relations:
+        * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
+        * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
+          Also, it does not require setting a margin.
+        * :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than all possible, valid triplets.
+
     Example:
         ::
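To make the new recommendation concrete: a hedged sketch of training this loss with ``BatchSamplers.GROUP_BY_LABEL``; the model name and toy dataset are placeholders, not taken from this commit:

```python
from datasets import Dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
train_dataset = Dataset.from_dict({
    "sentence": ["A dog runs.", "A puppy sprints.", "Stocks fell.", "Markets dropped."],
    "label": [0, 0, 1, 1],
})
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    per_device_train_batch_size=2,
    batch_sampler=BatchSamplers.GROUP_BY_LABEL,  # keeps 2+ samples per label in each batch
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=losses.BatchAllTripletLoss(model),
)
trainer.train()
```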
10 changes: 7 additions & 3 deletions sentence_transformers/losses/BatchHardSoftMarginTripletLoss.py
@@ -44,16 +44,20 @@ def __init__(
         2. Your dataset must contain at least 2 examples per label class.
         3. Your dataset should contain hard positives and negatives.
 
-    Relations:
-        * :class:`BatchHardTripletLoss` uses a user-specified margin, while this loss does not require setting a margin.
-
     Inputs:
     +------------------+--------+
     | Texts            | Labels |
     +==================+========+
     | single sentences | class  |
     +------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that each batch contains 2+ examples per label class.
+
+    Relations:
+        * :class:`BatchHardTripletLoss` uses a user-specified margin, while this loss does not require setting a margin.
+
     Example:
         ::
4 changes: 4 additions & 0 deletions sentence_transformers/losses/BatchHardTripletLoss.py
@@ -102,6 +102,10 @@ def __init__(
     | single sentences | class  |
     +------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that each batch contains 2+ examples per label class.
+
     Relations:
         * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only the hardest positive and negative samples.
         * :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than only the hardest positive and negative samples.
16 changes: 10 additions & 6 deletions sentence_transformers/losses/BatchSemiHardTripletLoss.py
@@ -50,19 +50,23 @@ def __init__(
         2. Your dataset must contain at least 2 examples per label class.
         3. Your dataset should contain semi-hard positives and negatives.
 
-    Relations:
-        * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than only semi-hard positives and negatives.
-        * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only semi-hard positives and negatives.
-        * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than only semi-hard positives and negatives.
-          Also, it does not require setting a margin.
-
     Inputs:
     +------------------+--------+
     | Texts            | Labels |
     +==================+========+
     | single sentences | class  |
     +------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.GROUP_BY_LABEL`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that each batch contains 2+ examples per label class.
+
+    Relations:
+        * :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than only semi-hard positives and negatives.
+        * :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only semi-hard positives and negatives.
+        * :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than only semi-hard positives and negatives.
+          Also, it does not require setting a margin.
+
     Example:
         ::
10 changes: 7 additions & 3 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -99,9 +99,6 @@ def __init__(
         1. (anchor, positive) pairs or (anchor, positive, negative) triplets
         2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss`
 
-    Relations:
-        - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes
-
     Inputs:
     +---------------------------------------+--------+
     | Texts                                 | Labels |
@@ -111,6 +108,13 @@ def __init__(
     | (anchor, positive, negative) triplets | none   |
     +---------------------------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that no in-batch negatives are duplicates of the anchor or positive samples.
+
+    Relations:
+        - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes
+
     Example:
         ::
sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -104,11 +104,6 @@ def __init__(
         1. (anchor, positive) pairs or (anchor, positive, negative) triplets
         2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss`
 
-    Relations:
-        - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes
-          (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than
-          :class:`MultipleNegativesRankingLoss`.
-
     Inputs:
     +---------------------------------------+--------+
     | Texts                                 | Labels |
@@ -118,6 +113,15 @@ def __init__(
     | (anchor, positive, negative) triplets | none   |
     +---------------------------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that no in-batch negatives are duplicates of the anchor or positive samples.
+
+    Relations:
+        - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes
+          (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than
+          :class:`MultipleNegativesRankingLoss`.
+
     Example:
         ::
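Analogously for the ``NO_DUPLICATES`` recommendation: a hedged sketch of training the cached loss with that sampler; the model name and toy data are placeholders:

```python
from datasets import Dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?", "How do planes fly?"],
    "positive": ["Paris is the capital of France.", "Wings generate lift."],
})
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # no repeated texts within a batch
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=32),
)
trainer.train()
```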
sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -74,17 +74,21 @@ def __init__(
         1. (anchor, positive) pairs
         2. Should be used with large batch sizes for superior performance, but has slower training time than non-cached versions
 
-    Relations:
-        - Like :class:`MultipleNegativesRankingLoss`, but with an additional symmetric loss term and caching mechanism.
-        - Inspired by :class:`CachedMultipleNegativesRankingLoss`, adapted for symmetric loss calculation.
-
     Inputs:
     +---------------------------------------+--------+
     | Texts                                 | Labels |
     +=======================================+========+
     | (anchor, positive) pairs              | none   |
     +---------------------------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that no in-batch negatives are duplicates of the anchor or positive samples.
+
+    Relations:
+        - Like :class:`MultipleNegativesRankingLoss`, but with an additional symmetric loss term and caching mechanism.
+        - Inspired by :class:`CachedMultipleNegativesRankingLoss`, adapted for symmetric loss calculation.
+
     Example:
         ::
8 changes: 4 additions & 4 deletions sentence_transformers/losses/CoSENTLoss.py
@@ -40,17 +40,17 @@ def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_f
     Requirements:
         - Sentence pairs with corresponding similarity scores in the range of the similarity function. Default is [-1, 1].
 
-    Relations:
-        - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``.
-        - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended.
-
     Inputs:
     +--------------------------------+------------------------+
     | Texts                          | Labels                 |
     +================================+========================+
     | (sentence_A, sentence_B) pairs | float similarity score |
     +--------------------------------+------------------------+
 
+    Relations:
+        - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``.
+        - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended.
+
     Example:
         ::
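For the similarity-score losses above, a hedged sketch of the expected ``(sentence_A, sentence_B, float score)`` input format; the model name and toy data are placeholders:

```python
from datasets import Dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
train_dataset = Dataset.from_dict({
    "sentence1": ["A man is eating.", "A plane takes off."],
    "sentence2": ["Someone is having a meal.", "A bird is flying."],
    "score": [0.9, 0.3],  # float similarity in the similarity function's range
})
trainer = SentenceTransformerTrainer(
    model=model,
    args=SentenceTransformerTrainingArguments(output_dir="checkpoints"),
    train_dataset=train_dataset,
    loss=losses.CoSENTLoss(model),
)
trainer.train()
```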
8 changes: 4 additions & 4 deletions sentence_transformers/losses/ContrastiveLoss.py
@@ -45,17 +45,17 @@ def __init__(
     Requirements:
         1. (anchor, positive/negative) pairs
 
-    Relations:
-        - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs.
-          It often yields better results.
-
     Inputs:
     +-----------------------------------------------+------------------------------+
     | Texts                                         | Labels                       |
     +===============================================+==============================+
     | (anchor, positive/negative) pairs             | 1 if positive, 0 if negative |
     +-----------------------------------------------+------------------------------+
 
+    Relations:
+        - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs.
+          It often yields better results.
+
     Example:
         ::
6 changes: 3 additions & 3 deletions sentence_transformers/losses/ContrastiveTensionLoss.py
@@ -33,16 +33,16 @@ class ContrastiveTensionLoss(nn.Module):
         * Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF
         * `Unsupervised Learning > CT <../../examples/unsupervised_learning/CT/README.html>`_
 
-    Relations:
-        * :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss.
-
     Inputs:
     +------------------+--------+
     | Texts            | Labels |
     +==================+========+
     | single sentences | none   |
     +------------------+--------+
 
+    Relations:
+        * :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss.
+
     Example:
         ::
8 changes: 4 additions & 4 deletions sentence_transformers/losses/CosineSimilarityLoss.py
@@ -37,17 +37,17 @@ def __init__(
     Requirements:
         1. Sentence pairs with corresponding similarity scores in the range `[0, 1]`
 
-    Relations:
-        - :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended.
-        - :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss.
-
     Inputs:
     +--------------------------------+------------------------+
     | Texts                          | Labels                 |
     +================================+========================+
     | (sentence_A, sentence_B) pairs | float similarity score |
     +--------------------------------+------------------------+
 
+    Relations:
+        - :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended.
+        - :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss.
+
     Example:
         ::
14 changes: 9 additions & 5 deletions sentence_transformers/losses/GISTEmbedLoss.py
@@ -37,11 +37,6 @@ def __init__(
         1. (anchor, positive, negative) triplets
         2. (anchor, positive) pairs
 
-    Relations:
-        - :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use
-          a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields
-          a stronger training signal at the cost of some training overhead.
-
     Inputs:
     +---------------------------------------+--------+
     | Texts                                 | Labels |
@@ -51,6 +46,15 @@ def __init__(
     | (anchor, positive) pairs              | none   |
     +---------------------------------------+--------+
 
+    Recommendations:
+        - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+          ensure that no in-batch negatives are duplicates of the anchor or positive samples.
+
+    Relations:
+        - :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use
+          a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields
+          a stronger training signal at the cost of some training overhead.
+
     Example:
         ::
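GISTEmbedLoss additionally takes a guide model whose similarities filter out misleading in-batch negatives. A hedged construction sketch; both model names are placeholders:

```python
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("all-MiniLM-L6-v2")   # placeholder model to train
guide = SentenceTransformer("all-mpnet-base-v2")  # placeholder guide model
# Pair this loss with BatchSamplers.NO_DUPLICATES, as recommended above.
loss = losses.GISTEmbedLoss(model, guide)
```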