diff --git a/mmlearn/tasks/ijepa_pretraining.py b/mmlearn/tasks/ijepa_pretraining.py
index 6e4c5d8..6b3b716 100644
--- a/mmlearn/tasks/ijepa_pretraining.py
+++ b/mmlearn/tasks/ijepa_pretraining.py
@@ -36,6 +36,8 @@ class IJEPA(L.LightningModule):
         Initial momentum for EMA of target encoder, by default 0.996.
     ema_decay_end : float, optional
         Final momentum for EMA of target encoder, by default 1.0.
+    ema_anneal_end_step : int, optional
+        Number of steps to anneal EMA momentum to `ema_decay_end`, by default 1000.
     loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
         Loss function to use, by default None.
     compute_validation_loss : bool, optional
@@ -55,6 +57,7 @@ def __init__(
         lr_scheduler: Optional[Any] = None,
         ema_decay: float = 0.996,
         ema_decay_end: float = 1.0,
+        ema_anneal_end_step: int = 1000,
         loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
         compute_validation_loss: bool = True,
         compute_test_loss: bool = True,
@@ -76,30 +79,35 @@ def __init__(
         self.encoder = encoder
         self.predictor = predictor
 
-        self.ema = ExponentialMovingAverage(encoder, ema_decay, ema_decay_end, 1000)
+        self.ema = ExponentialMovingAverage(
+            encoder,
+            ema_decay,
+            ema_decay_end,
+            ema_anneal_end_step,
+            device_id=self.device,
+        )
 
     def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
         """Perform a single training step."""
-        return self._shared_step(batch, batch_idx, step_type="train", is_training=True)
+        return self._shared_step(batch, batch_idx, step_type="train")
 
     def validation_step(
         self, batch: Dict[str, Any], batch_idx: int
     ) -> Optional[torch.Tensor]:
         """Run a single validation step."""
-        return self._shared_step(batch, batch_idx, step_type="val", is_training=False)
+        return self._shared_step(batch, batch_idx, step_type="val")
 
     def test_step(
         self, batch: Dict[str, Any], batch_idx: int
     ) -> Optional[torch.Tensor]:
         """Run a single test step."""
-        return self._shared_step(batch, batch_idx, step_type="test", is_training=False)
+        return self._shared_step(batch, batch_idx, step_type="test")
 
     def _shared_step(
         self,
         batch: Dict[str, Any],
         batch_idx: int,
         step_type: str,
-        is_training: bool = False,
     ) -> Optional[torch.Tensor]:
         images = batch[Modalities.RGB.name]
 
@@ -135,7 +143,7 @@ def _shared_step(
             sync_dist=True,
         )
 
-        if is_training:
+        if step_type == "train":
             # EMA update of target encoder
             self.ema.step(self.encoder)
 
@@ -240,3 +248,35 @@ def on_test_epoch_start(self) -> None:
     def on_test_epoch_end(self) -> None:
         """Actions at the end of the test epoch."""
         self._on_eval_epoch_end("test")
+
+    def _on_eval_epoch_start(self, step_type: str) -> None:
+        """Initialize states or configurations at the start of an evaluation epoch.
+
+        Parameters
+        ----------
+        step_type : str
+            Type of the evaluation phase ("val" or "test").
+        """
+        if (
+            step_type == "val"
+            and self.compute_validation_loss
+            or step_type == "test"
+            and self.compute_test_loss
+        ):
+            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)
+
+    def _on_eval_epoch_end(self, step_type: str) -> None:
+        """Finalize states or logging at the end of an evaluation epoch.
+
+        Parameters
+        ----------
+        step_type : str
+            Type of the evaluation phase ("val" or "test").
+        """
+        if (
+            step_type == "val"
+            and self.compute_validation_loss
+            or step_type == "test"
+            and self.compute_test_loss
+        ):
+            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
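
Note on the new `ema_anneal_end_step` parameter: the diff only threads it through to
`ExponentialMovingAverage`, whose internals are not shown here. Assuming the class
linearly interpolates the momentum from `ema_decay` to `ema_decay_end` over that many
steps (the usual schedule for I-JEPA-style target encoders), the behavior would be
equivalent to this minimal sketch; `ema_momentum` is a hypothetical standalone helper,
not a function in this module.

def ema_momentum(
    step: int,
    ema_decay: float = 0.996,
    ema_decay_end: float = 1.0,
    ema_anneal_end_step: int = 1000,
) -> float:
    """Momentum used for the EMA update at a given training step (assumed linear anneal)."""
    if step >= ema_anneal_end_step:
        # After the anneal window, hold the final momentum.
        return ema_decay_end
    # Linear interpolation from ema_decay to ema_decay_end.
    return ema_decay + (ema_decay_end - ema_decay) * step / ema_anneal_end_step

# Under these assumptions: ema_momentum(0) == 0.996,
# ema_momentum(500) == 0.998, ema_momentum(1000) == 1.0.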