Recently I have been trying to train an SMP model using PyTorch Lightning's checkpoint-saving and early-stopping features in my training loop to save progress. Lightning uses torch.save() under the hood to save the model at each user-specified condition as a .ckpt/.pth file. I then tried to load the file back:
checkpoint_path = r"DeepLabV3Plus_Training\checkpointsaves\last.pth"
model = SegmentationModel.load_from_checkpoint(checkpoint_path, strict=True)
and I got an error saying that many keys in the state dict do not match. I tried it with strict=False, and of course the model then behaved as if it had no trained weights, producing only ImageNet-pretrained results.
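From inspecting the checkpoint, I suspect the mismatch comes from Lightning prefixing every weight key with the attribute name the model is stored under (model. in my case). Here is a minimal sketch of the manual workaround I tried; the DeepLabV3Plus configuration below is just an example, not my real one:

import torch
import segmentation_models_pytorch as smp

# Rebuild the architecture with the same configuration used for training
model = smp.DeepLabV3Plus(encoder_name="resnet34", classes=1)  # example config only

# A Lightning .ckpt is a dict; the weights live under the "state_dict" key
checkpoint = torch.load(r"DeepLabV3Plus_Training\checkpointsaves\last.pth", map_location="cpu")

# Strip the "model." prefix Lightning adds for the self.model attribute
state_dict = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}
model.load_state_dict(state_dict, strict=True)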
Is there a way to make saving and loading the model with torch.save and torch.load work? I have no problem calling the provided methods myself; I just want to integrate them into the PyTorch Lightning workflow. I made a custom callback that saves the model every N epochs:
import pytorch_lightning as pl


class SaveModelEveryNepochs(pl.Callback):
    def __init__(self, save_dir: str, save_interval: int = 10):
        super().__init__()
        self.save_dir = save_dir
        self.save_interval = save_interval

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each training epoch.

        Args:
            trainer: The trainer instance.
            pl_module: The Lightning module instance.
        """
        current_epoch = trainer.current_epoch
        if current_epoch % self.save_interval == 0:
            # Save the model using SMP's method
            model = pl_module.model  # Assuming your model is stored in pl_module
            model.save_pretrained(f"{self.save_dir}/model_epoch_{current_epoch}")
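For completeness, this is how I attach it to the trainer; the module and datamodule names below are placeholders:

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[SaveModelEveryNepochs(save_dir="smp_saves", save_interval=10)],
)
trainer.fit(segmentation_module, datamodule=data_module)  # placeholder objects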
However, I'd still like the model to be saved the same way when EarlyStopping triggers. Additionally, the default Lightning ModelCheckpoint callback offers options such as saving the top-k models according to a user-defined metric, keeping the last checkpoint, saving every n epochs, and so on.
Is there a way to integrate the existing save_pretrained() and its loading counterpart into the existing EarlyStopping and ModelCheckpoint objects? This would be immensely useful for long experiments like mine.
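One idea I've been considering, though I don't know if it's the intended approach: the LightningModule's on_save_checkpoint hook fires whenever Lightning writes a checkpoint, including the ones triggered by ModelCheckpoint, so an SMP-native copy could be mirrored there. A rough sketch (the save path and model configuration are hypothetical):

import pytorch_lightning as pl
import segmentation_models_pytorch as smp


class SegmentationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = smp.DeepLabV3Plus(encoder_name="resnet34", classes=1)  # hypothetical config

    def on_save_checkpoint(self, checkpoint):
        # Mirror every Lightning checkpoint with an SMP-native save_pretrained copy
        self.model.save_pretrained(f"smp_saves/epoch_{self.current_epoch}")  # hypothetical path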