Skip to content

Commit

Permalink
change docstring
Browse files (browse the repository at this point in the history)
  • Loading branch information
vahid0001 committed Oct 9, 2024
1 parent cc98d33 commit f7fdcea
Showing 1 changed file with 44 additions and 26 deletions.
70 changes: 44 additions & 26 deletions mmlearn/tasks/ijepa_pretraining.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,7 +42,7 @@ class IJEPAPretraining(L.LightningModule):
, by default 0.996.
ema_momentum_end : float, optional
Final momentum for EMA of target encoder, by default 1.0.
loss_fn : Optional[Callable], optional
loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
Loss function to use, by default None.
compute_validation_loss : bool, optional
Whether to compute validation loss, by default True.
Expand Down Expand Up @@ -87,29 +87,30 @@ def __init__(
self.total_steps = None

self.encoder = VisionTransformer.__dict__[model_name](
img_size=[crop_size],
patch_size=patch_size
img_size=[crop_size], patch_size=patch_size
)

self.predictor = VisionTransformer.__dict__["vit_predictor"](
num_patches=self.encoder.patch_embed.num_patches,
embed_dim=self.encoder.embed_dim,
predictor_embed_dim=pred_emb_dim,
depth=pred_depth,
num_heads=self.encoder.num_heads
num_heads=self.encoder.num_heads,
)

self.target_encoder = copy.deepcopy(self.encoder)

if checkpoint_path != "":
self.encoder, self.predictor, self.target_encoder, _, _, _ = self.load_checkpoint(
device=self.device,
checkpoint_path=checkpoint_path,
encoder=self.encoder,
predictor=self.predictor,
target_encoder=self.target_encoder,
opt=None,
scaler=None,
self.encoder, self.predictor, self.target_encoder, _, _, _ = (
self.load_checkpoint(
device=self.device,
checkpoint_path=checkpoint_path,
encoder=self.encoder,
predictor=self.predictor,
target_encoder=self.target_encoder,
opt=None,
scaler=None,
)
)

# Freeze parameters of target encoder
Expand All @@ -121,14 +122,14 @@ def __init__(
self.target_encoder.to(self.device)

def load_checkpoint(
self,
device : str,
checkpoint_path : str,
encoder : nn.Module,
predictor : nn.Module,
target_encoder : nn.Module,
opt : Any,
scaler : Any,
self,
device: str,
checkpoint_path: str,
encoder: nn.Module,
predictor: nn.Module,
target_encoder: nn.Module,
opt: Any,
scaler: Any,
) -> Tuple[nn.Module, nn.Module, nn.Module, Any, Any, int]:
"""Load a pre-trained model from a checkpoint."""
try:
Expand Down Expand Up @@ -166,7 +167,9 @@ def load_checkpoint(
return encoder, predictor, target_encoder, opt, scaler, epoch

def forward(
self, x: torch.Tensor, masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
self,
x: torch.Tensor,
masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Forward pass through the encoder."""
return self.encode(x, masks)
Expand Down Expand Up @@ -205,7 +208,14 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
loss = self.loss_fn(z_pred, h_masked)

# Log loss
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
self.log(
"train/loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)

# EMA update of target encoder
self._update_target_encoder()
Expand All @@ -217,7 +227,9 @@ def _update_target_encoder(self) -> None:
if self.total_steps is None:
self.total_steps = self.trainer.estimated_stepping_batches
current_step = self.trainer.global_step
m = self.ema_momentum + (self.ema_momentum_end - self.ema_momentum) * (current_step / self.total_steps)
m = self.ema_momentum + (self.ema_momentum_end - self.ema_momentum) * (
current_step / self.total_steps
)
for param_q, param_k in zip(
self.encoder.parameters(), self.target_encoder.parameters()
):
Expand Down Expand Up @@ -253,15 +265,21 @@ def configure_optimizers(self) -> torch.optim.Optimizer:

return optimizer

def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Optional[torch.Tensor]:
def validation_step(
self, batch: Dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
"""Run a single validation step."""
return self._shared_eval_step(batch, batch_idx, "val")

def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Optional[torch.Tensor]:
def test_step(
self, batch: Dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
"""Run a single test step."""
return self._shared_eval_step(batch, batch_idx, "test")

def _shared_eval_step(self, batch: Dict[str, Any], batch_idx: int, eval_type: str) -> Optional[torch.Tensor]:
def _shared_eval_step(
self, batch: Dict[str, Any], batch_idx: int, eval_type: str
) -> Optional[torch.Tensor]:
images = batch[Modalities.RGB]

# Generate masks
Expand Down

0 comments on commit f7fdcea

Please sign in to comment.