Add IJEPA task #25

Merged
fcogidi merged 35 commits into main from ijepa_training on Dec 5, 2024

Conversation

@vahid0001 (Collaborator) commented on Oct 9, 2024

PR Type

[Feature]

Short Description

Add first version of IJEPA pretraining task

Tests Added

N/A


@fcogidi (Collaborator) left a comment

Reviewable status: 0 of 1 files reviewed, 14 unresolved discussions (waiting on @vahid0001)


mmlearn/tasks/ijepa_pretraining.py line 19 at r1 (raw file):

class IJEPAPretraining(L.LightningModule):

Suggestion:

class IJEPA(L.LightningModule):

mmlearn/tasks/ijepa_pretraining.py line 38 at r1 (raw file):

    pred_depth : int
        Depth of the predictor.
    optimizer : Optional[Any], optional

Suggestion:

optimizer : Optional[torch.optim.Optimizer], optional

mmlearn/tasks/ijepa_pretraining.py line 39 at r1 (raw file):

        Depth of the predictor.
    optimizer : Optional[Any], optional
        Optimizer configuration, by default None.

What you will get here is not a config, but an initialized optimizer.


mmlearn/tasks/ijepa_pretraining.py line 41 at r1 (raw file):

        Optimizer configuration, by default None.
    lr_scheduler : Optional[Any], optional
        Learning rate scheduler configuration, by default None.

Same as the optimizer. This will be an instantiated learning rate scheduler, if one is provided.
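
Combining both points, the two docstring entries could read something like this (just a sketch of the wording; the exact scheduler type annotation may differ in the codebase):

optimizer : Optional[torch.optim.Optimizer], optional
    The instantiated optimizer to use for training, by default None.
lr_scheduler : Optional[torch.optim.lr_scheduler.LRScheduler], optional
    The instantiated learning rate scheduler, if one is provided, by default None.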


mmlearn/tasks/ijepa_pretraining.py line 91 at r1 (raw file):

        self.total_steps = None

        self.encoder = VisionTransformer.__dict__[model_name](

Like we discussed, I think this should be passed in already initialized. Same with the predictor.
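
Roughly what I have in mind (a sketch only, not the final signature; the names are placeholders):

import copy

import lightning as L
import torch.nn as nn


class IJEPA(L.LightningModule):
    def __init__(self, encoder: nn.Module, predictor: nn.Module) -> None:
        super().__init__()
        # Hydra (or the caller) instantiates these; the task only wires them together.
        self.encoder = encoder
        self.predictor = predictor
        # The target encoder starts as a frozen copy of the context encoder.
        self.target_encoder = copy.deepcopy(encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False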


mmlearn/tasks/ijepa_pretraining.py line 106 at r1 (raw file):

        if checkpoint_path != "":
            self.encoder, self.predictor, self.target_encoder, _, _, _ = (

Also, as we discussed, I think each module should handle loading the pretrained checkpoints individually. An example use case for why is being able to use an iJEPA-pretrained encoder (the original one) for contrastive pretraining.
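
As a sketch of the idea (not the actual mmlearn API; the checkpoint key names are assumptions based on the original I-JEPA checkpoints and may differ):

import torch
import torch.nn as nn


def load_module_checkpoint(module: nn.Module, path: str, key: str = "encoder") -> nn.Module:
    """Load only this module's weights from a full pretraining checkpoint."""
    ckpt = torch.load(path, map_location="cpu")
    state_dict = ckpt.get(key, ckpt)  # fall back to a flat state dict
    # Strip a possible "module." prefix left over from DDP training.
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    module.load_state_dict(state_dict, strict=False)
    return module

Then the encoder could load its weights with key="encoder" and the predictor with key="predictor", and the encoder can be reused on its own for contrastive pretraining.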


mmlearn/tasks/ijepa_pretraining.py line 171 at r1 (raw file):

        return encoder, predictor, target_encoder, opt, scaler, epoch

    def forward(

This forward pass is currently not being used by this module. Everything is in training_step right now. Remember that the LightningModule is also an nn.Module.
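
As a rough sketch of what that could look like (placeholder logic, not the final implementation; the smooth L1 loss follows the original I-JEPA reference implementation):

def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
    """Predict target-encoder features from the context encoder's output."""
    context_features = self.encoder(batch)
    return self.predictor(context_features)

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step."""
    # A LightningModule is also an nn.Module, so self(batch) routes through forward.
    predictions = self(batch)
    with torch.no_grad():
        targets = self.target_encoder(batch)
    return torch.nn.functional.smooth_l1_loss(predictions, targets)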


mmlearn/tasks/ijepa_pretraining.py line 177 at r1 (raw file):

    ) -> torch.Tensor:
        """Forward pass through the encoder."""
        return self.encode(x, masks)

The encode method doesn't exist. Please make sure to run the code to verify that things work properly.
Also, please take a look at the other encoders in the library: the current convention is to pass the entire batch dictionary to the encoder.


mmlearn/tasks/ijepa_pretraining.py line 181 at r1 (raw file):

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step."""
        images = batch[Modalities.RGB]

The format for this has changed. See this PR.

Suggestion:

images = batch[Modalities.RGB.name]

mmlearn/tasks/ijepa_pretraining.py line 191 at r1 (raw file):

        # Move images and masks to device
        images = images.to(self.device)

You don't need to do this for anything inside the batch dictionary. Lightning will handle that.


mmlearn/tasks/ijepa_pretraining.py line 216 at r1 (raw file):

            "train/loss",
            loss,
            on_step=True,

Why log both on_step and on_epoch? During training, Lightning's default is to log on_step but not on_epoch (for validation and testing, the default is to log on_epoch but not on_step).
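
In other words, inside training_step this would already log per step (sketch):

self.log("train/loss", loss)  # defaults in training_step: on_step=True, on_epoch=False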


mmlearn/tasks/ijepa_pretraining.py line 227 at r1 (raw file):

        return loss

    def _update_target_encoder(self) -> None:

Please look into reusing the existing EMA module.
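
For reference, the update itself is just the usual parameter-wise interpolation, which is what the existing EMA module wraps (generic sketch, not mmlearn's API):

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, decay: float) -> None:
    """In-place update: target = decay * target + (1 - decay) * online."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(decay).add_(o_param.detach(), alpha=1.0 - decay)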


mmlearn/tasks/ijepa_pretraining.py line 247 at r1 (raw file):

                "params": (p for n, p in self.encoder.named_parameters()
                           if ("bias" not in n) and (len(p.shape) != 1)),
                "weight_decay": self.optimizer_cfg.get("weight_decay", 0.0)

Like I mentioned earlier, you will be getting an instantiated Optimizer object, not the config. Please take a look at the contrastive pretraining task for how to get the weight decay value from the instantiated Optimizer object.
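
For example (a sketch assuming the instance is stored as self.optimizer and is a plain torch.optim.Optimizer):

weight_decay = self.optimizer.defaults.get("weight_decay", 0.0)
# or, per parameter group:
weight_decay = self.optimizer.param_groups[0].get("weight_decay", 0.0)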


mmlearn/tasks/ijepa_pretraining.py line 309 at r1 (raw file):

        return self._shared_eval_step(batch, batch_idx, "test")

    def _shared_eval_step(

Notice that most of the code in this method is repeated in training_step. You can define it once and call it multiple times for training and eval.
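
Something along these lines (a sketch; _compute_loss is a hypothetical helper holding the logic that currently lives in training_step):

def _shared_step(self, batch: Dict[str, Any], batch_idx: int, step_type: str) -> torch.Tensor:
    loss = self._compute_loss(batch)  # hypothetical helper with the shared loss computation
    self.log(f"{step_type}/loss", loss)
    if step_type == "train":
        self._update_target_encoder()
    return loss

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    return self._shared_step(batch, batch_idx, "train")

def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    return self._shared_step(batch, batch_idx, "val")

def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
    return self._shared_step(batch, batch_idx, "test")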

@fcogidi (Collaborator) left a comment

Reviewable status: 0 of 2 files reviewed, 4 unresolved discussions (waiting on @vahid0001)


mmlearn/modules/encoders/vision.py line 313 at r2 (raw file):

                nn.init.constant_(m.bias, 0)

    def load_checkpoint(

How's this intended to be used?

It looks like both the encoder and predictor need to be instantiated first and then passed to this method. When I first made the suggestion to move the checkpoint-loading functionality to the encoder, I imagined the VisionTransformer having its own checkpoint-loading logic (just for the encoder) and the predictor having its own (loading the target encoder may not be necessary, I think, if one knows the exact EMA value it stopped at).


mmlearn/tasks/ijepa_pretraining.py line 79 at r2 (raw file):

        self.predictor = predictor

        self.ema = ExponentialMovingAverage(encoder, ema_decay, ema_decay_end, 1000)

Should the ema_anneal_end_step, currently fixed at 1000, be a user-defined value?

Also, device_id might be important to set, especially in a distributed setting. It can be set to self.device.


mmlearn/tasks/ijepa_pretraining.py line 138 at r2 (raw file):

        )

        if is_training:

Is the is_training flag necessary? Why not just check if step_type == "train"?


mmlearn/tasks/ijepa_pretraining.py line 230 at r2 (raw file):

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

The self._on_eval_epoch_{start/end} methods are not defined.

@fcogidi (Collaborator) left a comment

Reviewable status: 0 of 2 files reviewed, 4 unresolved discussions (waiting on @vahid0001)


mmlearn/modules/encoders/vision.py line 313 at r2 (raw file):

Previously, fcogidi (Franklin) wrote…

How's this intended to be used?

It looks like both the encoder and predictor need to be instantiated first and then passed to this method. When I first made the suggestion to move the checkpoint-loading functionality to the encoder, I imagined the VisionTransformer having its own checkpoint-loading logic (just for the encoder) and the predictor having its own (loading the target encoder may not be necessary, I think, if one knows the exact EMA value it stopped at).

Speaking of, please check this out.

@fcogidi changed the title from "add IJEPA task" to "Add IJEPA task" on Dec 5, 2024
@fcogidi merged commit c6b07e0 into main on Dec 5, 2024 (4 checks passed)
@fcogidi deleted the ijepa_training branch on December 5, 2024 at 16:59