diff --git a/mmlearn/tasks/contrastive_pretraining.py b/mmlearn/tasks/contrastive_pretraining.py index 966fb60..b91296f 100644 --- a/mmlearn/tasks/contrastive_pretraining.py +++ b/mmlearn/tasks/contrastive_pretraining.py @@ -112,6 +112,15 @@ class ContrastivePretraining(L.LightningModule): a `scheduler` key that specifies the scheduler and an optional `extras` key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training. + init_logit_scale : float, optional, default=1 / 0.07 + The initial value of the logit scale parameter. This is the log of the scale + factor applied to the logits before computing the contrastive loss. + max_logit_scale : float, optional, default=100 + The maximum value of the logit scale parameter. The logit scale parameter + is clamped to the range [0, log(max_logit_scale)]. + learnable_logit_scale : bool, optional, default=True + Whether the logit scale parameter is learnable. If set to False, the logit + scale parameter is treated as a constant. loss : CLIPLoss, optional, default=None The loss function to use. modality_loss_pairs : List[LossPairSpec], optional, default=None