diff --git a/mmlearn/tasks/ijepa_pretraining.py b/mmlearn/tasks/ijepa_pretraining.py
index 3b9fad5..cbf067b 100644
--- a/mmlearn/tasks/ijepa_pretraining.py
+++ b/mmlearn/tasks/ijepa_pretraining.py
@@ -42,7 +42,7 @@ class IJEPAPretraining(L.LightningModule):
         , by default 0.996.
     ema_momentum_end : float, optional
         Final momentum for EMA of target encoder, by default 1.0.
-    loss_fn : Optional[Callable], optional
+    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
         Loss function to use, by default None.
     compute_validation_loss : bool, optional
         Whether to compute validation loss, by default True.
@@ -87,8 +87,7 @@ def __init__(
         self.total_steps = None

         self.encoder = VisionTransformer.__dict__[model_name](
-            img_size=[crop_size],
-            patch_size=patch_size
+            img_size=[crop_size], patch_size=patch_size
         )

         self.predictor = VisionTransformer.__dict__["vit_predictor"](
@@ -96,20 +95,22 @@ def __init__(
             embed_dim=self.encoder.embed_dim,
             predictor_embed_dim=pred_emb_dim,
             depth=pred_depth,
-            num_heads=self.encoder.num_heads
+            num_heads=self.encoder.num_heads,
         )

         self.target_encoder = copy.deepcopy(self.encoder)

         if checkpoint_path != "":
-            self.encoder, self.predictor, self.target_encoder, _, _, _ = self.load_checkpoint(
-                device=self.device,
-                checkpoint_path=checkpoint_path,
-                encoder=self.encoder,
-                predictor=self.predictor,
-                target_encoder=self.target_encoder,
-                opt=None,
-                scaler=None,
+            self.encoder, self.predictor, self.target_encoder, _, _, _ = (
+                self.load_checkpoint(
+                    device=self.device,
+                    checkpoint_path=checkpoint_path,
+                    encoder=self.encoder,
+                    predictor=self.predictor,
+                    target_encoder=self.target_encoder,
+                    opt=None,
+                    scaler=None,
+                )
             )

         # Freeze parameters of target encoder
@@ -121,14 +122,14 @@ def __init__(
         self.target_encoder.to(self.device)

     def load_checkpoint(
-        self,
-        device : str,
-        checkpoint_path : str,
-        encoder : nn.Module,
-        predictor : nn.Module,
-        target_encoder : nn.Module,
-        opt : Any,
-        scaler : Any,
+        self,
+        device: str,
+        checkpoint_path: str,
+        encoder: nn.Module,
+        predictor: nn.Module,
+        target_encoder: nn.Module,
+        opt: Any,
+        scaler: Any,
     ) -> Tuple[nn.Module, nn.Module, nn.Module, Any, Any, int]:
         """Load a pre-trained model from a checkpoint."""
         try:
@@ -166,7 +167,9 @@ def load_checkpoint(
         return encoder, predictor, target_encoder, opt, scaler, epoch

     def forward(
-        self, x: torch.Tensor, masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
+        self,
+        x: torch.Tensor,
+        masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
     ) -> torch.Tensor:
         """Forward pass through the encoder."""
         return self.encode(x, masks)
@@ -205,7 +208,14 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
         loss = self.loss_fn(z_pred, h_masked)

         # Log loss
-        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
+        self.log(
+            "train/loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            sync_dist=True,
+        )

         # EMA update of target encoder
         self._update_target_encoder()
@@ -217,7 +227,9 @@ def _update_target_encoder(self) -> None:
         if self.total_steps is None:
             self.total_steps = self.trainer.estimated_stepping_batches
         current_step = self.trainer.global_step
-        m = self.ema_momentum + (self.ema_momentum_end - self.ema_momentum) * (current_step / self.total_steps)
+        m = self.ema_momentum + (self.ema_momentum_end - self.ema_momentum) * (
+            current_step / self.total_steps
+        )
         for param_q, param_k in zip(
             self.encoder.parameters(), self.target_encoder.parameters()
         ):
@@ -253,15 +265,21 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
         return optimizer

-    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Optional[torch.Tensor]:
+    def validation_step(
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Optional[torch.Tensor]:
         """Run a single validation step."""
         return self._shared_eval_step(batch, batch_idx, "val")

-    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Optional[torch.Tensor]:
+    def test_step(
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Optional[torch.Tensor]:
         """Run a single test step."""
         return self._shared_eval_step(batch, batch_idx, "test")

-    def _shared_eval_step(self, batch: Dict[str, Any], batch_idx: int, eval_type: str) -> Optional[torch.Tensor]:
+    def _shared_eval_step(
+        self, batch: Dict[str, Any], batch_idx: int, eval_type: str
+    ) -> Optional[torch.Tensor]:
         images = batch[Modalities.RGB]

         # Generate masks
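Note on the `_update_target_encoder` hunk: the wrapped expression is a linear interpolation of the EMA momentum from `ema_momentum` to `ema_momentum_end` over the full training run. Below is a minimal standalone sketch of that schedule, assuming the defaults quoted in the docstring hunk (0.996 and 1.0); `linear_ema_momentum` is a hypothetical helper name, not part of the module:

    def linear_ema_momentum(
        step: int, total_steps: int, start: float = 0.996, end: float = 1.0
    ) -> float:
        """Linearly interpolate the EMA momentum from `start` to `end`."""
        # Same arithmetic as the reformatted expression in _update_target_encoder:
        # m = start + (end - start) * (step / total_steps)
        return start + (end - start) * (step / total_steps)

    # Example: with total_steps=10_000 the momentum is 0.996 at step 0,
    # 0.998 at step 5_000, and 1.0 at the final step, at which point the
    # target encoder stops changing.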