Add IJEPA task #25
Conversation
Reviewable status: 0 of 1 files reviewed, 14 unresolved discussions (waiting on @vahid0001)
mmlearn/tasks/ijepa_pretraining.py
line 19 at r1 (raw file):
```python
class IJEPAPretraining(L.LightningModule):
```
Suggestion:
```python
class IJEPA(L.LightningModule):
```
mmlearn/tasks/ijepa_pretraining.py
line 38 at r1 (raw file):
```python
pred_depth : int
    Depth of the predictor.
optimizer : Optional[Any], optional
```
Suggestion:
```python
optimizer : Optional[torch.optim.Optimizer], optional
```
mmlearn/tasks/ijepa_pretraining.py
line 39 at r1 (raw file):
```python
    Depth of the predictor.
optimizer : Optional[Any], optional
    Optimizer configuration, by default None.
```
What you will get here is not a config, but an initialized optimizer.
mmlearn/tasks/ijepa_pretraining.py
line 41 at r1 (raw file):
```python
    Optimizer configuration, by default None.
lr_scheduler : Optional[Any], optional
    Learning rate scheduler configuration, by default None.
```
Same as the optimizer. This will be an instantiated learning rate scheduler, if one is provided.
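For instance, the two docstring entries could read something like this (just a sketch; exact wording up to you):
```python
optimizer : Optional[torch.optim.Optimizer], optional
    The already-instantiated optimizer, by default None.
lr_scheduler : Optional[torch.optim.lr_scheduler.LRScheduler], optional
    The already-instantiated learning rate scheduler, by default None.
```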
mmlearn/tasks/ijepa_pretraining.py
line 91 at r1 (raw file):
```python
self.total_steps = None
self.encoder = VisionTransformer.__dict__[model_name](
```
Like we discussed, I think this should be passed in already initialized. Same with the predictor.
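Something along these lines (a sketch only; the predictor's type is an assumption):
```python
def __init__(
    self,
    encoder: VisionTransformer,  # built by the config layer, passed in ready to use
    predictor: nn.Module,        # likewise pre-instantiated
) -> None:
    super().__init__()
    self.encoder = encoder
    self.predictor = predictor
```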
mmlearn/tasks/ijepa_pretraining.py
line 106 at r1 (raw file):
```python
if checkpoint_path != "":
    self.encoder, self.predictor, self.target_encoder, _, _, _ = (
```
Also, as we discussed, I think each module should handle loading the pretrained checkpoints individually. An example use case is being able to use an iJEPA-pretrained encoder (the original one) for contrastive pretraining.
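i.e., something like this at the call site (illustrative only; the constructor arguments, path, and `ContrastivePretraining` kwargs are assumptions):
```python
# each module restores its own weights, so the encoder can be reused elsewhere
encoder = VisionTransformer(img_size=224, patch_size=16)  # hypothetical args
encoder.load_checkpoint("ijepa_pretrained.pth")           # hypothetical path

# e.g. plug the iJEPA-pretrained encoder into contrastive pretraining
task = ContrastivePretraining(encoder=encoder)            # sketch; kwargs assumed
```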
mmlearn/tasks/ijepa_pretraining.py
line 171 at r1 (raw file):
```python
return encoder, predictor, target_encoder, opt, scaler, epoch

def forward(
```
This forward pass is currently not being used by this module. Everything is in `training_step` right now. Remember that the `LightningModule` is also an `nn.Module`.
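A minimal sketch of the idea (the mask key and signature are assumptions):
```python
def forward(self, x: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor:
    """Encode the masked context; a LightningModule is also an nn.Module."""
    return self.encoder(x, masks)

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    images = batch[Modalities.RGB.name]
    # route the encoder call through self(...) instead of duplicating it here
    context = self(images, batch["enc_masks"])  # mask key is an assumption
    ...
```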
mmlearn/tasks/ijepa_pretraining.py
line 177 at r1 (raw file):
```python
) -> torch.Tensor:
    """Forward pass through the encoder."""
    return self.encode(x, masks)
```
The `encode` method doesn't exist. Please make sure to try running the code to ensure that things are working properly.
Also, please take a look at the other encoders in the library; the current convention is to pass the entire batch dictionary to the encoder.
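i.e., roughly (a sketch following that convention; the mask key is an assumption):
```python
def forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
    """Take the whole batch dict, like the other encoders in the library."""
    x = inputs[Modalities.RGB.name]
    masks = inputs.get("masks")  # key name is an assumption
    return self.encoder(x, masks)
```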
mmlearn/tasks/ijepa_pretraining.py
line 181 at r1 (raw file):
```python
def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step."""
    images = batch[Modalities.RGB]
```
The format for this has changed. See this PR.
Suggestion:
```python
images = batch[Modalities.RGB.name]
```
mmlearn/tasks/ijepa_pretraining.py
line 191 at r1 (raw file):
```python
# Move images and masks to device
images = images.to(self.device)
```
You don't need to do this for anything inside the batch dictionary. Lightning will handle that.
mmlearn/tasks/ijepa_pretraining.py
line 216 at r1 (raw file):
"train/loss", loss, on_step=True,
Why log both `on_step` and `on_epoch`? During training, the default for Lightning is to log `on_step` but not `on_epoch` (for validation and testing, the default is to log `on_epoch` but not `on_step`).
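So this could just rely on the defaults:
```python
# training_step: defaults to on_step=True, on_epoch=False
self.log("train/loss", loss)

# validation_step / test_step: defaults flip to on_step=False, on_epoch=True
self.log("val/loss", loss)
```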
mmlearn/tasks/ijepa_pretraining.py
line 227 at r1 (raw file):
```python
return loss

def _update_target_encoder(self) -> None:
```
Please look into reusing the existing EMA module.
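Roughly (a sketch; the import path and the `step` call are assumptions about the existing module):
```python
from mmlearn.modules.ema import ExponentialMovingAverage  # path assumed

# in __init__: keep a momentum copy of the encoder
self.ema = ExponentialMovingAverage(encoder, ema_decay, ema_decay_end, 1000)

# instead of a hand-rolled _update_target_encoder, after each optimizer step:
self.ema.step(self.encoder)
```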
mmlearn/tasks/ijepa_pretraining.py
line 247 at r1 (raw file):
"params": (p for n, p in self.encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1)), "weight_decay": self.optimizer_cfg.get("weight_decay", 0.0)
Like I mentioned earlier, you will be getting an instantiated Optimizer object, not the config. Please take a look at the contrastive pretraining task for how to get the weight decay value from the instantiated Optimizer object.
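e.g., something along these lines (a sketch; the contrastive task may do it slightly differently):
```python
# an instantiated torch.optim.Optimizer carries its constructor kwargs in `defaults`
weight_decay = self.optimizer.defaults.get("weight_decay", 0.0)
```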
mmlearn/tasks/ijepa_pretraining.py
line 309 at r1 (raw file):
```python
return self._shared_eval_step(batch, batch_idx, "test")

def _shared_eval_step(
```
Notice that most of the code in this method is repeated in `training_step`. You can define it once and call it multiple times for training and eval.
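For example (sketch):
```python
def _shared_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Forward pass and loss computation shared by train/val/test."""
    ...  # everything currently duplicated between training_step and _shared_eval_step

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    loss = self._shared_step(batch, batch_idx)
    self.log("train/loss", loss)
    return loss

def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    loss = self._shared_step(batch, batch_idx)
    self.log("val/loss", loss)
    return loss
```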
Reviewable status: 0 of 2 files reviewed, 4 unresolved discussions (waiting on @vahid0001)
mmlearn/modules/encoders/vision.py
line 313 at r2 (raw file):
```python
nn.init.constant_(m.bias, 0)

def load_checkpoint(
```
How's this intended to be used?
It looks like both the encoder and predictor need to be instantiated first and then passed to this method. When I first made the suggestion to move the checkpoint loading functionality to the encoder, I imagined the `VisionTransformer` having its own checkpoint loading logic (just for the encoder) and the predictor having its own (loading the target encoder may not be necessary, I think, if one knows the exact EMA value it stopped at).
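i.e., something like this on the encoder (a sketch; the checkpoint key is an assumption about the saved layout):
```python
class VisionTransformer(nn.Module):
    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Restore only this encoder's weights from a full training checkpoint."""
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(ckpt["encoder"])  # "encoder" key assumed
```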
mmlearn/tasks/ijepa_pretraining.py
line 79 at r2 (raw file):
```python
self.predictor = predictor
self.ema = ExponentialMovingAverage(encoder, ema_decay, ema_decay_end, 1000)
```
Should the `ema_anneal_end_step`, currently fixed at `1000`, be a user-defined value?
Also, `device_id` might be important to set, especially in a distributed setting. It can be set to `self.device`.
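i.e. (sketch):
```python
self.ema = ExponentialMovingAverage(
    encoder,
    ema_decay,
    ema_decay_end,
    ema_anneal_end_step,    # user-defined instead of the hard-coded 1000
    device_id=self.device,  # matters in distributed settings
)
```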
mmlearn/tasks/ijepa_pretraining.py
line 138 at r2 (raw file):
```python
)
if is_training:
```
Is the `is_training` flag necessary? Why not `if step_type == "train"`?
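i.e. (sketch):
```python
# instead of threading a separate is_training flag through:
if step_type == "train":
    self._update_target_encoder()
```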
mmlearn/tasks/ijepa_pretraining.py
line 230 at r2 (raw file):
```python
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
```
The `self._on_eval_epoch_{start/end}` methods are not defined.
Reviewable status: 0 of 2 files reviewed, 4 unresolved discussions (waiting on @vahid0001)
mmlearn/modules/encoders/vision.py
line 313 at r2 (raw file):
Previously, fcogidi (Franklin) wrote…
> How's this intended to be used?
> It looks like both the encoder and predictor need to be instantiated first and then passed to this method. When I first made the suggestion to move the checkpoint loading functionality to the encoder, I imagined the `VisionTransformer` having its own checkpoint loading logic (just for the encoder) and the predictor having its own (loading the target encoder may not be necessary, I think, if one knows the exact EMA value it stopped at).
Speaking of, please check this out.
PR Type
[Feature]
Short Description
Add first version of IJEPA pretraining task
Tests Added
N/A