Reproduction of Chronos-small Zero-Shot Forecasting Results #120
-
Hi, I am attempting to reproduce the Chronos zero-shot forecasting results on the 27 Benchmark II datasets from the paper. I have carefully followed the methodology described in the paper, including the MASE formula, the seasonality parameter, the forecast horizon H, the context length, and so on. Despite this, my zero-shot forecast results using the pre-trained amazon/chronos-t5-small model deviate significantly from the MASE scores reported in Appendix E, Table 9. To ensure accuracy and consistency, could you share the code you used to evaluate Chronos-small on these 27 open-source datasets? Thank you!
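For reference, here is a minimal sketch of the MASE computation as I understand it (my own illustration, not the paper's evaluation code), where `seasonality` is the season length S set per dataset:

```python
import numpy as np

def mase(y_context, y_true, y_pred, seasonality):
    """Mean absolute forecast error over the horizon, scaled by the
    in-sample mean absolute error of the seasonal naive forecast."""
    forecast_error = np.mean(np.abs(y_true - y_pred))
    naive_error = np.mean(np.abs(y_context[seasonality:] - y_context[:-seasonality]))
    return forecast_error / naive_error
```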
-
@kushalkodn-db Please check Issue #75 for an example of how to use GluonTS to compute metrics the way we did in the paper. As we note in that issue, datasets in GluonTS with the same name may differ from the setting used in the paper. We are working towards releasing our evaluation datasets soon; please follow #102 for updates.
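For instance, one quick way to catch such mismatches is to print the GluonTS repository's metadata for a dataset and compare it against the frequency and horizon H used in the paper (a sketch; the dataset name here is just an example):

```python
from gluonts.dataset.repository import get_dataset

# Compare the repository's frequency and horizon against the paper's setting.
dataset = get_dataset("m4_quarterly")
print(dataset.metadata.freq, dataset.metadata.prediction_length)
```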
-
Hey @abdulfatir, I appreciate all the detailed responses you've provided across the discussion threads. I ran the code below to generate the MASE/WQL scores for Benchmark II (Tables 8-9 in Appendix E) using chronos-t5-small, but my scores don't match up exactly. Is this expected behavior, or am I missing something? The screenshot includes both my computed scores and the paper's reported scores.

```python
import numpy as np
import pandas as pd
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from src.chronos import ChronosPipeline

filenames = [
    "australian_electricity_demand", "car_parts_without_missing", "cif_2016",
    "covid_deaths", "dominick", "ercot", "ett_small_15min", "ett_small_1h",
    "exchange_rate", "fred_md", "hospital", "m1_monthly", "m1_quarterly",
    "m1_yearly", "m3_monthly", "m3_quarterly", "m3_yearly", "m4_quarterly",
    "m4_yearly", "m5", "nn5_daily_without_missing", "nn5_weekly",
    "tourism_monthly", "tourism_quarterly", "tourism_yearly", "traffic",
    "weather",
]

# Forecast horizon H per dataset, following Appendix E of the paper.
file_horizon_mapping = {
    "australian_electricity_demand": 48, "car_parts_without_missing": 12,
    "cif_2016": 12, "covid_deaths": 30, "dominick": 8, "ercot": 24,
    "ett_small_15min": 24, "ett_small_1h": 24, "exchange_rate": 30,
    "fred_md": 12, "hospital": 12, "m1_monthly": 18, "m1_quarterly": 8,
    "m1_yearly": 6, "m3_monthly": 18, "m3_quarterly": 8, "m3_yearly": 6,
    "m4_quarterly": 8, "m4_yearly": 6, "m5": 28,
    "nn5_daily_without_missing": 56, "nn5_weekly": 8, "tourism_monthly": 24,
    "tourism_quarterly": 8, "tourism_yearly": 4, "traffic": 24, "weather": 30,
}

all_metrics = pd.DataFrame(
    columns=["filename", "MASE[0.5]", "mean_weighted_sum_quantile_loss"]
)
batch_size = 32
num_samples = 20

# Load Chronos once; the same pipeline is reused for every dataset.
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

for filename in filenames:
    # Load the dataset and its forecast horizon
    prediction_length = file_horizon_mapping[filename]
    dataset = get_dataset(dataset_name=filename)

    # Hold out the last `prediction_length` steps of each series for evaluation
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)

    # Generate forecast samples in batches
    forecast_samples = []
    for batch in tqdm(batcher(test_data.input, batch_size=batch_size)):
        context = [torch.tensor(entry["target"]) for entry in batch]
        forecast_samples.append(
            pipeline.predict(
                context,
                prediction_length=prediction_length,
                num_samples=num_samples,
            ).numpy()
        )
    forecast_samples = np.concatenate(forecast_samples)

    # Convert forecast samples into GluonTS SampleForecast objects
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data.input):
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    # Evaluate MASE and WQL (mean weighted sum quantile loss)
    metrics_df = evaluate_forecasts(
        sample_forecasts,
        test_data=test_data,
        metrics=[
            MASE(),
            MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
        ],
    )
    new_df = pd.DataFrame({
        "filename": [filename],
        "MASE[0.5]": [metrics_df["MASE[0.5]"].iloc[0]],
        "mean_weighted_sum_quantile_loss": [
            metrics_df["mean_weighted_sum_quantile_loss"].iloc[0]
        ],
    })
    all_metrics = pd.concat([all_metrics, new_df], ignore_index=True)

print(all_metrics)
```
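Two possible sources of run-to-run differences I can think of (just guesses on my part): Chronos forecasts are drawn stochastically, so scores will vary across runs unless the seed is fixed, and GluonTS infers the seasonal-naive scaling for MASE from the series frequency, which may not match the seasonality used in the paper. A minimal sketch of what I mean, assuming `torch.manual_seed` covers the sampling path:

```python
import torch
from gluonts.time_feature import get_seasonality

# Fix the sampling seed before calling pipeline.predict(...) so the
# num_samples sample paths are reproducible across runs.
torch.manual_seed(0)

# Seasonality GluonTS would use for the seasonal-naive MASE scaling;
# compare this against the seasonality listed in the paper.
print(get_seasonality("H"))  # hourly frequency
```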
-
Update: We have just open-sourced the datasets used in the paper (thanks @shchur!). Please check the updated README. We have also released an evaluation script and backtest configs to compute the WQL and MASE numbers as reported in the paper. Please follow the instructions in this README to evaluate on the in-domain and zero-shot benchmarks.