Skip to content

Commit

Permalink
Add I-JEPA task (#25)
Browse files Browse the repository at this point in the history
Co-authored-by: fcogidi <[email protected]>
  • Loading branch information
vahid0001 and fcogidi authored Dec 5, 2024
1 parent e474698 commit c6b07e0
Show file tree
Hide file tree
Showing 11 changed files with 761 additions and 193 deletions.
80 changes: 48 additions & 32 deletions mmlearn/datasets/processors/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,31 +237,38 @@ def apply_masks(
Parameters
----------
x : torch.Tensor
Input tensor of shape (B, N, D), where B is the batch size, N is the number
of patches, and D is the feature dimension.
Input tensor of shape (B, N, D).
masks : Union[torch.Tensor, List[torch.Tensor]]
A list of tensors containing the indices of patches to keep for each sample.
Each mask tensor has shape (B, N), where B is the batch size and N is the number
of patches.
A list of mask tensors of shape (N,), (1, N), or (B, N).
Returns
-------
torch.Tensor
The masked tensor where only the patches indicated by the masks are kept.
The output tensor has shape (B', N', D), where B' is the new batch size
(which may be different due to concatenation) and N' is the
reduced number of patches.
Notes
-----
- The masks should indicate which patches to keep (1 for keep, 0 for discard).
- The function uses `torch.gather` to select the patches specified by the masks.
The output tensor has shape (B * num_masks, N', D),
where N' is the number of patches kept.
"""
all_x = []
for m in masks:
# Expand the mask to match the feature dimension and gather the relevant patches
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
all_x.append(torch.gather(x, dim=1, index=mask_keep))
batch_size = x.size(0)
for m_ in masks:
m = m_.to(x.device)

# Ensure mask is at least 2D
if m.dim() == 1:
m = m.unsqueeze(0) # Shape: (1, N)

# Expand mask to match the batch size if needed
if m.size(0) == 1 and batch_size > 1:
m = m.expand(batch_size, -1) # Shape: (B, N)

# Expand mask to match x's dimensions
m_expanded = (
m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
) # Shape: (B, N, D)

# Use boolean indexing
selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
all_x.append(selected_patches)

# Concatenate along the batch dimension
return torch.cat(all_x, dim=0)
Expand All @@ -271,40 +278,39 @@ def apply_masks(
class IJEPAMaskGenerator:
"""Generates encoder and predictor masks for preprocessing.
This class generates masks dynamically for individual examples and can be passed to
a data loader as a preprocessing step.
This class generates masks dynamically for batches of examples.
Parameters
----------
input_size : tuple[int, int], default=(224, 224)
Input image size.
patch_size : int, default=16
Size of each patch.
min_keep : int, default=4
min_keep : int, default=10
Minimum number of patches to keep.
allow_overlap : bool, default=False
Whether to allow overlap between encoder and predictor masks.
enc_mask_scale : tuple[float, float], default=(0.2, 0.8)
enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
Scale range for encoder mask.
pred_mask_scale : tuple[float, float], default=(0.2, 0.8)
pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
Scale range for predictor mask.
aspect_ratio : tuple[float, float], default=(0.3, 3.0)
aspect_ratio : tuple[float, float], default=(0.75, 1.0)
Aspect ratio range for mask blocks.
nenc : int, default=1
Number of encoder masks to generate.
npred : int, default=2
npred : int, default=4
Number of predictor masks to generate.
"""

input_size: Tuple[int, int] = (224, 224)
patch_size: int = 16
min_keep: int = 4
min_keep: int = 10
allow_overlap: bool = False
enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
aspect_ratio: Tuple[float, float] = (0.3, 3.0)
enc_mask_scale: Tuple[float, float] = (0.85, 1.0)
pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
aspect_ratio: Tuple[float, float] = (0.75, 1.0)
nenc: int = 1
npred: int = 2
npred: int = 4

def __post_init__(self) -> None:
"""Initialize the mask generator."""
Expand Down Expand Up @@ -353,8 +359,14 @@ def _sample_block_mask(

def __call__(
self,
batch_size: int = 1,
) -> Dict[str, Any]:
"""Generate encoder and predictor masks for a single example.
"""Generate encoder and predictor masks for a batch of examples.
Parameters
----------
batch_size : int, default=1
The batch size for which to generate masks.
Returns
-------
Expand All @@ -378,14 +390,18 @@ def __call__(
masks_pred, masks_enc = [], []
for _ in range(self.npred):
mask_p, _ = self._sample_block_mask(p_size)
# Expand mask to match batch size
mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
masks_pred.append(mask_p)

# Generate encoder masks
for _ in range(self.nenc):
mask_e, _ = self._sample_block_mask(e_size)
# Expand mask to match batch size
mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
masks_enc.append(mask_e)

return {
"encoder_masks": torch.stack(masks_enc),
"predictor_masks": torch.stack(masks_pred),
"encoder_masks": masks_enc, # List of tensors of shape (batch_size, N)
"predictor_masks": masks_pred, # List of tensors of shape (batch_size, N)
}
92 changes: 46 additions & 46 deletions mmlearn/modules/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,52 @@ def __init__(
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step

@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model."""
try:
return copy.deepcopy(model)
except RuntimeError as e:
raise RuntimeError("Unable to copy the model ", e) from e

@staticmethod
def get_annealed_rate(
start: float,
end: float,
curr_step: int,
total_steps: int,
) -> float:
"""Calculate EMA annealing rate."""
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining

def step(self, new_model: torch.nn.Module) -> None:
    """Run one EMA update.

    Parameters
    ----------
    new_model : torch.nn.Module
        Model passed on to ``self._update_weights``.
    """
    # Weights are updated first; the decay is refreshed only afterwards.
    self._update_weights(new_model)
    self._update_ema_decay()

def restore(self, model: torch.nn.Module) -> torch.nn.Module:
    """Copy the weights of ``self.model`` into ``model``.

    Parameters
    ----------
    model : nn.Module
        Model that receives the state dict of ``self.model``.

    Returns
    -------
    nn.Module
        The same ``model`` object, after loading.
    """
    # strict=False: mismatched keys between the two state dicts are tolerated.
    model.load_state_dict(self.model.state_dict(), strict=False)
    return model

def state_dict(self) -> dict[str, Any]:
    """Return the state dict of the wrapped model ``self.model``."""
    ema_state: dict[str, Any] = self.model.state_dict()
    return ema_state

@torch.no_grad() # type: ignore[misc]
def _update_weights(self, new_model: torch.nn.Module) -> None:
if self.decay < 1:
Expand Down Expand Up @@ -98,49 +144,3 @@ def _update_ema_decay(self) -> None:
self.ema_anneal_end_step,
)
self.decay = decay

def step(self, new_model: torch.nn.Module) -> None:
    """Run one EMA update.

    Parameters
    ----------
    new_model : torch.nn.Module
        Model passed on to ``self._update_weights``.
    """
    # Weights are updated first; the decay is refreshed only afterwards.
    self._update_weights(new_model)
    self._update_ema_decay()

@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model."""
try:
return copy.deepcopy(model)
except RuntimeError as e:
raise RuntimeError("Unable to copy the model ", e) from e

def restore(self, model: torch.nn.Module) -> torch.nn.Module:
    """Copy the weights of ``self.model`` into ``model``.

    Parameters
    ----------
    model : nn.Module
        Model that receives the state dict of ``self.model``.

    Returns
    -------
    nn.Module
        The same ``model`` object, after loading.
    """
    # strict=False: mismatched keys between the two state dicts are tolerated.
    model.load_state_dict(self.model.state_dict(), strict=False)
    return model

def state_dict(self) -> dict[str, Any]:
    """Return the state dict of the wrapped model ``self.model``."""
    ema_state: dict[str, Any] = self.model.state_dict()
    return ema_state

@staticmethod
def get_annealed_rate(
start: float,
end: float,
curr_step: int,
total_steps: int,
) -> float:
"""Calculate EMA annealing rate."""
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
Loading

0 comments on commit c6b07e0

Please sign in to comment.