
[BUG] Incomplete metrics on callback since darts 0.30 (everything still works with darts 0.29) #2637

Open
MarcBresson opened this issue Jan 6, 2025 · 3 comments

@MarcBresson (Contributor) commented Jan 6, 2025

Describe the bug
On the first on_validation_epoch_end or on_train_epoch_end callback, none of the metrics given in torch_metrics are available.

To Reproduce

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from torchmetrics import MeanAbsolutePercentageError, MetricCollection, MeanSquaredError

from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel


class LiveMetricsCallback(Callback):
    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        print()
        print("train", trainer.current_epoch, trainer.callback_metrics)

    def on_validation_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        print()
        print("val", trainer.current_epoch, trainer.callback_metrics)

# read data
series = AirPassengersDataset().load()

# create training and validation sets:
train, val = series.split_after(pd.Timestamp(year=1957, month=12, day=1))

# normalize the time series
transformer = Scaler()
train = transformer.fit_transform(train)
val = transformer.transform(val)

# torch metrics to compute during training and validation
torch_metrics = MetricCollection([MeanAbsolutePercentageError(), MeanSquaredError()])

# live metrics callback
metrics_callback = LiveMetricsCallback()
pl_trainer_kwargs = {"callbacks": [metrics_callback], "accelerator": "cpu", "log_every_n_steps": 1}

# create the model
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=3,
    torch_metrics=torch_metrics,
    pl_trainer_kwargs=pl_trainer_kwargs
)

# fit with a validation set so that validation metrics are computed
model.fit(
    series=train,
    val_series=val,
)

Expected behavior
In darts 0.29, all the metrics were available right away:

Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  9.71it/s]
val 0 {'val_loss': tensor(2.1967, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(1.2826, dtype=torch.float64), 'val_MeanSquaredError': tensor(2.1967, dtype=torch.float64)}
                                                                           
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 15.81it/s, train_loss=0.141]
val 0 {'train_loss': tensor(0.1408, dtype=torch.float64), 'val_loss': tensor(0.4933, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.6230, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.4933, dtype=torch.float64)}
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 14.64it/s, train_loss=0.141, val_loss=0.493, val_MeanAbsolutePercentageError=0.623, val_MeanSquaredError=0.493]
train 0 {'train_loss': tensor(0.1408, dtype=torch.float64), 'val_loss': tensor(0.4933, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.6230, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.4933, dtype=torch.float64), 'train_MeanAbsolutePercentageError': tensor(1.8711, dtype=torch.float64), 'train_MeanSquaredError': tensor(0.7567, dtype=torch.float64)}
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 18.56it/s, train_loss=0.0535, val_loss=0.493, val_MeanAbsolutePercentageError=0.623, val_MeanSquaredError=0.493, train_MeanAbsolutePercentageError=1.870, train_MeanSquaredError=0.757]
val 1 {'train_loss': tensor(0.0535, dtype=torch.float64), 'val_loss': tensor(0.3712, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.5145, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.3712, dtype=torch.float64), 'train_MeanAbsolutePercentageError': tensor(1.8711, dtype=torch.float64), 'train_MeanSquaredError': tensor(0.7567, dtype=torch.float64)}
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 16.84it/s, train_loss=0.0535, val_loss=0.371, val_MeanAbsolutePercentageError=0.515, val_MeanSquaredError=0.371, train_MeanAbsolutePercentageError=1.870, train_MeanSquaredError=0.757]
train 1 {'train_loss': tensor(0.0535, dtype=torch.float64), 'val_loss': tensor(0.3712, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.5145, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.3712, dtype=torch.float64), 'train_MeanAbsolutePercentageError': tensor(0.9090, dtype=torch.float64), 'train_MeanSquaredError': tensor(0.1487, dtype=torch.float64)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 19.50it/s, train_loss=0.0609, val_loss=0.371, val_MeanAbsolutePercentageError=0.515, val_MeanSquaredError=0.371, train_MeanAbsolutePercentageError=0.909, train_MeanSquaredError=0.149]
val 2 {'train_loss': tensor(0.0609, dtype=torch.float64), 'val_loss': tensor(0.2675, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.4588, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.2675, dtype=torch.float64), 'train_MeanAbsolutePercentageError': tensor(0.9090, dtype=torch.float64), 'train_MeanSquaredError': tensor(0.1487, dtype=torch.float64)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 17.72it/s, train_loss=0.0609, val_loss=0.267, val_MeanAbsolutePercentageError=0.459, val_MeanSquaredError=0.267, train_MeanAbsolutePercentageError=0.909, train_MeanSquaredError=0.149]
train 2 {'train_loss': tensor(0.0609, dtype=torch.float64), 'val_loss': tensor(0.2675, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.4588, dtype=torch.float64), 'val_MeanSquaredError': tensor(0.2675, dtype=torch.float64), 'train_MeanAbsolutePercentageError': tensor(0.5780, dtype=torch.float64), 'train_MeanSquaredError': tensor(0.0712, dtype=torch.float64)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 17.51it/s, train_loss=0.0609, val_loss=0.267, val_MeanAbsolutePercentageError=0.459, val_MeanSquaredError=0.267, train_MeanAbsolutePercentageError=0.578, train_MeanSquaredError=0.0712]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 17.44it/s, train_loss=0.0609, val_loss=0.267, val_MeanAbsolutePercentageError=0.459, val_MeanSquaredError=0.267, train_MeanAbsolutePercentageError=0.578, train_MeanSquaredError=0.0712]

But in darts 0.30, metrics are missing:

Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  9.62it/s]
val 0 {'val_loss': tensor(5.1715, dtype=torch.float64)}
                                                                           
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 16.22it/s, train_loss=0.595]
val 0 {'train_loss': tensor(0.5954, dtype=torch.float64), 'val_loss': tensor(0.5011, dtype=torch.float64)}
Epoch 0: 100%|██████████| 3/3 [00:00<00:00, 14.97it/s, train_loss=0.595, val_loss=0.501, val_MeanAbsolutePercentageError=0.556, val_MeanSquaredError=0.501]
train 0 {'train_loss': tensor(0.5954, dtype=torch.float64), 'val_loss': tensor(0.5011, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.5562), 'val_MeanSquaredError': tensor(0.5011)}
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 19.18it/s, train_loss=0.287, val_loss=0.501, val_MeanAbsolutePercentageError=0.556, val_MeanSquaredError=0.501, train_MeanAbsolutePercentageError=3.500, train_MeanSquaredError=2.280]
val 1 {'train_loss': tensor(0.2875, dtype=torch.float64), 'val_loss': tensor(0.1900, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.5562), 'val_MeanSquaredError': tensor(0.5011), 'train_MeanAbsolutePercentageError': tensor(3.5033), 'train_MeanSquaredError': tensor(2.2803)}
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 14.68it/s, train_loss=0.287, val_loss=0.190, val_MeanAbsolutePercentageError=0.366, val_MeanSquaredError=0.190, train_MeanAbsolutePercentageError=3.500, train_MeanSquaredError=2.280]
train 1 {'train_loss': tensor(0.2875, dtype=torch.float64), 'val_loss': tensor(0.1900, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.3659), 'val_MeanSquaredError': tensor(0.1900), 'train_MeanAbsolutePercentageError': tensor(3.5033), 'train_MeanSquaredError': tensor(2.2803)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 18.16it/s, train_loss=0.158, val_loss=0.190, val_MeanAbsolutePercentageError=0.366, val_MeanSquaredError=0.190, train_MeanAbsolutePercentageError=1.340, train_MeanSquaredError=0.332]
val 2 {'train_loss': tensor(0.1579, dtype=torch.float64), 'val_loss': tensor(0.5300, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.3659), 'val_MeanSquaredError': tensor(0.1900), 'train_MeanAbsolutePercentageError': tensor(1.3388), 'train_MeanSquaredError': tensor(0.3322)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 16.51it/s, train_loss=0.158, val_loss=0.530, val_MeanAbsolutePercentageError=0.529, val_MeanSquaredError=0.530, train_MeanAbsolutePercentageError=1.340, train_MeanSquaredError=0.332]
train 2 {'train_loss': tensor(0.1579, dtype=torch.float64), 'val_loss': tensor(0.5300, dtype=torch.float64), 'val_MeanAbsolutePercentageError': tensor(0.5293), 'val_MeanSquaredError': tensor(0.5300), 'train_MeanAbsolutePercentageError': tensor(1.3388), 'train_MeanSquaredError': tensor(0.3322)}
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 16.31it/s, train_loss=0.158, val_loss=0.530, val_MeanAbsolutePercentageError=0.529, val_MeanSquaredError=0.530, train_MeanAbsolutePercentageError=1.340, train_MeanSquaredError=0.332]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|██████████| 3/3 [00:00<00:00, 16.19it/s, train_loss=0.158, val_loss=0.530, val_MeanAbsolutePercentageError=0.529, val_MeanSquaredError=0.530, train_MeanAbsolutePercentageError=1.340, train_MeanSquaredError=0.332]

As you can see, at val 0 there are no metrics at all, while at train 0 all the validation metrics are present (but still no training metrics).

System (please complete the following information):

  • Python version: 3.10
  • darts version: 0.30
@MarcBresson added the "bug" and "triage" labels Jan 6, 2025
@dennisbader (Collaborator) commented

Hi @MarcBresson and thanks for raising this issue.

Indeed, this was introduced in #2391, where we added support for stateful metrics (i.e. metric state is cached per batch and only aggregated at the end of the epoch).
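
For illustration, here is a minimal torchmetrics sketch (mine, not code from #2391) of what "stateful" means here: update() only accumulates running state per batch, and the aggregation happens in compute():

import torch
from torchmetrics import MeanSquaredError

metric = MeanSquaredError()
for _ in range(3):  # one update() per batch: only running state is stored
    preds, target = torch.randn(8), torch.randn(8)
    metric.update(preds, target)
epoch_mse = metric.compute()  # aggregation over all batches happens here
metric.reset()                # clear the state for the next epoch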

The problem is that the Callback hook on_*_epoch_end() is called before the module's (model) internal on_*_epoch_end() where the metrics are computed.

So ultimately, in the callback, you currently only have access to the metrics from the previous epoch.
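
To make the ordering concrete, here is a minimal, self-contained Lightning sketch (mine, not darts code); running it prints "callback hook" before "module hook" on every validation epoch:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class OrderCallback(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        print("callback hook")  # fires first


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).mean()

    def validation_step(self, batch, batch_idx):
        pass

    def on_validation_epoch_end(self):
        # fires second: metrics computed here arrive too late for the callback
        print("module hook")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(4, 1)), batch_size=2)
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False,
                     enable_checkpointing=False, num_sanity_val_steps=0)
trainer.fit(TinyModule(), train_dataloaders=data, val_dataloaders=data)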

I can investigate if there is a way to compute the metrics before the callback hook. But the current implementation is actually based on what torchmetrics recommends.

Anyway, you can work around this in your Callback by computing the metrics on the fly:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class LiveMetricsCallback(Callback):
    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        print()
        print("train", trainer.current_epoch, self.get_metrics(trainer, pl_module))

    def on_validation_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        print()
        print("val", trainer.current_epoch, self.get_metrics(trainer, pl_module))

    @staticmethod
    def get_metrics(trainer, pl_module):
        """Computes and returns metrics and losses at the current state."""
        losses = {
            "train_loss": trainer.callback_metrics.get("train_loss"),
            "val_loss": trainer.callback_metrics.get("val_loss"),
        }
        return dict(
            losses,
            **pl_module.train_metrics.compute(),
            **pl_module.val_metrics.compute(),
        )

It will raise a warning during sanity checking (the first validation run), when the train metrics have not yet been updated. This is expected and not a problem (it can also be avoided with a few more lines of code, as in the sketch below).
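
For completeness, one way those extra lines could look (a sketch, assuming the standard trainer.sanity_checking flag and the same pl_module.train_metrics / pl_module.val_metrics attributes used above); it would replace get_metrics in the snippet:

    @staticmethod
    def get_metrics(trainer, pl_module):
        """Computes metrics and losses, skipping train metrics during sanity checks."""
        metrics = {
            "train_loss": trainer.callback_metrics.get("train_loss"),
            "val_loss": trainer.callback_metrics.get("val_loss"),
        }
        metrics.update(pl_module.val_metrics.compute())
        if not trainer.sanity_checking:  # train metrics have no state yet during sanity checking
            metrics.update(pl_module.train_metrics.compute())
        return metrics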

@MarcBresson (Contributor, Author) commented

Thank you for your quick answer :) I'll do what you suggested.

It has been almost two years since I contributed to torchmetrics, so I may be a bit rusty on that matter. After going through the source code of both pytorch-lightning and torchmetrics, I have one more question:

EarlyStopping relies on the exact same behaviour as my LiveMetricsCallback, i.e. it uses the trainer's callback_metrics. Does that mean that early stopping on a metric (I don't know why you would do that, but...) will be delayed by one epoch, since the computation of the metrics happens after the callback? One workaround would be to have a callback before the EarlyStopping one that actually computes the metrics (sketched below).
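
A hedged sketch of that workaround (mine, untested; it relies on Lightning invoking callbacks in list order, and assumes trainer.callback_metrics can be updated in place; the monitor name is taken from the logs above):

from pytorch_lightning.callbacks import Callback, EarlyStopping


class ComputeValMetricsFirst(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # push freshly computed val metrics into callback_metrics so the
        # EarlyStopping callback placed after us sees current-epoch values
        trainer.callback_metrics.update(pl_module.val_metrics.compute())


callbacks = [
    ComputeValMetricsFirst(),  # runs before EarlyStopping
    EarlyStopping(monitor="val_MeanSquaredError", mode="min", patience=5),
]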

@dennisbader (Collaborator) commented

That is true, and a bit worrying. It might be worth contacting the torchmetrics maintainers at this point to clarify.

@dennisbader removed the "triage" label Jan 8, 2025
@dennisbader dennisbader added this to darts Jan 8, 2025
@github-project-automation github-project-automation bot moved this to To do in darts Jan 8, 2025