Skip to content

Commit

Permalink
revise ijepa trainer class
Browse files Browse the repository at this point in the history
  • Loading branch information
vahid0001 committed Oct 30, 2024
1 parent f9ee90a commit 8d41a76
Showing 1 changed file with 46 additions and 6 deletions.
52 changes: 46 additions & 6 deletions mmlearn/tasks/ijepa_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class IJEPA(L.LightningModule):
Initial momentum for EMA of target encoder, by default 0.996.
ema_decay_end : float, optional
Final momentum for EMA of target encoder, by default 1.0.
ema_anneal_end_step : int, optional
Number of steps to anneal EMA momentum to `ema_decay_end`, by default 1000.
loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
Loss function to use, by default None.
compute_validation_loss : bool, optional
Expand All @@ -55,6 +57,7 @@ def __init__(
lr_scheduler: Optional[Any] = None,
ema_decay: float = 0.996,
ema_decay_end: float = 1.0,
ema_anneal_end_step: int = 1000,
loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
compute_validation_loss: bool = True,
compute_test_loss: bool = True,
Expand All @@ -76,30 +79,35 @@ def __init__(
self.encoder = encoder
self.predictor = predictor

self.ema = ExponentialMovingAverage(encoder, ema_decay, ema_decay_end, 1000)
self.ema = ExponentialMovingAverage(
encoder,
ema_decay,
ema_decay_end,
ema_anneal_end_step,
device_id=self.device,
)

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
"""Perform a single training step."""
return self._shared_step(batch, batch_idx, step_type="train", is_training=True)
return self._shared_step(batch, batch_idx, step_type="train")

def validation_step(
self, batch: Dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
"""Run a single validation step."""
return self._shared_step(batch, batch_idx, step_type="val", is_training=False)
return self._shared_step(batch, batch_idx, step_type="val")

def test_step(
self, batch: Dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
"""Run a single test step."""
return self._shared_step(batch, batch_idx, step_type="test", is_training=False)
return self._shared_step(batch, batch_idx, step_type="test")

def _shared_step(
self,
batch: Dict[str, Any],
batch_idx: int,
step_type: str,
is_training: bool = False,
) -> Optional[torch.Tensor]:
images = batch[Modalities.RGB.name]

Expand Down Expand Up @@ -135,7 +143,7 @@ def _shared_step(
sync_dist=True,
)

if is_training:
if step_type == "train":
# EMA update of target encoder
self.ema.step(self.encoder)

Expand Down Expand Up @@ -240,3 +248,35 @@ def on_test_epoch_start(self) -> None:
def on_test_epoch_end(self) -> None:
"""Actions at the end of the test epoch."""
self._on_eval_epoch_end("test")

def _on_eval_epoch_start(self, step_type: str) -> None:
"""Initialize states or configurations at the start of an evaluation epoch.
Parameters
----------
step_type : str
Type of the evaluation phase ("val" or "test").
"""
if (
step_type == "val"
and self.compute_validation_loss
or step_type == "test"
and self.compute_test_loss
):
self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

def _on_eval_epoch_end(self, step_type: str) -> None:
"""Finalize states or logging at the end of an evaluation epoch.
Parameters
----------
step_type : str
Type of the evaluation phase ("val" or "test").
"""
if (
step_type == "val"
and self.compute_validation_loss
or step_type == "test"
and self.compute_test_loss
):
self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)

0 comments on commit 8d41a76

Please sign in to comment.