Merge branch 'main' into ijepa_training
fcogidi committed Dec 5, 2024
2 parents 5937565 + e474698 commit ceed49d
Showing 14 changed files with 515 additions and 270 deletions.
2 changes: 1 addition & 1 deletion mmlearn/cli/run.py
```diff
@@ -45,7 +45,7 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
 
     if is_torch_tf32_available():
         torch.backends.cuda.matmul.allow_tf32 = True
-        if "16-mixed" in cfg.trainer.precision:
+        if "16-mixed" in str(cfg.trainer.precision):
            cfg.trainer.precision = "bf16-mixed"
 
     # setup trainer first so that we can get some variables for distributed training
```
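Note: the `str()` cast is needed because Lightning accepts `precision` as either a string (`"16-mixed"`) or a number (`16`), and Python's `in` operator raises a `TypeError` when the right-hand operand is an `int`. A minimal standalone reproduction (not mmlearn code):

```python
precision = 16  # Lightning also accepts plain ints for precision

try:
    "16-mixed" in precision  # TypeError: argument of type 'int' is not iterable
except TypeError:
    pass

# Casting first makes the membership test safe for both ints and strings.
assert "16-mixed" not in str(precision)
```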
2 changes: 1 addition & 1 deletion mmlearn/conf/__init__.py
```diff
@@ -168,7 +168,7 @@ class MMLearnConf:
             job=JobConf(
                 name=II("experiment_name"),
                 env_set={
-                    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "3",
+                    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
                     "HYDRA_FULL_ERROR": "1",
                 },
             ),
```
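Note: from my reading of the PyTorch distributed docs, `TORCH_NCCL_ASYNC_ERROR_HANDLING=1` is the "tear down" mode, which aborts the process when a NCCL collective errors or times out (whereas `3` throws without cleanup); `1` also matches the default in recent PyTorch releases. Hydra's `env_set` amounts to setting these variables before the job starts:

```python
import os

# Equivalent effect of the env_set block above (values copied from the diff).
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"  # fail fast on NCCL errors
os.environ["HYDRA_FULL_ERROR"] = "1"  # full tracebacks from Hydra
```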
2 changes: 1 addition & 1 deletion mmlearn/hf_utils.py
```diff
@@ -67,7 +67,7 @@ def load_huggingface_model(
         return_unused_kwargs=True,
         **model_config_kwargs,
     )
-    model = model_type._from_config(config, **kwargs)
+    model = model_type.from_config(config, **kwargs)
 
     if get_model_attr is not None and hasattr(model, get_model_attr):
         model = getattr(model, get_model_attr)
```
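Note: `_from_config` is a private method of `PreTrainedModel`; the public `from_config` classmethod lives on the `Auto*` factory classes, which is how `load_huggingface_model` is invoked elsewhere in this codebase (e.g. with `transformers.AutoModelForMaskedLM`). A minimal sketch of the public API, using a standard checkpoint for illustration:

```python
from transformers import AutoConfig, AutoModel

# Build the architecture from a config with randomly initialized weights,
# without downloading pretrained parameters.
config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_config(config)
```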
4 changes: 2 additions & 2 deletions mmlearn/modules/encoders/__init__.py
```diff
@@ -5,9 +5,9 @@
     HFCLIPTextEncoderWithProjection,
     HFCLIPVisionEncoder,
     HFCLIPVisionEncoderWithProjection,
-    PubMedBERTForCLIPTextEncoding,
 )
 from mmlearn.modules.encoders.text import HFTextEncoder
+from mmlearn.modules.encoders.vision import TimmViT
 
 
 __all__ = [
@@ -16,5 +16,5 @@
     "HFCLIPTextEncoderWithProjection",
     "HFCLIPVisionEncoder",
     "HFCLIPVisionEncoderWithProjection",
-    "PubMedBERTForCLIPTextEncoding",
+    "TimmViT",
 ]
```
117 changes: 0 additions & 117 deletions mmlearn/modules/encoders/clip.py
```diff
@@ -474,123 +474,6 @@ def forward(self, inputs: Dict[str, Any]) -> Tuple[torch.Tensor]:
         return (self.model.visual_projection(pooled_output),)
 
 
-@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
-class PubMedBERTForCLIPTextEncoding(nn.Module):
-    """BiomedNLP's PubMedBERT model for CLIP text encoding.
-
-    This module is wrapper around the PubMedBERT model from huggingface.
-
-    Parameters
-    ----------
-    pretrained : bool, default=False
-        Whether to load the pretrained weights or not.
-    pooling_layer : nn.Module, optional, default=None
-        Pooling layer to apply to the last hidden state of the model.
-    freeze_layers : int | float | List[int] | bool, default=False
-        Whether to freeze layers of the model and which layers to freeze. If `True`,
-        all model layers are frozen. If it is an integer, the first `N` layers of
-        the model are frozen. If it is a float, the first `N` percent of the layers
-        are frozen. If it is a list of integers, the layers at the indices in the
-        list are frozen.
-    freeze_layer_norm : bool, default=True
-        Whether to freeze the layer normalization layers of the model.
-    peft_config : PeftConfig, optional, default=None
-        The configuration from the `peft` library to use to wrap the model
-        for parameter-efficient finetuning.
-    model_config_kwargs : Dict[str, Any], optional, default=None
-        Additional keyword arguments to pass to the model configuration.
-
-    Warns
-    -----
-    UserWarning
-        If both `peft_config` and `freeze_layers` are set. The `peft_config` will
-        override the `freeze_layers` setting.
-    """
-
-    def __init__(
-        self,
-        pretrained: bool = True,
-        pooling_layer: Optional[nn.Module] = None,
-        freeze_layers: Union[int, float, List[int], bool] = False,
-        freeze_layer_norm: bool = True,
-        peft_config: Optional["PeftConfig"] = None,
-        model_config_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> None:
-        """Initialize the model."""
-        super().__init__()
-        _warn_freeze_with_peft(peft_config, freeze_layers)
-
-        model = hf_utils.load_huggingface_model(
-            transformers.AutoModelForMaskedLM,
-            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
-            load_pretrained_weights=pretrained,
-            get_model_attr="bert",
-            model_config_kwargs=model_config_kwargs,
-        )
-
-        if isinstance(freeze_layers, bool) and freeze_layers:
-            for name, param in model.named_parameters():
-                param.requires_grad = (
-                    (not freeze_layer_norm) if "LayerNorm" in name else False
-                )
-
-        layers = [model.embeddings, *model.encoder.layer]
-        if isinstance(freeze_layers, float):
-            freeze_layers = int(freeze_layers * len(layers))
-        if isinstance(freeze_layers, int):
-            freeze_layers = list(range(freeze_layers))
-
-        if isinstance(freeze_layers, list):
-            for idx, layer in enumerate(layers):
-                if idx in freeze_layers:
-                    for name, param in layer.named_parameters():
-                        param.requires_grad = (
-                            (not freeze_layer_norm) if "LayerNorm" in name else False
-                        )
-
-        if peft_config is not None:
-            model = hf_utils._wrap_peft_model(model, peft_config)
-
-        self.model = model
-        self.pooling_layer = pooling_layer
-
-    def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
-        """Run the forward pass.
-
-        Parameters
-        ----------
-        inputs : Dict[str, Any]
-            The input data. The `input_ids` will be expected under the `Modalities.TEXT`
-            key.
-
-        Returns
-        -------
-        BaseModelOutput
-            The output of the model, including the last hidden state, all hidden states,
-            and the attention weights, if `output_attentions` is set to `True`.
-        """
-        output = self.model(
-            input_ids=inputs[Modalities.TEXT.name],
-            attention_mask=inputs.get(
-                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
-            ),
-            inputs_embeds=inputs.get("inputs_embeds"),
-            output_attentions=inputs.get("output_attentions"),
-            output_hidden_states=True,
-            return_dict=True,
-        )
-        last_hidden_state = output.last_hidden_state
-        if self.pooling_layer is not None:
-            last_hidden_state = self.pooling_layer(last_hidden_state)
-
-        return BaseModelOutput(
-            last_hidden_state=last_hidden_state,
-            hidden_states=output.hidden_states,
-            attentions=output.attentions,
-        )
-
-
 #### Utility methods ####
 
 
```
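Note: everything the deleted wrapper did at load time can be reproduced with plain `transformers` calls; a sketch grounded in the removed code above (this is not a new mmlearn API):

```python
from transformers import AutoModelForMaskedLM

# What the removed class did internally: load the masked-LM checkpoint and
# keep only the BERT backbone (it passed get_model_attr="bert").
mlm = AutoModelForMaskedLM.from_pretrained(
    "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
)
backbone = mlm.bert
```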
4 changes: 2 additions & 2 deletions mmlearn/modules/losses/__init__.py
```diff
@@ -1,7 +1,7 @@
 """Loss functions."""
 
-from mmlearn.modules.losses.contrastive import CLIPLoss
+from mmlearn.modules.losses.contrastive import ContrastiveLoss
 from mmlearn.modules.losses.data2vec import Data2VecLoss
 
 
-__all__ = ["CLIPLoss", "Data2VecLoss"]
+__all__ = ["ContrastiveLoss", "Data2VecLoss"]
```
527 changes: 440 additions & 87 deletions mmlearn/modules/losses/contrastive.py

Large diffs are not rendered by default.
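Note: the rewritten loss generalizes the old two-tower `CLIPLoss`. Judging from the new call site in `mmlearn/tasks/contrastive_pretraining.py` (see that file's diff below), it now receives all modality embeddings, the example IDs, the logit scale, and the modality-pair specs, and handles pair matching internally. For orientation only, here is a minimal sketch of the symmetric InfoNCE objective a CLIP-style contrastive loss computes per modality pair; the name and signature are illustrative, not the actual `ContrastiveLoss` API:

```python
import torch
import torch.nn.functional as F


def clip_style_contrastive_loss(
    x: torch.Tensor,  # (N, D) L2-normalized embeddings, modality A
    y: torch.Tensor,  # (N, D) L2-normalized embeddings, modality B
    logit_scale: torch.Tensor,  # exp(log_logit_scale), clamped upstream
) -> torch.Tensor:
    """Symmetric InfoNCE over matched pairs (illustrative only)."""
    logits = logit_scale * x @ y.t()  # (N, N) similarity matrix
    targets = torch.arange(x.size(0), device=x.device)
    # Cross-entropy in both retrieval directions, averaged.
    return 0.5 * (
        F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)
    )
```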

2 changes: 1 addition & 1 deletion mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
```diff
@@ -73,7 +73,7 @@ def linear_warmup_cosine_annealing_lr(
     )
     cosine_lr = CosineAnnealingLR(
         optimizer,
-        T_max=max_steps - warmup_steps - 1,
+        T_max=max_steps - warmup_steps,
         eta_min=eta_min,
         last_epoch=last_epoch,
     )
```
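Note: `T_max` is the number of steps the cosine phase spans; since annealing only takes over after warmup, that phase has exactly `max_steps - warmup_steps` steps, and the extra `- 1` bottomed the schedule out one step early. A self-contained sketch of the same warmup-plus-cosine pattern with stock PyTorch schedulers (values are illustrative):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

max_steps, warmup_steps, eta_min = 1000, 100, 0.0

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

warmup = LinearLR(optimizer, start_factor=1e-4, total_iters=warmup_steps)
# The cosine phase covers every step after warmup, hence no "- 1".
cosine = CosineAnnealingLR(optimizer, T_max=max_steps - warmup_steps, eta_min=eta_min)
scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_steps])
```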
10 changes: 4 additions & 6 deletions mmlearn/modules/metrics/retrieval_recall.py
```diff
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed
 from hydra_zen import store
+from torch.nn import functional as F  # noqa: N812
 from torchmetrics import Metric
 from torchmetrics.retrieval.base import _retrieval_aggregate
 from torchmetrics.utilities.checks import _check_same_shape
@@ -52,7 +53,7 @@ class RetrievalRecallAtK(Metric):
     def __init__(
         self,
         top_k: int,
-        reduction: Literal["mean", "sum", "none", None] = "sum",
+        reduction: Literal["mean", "sum", "none", None] = None,
         aggregation: Union[
             Literal["mean", "median", "min", "max"],
             Callable[[torch.Tensor, int], torch.Tensor],
@@ -166,12 +167,9 @@ def compute(self) -> torch.Tensor:
         torch.Tensor
             The computed metric.
         """
-        x = dim_zero_cat(self.x)
-        y = dim_zero_cat(self.y)
-
-        # compute the cosine similarity
-        x_norm = x / x.norm(dim=-1, p=2, keepdim=True)
-        y_norm = y / y.norm(dim=-1, p=2, keepdim=True)
+        x_norm = F.normalize(dim_zero_cat(self.x), p=2, dim=-1)
+        y_norm = F.normalize(dim_zero_cat(self.y), p=2, dim=-1)
+
         similarity = _safe_matmul(x_norm, y_norm)
         reduction_mapping: Dict[
             Optional[str], Callable[[torch.Tensor], torch.Tensor]
```
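Note: besides being shorter, `F.normalize` clamps the denominator (`eps=1e-12` by default), so an all-zero embedding row yields zeros instead of the NaNs the manual division produced. A quick equivalence check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

manual = x / x.norm(dim=-1, p=2, keepdim=True)  # divides by zero for zero rows
safe = F.normalize(x, p=2, dim=-1)              # clamps the norm with eps

assert torch.allclose(manual, safe)  # identical whenever the norm is nonzero
```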
87 changes: 47 additions & 40 deletions mmlearn/tasks/contrastive_pretraining.py
```diff
@@ -14,9 +14,8 @@
 from hydra_zen import store
 from torch import nn
 
-from mmlearn.datasets.core import Modalities, find_matching_indices
+from mmlearn.datasets.core import Modalities
 from mmlearn.datasets.core.modalities import Modality
-from mmlearn.modules.losses import CLIPLoss
 from mmlearn.tasks.base import TrainingTask
 from mmlearn.tasks.hooks import EvaluationHooks
 
@@ -119,7 +118,7 @@ class ContrastivePretraining(TrainingTask):
     learnable_logit_scale : bool, optional, default=True
         Whether the logit scale parameter is learnable. If set to False, the logit
         scale parameter is treated as a constant.
-    loss : CLIPLoss, optional, default=None
+    loss : nn.Module, optional, default=None
         The loss function to use.
     modality_loss_pairs : List[LossPairSpec], optional, default=None
         A list of pairs of modalities to compute the contrastive loss between and
@@ -163,7 +162,7 @@ def __init__( # noqa: PLR0912, PLR0915
         init_logit_scale: float = 1 / 0.07,
         max_logit_scale: float = 100,
         learnable_logit_scale: bool = True,
-        loss: Optional[CLIPLoss] = None,
+        loss: Optional[nn.Module] = None,
         modality_loss_pairs: Optional[List[LossPairSpec]] = None,
         auxiliary_tasks: Optional[Dict[str, AuxiliaryTaskSpec]] = None,
         log_auxiliary_tasks_loss: bool = False,
@@ -189,6 +188,7 @@ def __init__( # noqa: PLR0912, PLR0915
                 "loss",
                 "auxiliary_tasks",
                 "evaluation_tasks",
+                "modality_loss_pairs",
             ]
         )
 
@@ -260,7 +260,7 @@ def __init__( # noqa: PLR0912, PLR0915
                 if isinstance(heads[head_key], nn.Module)
                 else nn.Sequential(*heads[head_key].values())
                 for modality_key, head_key in modality_head_mapping.items()
-                if head_key is not None
+                if head_key is not None and head_key in heads
             }
         )
 
@@ -275,6 +275,7 @@ def __init__( # noqa: PLR0912, PLR0915
                 else nn.Sequential(*postprocessors[postprocessor_key].values())
                 for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                 if postprocessor_key is not None
+                and postprocessor_key in postprocessors
             }
         )
 
@@ -363,12 +364,12 @@ def encode(
         if self.heads and modality.name in self.heads:
             output = self.heads[modality.name](output)
 
-        if self.postprocessors and modality.name in self.postprocessors:
-            output = self.postprocessors[modality.name](output)
-
         if normalize:
             output = torch.nn.functional.normalize(output, p=2, dim=-1)
 
+        if self.postprocessors and modality.name in self.postprocessors:
+            output = self.postprocessors[modality.name](output)
+
         return output
 
     def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
@@ -387,6 +388,7 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
         outputs = {
             modality.embedding: self.encode(inputs, modality, normalize=True)
             for modality in self._available_modalities
+            if modality.name in inputs
         }
 
         if not all(
@@ -403,37 +405,13 @@ def _compute_loss(
         if self.loss_fn is None:
             return None
 
-        with torch.no_grad():
-            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))
-        self.log(
-            "train/logit_scale",
-            self.log_logit_scale.exp(),
-            prog_bar=True,
-            on_step=True,
-            on_epoch=False,
-        )
-
-        contrastive_losses: list[torch.Tensor] = []
-        for loss_pair in self.modality_loss_pairs:
-            modality_a = Modalities.get_modality(loss_pair.modalities[0])
-            modality_b = Modalities.get_modality(loss_pair.modalities[1])
-
-            indices_a, indices_b = find_matching_indices(
-                batch["example_ids"][modality_a.name],
-                batch["example_ids"][modality_b.name],
-            )
-            if indices_a.numel() == 0 or indices_b.numel() == 0:
-                continue
-
-            contrastive_losses.append(
-                self.loss_fn(
-                    outputs[modality_a.embedding][indices_a],
-                    outputs[modality_b.embedding][indices_b],
-                    self.log_logit_scale.exp(),
-                )
-                * loss_pair.weight
-            )
+        contrastive_loss = self.loss_fn(
+            outputs,
+            batch["example_ids"],
+            self.log_logit_scale.exp(),
+            self.modality_loss_pairs,
+        )
 
         auxiliary_losses: list[torch.Tensor] = []
         if self.auxiliary_tasks:
             for task_name, task_spec in self.aux_task_specs.items():
@@ -452,9 +430,22 @@ def _compute_loss(
 
             auxiliary_losses.append(task_spec.loss_weight * auxiliary_task_loss)
             if self.log_auxiliary_tasks_loss:
-                self.log(f"train/{task_name}_loss", auxiliary_task_loss)
+                self.log(
+                    f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
+                )
 
-        return torch.stack(contrastive_losses + auxiliary_losses).sum()
+        if not auxiliary_losses:
+            return contrastive_loss
+
+        return torch.stack(auxiliary_losses).sum() + contrastive_loss
+
+    def on_train_epoch_start(self) -> None:
+        """Prepare for the training epoch."""
+        self.encoders.train()
+        if self.heads:
+            self.heads.train()
+        if self.postprocessors:
+            self.postprocessors.train()
 
     def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
         """Compute the loss for the batch.
@@ -472,12 +463,23 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
             The loss for the batch.
         """
         outputs = self(batch)
+
+        with torch.no_grad():
+            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))
+
         loss = self._compute_loss(batch, batch_idx, outputs)
+
         if loss is None:
             raise ValueError("The loss function must be provided for training.")
 
-        self.log("train/loss", loss, prog_bar=True)
+        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
+        self.log(
+            "train/logit_scale",
+            self.log_logit_scale.exp(),
+            prog_bar=True,
+            on_step=True,
+            on_epoch=False,
+        )
 
         return loss
 
@@ -537,6 +539,11 @@ def on_test_epoch_end(self) -> None:
 
     def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
         """Prepare for the evaluation epoch."""
+        self.encoders.eval()
+        if self.heads:
+            self.heads.eval()
+        if self.postprocessors:
+            self.postprocessors.eval()
         if self.evaluation_tasks:
             for task_spec in self.evaluation_tasks.values():
                 if (eval_type == "val" and task_spec.run_on_validation) or (
@@ -571,7 +578,7 @@ def _shared_eval_step(
         outputs = self(batch)
         loss = self._compute_loss(batch, batch_idx, outputs)
         if loss is not None and not self.trainer.sanity_checking:
-            self.log(f"{eval_type}/loss", loss, prog_bar=True)
+            self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)
 
         if self.evaluation_tasks:
             for task_spec in self.evaluation_tasks.values():
```
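Note: the new `on_train_epoch_start`/`_on_eval_epoch_start` hooks matter because modules such as `Dropout` and `BatchNorm` behave differently in the two modes; without an explicit `.eval()`, validation embeddings would be stochastic. The added `sync_dist=True` flags likewise make the logged scalars agree across ranks in distributed runs. A minimal illustration of the mode toggling:

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

encoder.eval()  # dropout disabled: deterministic outputs for evaluation
assert torch.equal(encoder(x), encoder(x))

encoder.train()  # dropout re-enabled for the next training epoch
```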
20 changes: 13 additions & 7 deletions mmlearn/tasks/zero_shot_retrieval.py
```diff
@@ -48,6 +48,7 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
 
         self.task_specs = task_specs
         self.metrics: Dict[Tuple[str, str], MetricCollection] = {}
+        self._available_modalities = set()
 
         for spec in self.task_specs:
             query_modality = spec.query_modality
@@ -63,6 +64,8 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
                     for k in spec.top_k
                 }
             )
+            self._available_modalities.add(query_modality)
+            self._available_modalities.add(target_modality)
 
     def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
         """Move the metrics to the device of the Lightning module."""
@@ -90,14 +93,17 @@ def evaluation_step(
         if pl_module.trainer.sanity_checking:
             return
 
-        outputs: Dict[str, Any] = pl_module(batch)
+        outputs: Dict[str, Any] = {}
+        for modality_name in self._available_modalities:
+            if modality_name in batch:
+                outputs[modality_name] = pl_module.encode(
+                    batch, Modalities.get_modality(modality_name), normalize=False
+                )
         for (query_modality, target_modality), metric in self.metrics.items():
-            query_embeddings: torch.Tensor = outputs[
-                Modalities.get_modality(query_modality).embedding
-            ]
-            target_embeddings: torch.Tensor = outputs[
-                Modalities.get_modality(target_modality).embedding
-            ]
+            if query_modality not in outputs or target_modality not in outputs:
+                continue
+            query_embeddings: torch.Tensor = outputs[query_modality]
+            target_embeddings: torch.Tensor = outputs[target_modality]
             indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)
 
             metric.update(query_embeddings, target_embeddings, indexes)
```
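Note: `evaluation_step` now encodes only the modalities that some retrieval task actually declares, instead of running the full forward pass. The `indexes = torch.arange(...)` line encodes the assumption that the i-th query's positive match is the i-th target; a small sketch of the shapes involved (tensors are illustrative):

```python
import torch

query = torch.randn(4, 16)   # e.g. text embeddings
target = torch.randn(4, 16)  # e.g. image embeddings

# In paired cross-modal batches, query i matches target i, so the grouping
# indexes for the retrieval metric are simply 0..N-1.
indexes = torch.arange(query.size(0))
scores = query @ target.t()  # (4, 4): row i scores query i against every target
```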
4 changes: 1 addition & 3 deletions projects/bioscan_clip/configs/experiment/bioscan_1m.yaml
```diff
@@ -12,7 +12,7 @@ defaults:
   - /modules/encoders@task.encoders.rgb: timm-vit-lora
   - /modules/encoders@task.encoders.dna: barcode-bert-lora
   - /modules/layers@task.heads.text: MLP # the other modalities have projection heads in their encoders
-  - /modules/losses@task.loss: CLIPLoss
+  - /modules/losses@task.loss: ContrastiveLoss
   - /modules/optimizers@task.optimizer: AdamW
   - /modules/lr_schedulers@task.lr_scheduler.scheduler: OneCycleLR
   - /eval_task@task.evaluation_tasks.tax_cls.task: TaxonomicClassification
@@ -101,8 +101,6 @@
   model_summary:
     max_depth: 2
 
-  strict_loading: False
-
 tags:
   - ${experiment_name}
   - contrastive pretraining
```
2 changes: 1 addition & 1 deletion projects/med_benchmarking/configs/experiment/baseline.yaml
```diff
@@ -15,7 +15,7 @@ defaults:
   - /datasets/tokenizers@dataloader.val.collate_fn.batch_processors.text: HFCLIPTokenizer
   - /modules/encoders@task.encoders.text: HFCLIPTextEncoderWithProjection
   - /modules/encoders@task.encoders.rgb: HFCLIPVisionEncoderWithProjection
-  - /modules/losses@task.loss: CLIPLoss
+  - /modules/losses@task.loss: ContrastiveLoss
   - /modules/optimizers@task.optimizer: AdamW
   - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR
   - /eval_task@task.evaluation_tasks.retrieval.task: ZeroShotCrossModalRetrieval
```
2 changes: 1 addition & 1 deletion pyproject.toml
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mmlearn"
-version = "0.1.0a0.dev8" # https://www.python.org/dev/peps/pep-0440/#version-schemes
+version = "0.1.0a0" # https://www.python.org/dev/peps/pep-0440/#version-schemes
 description = "A modular framework for research on multimodal representation learning."
 readme = "README.md"
 authors = ["Vector AI Engineering <ai_engineering@vectorinstitute.ai>"]
```
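Note: dropping the `.dev8` suffix promotes the build from a development snapshot to the `0.1.0a0` alpha proper; under PEP 440, dev releases sort before the pre-release they belong to. A quick check with the `packaging` library:

```python
from packaging.version import Version

# dev snapshot < alpha pre-release < final release
assert Version("0.1.0a0.dev8") < Version("0.1.0a0") < Version("0.1.0")
```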
