Add IJEPA task #25

Merged — 35 commits, Dec 5, 2024

Commits
e108aef
add v1 of ijepa task
vahid0001 Oct 9, 2024
763a7fa
reorganize imports
vahid0001 Oct 9, 2024
cc98d33
fix precommits
vahid0001 Oct 9, 2024
f7fdcea
change docstring
vahid0001 Oct 9, 2024
be3ac26
revise optimizer configure method
vahid0001 Oct 10, 2024
a1dcd84
add load checkpoint to vit class
vahid0001 Oct 16, 2024
41a21f6
revise ijepa class
vahid0001 Oct 16, 2024
f9ee90a
remove load checkpoint
vahid0001 Oct 30, 2024
8d41a76
revise ijepa trainer class
vahid0001 Oct 30, 2024
b671d09
change/add files for ijepa training
vahid0001 Oct 30, 2024
0f44546
add IJEPA task
vahid0001 Nov 19, 2024
b91b9ea
revise masking function
vahid0001 Nov 19, 2024
d4be471
revise vit classes
vahid0001 Nov 19, 2024
9046cbb
add IJEPA task
vahid0001 Nov 19, 2024
4cd057c
revise masking function
vahid0001 Nov 19, 2024
48eea90
revise vit classes
vahid0001 Nov 19, 2024
6db4155
remove state_dict method
vahid0001 Nov 19, 2024
dbc4154
revise ijepa pretraining task
vahid0001 Nov 19, 2024
583c020
revise ijepa config file
vahid0001 Nov 19, 2024
871e6a2
add init to ijepa project config file
vahid0001 Nov 19, 2024
f508293
add proper transformations
vahid0001 Nov 21, 2024
336c040
revise ijepa config file
vahid0001 Nov 21, 2024
c06390c
change some hp, fix precommit errors
vahid0001 Nov 21, 2024
facac4b
fix precommit errors
vahid0001 Nov 21, 2024
6e91766
fix type annotation
vahid0001 Nov 21, 2024
2fe2a7c
Remove IJEPA import from __init__.py and add it to tasks/__init__.py
fcogidi Nov 27, 2024
7a0bd33
Refactor ExponentialMovingAverage class in ema.py
fcogidi Nov 27, 2024
70bd0e6
Add base class for tasks that require training
fcogidi Nov 27, 2024
b0be82b
Update configuration file for reproducing ImageNet experiment
fcogidi Nov 27, 2024
e37d476
Merge branch 'main' into ijepa_training
fcogidi Nov 27, 2024
099ae33
Fix modalities key in ijepa_pretraining.py
fcogidi Nov 27, 2024
39c2e19
Fix type hints and rename ijepa task module
fcogidi Nov 27, 2024
5e19f87
Add dirpath for saving checkpoints in Vector SLURM environment
fcogidi Nov 27, 2024
5937565
Fix import statement for IJEPA class
fcogidi Nov 27, 2024
ceed49d
Merge branch 'main' into ijepa_training
fcogidi Dec 5, 2024
80 changes: 48 additions & 32 deletions mmlearn/datasets/processors/masking.py
@@ -237,31 +237,38 @@ def apply_masks(
Parameters
----------
x : torch.Tensor
Input tensor of shape (B, N, D), where B is the batch size, N is the number
of patches, and D is the feature dimension.
Input tensor of shape (B, N, D).
masks : Union[torch.Tensor, List[torch.Tensor]]
A list of tensors containing the indices of patches to keep for each sample.
Each mask tensor has shape (B, N), where B is the batch size and N is the number
of patches.
A list of mask tensors of shape (N,), (1, N), or (B, N).

Returns
-------
torch.Tensor
The masked tensor where only the patches indicated by the masks are kept.
The output tensor has shape (B', N', D), where B' is the new batch size
(which may be different due to concatenation) and N' is the
reduced number of patches.

Notes
-----
- The masks should indicate which patches to keep (1 for keep, 0 for discard).
- The function uses `torch.gather` to select the patches specified by the masks.
The output tensor has shape (B * num_masks, N', D),
where N' is the number of patches kept.
"""
all_x = []
for m in masks:
# Expand the mask to match the feature dimension and gather the relevant patches
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x.append(torch.gather(x, dim=1, index=mask_keep))
batch_size = x.size(0)
for m_ in masks:
m = m_.to(x.device)

# Ensure mask is at least 2D
if m.dim() == 1:
m = m.unsqueeze(0) # Shape: (1, N)

# Expand mask to match the batch size if needed
if m.size(0) == 1 and batch_size > 1:
m = m.expand(batch_size, -1) # Shape: (B, N)

# Expand mask to match x's dimensions
m_expanded = (
m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
) # Shape: (B, N, D)

# Use boolean indexing
selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
all_x.append(selected_patches)

# Concatenate along the batch dimension
return torch.cat(all_x, dim=0)
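
For context, a minimal usage sketch of the reworked `apply_masks` shown above. The import path follows the file in this diff; the toy shapes and mask construction are illustrative assumptions, not part of the PR.

```python
import torch

# Assumed import path, matching mmlearn/datasets/processors/masking.py above.
from mmlearn.datasets.processors.masking import apply_masks

B, N, D = 2, 196, 768                 # batch size, number of patches, feature dim
x = torch.randn(B, N, D)              # patch embeddings

# Two boolean masks that each keep the same number of patches per sample;
# an equal count per row is needed for the view back to (B, -1, D).
keep = 49
mask = torch.zeros(B, N, dtype=torch.bool)
mask[:, :keep] = True
masks = [mask, mask.roll(keep, dims=1)]

out = apply_masks(x, masks)
print(out.shape)                      # torch.Size([4, 49, 768]) == (B * num_masks, N', D)
```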
@@ -271,40 +278,39 @@ def apply_masks(
class IJEPAMaskGenerator:
"""Generates encoder and predictor masks for preprocessing.

This class generates masks dynamically for individual examples and can be passed to
a data loader as a preprocessing step.
This class generates masks dynamically for batches of examples.

Parameters
----------
input_size : tuple[int, int], default=(224, 224)
Input image size.
patch_size : int, default=16
Size of each patch.
min_keep : int, default=4
min_keep : int, default=10
Minimum number of patches to keep.
allow_overlap : bool, default=False
Whether to allow overlap between encoder and predictor masks.
enc_mask_scale : tuple[float, float], default=(0.2, 0.8)
enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
Scale range for encoder mask.
pred_mask_scale : tuple[float, float], default=(0.2, 0.8)
pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
Scale range for predictor mask.
aspect_ratio : tuple[float, float], default=(0.3, 3.0)
aspect_ratio : tuple[float, float], default=(0.75, 1.0)
Aspect ratio range for mask blocks.
nenc : int, default=1
Number of encoder masks to generate.
npred : int, default=2
npred : int, default=4
Number of predictor masks to generate.
"""

input_size: Tuple[int, int] = (224, 224)
patch_size: int = 16
min_keep: int = 4
min_keep: int = 10
allow_overlap: bool = False
enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
aspect_ratio: Tuple[float, float] = (0.3, 3.0)
enc_mask_scale: Tuple[float, float] = (0.85, 1.0)
pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
aspect_ratio: Tuple[float, float] = (0.75, 1.0)
nenc: int = 1
npred: int = 2
npred: int = 4

def __post_init__(self) -> None:
"""Initialize the mask generator."""
@@ -353,8 +359,14 @@ def _sample_block_mask(

def __call__(
self,
batch_size: int = 1,
) -> Dict[str, Any]:
"""Generate encoder and predictor masks for a single example.
"""Generate encoder and predictor masks for a batch of examples.

Parameters
----------
batch_size : int, default=1
The batch size for which to generate masks.

Returns
-------
@@ -378,14 +390,18 @@ def __call__(
masks_pred, masks_enc = [], []
for _ in range(self.npred):
mask_p, _ = self._sample_block_mask(p_size)
# Expand mask to match batch size
mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
masks_pred.append(mask_p)

# Generate encoder masks
for _ in range(self.nenc):
mask_e, _ = self._sample_block_mask(e_size)
# Expand mask to match batch size
mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
masks_enc.append(mask_e)

return {
"encoder_masks": torch.stack(masks_enc),
"predictor_masks": torch.stack(masks_pred),
"encoder_masks": masks_enc, # List of tensors of shape (batch_size, N)
"predictor_masks": masks_pred, # List of tensors of shape (batch_size, N)
}
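
Similarly, a short sketch of how the batched mask generator would be called after this change. Defaults and return keys are taken from the diff; the import path is assumed.

```python
from mmlearn.datasets.processors.masking import IJEPAMaskGenerator

# Defaults from this diff: 224x224 input, 16px patches, nenc=1, npred=4.
mask_generator = IJEPAMaskGenerator()

masks = mask_generator(batch_size=8)
encoder_masks = masks["encoder_masks"]      # list of nenc tensors, each (batch_size, N)
predictor_masks = masks["predictor_masks"]  # list of npred tensors, each (batch_size, N)

print(len(encoder_masks), len(predictor_masks))
print(encoder_masks[0].shape, predictor_masks[0].shape)
```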
92 changes: 46 additions & 46 deletions mmlearn/modules/ema.py
@@ -52,6 +52,52 @@ def __init__(
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step

@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model."""
try:
return copy.deepcopy(model)
except RuntimeError as e:
raise RuntimeError("Unable to copy the model ", e) from e

@staticmethod
def get_annealed_rate(
start: float,
end: float,
curr_step: int,
total_steps: int,
) -> float:
"""Calculate EMA annealing rate."""
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining

def step(self, new_model: torch.nn.Module) -> None:
"""Perform single EMA update step."""
self._update_weights(new_model)
self._update_ema_decay()

def restore(self, model: torch.nn.Module) -> torch.nn.Module:
"""Reassign weights from another model.

Parameters
----------
model : nn.Module
Model to load weights from.

Returns
-------
nn.Module
model with new weights
"""
d = self.model.state_dict()
model.load_state_dict(d, strict=False)
return model

def state_dict(self) -> dict[str, Any]:
"""Return the state dict of the model."""
return self.model.state_dict() # type: ignore[no-any-return]

@torch.no_grad() # type: ignore[misc]
def _update_weights(self, new_model: torch.nn.Module) -> None:
if self.decay < 1:
@@ -98,49 +144,3 @@ def _update_ema_decay(self) -> None:
self.ema_anneal_end_step,
)
self.decay = decay

def step(self, new_model: torch.nn.Module) -> None:
"""Perform single EMA update step."""
self._update_weights(new_model)
self._update_ema_decay()

@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model."""
try:
return copy.deepcopy(model)
except RuntimeError as e:
raise RuntimeError("Unable to copy the model ", e) from e

def restore(self, model: torch.nn.Module) -> torch.nn.Module:
"""Reassign weights from another model.

Parameters
----------
model : nn.Module
Model to load weights from.

Returns
-------
nn.Module
model with new weights
"""
d = self.model.state_dict()
model.load_state_dict(d, strict=False)
return model

def state_dict(self) -> dict[str, Any]:
"""Return the state dict of the model."""
return self.model.state_dict() # type: ignore[no-any-return]

@staticmethod
def get_annealed_rate(
start: float,
end: float,
curr_step: int,
total_steps: int,
) -> float:
"""Calculate EMA annealing rate."""
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
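
As a reference for the EMA refactor above, a small sketch of the decay schedule computed by `get_annealed_rate`. The class name and import path are inferred from the commit messages and file path in this PR; the start/end/step values are illustrative.

```python
from mmlearn.modules.ema import ExponentialMovingAverage

start_decay, end_decay, total_steps = 0.996, 1.0, 10_000

for step in (0, 5_000, 10_000):
    decay = ExponentialMovingAverage.get_annealed_rate(
        start_decay, end_decay, step, total_steps
    )
    print(step, round(decay, 4))

# step 0      -> 0.996  (end - (end - start) * 1.0)
# step 5,000  -> 0.998  (halfway between start and end)
# step 10,000 -> 1.0    (annealing done; with decay == 1, _update_weights stops changing the EMA copy)
```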