Skip to content

Commit

Permalink
Fix README example to use predict_quantiles (#220)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:* `predict` returns different things based on
model type. This fixes the example to use `predict_quantiles` which will
give correct quantiles.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Abdul Fatir Ansari <[email protected]>
  • Loading branch information
abdulfatir and Abdul Fatir Ansari authored Nov 29, 2024
1 parent 4c43cfb commit e3bbda7
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,16 @@ df = pd.read_csv(

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
# The original Chronos models generate forecast samples, so forecast has shape
# [num_series, num_samples, prediction_length].
# Chronos-Bolt models generate quantile forecasts, so forecast has shape
# [num_series, num_quantiles, prediction_length].
forecast = pipeline.predict(
context=torch.tensor(df["#Passengers"]), prediction_length=12
# quantiles is an fp32 tensor with shape [batch_size, prediction_length, num_quantile_levels]
# mean is an fp32 tensor with shape [batch_size, prediction_length]
quantiles, mean = pipeline.predict_quantiles(
context=torch.tensor(df["#Passengers"]),
prediction_length=12,
quantile_levels=[0.1, 0.5, 0.9],
)
```

More options for `pipeline.predict` can be found with:
For the original Chronos models, `pipeline.predict` can be used to draw forecast samples. More options for `predict_kwargs` in `pipeline.predict_quantiles` can be found with:

```python
from chronos import ChronosPipeline, ChronosBoltPipeline
Expand All @@ -136,10 +136,9 @@ We can now visualize the forecast:

```python
import matplotlib.pyplot as plt # requires: pip install matplotlib
import numpy as np

forecast_index = range(len(df), len(df) + 12)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
low, median, high = quantiles[0, :, 0], quantiles[0, :, 1], quantiles[0, :, 2]

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
Expand Down

0 comments on commit e3bbda7

Please sign in to comment.