Commit

change/add files for ijepa training
vahid0001 committed Oct 30, 2024
1 parent 8d41a76 commit b671d09
Showing 3 changed files with 84 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mmlearn/modules/encoders/vision.py
@@ -561,6 +561,10 @@ def forward(
        return self.predictor_proj(x)


@store(
    group="modules/encoders",
    provider="mmlearn",
)
def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
    """
    Create a VisionTransformerPredictor model.
@@ -575,6 +579,10 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
    )


@store(
    group="modules/encoders",
    provider="mmlearn",
)
def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
    """
    Create a VisionTransformer model with tiny configuration.
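The @store(group="modules/encoders", provider="mmlearn") decorators added here register vit_predictor and vit_tiny with hydra-zen's config store, which is what lets the experiment config below pull them in by name. The following is a minimal sketch of how such an entry can be looked up and instantiated; the lookup and add_to_hydra_store calls are standard hydra-zen usage, not code from this commit.

from hydra_zen import instantiate, store

# Importing the module runs the @store decorators and fills hydra-zen's global store.
from mmlearn.modules.encoders import vision  # noqa: F401

# Fetch the auto-generated config node and build the encoder from it.
encoder_cfg = store["modules/encoders", "vit_tiny"]
encoder = instantiate(encoder_cfg)  # roughly equivalent to calling vit_tiny(patch_size=16)

# Push the stored entries into Hydra's ConfigStore so that YAML defaults such as
# "/modules/encoders@task.encoder: vit_tiny" can resolve them.
store.add_to_hydra_store()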
2 changes: 2 additions & 0 deletions mmlearn/tasks/ijepa_pretraining.py
@@ -5,6 +5,7 @@
import lightning as L # noqa: N812
import torch
import torch.nn.functional as F # noqa: N812
from hydra_zen import store
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from lightning_utilities.core.rank_zero import rank_zero_warn

@@ -15,6 +16,7 @@
from mmlearn.modules.encoders.vision import VisionTransformer


@store(group="task", provider="mmlearn")
class IJEPA(L.LightningModule):
    """Pretraining module for IJEPA.
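Registering the IJEPA class under the "task" group is what makes "override /task: IJEPA" work in the experiment config below, since a stored entry defaults to the name of the decorated class. A rough, self-contained illustration of that link follows; the overwrite_ok flag is only there in case mmlearn has already pushed its entries into Hydra's store.

from hydra.core.config_store import ConfigStore
from hydra_zen import store

import mmlearn.tasks.ijepa_pretraining  # noqa: F401  # importing runs the @store decorator

store.add_to_hydra_store(overwrite_ok=True)

# Hydra now has a "task" group with an "IJEPA" option, so an experiment's
# defaults list can contain "- override /task: IJEPA" as run_ijepa.yaml does.
print(ConfigStore.instance().repo["task"].keys())  # expected to include 'IJEPA.yaml'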
74 changes: 74 additions & 0 deletions projects/ijepa/configs/experiment/run_ijepa.yaml
@@ -0,0 +1,74 @@
# @package _global_

defaults:
  - /datasets@datasets.train: ImageNet
  - /datasets/transforms@datasets.train.transform: med_clip_vision_transform
  - /datasets@datasets.val: ImageNet
  - /datasets/transforms@datasets.val.transform: med_clip_vision_transform
  - /modules/encoders@task.encoder: vit_tiny
  - /modules/encoders@task.predictor: vit_predictor
  - /modules/optimizers@task.optimizer: AdamW
  - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR
  - /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor
  - /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint
  - /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping
  - /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary
  - /trainer/logger@trainer.logger.wandb: WandbLogger
  - override /task: IJEPA
  - _self_

seed: 0

datasets:
  val:
    split: valid
    transform:
      job_type: eval

dataloader:
  train:
    batch_size: 32
    num_workers: 4
  val:
    batch_size: 32
    num_workers: 4

task:
  optimizer:
    betas:
      - 0.9
      - 0.98
    lr: 5.0e-5
    weight_decay: 0.1
    eps: 1.0e-6
  lr_scheduler:
    scheduler:
      T_max: 107_559 # make sure to change this if max_epochs or accumulate_grad_batches is changed; see the note after this config
    extras:
      interval: step

trainer:
  max_epochs: 20
  precision: 16-mixed
  deterministic: False
  benchmark: True
  sync_batchnorm: False # set to True if using DDP with batchnorm
  log_every_n_steps: 100
  accumulate_grad_batches: 4
  check_val_every_n_epoch: 1
  callbacks:
    model_checkpoint:
      monitor: val/loss
      save_top_k: 1
      save_last: True
      every_n_epochs: 1
    early_stopping:
      monitor: val/loss
      patience: 5
      mode: min
    model_summary:
      max_depth: 2

tags:
  - ${experiment_name}
  - ijepa pretraining
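Because the scheduler is stepped once per optimizer step (interval: step), T_max has to equal the total number of optimizer steps, which is what the comment next to it warns about. Below is a rough helper for recomputing it; the training-set size and device count are placeholders for a given setup, so the example call is not expected to reproduce the 107_559 used in this config.

def cosine_t_max(
    num_train_samples: int,
    batch_size: int,
    accumulate_grad_batches: int,
    num_devices: int,
    max_epochs: int,
) -> int:
    """Total optimizer steps: per-device batches per epoch, divided by the
    gradient-accumulation factor, times the number of epochs."""
    batches_per_device_epoch = num_train_samples // (batch_size * num_devices)
    optimizer_steps_per_epoch = batches_per_device_epoch // accumulate_grad_batches
    return optimizer_steps_per_epoch * max_epochs

# Placeholder dataset size and device count; only batch_size, accumulate_grad_batches
# and max_epochs come from the config above.
print(cosine_t_max(num_train_samples=1_281_167, batch_size=32,
                   accumulate_grad_batches=4, num_devices=1, max_epochs=20))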
