diff --git a/README.md b/README.md index 9489608..e921489 100644 --- a/README.md +++ b/README.md @@ -114,16 +114,16 @@ df = pd.read_csv( # context must be either a 1D tensor, a list of 1D tensors, # or a left-padded 2D tensor with batch as the first dimension -# The original Chronos models generate forecast samples, so forecast has shape -# [num_series, num_samples, prediction_length]. -# Chronos-Bolt models generate quantile forecasts, so forecast has shape -# [num_series, num_quantiles, prediction_length]. -forecast = pipeline.predict( - context=torch.tensor(df["#Passengers"]), prediction_length=12 +# quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels] +# mean is an fp32 tensor with shape [batch_size, prediction_length] +quantiles, mean = pipeline.predict_quantiles( + context=torch.tensor(df["#Passengers"]), + prediction_length=12, + quantile_levels=[0.1, 0.5, 0.9], ) ``` -More options for `pipeline.predict` can be found with: +For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with: ```python from chronos import ChronosPipeline, ChronosBoltPipeline @@ -136,10 +136,9 @@ We can now visualize the forecast: ```python import matplotlib.pyplot as plt # requires: pip install matplotlib -import numpy as np forecast_index = range(len(df), len(df) + 12) -low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) +low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2] plt.figure(figsize=(8, 4)) plt.plot(df["#Passengers"], color="royalblue", label="historical data")