Skip to content

Commit

Permalink
remove load checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
vahid0001 committed Oct 30, 2024
1 parent 41a21f6 commit f9ee90a
Showing 1 changed file with 0 additions and 35 deletions.
35 changes: 0 additions & 35 deletions mmlearn/modules/encoders/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,41 +310,6 @@ def _init_weights(self, m: nn.Module) -> None:
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def load_checkpoint(
self,
device: str,
checkpoint_path: str,
encoder: nn.Module,
predictor: nn.Module,
target_encoder: nn.Module,
) -> Tuple[nn.Module, nn.Module, nn.Module]:
"""Load a pre-trained model from a checkpoint."""
try:
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
epoch = checkpoint["epoch"]

# loading encoder
pretrained_dict = checkpoint["encoder"]
msg = encoder.load_state_dict(pretrained_dict)
print(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")

# loading predictor
pretrained_dict = checkpoint["predictor"]
msg = predictor.load_state_dict(pretrained_dict)
print(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")

# loading target_encoder
if target_encoder is not None:
print(list(checkpoint.keys()))
pretrained_dict = checkpoint["target_encoder"]
msg = target_encoder.load_state_dict(pretrained_dict)
print(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}")

except Exception as e:
print(f"Encountered exception when loading checkpoint {e}")

return encoder, predictor, target_encoder

def forward(
self,
x: torch.Tensor,
Expand Down

0 comments on commit f9ee90a

Please sign in to comment.