From c6b07e0f43b837730b6686ed92cc92282db4921c Mon Sep 17 00:00:00 2001 From: Vahid <34673511+vahid0001@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:59:47 -0500 Subject: [PATCH] Add I-JEPA task (#25) Co-authored-by: fcogidi <41602287+fcogidi@users.noreply.github.com> --- mmlearn/datasets/processors/masking.py | 80 +++--- mmlearn/modules/ema.py | 92 +++---- mmlearn/modules/encoders/vision.py | 71 ++++- mmlearn/tasks/__init__.py | 2 + mmlearn/tasks/base.py | 149 ++++++++++ mmlearn/tasks/contrastive_pretraining.py | 118 ++------ mmlearn/tasks/ijepa.py | 259 ++++++++++++++++++ .../configs/experiment/bioscan_1m.yaml | 4 +- projects/ijepa/configs/__init__.py | 102 +++++++ .../experiment/reproduce_imagenet.yaml | 75 +++++ .../configs/experiment/baseline.yaml | 2 +- 11 files changed, 761 insertions(+), 193 deletions(-) create mode 100644 mmlearn/tasks/base.py create mode 100644 mmlearn/tasks/ijepa.py create mode 100644 projects/ijepa/configs/__init__.py create mode 100644 projects/ijepa/configs/experiment/reproduce_imagenet.yaml diff --git a/mmlearn/datasets/processors/masking.py b/mmlearn/datasets/processors/masking.py index 6a0f986..f56bb87 100644 --- a/mmlearn/datasets/processors/masking.py +++ b/mmlearn/datasets/processors/masking.py @@ -237,31 +237,38 @@ def apply_masks( Parameters ---------- x : torch.Tensor - Input tensor of shape (B, N, D), where B is the batch size, N is the number - of patches, and D is the feature dimension. + Input tensor of shape (B, N, D). masks : Union[torch.Tensor, List[torch.Tensor]] - A list of tensors containing the indices of patches to keep for each sample. - Each mask tensor has shape (B, N), where B is the batch size and N is the number - of patches. + A list of mask tensors of shape (N,), (1, N), or (B, N). Returns ------- torch.Tensor The masked tensor where only the patches indicated by the masks are kept. - The output tensor has shape (B', N', D), where B' is the new batch size - (which may be different due to concatenation) and N' is the - reduced number of patches. - - Notes - ----- - - The masks should indicate which patches to keep (1 for keep, 0 for discard). - - The function uses `torch.gather` to select the patches specified by the masks. + The output tensor has shape (B * num_masks, N', D), + where N' is the number of patches kept. """ all_x = [] - for m in masks: - # Expand the mask to match the feature dimension and gather the relevant patches - mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) - all_x.append(torch.gather(x, dim=1, index=mask_keep)) + batch_size = x.size(0) + for m_ in masks: + m = m_.to(x.device) + + # Ensure mask is at least 2D + if m.dim() == 1: + m = m.unsqueeze(0) # Shape: (1, N) + + # Expand mask to match the batch size if needed + if m.size(0) == 1 and batch_size > 1: + m = m.expand(batch_size, -1) # Shape: (B, N) + + # Expand mask to match x's dimensions + m_expanded = ( + m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool() + ) # Shape: (B, N, D) + + # Use boolean indexing + selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1)) + all_x.append(selected_patches) # Concatenate along the batch dimension return torch.cat(all_x, dim=0) @@ -271,8 +278,7 @@ def apply_masks( class IJEPAMaskGenerator: """Generates encoder and predictor masks for preprocessing. - This class generates masks dynamically for individual examples and can be passed to - a data loader as a preprocessing step. + This class generates masks dynamically for batches of examples. 
Parameters ---------- @@ -280,31 +286,31 @@ class IJEPAMaskGenerator: Input image size. patch_size : int, default=16 Size of each patch. - min_keep : int, default=4 + min_keep : int, default=10 Minimum number of patches to keep. allow_overlap : bool, default=False Whether to allow overlap between encoder and predictor masks. - enc_mask_scale : tuple[float, float], default=(0.2, 0.8) + enc_mask_scale : tuple[float, float], default=(0.85, 1.0) Scale range for encoder mask. - pred_mask_scale : tuple[float, float], default=(0.2, 0.8) + pred_mask_scale : tuple[float, float], default=(0.15, 0.2) Scale range for predictor mask. - aspect_ratio : tuple[float, float], default=(0.3, 3.0) + aspect_ratio : tuple[float, float], default=(0.75, 1.0) Aspect ratio range for mask blocks. nenc : int, default=1 Number of encoder masks to generate. - npred : int, default=2 + npred : int, default=4 Number of predictor masks to generate. """ input_size: Tuple[int, int] = (224, 224) patch_size: int = 16 - min_keep: int = 4 + min_keep: int = 10 allow_overlap: bool = False - enc_mask_scale: Tuple[float, float] = (0.2, 0.8) - pred_mask_scale: Tuple[float, float] = (0.2, 0.8) - aspect_ratio: Tuple[float, float] = (0.3, 3.0) + enc_mask_scale: Tuple[float, float] = (0.85, 1.0) + pred_mask_scale: Tuple[float, float] = (0.15, 0.2) + aspect_ratio: Tuple[float, float] = (0.75, 1.0) nenc: int = 1 - npred: int = 2 + npred: int = 4 def __post_init__(self) -> None: """Initialize the mask generator.""" @@ -353,8 +359,14 @@ def _sample_block_mask( def __call__( self, + batch_size: int = 1, ) -> Dict[str, Any]: - """Generate encoder and predictor masks for a single example. + """Generate encoder and predictor masks for a batch of examples. + + Parameters + ---------- + batch_size : int, default=1 + The batch size for which to generate masks. 
Returns ------- @@ -378,14 +390,18 @@ def __call__( masks_pred, masks_enc = [], [] for _ in range(self.npred): mask_p, _ = self._sample_block_mask(p_size) + # Expand mask to match batch size + mask_p = mask_p.unsqueeze(0).expand(batch_size, -1) masks_pred.append(mask_p) # Generate encoder masks for _ in range(self.nenc): mask_e, _ = self._sample_block_mask(e_size) + # Expand mask to match batch size + mask_e = mask_e.unsqueeze(0).expand(batch_size, -1) masks_enc.append(mask_e) return { - "encoder_masks": torch.stack(masks_enc), - "predictor_masks": torch.stack(masks_pred), + "encoder_masks": masks_enc, # List of tensors of shape (batch_size, N) + "predictor_masks": masks_pred, # List of tensors of shape (batch_size, N) } diff --git a/mmlearn/modules/ema.py b/mmlearn/modules/ema.py index 570b8ca..b3e7f59 100644 --- a/mmlearn/modules/ema.py +++ b/mmlearn/modules/ema.py @@ -52,6 +52,52 @@ def __init__( self.ema_end_decay = ema_end_decay self.ema_anneal_end_step = ema_anneal_end_step + @staticmethod + def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module: + """Deep copy the model.""" + try: + return copy.deepcopy(model) + except RuntimeError as e: + raise RuntimeError("Unable to copy the model ", e) from e + + @staticmethod + def get_annealed_rate( + start: float, + end: float, + curr_step: int, + total_steps: int, + ) -> float: + """Calculate EMA annealing rate.""" + r = end - start + pct_remaining = 1 - curr_step / total_steps + return end - r * pct_remaining + + def step(self, new_model: torch.nn.Module) -> None: + """Perform single EMA update step.""" + self._update_weights(new_model) + self._update_ema_decay() + + def restore(self, model: torch.nn.Module) -> torch.nn.Module: + """Reassign weights from another model. + + Parameters + ---------- + model : nn.Module + Model to load weights from. + + Returns + ------- + nn.Module + model with new weights + """ + d = self.model.state_dict() + model.load_state_dict(d, strict=False) + return model + + def state_dict(self) -> dict[str, Any]: + """Return the state dict of the model.""" + return self.model.state_dict() # type: ignore[no-any-return] + @torch.no_grad() # type: ignore[misc] def _update_weights(self, new_model: torch.nn.Module) -> None: if self.decay < 1: @@ -98,49 +144,3 @@ def _update_ema_decay(self) -> None: self.ema_anneal_end_step, ) self.decay = decay - - def step(self, new_model: torch.nn.Module) -> None: - """Perform single EMA update step.""" - self._update_weights(new_model) - self._update_ema_decay() - - @staticmethod - def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module: - """Deep copy the model.""" - try: - return copy.deepcopy(model) - except RuntimeError as e: - raise RuntimeError("Unable to copy the model ", e) from e - - def restore(self, model: torch.nn.Module) -> torch.nn.Module: - """Reassign weights from another model. - - Parameters - ---------- - model : nn.Module - Model to load weights from. 
- - Returns - ------- - nn.Module - model with new weights - """ - d = self.model.state_dict() - model.load_state_dict(d, strict=False) - return model - - def state_dict(self) -> dict[str, Any]: - """Return the state dict of the model.""" - return self.model.state_dict() # type: ignore[no-any-return] - - @staticmethod - def get_annealed_rate( - start: float, - end: float, - curr_step: int, - total_steps: int, - ) -> float: - """Calculate EMA annealing rate.""" - r = end - start - pct_remaining = 1 - curr_step / total_steps - return end - r * pct_remaining diff --git a/mmlearn/modules/encoders/vision.py b/mmlearn/modules/encoders/vision.py index 1f2ee7c..a427fdd 100644 --- a/mmlearn/modules/encoders/vision.py +++ b/mmlearn/modules/encoders/vision.py @@ -2,7 +2,7 @@ import math from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import timm import torch @@ -284,7 +284,6 @@ def __init__( # Weight Initialization self.init_std = init_std self.apply(self._init_weights) - self.fix_init_weight() def fix_init_weight(self) -> None: """Fix initialization of weights by rescaling them according to layer depth.""" @@ -294,7 +293,7 @@ def rescale(param: torch.Tensor, layer_id: int) -> None: for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) + rescale(layer.mlp[-1].weight.data, layer_id + 1) def _init_weights(self, m: nn.Module) -> None: """Initialize weights for the layers.""" @@ -428,7 +427,7 @@ class VisionTransformerPredictor(nn.Module): def __init__( self, - num_patches: int, + num_patches: int = 196, embed_dim: int = 768, predictor_embed_dim: int = 384, depth: int = 6, @@ -445,7 +444,11 @@ def __init__( ) -> None: """Initialize the Vision Transformer Predictor module.""" super().__init__() - self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.num_patches = num_patches + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.predictor_embed = nn.Linear(self.embed_dim, predictor_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) @@ -453,10 +456,12 @@ def __init__( # Positional Embedding self.predictor_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + torch.zeros(1, self.num_patches, predictor_embed_dim), requires_grad=False ) predictor_pos_embed = get_2d_sincos_pos_embed( - self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False + self.predictor_pos_embed.shape[-1], + int(self.num_patches**0.5), + cls_token=False, ) self.predictor_pos_embed.data.copy_( torch.from_numpy(predictor_pos_embed).float().unsqueeze(0) @@ -467,7 +472,7 @@ def __init__( [ Block( dim=predictor_embed_dim, - num_heads=num_heads, + num_heads=self.num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, @@ -487,7 +492,6 @@ def __init__( self.init_std = init_std trunc_normal_(self.mask_token, std=self.init_std) self.apply(self._init_weights) - self.fix_init_weight() def fix_init_weight(self) -> None: """Fix initialization of weights by rescaling them according to layer depth.""" @@ -561,6 +565,13 @@ def forward( return self.predictor_proj(x) +@cast( + VisionTransformerPredictor, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_predictor(**kwargs: Any) -> 
VisionTransformerPredictor: """ Create a VisionTransformerPredictor model. @@ -575,6 +586,13 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with tiny configuration. @@ -596,6 +614,13 @@ def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with small configuration. @@ -617,6 +642,13 @@ def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with base configuration. @@ -638,6 +670,13 @@ def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with large configuration. @@ -659,6 +698,13 @@ def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with huge configuration. @@ -680,6 +726,13 @@ def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: ) +@cast( + VisionTransformer, + store( + group="modules/encoders", + provider="mmlearn", + ), +) def vit_giant(patch_size: int = 16, **kwargs: Any) -> VisionTransformer: """ Create a VisionTransformer model with giant configuration. diff --git a/mmlearn/tasks/__init__.py b/mmlearn/tasks/__init__.py index 6396428..552cae0 100644 --- a/mmlearn/tasks/__init__.py +++ b/mmlearn/tasks/__init__.py @@ -1,12 +1,14 @@ """Modules for pretraining, downstream and evaluation tasks.""" from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining +from mmlearn.tasks.ijepa import IJEPA from mmlearn.tasks.zero_shot_classification import ZeroShotClassification from mmlearn.tasks.zero_shot_retrieval import ZeroShotCrossModalRetrieval __all__ = [ "ContrastivePretraining", + "IJEPA", "ZeroShotCrossModalRetrieval", "ZeroShotClassification", ] diff --git a/mmlearn/tasks/base.py b/mmlearn/tasks/base.py new file mode 100644 index 0000000..4127382 --- /dev/null +++ b/mmlearn/tasks/base.py @@ -0,0 +1,149 @@ +"""Base class for all tasks in mmlearn that require training.""" + +import inspect +from functools import partial +from typing import Any, Dict, Optional, Union + +import lightning as L # noqa: N812 +import torch +import torch.distributed +import torch.distributed.nn +from lightning.pytorch.utilities.types import OptimizerLRScheduler +from lightning_utilities.core.rank_zero import rank_zero_warn + + +class TrainingTask(L.LightningModule): + """Base class for all tasks in mmlearn that require training. + + Parameters + ---------- + optimizer : partial[torch.optim.Optimizer], optional, default=None + The optimizer to use for training. 
This is expected to be a partial function, + created using `functools.partial`, that takes the model parameters as the + only required argument. If not provided, training will continue without an + optimizer. + lr_scheduler : Union[Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]], optional, default=None + The learning rate scheduler to use for training. This can be a partial function + that takes the optimizer as the only required argument or a dictionary with + a `scheduler` key that specifies the scheduler and an optional `extras` key + that specifies additional arguments to pass to the scheduler. If not provided, + the learning rate will not be adjusted during training. + loss_fn : Optional[torch.nn.Module], optional, default=None + Loss function to use for training. + compute_validation_loss : bool, optional, default=True + Whether to compute the validation loss if a validation dataloader is provided. + The loss function must be provided to compute the validation loss. + compute_test_loss : bool, optional, default=True + Whether to compute the test loss if a test dataloader is provided. The loss + function must be provided to compute the test loss. + """ # noqa: W505 + + def __init__( + self, + optimizer: Optional[partial[torch.optim.Optimizer]] = None, + lr_scheduler: Optional[ + Union[ + Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], + partial[torch.optim.lr_scheduler.LRScheduler], + ] + ] = None, + loss_fn: Optional[torch.nn.Module] = None, + compute_validation_loss: bool = True, + compute_test_loss: bool = True, + ): + super().__init__() + if loss_fn is None and (compute_validation_loss or compute_test_loss): + raise ValueError( + "Loss function must be provided to compute validation or test loss." + ) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.loss_fn = loss_fn + self.compute_validation_loss = compute_validation_loss + self.compute_test_loss = compute_test_loss + + def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 + """Configure the optimizer and learning rate scheduler.""" + if self.optimizer is None: + rank_zero_warn( + "Optimizer not provided. Training will continue without an optimizer. 
" + "LR scheduler will not be used.", + ) + return None + + weight_decay: Optional[float] = self.optimizer.keywords.get( + "weight_decay", None + ) + if weight_decay is None: # try getting default value + kw_param = inspect.signature(self.optimizer.func).parameters.get( + "weight_decay" + ) + if kw_param is not None and kw_param.default != inspect.Parameter.empty: + weight_decay = kw_param.default + + parameters = [param for param in self.parameters() if param.requires_grad] + + if weight_decay is not None: + decay_params = [] + no_decay_params = [] + + for param in self.parameters(): + if not param.requires_grad: + continue + + if param.ndim < 2: # includes all bias and normalization parameters + no_decay_params.append(param) + else: + decay_params.append(param) + + parameters = [ + { + "params": decay_params, + "weight_decay": weight_decay, + "name": "weight_decay_params", + }, + { + "params": no_decay_params, + "weight_decay": 0.0, + "name": "no_weight_decay_params", + }, + ] + + optimizer = self.optimizer(parameters) + if not isinstance(optimizer, torch.optim.Optimizer): + raise TypeError( + "Expected optimizer to be an instance of `torch.optim.Optimizer`, " + f"but got {type(optimizer)}.", + ) + + if self.lr_scheduler is not None: + if isinstance(self.lr_scheduler, dict): + if "scheduler" not in self.lr_scheduler: + raise ValueError( + "Expected 'scheduler' key in the learning rate scheduler dictionary.", + ) + + lr_scheduler = self.lr_scheduler["scheduler"](optimizer) + if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): + raise TypeError( + "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " + f"but got {type(lr_scheduler)}.", + ) + lr_scheduler_dict: Dict[ + str, Union[torch.optim.lr_scheduler.LRScheduler, Any] + ] = {"scheduler": lr_scheduler} + + if self.lr_scheduler.get("extras"): + lr_scheduler_dict.update(self.lr_scheduler["extras"]) + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} + + lr_scheduler = self.lr_scheduler(optimizer) + if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): + raise TypeError( + "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " + f"but got {type(lr_scheduler)}.", + ) + return [optimizer], [lr_scheduler] + + return optimizer diff --git a/mmlearn/tasks/contrastive_pretraining.py b/mmlearn/tasks/contrastive_pretraining.py index c48a724..dc1b07c 100644 --- a/mmlearn/tasks/contrastive_pretraining.py +++ b/mmlearn/tasks/contrastive_pretraining.py @@ -1,6 +1,5 @@ """Contrastive pretraining task.""" -import inspect import itertools import math from dataclasses import dataclass @@ -13,12 +12,11 @@ import torch.distributed import torch.distributed.nn from hydra_zen import store -from lightning.pytorch.utilities.types import OptimizerLRScheduler -from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn from mmlearn.datasets.core import Modalities from mmlearn.datasets.core.modalities import Modality +from mmlearn.tasks.base import TrainingTask from mmlearn.tasks.hooks import EvaluationHooks @@ -65,7 +63,7 @@ class EvaluationSpec: @store(group="task", provider="mmlearn") -class ContrastivePretraining(L.LightningModule): +class ContrastivePretraining(TrainingTask): """Contrastive pretraining task. This class supports contrastive pretraining with `N` modalities of data. 
It @@ -120,7 +118,7 @@ class ContrastivePretraining(L.LightningModule): learnable_logit_scale : bool, optional, default=True Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant. - loss : CLIPLoss, optional, default=None + loss : nn.Module, optional, default=None The loss function to use. modality_loss_pairs : List[LossPairSpec], optional, default=None A list of pairs of modalities to compute the contrastive loss between and @@ -173,7 +171,14 @@ def __init__( # noqa: PLR0912, PLR0915 evaluation_tasks: Optional[Dict[str, EvaluationSpec]] = None, ) -> None: """Initialize the module.""" - super().__init__() + super().__init__( + optimizer=optimizer, + lr_scheduler=lr_scheduler, + loss_fn=loss, + compute_validation_loss=compute_validation_loss, + compute_test_loss=compute_test_loss, + ) + self.save_hyperparameters( ignore=[ "encoders", @@ -323,17 +328,7 @@ def __init__( # noqa: PLR0912, PLR0915 self.encoders[Modalities.get_modality(task_spec.modality).name] ) - if loss is None and (compute_validation_loss or compute_test_loss): - raise ValueError( - "Loss function must be provided to compute validation or test loss." - ) - - self.loss_fn = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss - self.compute_validation_loss = compute_validation_loss - self.compute_test_loss = compute_test_loss if evaluation_tasks is not None: for eval_task_spec in evaluation_tasks.values(): @@ -473,7 +468,6 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale)) loss = self._compute_loss(batch, batch_idx, outputs) - print("loss: ", loss) if loss is None: raise ValueError("The loss function must be provided for training.") @@ -543,93 +537,13 @@ def on_test_epoch_end(self) -> None: """Compute and log epoch-level metrics at the end of the test epoch.""" self._on_eval_epoch_end("test") - def configure_optimizers(self) -> OptimizerLRScheduler: # noqa: PLR0912 - """Configure the optimizer and learning rate scheduler.""" - if self.optimizer is None: - rank_zero_warn( - "Optimizer not provided. Training will continue without an optimizer. 
" - "LR scheduler will not be used.", - ) - return None - - weight_decay: Optional[float] = self.optimizer.keywords.get( - "weight_decay", None - ) - if weight_decay is None: # try getting default value - kw_param = inspect.signature(self.optimizer.func).parameters.get( - "weight_decay" - ) - if kw_param is not None and kw_param.default != inspect.Parameter.empty: - weight_decay = kw_param.default - - parameters = [param for param in self.parameters() if param.requires_grad] - - if weight_decay is not None: - decay_params = [] - no_decay_params = [] - - for param in self.parameters(): - if not param.requires_grad: - continue - - if param.ndim < 2: # includes all bias and normalization parameters - no_decay_params.append(param) - else: - decay_params.append(param) - - parameters = [ - { - "params": decay_params, - "weight_decay": weight_decay, - "name": "weight_decay_params", - }, - { - "params": no_decay_params, - "weight_decay": 0.0, - "name": "no_weight_decay_params", - }, - ] - - optimizer = self.optimizer(parameters) - if not isinstance(optimizer, torch.optim.Optimizer): - raise TypeError( - "Expected optimizer to be an instance of `torch.optim.Optimizer`, " - f"but got {type(optimizer)}.", - ) - - if self.lr_scheduler is not None: - if isinstance(self.lr_scheduler, dict): - if "scheduler" not in self.lr_scheduler: - raise ValueError( - "Expected 'scheduler' key in the learning rate scheduler dictionary.", - ) - - lr_scheduler = self.lr_scheduler["scheduler"](optimizer) - if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): - raise TypeError( - "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " - f"but got {type(lr_scheduler)}.", - ) - lr_scheduler_dict: Dict[ - str, Union[torch.optim.lr_scheduler.LRScheduler, Any] - ] = {"scheduler": lr_scheduler} - - if self.lr_scheduler.get("extras"): - lr_scheduler_dict.update(self.lr_scheduler["extras"]) - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} - - lr_scheduler = self.lr_scheduler(optimizer) - if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler): - raise TypeError( - "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, " - f"but got {type(lr_scheduler)}.", - ) - return [optimizer], [lr_scheduler] - - return optimizer - def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None: """Prepare for the evaluation epoch.""" + self.encoders.eval() + if self.heads: + self.heads.eval() + if self.postprocessors: + self.postprocessors.eval() if self.evaluation_tasks: for task_spec in self.evaluation_tasks.values(): if (eval_type == "val" and task_spec.run_on_validation) or ( diff --git a/mmlearn/tasks/ijepa.py b/mmlearn/tasks/ijepa.py new file mode 100644 index 0000000..a581bc4 --- /dev/null +++ b/mmlearn/tasks/ijepa.py @@ -0,0 +1,259 @@ +"""IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task.""" + +from functools import partial +from typing import Any, Callable, Dict, Optional, Union + +import torch +import torch.nn.functional as F # noqa: N812 +from hydra_zen import store + +from mmlearn.datasets.core import Modalities +from mmlearn.datasets.processors.masking import IJEPAMaskGenerator, apply_masks +from mmlearn.datasets.processors.transforms import repeat_interleave_batch +from mmlearn.modules.ema import ExponentialMovingAverage +from mmlearn.modules.encoders.vision import VisionTransformer +from mmlearn.tasks.base import TrainingTask + + +@store(group="task", provider="mmlearn") +class 
IJEPA(TrainingTask): + """Pretraining module for IJEPA. + + This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) + pretraining task using PyTorch Lightning. It trains an encoder and a predictor to + reconstruct masked regions of an image based on its unmasked context. + + Parameters + ---------- + encoder : VisionTransformer + Vision transformer encoder. + predictor : VisionTransformer + Vision transformer predictor. + optimizer : partial[torch.optim.Optimizer], optional, default=None + The optimizer to use for training. This is expected to be a partial function, + created using `functools.partial`, that takes the model parameters as the + only required argument. If not provided, training will continue without an + optimizer. + lr_scheduler : Union[Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]], optional, default=None + The learning rate scheduler to use for training. This can be a partial function + that takes the optimizer as the only required argument or a dictionary with + a `scheduler` key that specifies the scheduler and an optional `extras` key + that specifies additional arguments to pass to the scheduler. If not provided, + the learning rate will not be adjusted during training. + ema_decay : float, optional + Initial momentum for EMA of target encoder, by default 0.996. + ema_decay_end : float, optional + Final momentum for EMA of target encoder, by default 1.0. + ema_anneal_end_step : int, optional + Number of steps to anneal EMA momentum to `ema_decay_end`, by default 1000. + loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional + Loss function to use, by default None. + compute_validation_loss : bool, optional + Whether to compute validation loss, by default True. + compute_test_loss : bool, optional + Whether to compute test loss, by default True. + """ # noqa: W505 + + def __init__( + self, + encoder: VisionTransformer, + predictor: VisionTransformer, + optimizer: Optional[partial[torch.optim.Optimizer]] = None, + lr_scheduler: Optional[ + Union[ + Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], + partial[torch.optim.lr_scheduler.LRScheduler], + ] + ] = None, + ema_decay: float = 0.996, + ema_decay_end: float = 1.0, + ema_anneal_end_step: int = 1000, + loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + compute_validation_loss: bool = True, + compute_test_loss: bool = True, + ): + super().__init__( + optimizer=optimizer, + lr_scheduler=lr_scheduler, + loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss, + compute_validation_loss=compute_validation_loss, + compute_test_loss=compute_test_loss, + ) + + self.mask_generator = IJEPAMaskGenerator() + + self.current_step = 0 + self.total_steps = None + + self.encoder = encoder + self.predictor = predictor + + self.predictor.num_patches = encoder.patch_embed.num_patches + self.predictor.embed_dim = encoder.embed_dim + self.predictor.num_heads = encoder.num_heads + + self.ema = ExponentialMovingAverage( + self.encoder, + ema_decay, + ema_decay_end, + ema_anneal_end_step, + device_id=self.device, + ) + + def configure_model(self) -> None: + """Configure the model.""" + self.ema.model.to(device=self.device, dtype=self.dtype) + + def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: + """Perform exponential moving average update of target encoder. 
+ + This is done right after the optimizer step, which comes just before `zero_grad` + to account for gradient accumulation. + """ + if self.ema is not None: + self.ema.step(self.encoder) + + def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: + """Perform a single training step.""" + return self._shared_step(batch, batch_idx, step_type="train") + + def validation_step( + self, batch: Dict[str, Any], batch_idx: int + ) -> Optional[torch.Tensor]: + """Run a single validation step.""" + return self._shared_step(batch, batch_idx, step_type="val") + + def test_step( + self, batch: Dict[str, Any], batch_idx: int + ) -> Optional[torch.Tensor]: + """Run a single test step.""" + return self._shared_step(batch, batch_idx, step_type="test") + + def on_validation_epoch_start(self) -> None: + """Prepare for the validation epoch.""" + self._on_eval_epoch_start("val") + + def on_validation_epoch_end(self) -> None: + """Actions at the end of the validation epoch.""" + self._on_eval_epoch_end("val") + + def on_test_epoch_start(self) -> None: + """Prepare for the test epoch.""" + self._on_eval_epoch_start("test") + + def on_test_epoch_end(self) -> None: + """Actions at the end of the test epoch.""" + self._on_eval_epoch_end("test") + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Add relevant EMA state to the checkpoint. + + Parameters + ---------- + checkpoint : Dict[str, Any] + The state dictionary to save the EMA state to. + """ + if self.ema is not None: + checkpoint["ema_params"] = { + "decay": self.ema.decay, + "num_updates": self.ema.num_updates, + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Restore EMA state from the checkpoint. + + Parameters + ---------- + checkpoint : Dict[str, Any] + The state dictionary to restore the EMA state from. 
+ """ + if "ema_params" in checkpoint and self.ema is not None: + ema_params = checkpoint.pop("ema_params") + self.ema.decay = ema_params["decay"] + self.ema.num_updates = ema_params["num_updates"] + + self.ema.restore(self.encoder) + + def _shared_step( + self, batch: Dict[str, Any], batch_idx: int, step_type: str + ) -> Optional[torch.Tensor]: + images = batch[Modalities.RGB.name] + + # Generate masks + batch_size = images.size(0) + mask_info = self.mask_generator(batch_size=batch_size) + + # Extract masks and move to device + device = images.device + encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]] + predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]] + + # Forward pass through target encoder to get h + with torch.no_grad(): + h = self.ema.model(images) + h = F.layer_norm(h, h.size()[-1:]) + h_masked = apply_masks(h, predictor_masks) + h_masked = repeat_interleave_batch( + h_masked, images.size(0), repeat=len(encoder_masks) + ) + + # Forward pass through encoder with encoder_masks + z = self.encoder(images, masks=encoder_masks) + + # Pass z through predictor with encoder_masks and predictor_masks + z_pred = self.predictor(z, encoder_masks, predictor_masks) + + if step_type == "train": + self.log("train/ema_decay", self.ema.decay, prog_bar=True) + + if self.loss_fn is not None and ( + step_type == "train" + or (step_type == "val" and self.compute_validation_loss) + or (step_type == "test" and self.compute_test_loss) + ): + # Compute loss between z_pred and h_masked + loss = self.loss_fn(z_pred, h_masked) + + # Log loss + self.log( + f"{step_type}/loss", + loss, + prog_bar=True, + sync_dist=True, + ) + + return loss + + return None + + def _on_eval_epoch_start(self, step_type: str) -> None: + """Initialize states or configurations at the start of an evaluation epoch. + + Parameters + ---------- + step_type : str + Type of the evaluation phase ("val" or "test"). + """ + if ( + step_type == "val" + and self.compute_validation_loss + or step_type == "test" + and self.compute_test_loss + ): + self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True) + + def _on_eval_epoch_end(self, step_type: str) -> None: + """Finalize states or logging at the end of an evaluation epoch. + + Parameters + ---------- + step_type : str + Type of the evaluation phase ("val" or "test"). 
+ """ + if ( + step_type == "val" + and self.compute_validation_loss + or step_type == "test" + and self.compute_test_loss + ): + self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True) diff --git a/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml b/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml index be807c3..b678c8d 100644 --- a/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml +++ b/projects/bioscan_clip/configs/experiment/bioscan_1m.yaml @@ -12,7 +12,7 @@ defaults: - /modules/encoders@task.encoders.rgb: timm-vit-lora - /modules/encoders@task.encoders.dna: barcode-bert-lora - /modules/layers@task.heads.text: MLP # the other modalities have projection heads in their encoders - - /modules/losses@task.loss: CLIPLoss + - /modules/losses@task.loss: ContrastiveLoss - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: OneCycleLR - /eval_task@task.evaluation_tasks.tax_cls.task: TaxonomicClassification @@ -101,8 +101,6 @@ trainer: model_summary: max_depth: 2 -strict_loading: False - tags: - ${experiment_name} - contrastive pretraining diff --git a/projects/ijepa/configs/__init__.py b/projects/ijepa/configs/__init__.py new file mode 100644 index 0000000..aefa1e4 --- /dev/null +++ b/projects/ijepa/configs/__init__.py @@ -0,0 +1,102 @@ +import os +from typing import Literal +from logging import getLogger + +from PIL import ImageFilter + +import torch +from torchvision import transforms +from mmlearn.conf import external_store + +logger = getLogger() + + +@external_store(group="datasets/transforms") +def ijepa_transforms( + crop_size: int = 224, + crop_scale: tuple = (0.3, 1.0), + color_jitter: float = 0.0, + horizontal_flip: bool = False, + color_distortion: bool = False, + gaussian_blur: bool = False, + normalization: tuple = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + job_type: Literal["train", "eval"] = "train", +) -> transforms.Compose: + """ + Create transforms for training and evaluation. + + Parameters + ---------- + crop_size : int, default=224 + Size of the image crop. + crop_scale : tuple, default=(0.3, 1.0) + Range for the random resized crop scaling. + color_jitter : float, default=0.0 + Strength of color jitter. + horizontal_flip : bool, default=False + Whether to apply random horizontal flip. + color_distortion : bool, default=False + Whether to apply color distortion. + gaussian_blur : bool, default=False + Whether to apply Gaussian blur. + normalization : tuple, default=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + Mean and std for normalization. + job_type : {"train", "eval"}, default="train" + Type of the job (training or evaluation) for which the transforms are needed. + + Returns + ------- + transforms.Compose + Composed transforms for training/evaluation with images. 
+ """ + logger.info("Creating data transforms") + + def get_color_distortion(s: float = 1.0): + """Apply color jitter and random grayscale.""" + color_jitter_transform = transforms.ColorJitter( + 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s + ) + return transforms.Compose( + [ + transforms.RandomApply([color_jitter_transform], p=0.8), + transforms.RandomGrayscale(p=0.2), + ] + ) + + class GaussianBlur: + """Gaussian blur transform.""" + + def __init__( + self, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0 + ): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + if torch.bernoulli(torch.tensor(self.prob)) == 0: + return img + radius = self.radius_min + torch.rand(1).item() * ( + self.radius_max - self.radius_min + ) + return img.filter(ImageFilter.GaussianBlur(radius)) + + transforms_list = [] + if job_type == "train": + transforms_list.append( + transforms.RandomResizedCrop(crop_size, scale=crop_scale) + ) + if horizontal_flip: + transforms_list.append(transforms.RandomHorizontalFlip()) + if color_distortion: + transforms_list.append(get_color_distortion(s=color_jitter)) + if gaussian_blur: + transforms_list.append(GaussianBlur(p=0.5)) + else: + transforms_list.append(transforms.Resize(crop_size)) + transforms_list.append(transforms.CenterCrop(crop_size)) + + transforms_list.append(transforms.ToTensor()) + transforms_list.append(transforms.Normalize(normalization[0], normalization[1])) + + return transforms.Compose(transforms_list) diff --git a/projects/ijepa/configs/experiment/reproduce_imagenet.yaml b/projects/ijepa/configs/experiment/reproduce_imagenet.yaml new file mode 100644 index 0000000..58b01fe --- /dev/null +++ b/projects/ijepa/configs/experiment/reproduce_imagenet.yaml @@ -0,0 +1,75 @@ +# @package _global_ + +defaults: + - /datasets@datasets.train: ImageNet + - /datasets/transforms@datasets.train.transform: ijepa_transforms + - /datasets@datasets.val: ImageNet + - /datasets/transforms@datasets.val.transform: ijepa_transforms + - /modules/encoders@task.encoder: vit_base + - /modules/encoders@task.predictor: vit_predictor + - /modules/optimizers@task.optimizer: AdamW + - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR + - /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor + - /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint + - /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping + - /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary + - /trainer/logger@trainer.logger.wandb: WandbLogger + - override /task: IJEPA + - _self_ + +seed: 0 + +datasets: + val: + split: val + transform: + job_type: eval + +dataloader: + train: + batch_size: 256 + num_workers: 10 + val: + batch_size: 256 + num_workers: 10 + +task: + optimizer: + betas: + - 0.9 + - 0.999 + lr: 1.0e-3 + weight_decay: 0.05 + eps: 1.0e-8 + lr_scheduler: + scheduler: + T_max: ${trainer.max_epochs} + extras: + interval: epoch + +trainer: + max_epochs: 300 + precision: 16-mixed + deterministic: False + benchmark: True + sync_batchnorm: False # Set to True if using DDP with batchnorm + log_every_n_steps: 100 + accumulate_grad_batches: 4 + check_val_every_n_epoch: 1 + callbacks: + model_checkpoint: + monitor: val/loss + save_top_k: 1 + save_last: True + every_n_epochs: 1 + dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment + early_stopping: + monitor: val/loss + patience: 5 + mode: min + model_summary: + max_depth: 2 + +tags: + - 
${experiment_name} + - ijepa pretraining diff --git a/projects/med_benchmarking/configs/experiment/baseline.yaml b/projects/med_benchmarking/configs/experiment/baseline.yaml index e3c5494..d7bdfc8 100644 --- a/projects/med_benchmarking/configs/experiment/baseline.yaml +++ b/projects/med_benchmarking/configs/experiment/baseline.yaml @@ -15,7 +15,7 @@ defaults: - /datasets/tokenizers@dataloader.val.collate_fn.batch_processors.text: HFCLIPTokenizer - /modules/encoders@task.encoders.text: HFCLIPTextEncoderWithProjection - /modules/encoders@task.encoders.rgb: HFCLIPVisionEncoderWithProjection - - /modules/losses@task.loss: CLIPLoss + - /modules/losses@task.loss: ContrastiveLoss - /modules/optimizers@task.optimizer: AdamW - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR - /eval_task@task.evaluation_tasks.retrieval.task: ZeroShotCrossModalRetrieval
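
For reference, a minimal sketch of how the batched masking API added in this patch fits together, mirroring the predictor-mask step of `IJEPA._shared_step`. This is not part of the diff above: the random tensor stands in for the target encoder's patch embeddings, the 196-patch grid assumes the generator defaults introduced here (224x224 input, 16x16 patches), and the batch size and embedding dimension are arbitrary placeholders.

import torch
import torch.nn.functional as F

from mmlearn.datasets.processors.masking import IJEPAMaskGenerator, apply_masks

# Stand-in for a batch of ViT patch embeddings: a 224x224 image with 16x16
# patches yields 196 patches; batch size and embedding dim are placeholders.
batch_size, num_patches, embed_dim = 4, 196, 768
h = torch.randn(batch_size, num_patches, embed_dim)

mask_generator = IJEPAMaskGenerator()  # patch defaults: nenc=1, npred=4
masks = mask_generator(batch_size=batch_size)

# __call__ now returns lists of (batch_size, num_patches) mask tensors,
# one per encoder/predictor block.
encoder_masks = masks["encoder_masks"]      # len == nenc == 1
predictor_masks = masks["predictor_masks"]  # len == npred == 4

# As in IJEPA._shared_step: layer-normalize the target features, then keep only
# the patches selected by each predictor mask. apply_masks concatenates its
# per-mask results along the batch dimension, so the output has shape
# (batch_size * len(predictor_masks), N', embed_dim), where N' is the number
# of patches each mask keeps.
h = F.layer_norm(h, h.shape[-1:])
h_masked = apply_masks(h, predictor_masks)
print(h_masked.shape)

In the task itself, `encoder_masks` are not applied this way; they are passed to the context encoder as `self.encoder(images, masks=encoder_masks)`, and `repeat_interleave_batch` is used afterwards to align the target tensor with the predictions when more than one encoder mask is configured.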