Pytorch Lightning Checkpoint saving and Earlystopping #979

Open

omarequalmars opened this issue Nov 15, 2024 · 0 comments
Recently I have been trying to train an SMP model using PyTorch Lightning's checkpoint saving and EarlyStopping features in my training loop to save progress. Lightning uses torch.save() to save the model as a .ckpt/.pth file whenever a user-specified condition is met. I then tried to load the resulting file:

    checkpoint_path = r"DeepLabV3Plus_Training\checkpointsaves\last.pth"
    model = SegmentationModel.load_from_checkpoint(checkpoint_path, strict=True)

and I got an error saying that many keys in the state dict do not match. I tried it with strict=False, and of course the model then behaved as if no weights had been loaded, producing only ImageNet-pretrained results.
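
For context, the manual workaround I'd like to avoid looks roughly like this (just a sketch: the DeepLabV3Plus arguments are placeholders, and it assumes the SMP network is stored as self.model inside my LightningModule, so every key in the Lightning state dict carries a "model." prefix):

    import torch
    import segmentation_models_pytorch as smp

    # Rebuild the bare SMP model with the same architecture used during training
    model = smp.DeepLabV3Plus(encoder_name="resnet34", classes=1)

    ckpt = torch.load(
        r"DeepLabV3Plus_Training\checkpointsaves\last.pth", map_location="cpu"
    )

    # Lightning prefixes every parameter with the attribute name used inside the
    # LightningModule (here "model."), so strip it before loading into the bare model
    state_dict = {
        k.removeprefix("model."): v
        for k, v in ckpt["state_dict"].items()
        if k.startswith("model.")
    }
    model.load_state_dict(state_dict, strict=True)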

Is there a way to make saving and loading the model via torch.save and torch.load work? I have no problem using SMP's provided methods myself; I just want to integrate them into the PyTorch Lightning workflow. I made a custom callback that saves the model every N epochs:

    import pytorch_lightning as pl


    class SaveModelEveryNepochs(pl.Callback):
        def __init__(self, save_dir: str, save_interval: int = 10):
            super().__init__()
            self.save_dir = save_dir
            self.save_interval = save_interval

        def on_train_epoch_end(self, trainer, pl_module):
            """
            Called at the end of each training epoch.

            Args:
                trainer: The trainer instance.
                pl_module: The Lightning module instance.
            """
            current_epoch = trainer.current_epoch
            if current_epoch % self.save_interval == 0:
                # Save the model using SMP's method
                model = pl_module.model  # Assuming your model is stored in pl_module
                model.save_pretrained(f"{self.save_dir}/model_epoch_{current_epoch}")
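
I attach it to the Trainer alongside the built-in callbacks, roughly like this (max_epochs and the paths are just placeholders):

    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[SaveModelEveryNepochs(save_dir="checkpointsaves", save_interval=10)],
    )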

However, I'd still like the model to be saved in the same way when EarlyStopping triggers. Additionally, the default Lightning checkpoint callback offers options such as saving the top k models according to user criteria, keeping the last k models, saving every N epochs, and so on.

Is there a way to integrate the existing save_pretrained() method and its loading counterpart into the existing EarlyStopping and ModelCheckpoint objects? That would be immensely useful for long experiments like mine.
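
For what it's worth, the closest thing I can imagine is a small extra callback hooked into on_save_checkpoint, which (as far as I understand) fires whenever Lightning writes a checkpoint, so an SMP-format copy would stay in sync with whatever ModelCheckpoint and EarlyStopping decide to save. This is only a sketch and again assumes the SMP network lives at pl_module.model:

    import pytorch_lightning as pl


    class SavePretrainedAlongside(pl.Callback):
        """Write an SMP save_pretrained() copy next to every Lightning checkpoint."""

        def __init__(self, save_dir: str):
            super().__init__()
            self.save_dir = save_dir

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Called whenever Lightning assembles a checkpoint to write to disk
            pl_module.model.save_pretrained(
                f"{self.save_dir}/epoch_{trainer.current_epoch}"
            )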
