diff --git a/mmlearn/modules/layers/logit_scaling.py b/mmlearn/modules/layers/logit_scaling.py
index ac7f7e5..84f90b3 100644
--- a/mmlearn/modules/layers/logit_scaling.py
+++ b/mmlearn/modules/layers/logit_scaling.py
@@ -22,15 +22,15 @@ class LearnableLogitScaling(torch.nn.Module):

     def __init__(
         self,
-        logit_scale_init: float = 1 / 0.07,
-        learnable: bool = True,
+        init_logit_scale: float = 1 / 0.07,
         max_logit_scale: float = 100,
+        learnable: bool = True,
     ) -> None:
         super().__init__()
         self.max_logit_scale = max_logit_scale
-        self.logit_scale_init = logit_scale_init
+        self.init_logit_scale = init_logit_scale
         self.learnable = learnable
-        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
+        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
         if learnable:
             self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
         else:
@@ -49,6 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     def extra_repr(self) -> str:
         """Return the string representation of the layer."""
         return (
-            f"logit_scale_init={self.logit_scale_init},learnable={self.learnable},"
+            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
             f" max_logit_scale={self.max_logit_scale}"
         )
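
A minimal usage sketch of the renamed constructor argument, assuming the class is importable from the module path shown above (the values are the defaults from the diff):

    from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling

    # trainable temperature, initialised to 1/0.07 (~14.29) and capped at 100
    scaling = LearnableLogitScaling(
        init_logit_scale=1 / 0.07, max_logit_scale=100, learnable=True
    )
    print(scaling)  # extra_repr() reports the initial value, learnability, and the cap
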
diff --git a/mmlearn/modules/losses/contrastive.py b/mmlearn/modules/losses/contrastive.py
index 9ad2a75..a2adafc 100644
--- a/mmlearn/modules/losses/contrastive.py
+++ b/mmlearn/modules/losses/contrastive.py
@@ -83,6 +83,7 @@ def _get_logits(
         self,
         features_1: torch.Tensor,
         features_2: torch.Tensor,
+        logit_scale: torch.Tensor,
         rank: int,
         world_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -93,7 +94,9 @@ def _get_logits(
         features_1 : torch.Tensor
             First feature tensor.
         features_2 : torch.Tensor
-            Second feature tensor
+            Second feature tensor.
+        logit_scale : torch.Tensor
+            Logit scale.
         rank : int
             Rank of the current process.
         world_size : int
@@ -114,19 +117,28 @@ def _get_logits(
             )

             if self.local_loss:
-                logits_per_feature_1 = _safe_matmul(features_1, all_features_2)
-                logits_per_feature_2 = _safe_matmul(features_2, all_features_1)
+                logits_per_feature_1 = logit_scale * _safe_matmul(
+                    features_1, all_features_2
+                )
+                logits_per_feature_2 = logit_scale * _safe_matmul(
+                    features_2, all_features_1
+                )
             else:
-                logits_per_feature_1 = _safe_matmul(all_features_1, all_features_2)
+                logits_per_feature_1 = logit_scale * _safe_matmul(
+                    all_features_1, all_features_2
+                )
                 logits_per_feature_2 = logits_per_feature_1.T
         else:
-            logits_per_feature_1 = _safe_matmul(features_1, features_2)
-            logits_per_feature_2 = _safe_matmul(features_2, features_1)
+            logits_per_feature_1 = logit_scale * _safe_matmul(features_1, features_2)
+            logits_per_feature_2 = logit_scale * _safe_matmul(features_2, features_1)

         return logits_per_feature_1, logits_per_feature_2

     def forward(
-        self, features_1: torch.Tensor, features_2: torch.Tensor
+        self,
+        features_1: torch.Tensor,
+        features_2: torch.Tensor,
+        logit_scale: torch.Tensor,
     ) -> torch.Tensor:
         """Calculate the CLIP-style loss between two sets of features.

@@ -136,6 +148,8 @@ def forward(
             First set of features.
         features_2 : torch.Tensor
             Second set of features.
+        logit_scale : torch.Tensor
+            Logit scale.

         Returns
         -------
@@ -150,7 +164,7 @@ def forward(
         features_2 = F.normalize(features_2, p=2, dim=-1)

         logits_per_feat1, logits_per_feat2 = self._get_logits(
-            features_1, features_2, rank=rank, world_size=world_size
+            features_1, features_2, logit_scale, rank=rank, world_size=world_size
         )
         labels = self._get_ground_truth(
             features_1.device,
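
A sketch of the updated call: the loss now takes the already-exponentiated logit scale as an explicit third argument and normalizes the features internally. Import path, constructor defaults, and tensor shapes are assumptions for illustration, not taken from elsewhere in the repo:

    import torch
    from mmlearn.modules.losses.contrastive import CLIPLoss  # path assumed from the diff

    loss_fn = CLIPLoss()                    # assuming default constructor options
    image_embeddings = torch.randn(8, 512)  # illustrative batch of features
    text_embeddings = torch.randn(8, 512)
    logit_scale = torch.tensor(1 / 0.07)    # exp(log_logit_scale), not the log itself
    loss = loss_fn(image_embeddings, text_embeddings, logit_scale)
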
""" outputs = { - modality.embedding: self.encode(inputs, modality) + modality.embedding: self.encode(inputs, modality, normalize=True) for modality in self._available_modalities } @@ -373,6 +399,16 @@ def _compute_loss( if self.loss_fn is None: return None + with torch.no_grad(): + self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale)) + self.log( + "train/logit_scale", + self.log_logit_scale.exp(), + prog_bar=True, + on_step=True, + on_epoch=False, + ) + contrastive_losses: list[torch.Tensor] = [] for loss_pair in self.modality_loss_pairs: modality_a = Modalities.get_modality(loss_pair.modalities[0]) @@ -389,6 +425,7 @@ def _compute_loss( self.loss_fn( outputs[modality_a.embedding][indices_a], outputs[modality_b.embedding][indices_b], + self.log_logit_scale.exp(), ) * loss_pair.weight ) diff --git a/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml b/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml index 16b407d..be807c3 100644 --- a/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml +++ b/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml @@ -12,8 +12,6 @@ defaults: - /modules/encoders@task.encoders.rgb: timm-vit-lora - /modules/encoders@task.encoders.dna: barcode-bert-lora - /modules/layers@task.heads.text: MLP # the other modalities have projection heads in their encoders - - /modules/layers@task.postprocessors.norm_and_logit_scale.norm: L2Norm - - /modules/layers@task.postprocessors.norm_and_logit_scale.logit_scale: LearnableLogitScaling - /modules/losses@task.loss: CLIPLoss - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: OneCycleLR @@ -67,19 +65,6 @@ task: text: in_dim: 512 out_dim: ${task.encoders.rgb.projection_dim} - postprocessors: - norm_and_logit_scale: - norm: - dim: -1 - logit_scale: - learnable: True - modality_module_mapping: - text: - postprocessor_key: norm_and_logit_scale - rgb: - postprocessor_key: norm_and_logit_scale - dna: - postprocessor_key: norm_and_logit_scale optimizer: lr: 1.0e-3 eps: 1.0e-6 diff --git a/projects/med_benchmarking/configs/experiment/baseline.yaml b/projects/med_benchmarking/configs/experiment/baseline.yaml index 2776d33..e3c5494 100644 --- a/projects/med_benchmarking/configs/experiment/baseline.yaml +++ b/projects/med_benchmarking/configs/experiment/baseline.yaml @@ -15,8 +15,6 @@ defaults: - /datasets/tokenizers@dataloader.val.collate_fn.batch_processors.text: HFCLIPTokenizer - /modules/encoders@task.encoders.text: HFCLIPTextEncoderWithProjection - /modules/encoders@task.encoders.rgb: HFCLIPVisionEncoderWithProjection - - /modules/layers@task.postprocessors.norm_and_logit_scale.norm: L2Norm - - /modules/layers@task.postprocessors.norm_and_logit_scale.logit_scale: LearnableLogitScaling - /modules/losses@task.loss: CLIPLoss - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR @@ -47,17 +45,6 @@ dataloader: num_workers: 4 task: - postprocessors: - norm_and_logit_scale: - norm: - dim: -1 - logit_scale: - learnable: True - modality_module_mapping: - text: - postprocessor_key: norm_and_logit_scale - rgb: - postprocessor_key: norm_and_logit_scale optimizer: betas: - 0.9 diff --git a/projects/med_benchmarking/datasets/pad_ufes_20.py b/projects/med_benchmarking/datasets/pad_ufes_20.py index d0c4c04..35a9700 100644 --- a/projects/med_benchmarking/datasets/pad_ufes_20.py +++ b/projects/med_benchmarking/datasets/pad_ufes_20.py @@ -43,13 +43,13 @@ def __init__( self.split = split # Load cached data if available 
- cache_path = f"cache/PadUfes20_{split}.pkl" + cache_path = f".cache/PadUfes20_{split}.pkl" if os.path.exists(cache_path): print(f"!!! Using cached dataset for {split}") with open(cache_path, "rb") as f: self.metadata = pickle.load(f) else: - os.makedirs("cache/", exist_ok=True) + os.makedirs(".cache/", exist_ok=True) self.metadata = self._load_and_process_metadata() with open(cache_path, "wb") as f: pickle.dump(self.metadata.to_dict("records"), f) @@ -68,14 +68,13 @@ def _load_and_process_metadata(self) -> pd.DataFrame: df["path"] = df["img_id"].apply( lambda imgid: os.path.join(self.root_dir, "Dataset", imgid) ) - df.drop(columns=["img_id", "diagnostic"], inplace=True).reset_index( - drop=True, inplace=True - ) + df.drop(columns=["img_id", "diagnostic"], inplace=True) + df.reset_index(drop=True, inplace=True) # Split into train and test dataset = {} - dataset["test"] = df.sample(frac=0.2) - dataset["train"] = df.drop(dataset["test"].index) + dataset["test"] = df.sample(frac=0.2, ignore_index=True) + dataset["train"] = df.drop(dataset["test"].index).reset_index(drop=True) return dataset[self.split] def _build_label(self, str_label: str) -> int: