Handling checkpoint-breaking changes #48

Open
joeloskarsson opened this issue Jun 2, 2024 · 3 comments

@joeloskarsson
Collaborator

Background

As we make more changes to the code, there will be points where checkpoints from saved models cannot be directly loaded in a newer version of neural-lam. This happens in particular if we start changing the variable names of nn.Module attributes or the overall structure of the model classes. It would be good to have a policy for how we want to handle such breaking changes, which is what this issue is for discussing.

Proposals

I see three main options:

  1. Ignore this issue, and only guarantee that checkpoints trained with a specific version of neural-lam work with that version. If you upgrade you have to re-train models or do some "surgery" on your checkpoint files yourself.
  2. Make sure that we can load checkpoints from all previous versions. This is doable as long as the same neural network parameters are in there, just with different names. We have an example of this already, in the current ARModel:
    def on_load_checkpoint(self, checkpoint):
        """
        Perform any changes to state dict before loading checkpoint
        """
        loaded_state_dict = checkpoint["state_dict"]
        # Fix for loading older models after InteractionNet refactoring, where
        # the grid MLP was moved outside the encoder InteractionNet class
        if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict:
            replace_keys = list(
                filter(
                    lambda key: key.startswith("g2m_gnn.grid_mlp"),
                    loaded_state_dict.keys(),
                )
            )
            for old_key in replace_keys:
                new_key = old_key.replace(
                    "g2m_gnn.grid_mlp", "encoding_grid_mlp"
                )
                loaded_state_dict[new_key] = loaded_state_dict[old_key]
                del loaded_state_dict[old_key]
  3. Create a separate script for converting checkpoint files from one version to another. The required logic is the same as in point 2, but moved to a separate script that takes a checkpoint file as input and saves a new checkpoint file, now compatible with the new neural-lam version.

Considerations for point 2 and 3

  • This requires that as soon as we make such a checkpoint-breaking change, we also write the code for handling checkpoints from before that change.
  • It would likely be useful to keep track of which version a certain checkpoint was created with, so we know if it needs to be converted. A simple way to do this could be to handle it similarly to Lightning, which stores the version of the package in the checkpoint file (e.g. `ckpt["pytorch-lightning_version"] = "2.2.1"`). A sketch of this follows below.
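
To illustrate the version tracking in the last point, here is a minimal sketch of what this could look like using Lightning's on_save_checkpoint hook. The checkpoint key name and the way the version string is obtained are assumptions for the example, not existing neural-lam code:

# Sketch only: stamp the neural-lam version into each saved checkpoint,
# mirroring how Lightning stores "pytorch-lightning_version". The key name
# and the assumption that neural-lam is installed as a package are
# illustrative only.
from importlib.metadata import version

import pytorch_lightning as pl


class ARModel(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Record which neural-lam version produced this checkpoint, so a
        # later conversion script knows whether it needs to be upgraded
        checkpoint["neural-lam_version"] = version("neural-lam")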

My view

  • I expect these kinds of changes to not happen that often, and to mostly happen right now while we are refactoring some things. That could be a reason to just go for alternative 1, but it also means that alternatives 2/3 are less work. I think I am leaning towards alternative 3, as I would like to be able to keep using my existing trained checkpoints.
  • I prefer 3 over 2 as I think on_load_checkpoint would get unnecessarily complicated and I'd rather just do the conversion once and have a set of new checkpoint files. It is also easy to do both 2 and 3: if you try to load an old checkpoint you just convert it before loading.
  • While I think it makes sense to keep some track of which version a checkpoint was created with, I would like to avoid building any complex system around this. At the end of the day it is up to the user to make sure that they are using their checkpoint in a training/eval configuration that makes sense. With good tracking in e.g. W&B this is entirely doable. But it is nice to provide some tools to easily upgrade checkpoints if possible.

Tagging @leifdenby and @sadamov to get your input.

@sadamov
Collaborator

sadamov commented Jun 3, 2024

These are some very important considerations. I myself have angered some colleagues by making old checkpoints unusable. Now I am also looking at #49, which would introduce much more flexibility for the user with regard to model choices. Mostly for that reason, and because I don't think we have the person-power to ensure backwards compatibility, I am leaning towards option 1. Maybe in the future, with a more stable repo and more staff, we can implement 3?
What I would do now is set up very solid logging with wandb:

  • git commit + version-tag
  • pinned environment.txt file
  • all yaml-files and user inputs
  • submission scripts (e.g. SLURM)

With such information, every checkpoint should be usable for a long time. Maybe I am very much overestimating how much time 3 would require; if that is the case I will gladly change my opinion.
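
For concreteness, a minimal sketch (not existing neural-lam code) of logging that kind of run metadata with wandb; the project name and file names are placeholders:

# Sketch only: record the git commit and save the pinned environment,
# configuration files and submission script with the W&B run, so that any
# checkpoint it produces can be traced back to an exact code/config state.
import subprocess

import wandb

git_commit = subprocess.check_output(
    ["git", "rev-parse", "HEAD"], text=True
).strip()

wandb.init(project="neural-lam", config={"git_commit": git_commit})

for metadata_file in ("environment.txt", "config.yaml", "submit_job.sh"):
    wandb.save(metadata_file)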

@joeloskarsson
Collaborator Author

I am a bit unsure myself about how much work it would really be. As long as we only rename members or change the hierarchy of nn.Modules, it just boils down to renaming keys in the state dict. We already have an implementation of this here:

if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict:
replace_keys = list(
filter(
lambda key: key.startswith("g2m_gnn.grid_mlp"),
loaded_state_dict.keys(),
)
)
for old_key in replace_keys:
new_key = old_key.replace(
"g2m_gnn.grid_mlp", "encoding_grid_mlp"
)
loaded_state_dict[new_key] = loaded_state_dict[old_key]
del loaded_state_dict[old_key]

It just has to be generalized to more than g2m_gnn.grid_mlp.0.weight.
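
For example, a hypothetical sketch of what that generalization could look like, driven by a mapping from old key prefixes to new ones (the function and mapping names are made up for illustration; the single entry is just the existing g2m_gnn.grid_mlp rename):

# Sketch of a generalized key-renaming pass over a loaded state dict.
# KEY_RENAMES would grow with every checkpoint-breaking rename; the entry
# here is just the existing g2m_gnn.grid_mlp -> encoding_grid_mlp example.
KEY_RENAMES = {
    "g2m_gnn.grid_mlp": "encoding_grid_mlp",
}


def rename_state_dict_keys(state_dict):
    """Return a new state dict with all known old key prefixes renamed"""
    renamed_state_dict = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in KEY_RENAMES.items():
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix) :]
        renamed_state_dict[key] = value
    return renamed_state_dict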

Where things can get tricky is if we reorder input features or change the dimensionality of something. But thinking about this a bit more now, I realize:

  • If we reorder some input features it is still not a lot of work to just pick apart a weight matrix and reorder its columns accordingly (the input features correspond to columns of the first linear layer's weight matrix). I also doubt there will be many changes that do this.
  • If we change the dimensionality somewhere (let's say that we for some reason now think that all MLPs should be latent_dim -> 2 x latent_dim -> latent_dim instead of latent_dim -> latent_dim -> latent_dim), then we cannot re-use any old checkpoint anyway, because there are new entries in the weight matrices. If we were to reduce some dimensionality we would also not be able to convert an old checkpoint, as that part of the model would produce a different output. But such changes can to a very large extent be made optional, and then you just have to set these options correctly for old checkpoints to be compatible. This points towards leaving 1 as an option in such cases, rather than promising that we will never introduce a change that breaks checkpoints in a way that cannot be converted to the new version.

@joeloskarsson
Collaborator Author

I had to do some "surgery" on one of my old checkpoint files after I changed the ordering of input features in the implementation. This corresponds to the first bullet point in my comment above. I'll put the script here as an example of what a checkpoint-conversion script could look like:

# Standard library
import os
from argparse import ArgumentParser
from collections import OrderedDict

# Third-party
import torch

# Parameters to reorder dimensions in
# NOTE: If there are multiple reorders per parameter, they are applied sequentially
REORDER_INPUT_DIMS = {
    "grid_prev_embedder.0.weight": OrderedDict({49: 34}),
    "grid_current_embedder.0.weight": OrderedDict({66: 51}),
}


def main():
    """
    Upgrade a checkpoint file to reflect changes to architecture.
    Here specifically reordering of input features.
    """
    parser = ArgumentParser(description="Upgrade checkpoint file")
    parser.add_argument(
        "--load",
        type=str,
        help="Path to checkpoint file to upgrade",
    )
    args = parser.parse_args()

    assert args.load, "Must specify path to checkpoint file to load"

    # Load checkpoint file
    checkpoint_dict = torch.load(args.load, map_location="cpu")
    state_dict = checkpoint_dict["state_dict"]

    # Reorder dimensions
    for param_name, reorder_dict in REORDER_INPUT_DIMS.items():
        param_tensor = state_dict[
            param_name
        ]  # Reorder dimensions in this param
        for from_dim, to_dim in reorder_dict.items():
            # Extract vector at from_dim
            # indexing along dim 1 for input features
            moved_vec = param_tensor[:, from_dim : (from_dim + 1)]

            # Remove from_dim from param
            param_tensor = torch.cat(
                (param_tensor[:, :from_dim], param_tensor[:, (from_dim + 1) :]),
                dim=1,
            )

            # Insert vector as dimension to_dim
            param_tensor = torch.cat(
                (param_tensor[:, :to_dim], moved_vec, param_tensor[:, to_dim:]),
                dim=1,
            )

        # Re-write parameter in state dict
        state_dict[param_name] = param_tensor

    # Save updated state dict
    path_dirname, path_basename = os.path.split(args.load)
    upgraded_ckpt_path = os.path.join(path_dirname, f"upgraded_{path_basename}")

    torch.save(checkpoint_dict, upgraded_ckpt_path)


if __name__ == "__main__":
    main()
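
Assuming the script is saved as e.g. upgrade_checkpoint.py (the script name and checkpoint path below are just placeholders), it would be run as

python upgrade_checkpoint.py --load path/to/old_checkpoint.ckpt

which writes an upgraded copy next to the original, prefixed with upgraded_.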
