
[BUG] TCN model cannot be saved when used with callbacks #2638

Open
MarcBresson opened this issue Jan 7, 2025 · 5 comments
Labels
bug Something isn't working improvement New feature or improvement

Comments

@MarcBresson
Contributor

Describe the bug
For some reason, defining callbacks in pl_trainer_kwargs makes the torch.save() call inside TCNModel.save() also try to save the Lightning module. Because TCN contains parametrized layers, torch raises an error: RuntimeError: Serialization of parametrized modules is only supported through state_dict().

Not setting any custom callbacks will not trigger the bug.

To Reproduce

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

from darts.models.forecasting.tcn_model import TCNModel
from darts.utils import timeseries_generation as tg


class LiveMetricsCallback(Callback):
    def __init__(self):
        self.is_sanity_checking = True

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        print()
        print("train", trainer.current_epoch, self.get_metrics(trainer, pl_module))

    def on_validation_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        # skip the first call, which comes from Lightning's sanity-check run
        if self.is_sanity_checking and trainer.num_sanity_val_steps != 0:
            self.is_sanity_checking = False
            return
        print()
        print("val", trainer.current_epoch, self.get_metrics(trainer, pl_module))

    @staticmethod
    def get_metrics(trainer, pl_module):
        """Computes and returns metrics and losses at the current state."""
        losses = {
            "train_loss": trainer.callback_metrics.get("train_loss"),
            "val_loss": trainer.callback_metrics.get("val_loss"),
        }
        return dict(
            losses,
            **pl_module.train_metrics.compute(),
            **pl_module.val_metrics.compute(),
        )


def test_save():
    large_ts = tg.constant_timeseries(length=100, value=1000)
    model = TCNModel(
        input_chunk_length=6,
        output_chunk_length=2,
        n_epochs=10,
        num_layers=2,
        kernel_size=3,
        dilation_base=3,
        weight_norm=True,
        dropout=0.1,
        pl_trainer_kwargs={
            "accelerator": "cpu",
            "enable_progress_bar": False,
            "enable_model_summary": False,
            "callbacks": [LiveMetricsCallback()],
        },
    )
    model.fit(large_ts[:98])

    # fails with the RuntimeError shown below
    model.save("model.pt")
```

This outputs:

```
darts/models/forecasting/torch_forecasting_model.py:1679: in save
    torch.save(self, f_out)
../o2_ml_2/.venv/lib/python3.10/site-packages/torch/serialization.py:629: in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
../o2_ml_2/.venv/lib/python3.10/site-packages/torch/serialization.py:841: in _save
    pickler.dump(obj)
RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
```

System:

  • Python version: 3.10
  • darts version: 0.32.0 (the bug does not exist in 0.31)

Additional context
It likely comes from #2593

@MarcBresson MarcBresson added bug Something isn't working triage Issue waiting for triaging labels Jan 7, 2025
@dennisbader
Collaborator

Thanks for raising this issue @MarcBresson. It does indeed come from the new parametrized weight norm, which PyTorch can only serialize through state_dict(), not through pickle. The problem only appears in combination with callbacks (I'll explain further below).

TL;DR: You can save the model if you remove the callbacks before calling save().

```python
model.fit(large_ts[:98], val_series=large_ts)

# drop the callbacks (they reference the parametrized module) before saving
model.trainer_params["callbacks"] = []
model.save("model.pt")

model_loaded = TCNModel.load("model.pt")
preds = model_loaded.predict(n=2)
```
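If you want to keep the callbacks on the in-memory model after saving, the workaround can be wrapped in a context manager. A minimal sketch: `callbacks_removed` is a hypothetical helper, not part of darts, and it only assumes that `trainer_params` is a plain dict, as in the snippet above:

```python
from contextlib import contextmanager

# Sentinel to tell "no callbacks key" apart from "callbacks key set to []".
_MISSING = object()


@contextmanager
def callbacks_removed(model):
    """Temporarily clear model.trainer_params["callbacks"], restoring it afterwards."""
    original = model.trainer_params.get("callbacks", _MISSING)
    model.trainer_params["callbacks"] = []
    try:
        yield model
    finally:
        # Restore exactly what was there before, even if saving raised.
        if original is _MISSING:
            model.trainer_params.pop("callbacks", None)
        else:
            model.trainer_params["callbacks"] = original


# Intended usage with a fitted darts model (not executed here):
# with callbacks_removed(model):
#     model.save("model.pt")
```

Restoring in `finally` means a failed save does not silently leave the model without its callbacks.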

Explanation: To avoid these issues, we already exclude the Lightning Module (the neural network) from being pickled by torch.save() here:

```python
return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}
```

Now (unfortunately), after training, the callbacks themselves also hold a reference to the LightningModule (in this case _TCNModule, which contains the parametrized weight norm). torch then reaches the _TCNModule through the callbacks and tries to pickle it, which raises this error.

You can fix this issue by removing the callbacks before storing the model.
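The mechanism can be illustrated with plain `pickle`. This is a simplified, darts-free sketch: `ParametrizedNet`, `RecordingCallback`, `ForecastingModelSketch` and `NO_PICKLE` are hypothetical stand-ins; only the `__getstate__` filter mirrors the darts line quoted above. Excluding `model` from the state does not help as long as another pickled attribute still references the network:

```python
import pickle


class ParametrizedNet:
    """Stand-in for a parametrized module: refuses pickling, like torch does."""

    def __getstate__(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only supported through state_dict()."
        )


class RecordingCallback:
    """Stand-in for a callback that ends up holding a reference to the network."""

    def __init__(self):
        self.pl_module = None


# Attributes excluded from pickling (same idea as TFM_ATTRS_NO_PICKLE).
NO_PICKLE = {"model"}


class ForecastingModelSketch:
    def __init__(self, callbacks):
        self.model = ParametrizedNet()
        self.trainer_params = {"callbacks": callbacks}

    def fit(self):
        # After "training", every callback points at the network.
        for cb in self.trainer_params["callbacks"]:
            cb.pl_module = self.model

    def __getstate__(self):
        return {k: v for k, v in self.__dict__.items() if k not in NO_PICKLE}


def can_pickle(obj) -> bool:
    try:
        pickle.dumps(obj)
    except RuntimeError:
        return False
    return True


model = ForecastingModelSketch(callbacks=[RecordingCallback()])
model.fit()
# Fails: the network is reached via trainer_params["callbacks"][0].pl_module.
before = can_pickle(model)
model.trainer_params["callbacks"] = []
# Succeeds: the only remaining reference ("model") is filtered out.
after = can_pickle(model)
print(before, after)
```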

It has been in our backlog to add an option to remove the trainer parameters, training series, and other nonessential objects before saving. I'll move it higher up the priority list.

@dennisbader dennisbader added improvement New feature or improvement and removed triage Issue waiting for triaging labels Jan 8, 2025
@dennisbader dennisbader added this to darts Jan 8, 2025
@github-project-automation github-project-automation bot moved this to To do in darts Jan 8, 2025
@MarcBresson
Contributor Author

MarcBresson commented Jan 8, 2025

OK, thank you. What is the role of the trainer_params attribute? Isn't it redundant with the trainer attribute?

--EDIT-- To clarify, it seems like most of the info available in trainer_params is also available under the trainer attribute. Is it because the trainer is saved as a separate file that you still want to include the trainer_params mapping in the main file?

@dennisbader
Collaborator

dennisbader commented Jan 9, 2025

There are three things:

  • pl_trainer_kwargs is not used directly, but is stored under ForecastingModel._model_params (accessible through ForecastingModel.model_params). This allows us to re-create the model from scratch using your input arguments as they were when you created the model. We need this, for example, in historical_forecasts with retrain=True, where we need a fresh model instance for every iteration.
  • trainer_params is initially a deepcopy of your pl_trainer_kwargs. We use trainer_params to create the Lightning Trainer for training and prediction. We store this attribute when saving the model to be able to recreate the trainer upon loading. However, some trainer parameters cause issues when loading (e.g. callbacks, some objects, ...). This is what we want to improve, since callbacks are mostly only required for training, or can simply be passed to predict() with a new trainer object.
  • trainer is the PyTorch Lightning trainer used for training and prediction. If you do not pass a trainer to fit()/predict(), we use the trainer_params from your pl_trainer_kwargs set at model creation.
    We do not save the trainer; it's only used to handle the underlying model (training, prediction, checkpointing, saving / loading...).
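A toy, darts-free sketch of how the three attributes relate. `ModelSketch`, `FakeTrainer` and `untrained_copy` are hypothetical names, not darts' API, and whether darts' `model_params` returns a copy is not stated here; the sketch simply does so. The point it shows: editing `trainer_params` (e.g. dropping callbacks before saving) leaves the creation arguments untouched, so a fresh model still gets the original callbacks:

```python
import copy
from typing import Any, Dict, Optional


class FakeTrainer:
    """Hypothetical stand-in for a Lightning Trainer built from keyword arguments."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class ModelSketch:
    """Toy model mirroring the model_params / trainer_params / trainer split."""

    def __init__(self, pl_trainer_kwargs: Optional[Dict[str, Any]] = None) -> None:
        pl_trainer_kwargs = pl_trainer_kwargs or {}
        # 1. Creation arguments, kept as-is so an identical fresh model can be built.
        self._model_params = {"pl_trainer_kwargs": pl_trainer_kwargs}
        # 2. A deep copy that is used (and may be modified) to build trainers.
        self.trainer_params = copy.deepcopy(pl_trainer_kwargs)
        # 3. The trainer itself: created on demand, never meant to be saved.
        self.trainer: Optional[FakeTrainer] = None

    @property
    def model_params(self) -> Dict[str, Any]:
        return copy.deepcopy(self._model_params)

    def fit(self, trainer: Optional[FakeTrainer] = None) -> "ModelSketch":
        # Use the caller's trainer if given, else build one from trainer_params.
        self.trainer = trainer if trainer is not None else FakeTrainer(**self.trainer_params)
        return self

    def untrained_copy(self) -> "ModelSketch":
        # Re-create a fresh model from the original creation arguments,
        # as needed e.g. for historical forecasts with retraining.
        return ModelSketch(**self.model_params)


model = ModelSketch({"accelerator": "cpu", "callbacks": ["my_callback"]})
model.fit()
model.trainer_params["callbacks"] = []  # e.g. dropped before saving
fresh = model.untrained_copy()
```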

Hope this clears things up.

@MarcBresson
Contributor Author

Thank you very much. It is crystal clear.

It seems like PyTorch Lightning checkpoints contain a lot of info, but it's probably not enough to recreate the trainer. When I have time, I will take a deeper dive into PyTorch Lightning; it seems to be full of good stuff!

@dennisbader
Collaborator

No worries :) Indeed, it's a great tool!

Projects
Status: In progress
Development

No branches or pull requests

3 participants