Skip to content

Commit

Permalink
Update logit scaling method (#29)
Browse files (browse the repository at this point in the history)
  • Loading branch information
fcogidi authored Nov 4, 2024
1 parent c5d1244 commit dc264de
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 51 deletions.
10 changes: 5 additions & 5 deletions mmlearn/modules/layers/logit_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class LearnableLogitScaling(torch.nn.Module):

def __init__(
self,
logit_scale_init: float = 1 / 0.07,
learnable: bool = True,
init_logit_scale: float = 1 / 0.07,
max_logit_scale: float = 100,
learnable: bool = True,
) -> None:
super().__init__()
self.max_logit_scale = max_logit_scale
self.logit_scale_init = logit_scale_init
self.init_logit_scale = init_logit_scale
self.learnable = learnable
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
if learnable:
self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
else:
Expand All @@ -49,6 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def extra_repr(self) -> str:
"""Return the string representation of the layer."""
return (
f"logit_scale_init={self.logit_scale_init},learnable={self.learnable},"
f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
f" max_logit_scale={self.max_logit_scale}"
)
30 changes: 22 additions & 8 deletions mmlearn/modules/losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _get_logits(
self,
features_1: torch.Tensor,
features_2: torch.Tensor,
logit_scale: torch.Tensor,
rank: int,
world_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -93,7 +94,9 @@ def _get_logits(
features_1 : torch.Tensor
First feature tensor.
features_2 : torch.Tensor
Second feature tensor
Second feature tensor.
logit_scale : torch.Tensor
Logit scale.
rank : int
Rank of the current process.
world_size : int
Expand All @@ -114,19 +117,28 @@ def _get_logits(
)

if self.local_loss:
logits_per_feature_1 = _safe_matmul(features_1, all_features_2)
logits_per_feature_2 = _safe_matmul(features_2, all_features_1)
logits_per_feature_1 = logit_scale * _safe_matmul(
features_1, all_features_2
)
logits_per_feature_2 = logit_scale * _safe_matmul(
features_2, all_features_1
)
else:
logits_per_feature_1 = _safe_matmul(all_features_1, all_features_2)
logits_per_feature_1 = logit_scale * _safe_matmul(
all_features_1, all_features_2
)
logits_per_feature_2 = logits_per_feature_1.T
else:
logits_per_feature_1 = _safe_matmul(features_1, features_2)
logits_per_feature_2 = _safe_matmul(features_2, features_1)
logits_per_feature_1 = logit_scale * _safe_matmul(features_1, features_2)
logits_per_feature_2 = logit_scale * _safe_matmul(features_2, features_1)

return logits_per_feature_1, logits_per_feature_2

def forward(
self, features_1: torch.Tensor, features_2: torch.Tensor
self,
features_1: torch.Tensor,
features_2: torch.Tensor,
logit_scale: torch.Tensor,
) -> torch.Tensor:
"""Calculate the CLIP-style loss between two sets of features.
Expand All @@ -136,6 +148,8 @@ def forward(
First set of features.
features_2 : torch.Tensor
Second set of features.
logit_scale : torch.Tensor
Logit scale.
Returns
-------
Expand All @@ -150,7 +164,7 @@ def forward(
features_2 = F.normalize(features_2, p=2, dim=-1)

logits_per_feat1, logits_per_feat2 = self._get_logits(
features_1, features_2, rank=rank, world_size=world_size
features_1, features_2, logit_scale, rank=rank, world_size=world_size
)
labels = self._get_ground_truth(
features_1.device,
Expand Down
43 changes: 40 additions & 3 deletions mmlearn/tasks/contrastive_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import inspect
import itertools
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union

import lightning as L # noqa: N812
import numpy as np
import torch
import torch.distributed
import torch.distributed.nn
Expand Down Expand Up @@ -151,6 +153,9 @@ def __init__( # noqa: PLR0912, PLR0915
partial[torch.optim.lr_scheduler.LRScheduler],
]
] = None,
init_logit_scale: float = 1 / 0.07,
max_logit_scale: float = 100,
learnable_logit_scale: bool = True,
loss: Optional[CLIPLoss] = None,
modality_loss_pairs: Optional[List[LossPairSpec]] = None,
auxiliary_tasks: Optional[Dict[str, AuxiliaryTaskSpec]] = None,
Expand Down Expand Up @@ -259,6 +264,19 @@ def __init__( # noqa: PLR0912, PLR0915
}
)

# set up logit scaling
log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
self.max_logit_scale = max_logit_scale
self.learnable_logit_scale = learnable_logit_scale

if self.learnable_logit_scale:
self.log_logit_scale = torch.nn.Parameter(
log_logit_scale, requires_grad=True
)
else:
self.register_buffer("log_logit_scale", log_logit_scale)

# set up contrastive loss pairs
if modality_loss_pairs is None:
modality_loss_pairs = [
LossPairSpec(modalities=(m1.name, m2.name))
Expand All @@ -277,6 +295,7 @@ def __init__( # noqa: PLR0912, PLR0915
)
self.modality_loss_pairs = modality_loss_pairs

# set up auxiliary tasks
self.aux_task_specs = auxiliary_tasks or {}
self.auxiliary_tasks: Dict[str, L.LightningModule] = {}
for task_name, task_spec in self.aux_task_specs.items():
Expand Down Expand Up @@ -313,10 +332,11 @@ def __init__( # noqa: PLR0912, PLR0915
f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
f"but got {type(eval_task_spec.task)}."
)

self.evaluation_tasks = evaluation_tasks

def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
def encode(
self, inputs: Dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
"""Encode the input values for the given modality.
Parameters
Expand All @@ -325,6 +345,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
Input values.
modality : Modality
The modality to encode.
normalize : bool, optional, default=False
Whether to apply L2 normalization to the output (after the head and
postprocessor layers, if present).
Returns
-------
Expand All @@ -339,6 +362,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
if self.postprocessors and modality.name in self.postprocessors:
output = self.postprocessors[modality.name](output)

if normalize:
output = torch.nn.functional.normalize(output, p=2, dim=-1)

return output

def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
Expand All @@ -355,7 +381,7 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
The encodings for each modality.
"""
outputs = {
modality.embedding: self.encode(inputs, modality)
modality.embedding: self.encode(inputs, modality, normalize=True)
for modality in self._available_modalities
}

Expand All @@ -373,6 +399,16 @@ def _compute_loss(
if self.loss_fn is None:
return None

with torch.no_grad():
self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))
self.log(
"train/logit_scale",
self.log_logit_scale.exp(),
prog_bar=True,
on_step=True,
on_epoch=False,
)

contrastive_losses: list[torch.Tensor] = []
for loss_pair in self.modality_loss_pairs:
modality_a = Modalities.get_modality(loss_pair.modalities[0])
Expand All @@ -389,6 +425,7 @@ def _compute_loss(
self.loss_fn(
outputs[modality_a.embedding][indices_a],
outputs[modality_b.embedding][indices_b],
self.log_logit_scale.exp(),
)
* loss_pair.weight
)
Expand Down
15 changes: 0 additions & 15 deletions projects/bioscan_clip/configs/experiment/bioscan_1m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ defaults:
- /modules/[email protected]: timm-vit-lora
- /modules/[email protected]: barcode-bert-lora
- /modules/[email protected]: MLP # the other modalities have projection heads in their encoders
- /modules/[email protected]_and_logit_scale.norm: L2Norm
- /modules/[email protected]_and_logit_scale.logit_scale: LearnableLogitScaling
- /modules/[email protected]: CLIPLoss
- /modules/[email protected]: AdamW
- /modules/[email protected]_scheduler.scheduler: OneCycleLR
Expand Down Expand Up @@ -67,19 +65,6 @@ task:
text:
in_dim: 512
out_dim: ${task.encoders.rgb.projection_dim}
postprocessors:
norm_and_logit_scale:
norm:
dim: -1
logit_scale:
learnable: True
modality_module_mapping:
text:
postprocessor_key: norm_and_logit_scale
rgb:
postprocessor_key: norm_and_logit_scale
dna:
postprocessor_key: norm_and_logit_scale
optimizer:
lr: 1.0e-3
eps: 1.0e-6
Expand Down
13 changes: 0 additions & 13 deletions projects/med_benchmarking/configs/experiment/baseline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ defaults:
- /datasets/[email protected]_fn.batch_processors.text: HFCLIPTokenizer
- /modules/[email protected]: HFCLIPTextEncoderWithProjection
- /modules/[email protected]: HFCLIPVisionEncoderWithProjection
- /modules/[email protected]_and_logit_scale.norm: L2Norm
- /modules/[email protected]_and_logit_scale.logit_scale: LearnableLogitScaling
- /modules/[email protected]: CLIPLoss
- /modules/[email protected]: AdamW
- /modules/[email protected]_scheduler.scheduler: CosineAnnealingLR
Expand Down Expand Up @@ -47,17 +45,6 @@ dataloader:
num_workers: 4

task:
postprocessors:
norm_and_logit_scale:
norm:
dim: -1
logit_scale:
learnable: True
modality_module_mapping:
text:
postprocessor_key: norm_and_logit_scale
rgb:
postprocessor_key: norm_and_logit_scale
optimizer:
betas:
- 0.9
Expand Down
13 changes: 6 additions & 7 deletions projects/med_benchmarking/datasets/pad_ufes_20.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def __init__(
self.split = split

# Load cached data if available
cache_path = f"cache/PadUfes20_{split}.pkl"
cache_path = f".cache/PadUfes20_{split}.pkl"
if os.path.exists(cache_path):
print(f"!!! Using cached dataset for {split}")
with open(cache_path, "rb") as f:
self.metadata = pickle.load(f)
else:
os.makedirs("cache/", exist_ok=True)
os.makedirs(".cache/", exist_ok=True)
self.metadata = self._load_and_process_metadata()
with open(cache_path, "wb") as f:
pickle.dump(self.metadata.to_dict("records"), f)
Expand All @@ -68,14 +68,13 @@ def _load_and_process_metadata(self) -> pd.DataFrame:
df["path"] = df["img_id"].apply(
lambda imgid: os.path.join(self.root_dir, "Dataset", imgid)
)
df.drop(columns=["img_id", "diagnostic"], inplace=True).reset_index(
drop=True, inplace=True
)
df.drop(columns=["img_id", "diagnostic"], inplace=True)
df.reset_index(drop=True, inplace=True)

# Split into train and test
dataset = {}
dataset["test"] = df.sample(frac=0.2)
dataset["train"] = df.drop(dataset["test"].index)
dataset["test"] = df.sample(frac=0.2, ignore_index=True)
dataset["train"] = df.drop(dataset["test"].index).reset_index(drop=True)
return dataset[self.split]

def _build_label(self, str_label: str) -> int:
Expand Down

0 comments on commit dc264de

Please sign in to comment.