Add prediction via callback
jdb78 committed Apr 6, 2023
1 parent d1fa3a7 commit 038b0c3
Showing 14 changed files with 4,647 additions and 4,557 deletions.
1 change: 1 addition & 0 deletions .env
@@ -0,0 +1 @@
+PYTORCH_ENABLE_MPS_FALLBACK=1
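
For context, a minimal sketch of what this flag does (illustration, not part of the diff; assumes torch built with MPS support on Apple silicon, and the variable must be set before `torch` is imported):

```python
import os

# Must be set before `import torch` to take effect.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

# With the fallback enabled, operators that lack an MPS kernel run on
# CPU instead of raising NotImplementedError.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(8, 4, device=device)
print(x.sum())
```
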
2 changes: 1 addition & 1 deletion .gitignore
@@ -104,7 +104,7 @@ celerybeat.pid
*.sage.py

# Environments
-.env
+# .env
.venv
env/
venv/
13 changes: 11 additions & 2 deletions CHANGELOG.md
@@ -1,6 +1,15 @@
# Release Notes

-## v0.10.4 UNRELEASED (xx/xx/xxxx)
+## v1.0.0 Update to pytorch 2.0 (xx/xx/xxxx)


### Breaking Changes

- Upgraded to pytorch 2.0 and lightning 2.0. This brings a couple of changes, such as the configuration of trainers; see the [lightning upgrade guide](https://lightning.ai/docs/pytorch/latest/upgrade/migration_guide.html). For PyTorch Forecasting this particularly means that if you are developing your own models, the class method `epoch_end` has been renamed to `on_epoch_end`, `model.summarize()` is replaced by `ModelSummary(model, max_depth=-1)`, and `Tuner(trainer)` is now its own class, so `trainer.tuner` needs replacing (see the sketch below). (#1280)
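
A hedged before/after sketch of these renames (illustration only, not part of the diff; the tiny module and dataloader are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.utilities.model_summary import ModelSummary


class TinyModule(pl.LightningModule):
    """Placeholder stand-in for a model; in pytorch-forecasting models,
    rename the `epoch_end` hook to `on_epoch_end`."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr  # lr_find looks for an `lr` or `learning_rate` attribute
        self.net = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


model = TinyModule()
loader = DataLoader(TensorDataset(torch.randn(800, 4), torch.randn(800, 1)), batch_size=8)
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)

# old: model.summarize()          -> new:
print(ModelSummary(model, max_depth=-1))

# old: trainer.tuner.lr_find(...) -> new: Tuner is its own class
res = Tuner(trainer).lr_find(model, train_dataloaders=loader)
```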

### Changes

- The `predict` method now uses the lightning `predict` functionality and allows writing results to disk (#1280); see the sketch below.
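
A hedged sketch of the reworked call (illustration only; `best_tft` and `val_dataloader` come from the usual tutorial setup, and the keyword names are my reading of the release notes, not verbatim from this diff):

```python
# Prediction now runs through a lightning Trainer under the hood,
# so trainer arguments can be forwarded explicitly.
predictions = best_tft.predict(
    val_dataloader,
    mode="prediction",
    return_x=True,  # also return the network inputs
    trainer_kwargs=dict(accelerator="cpu"),
)

# Results can be written to disk instead of collected in memory
# (assumed parameter name, per "allows writing results to disk").
best_tft.predict(val_dataloader, output_dir="predictions")
```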

### Fixed

@@ -402,7 +411,7 @@ This release has only one purpose: Allow usage of PyTorch Lightning 1.0 - all te
- Using `LearningRateMonitor` instead of `LearningRateLogger`
- Use `EarlyStopping` callback in trainer `callbacks` instead of `early_stopping` argument
- Update metric system `update()` and `compute()` methods
-- Use `trainer.tuner.lr_find()` instead of `trainer.lr_find()` in tutorials and examples
+- Use `Tuner(trainer).lr_find()` instead of `trainer.lr_find()` in tutorials and examples
- Update poetry to 1.1.0
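
For reference, a minimal sketch of the callback-based setup the bullets above describe, written against the lightning 2.0 namespace used elsewhere in this diff (argument values are illustrative):

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

# LearningRateMonitor replaces the old LearningRateLogger
lr_monitor = LearningRateMonitor(logging_interval="epoch")
# EarlyStopping is passed via `callbacks`, not a trainer `early_stopping` argument
early_stop = EarlyStopping(monitor="val_loss", patience=10, mode="min")

trainer = pl.Trainer(max_epochs=30, callbacks=[lr_monitor, early_stop])
```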

---
3 changes: 2 additions & 1 deletion README.md
@@ -90,6 +90,7 @@ from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
# import dataset, network to train and metric to optimize
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
+from lightning.pytorch.tuner import Tuner

# load data: this is pandas dataframe with at least a column for
# * the target (what you want to predict)
@@ -159,7 +160,7 @@ tft = TemporalFusionTransformer.from_dataset(
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find the optimal learning rate
-res = trainer.lr_find(
+res = Tuner(trainer).lr_find(
tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)
# and plot the result - always visually confirm that the suggested learning rate makes sense
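
The README lines truncated here plot the result; a minimal sketch of that step with the lightning 2.0 LR-finder result object (illustration, not the verbatim README):

```python
# Inspect and visualize the learning-rate sweep returned by lr_find.
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)  # marks the suggested rate on the curve
fig.show()
```
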
4 changes: 2 additions & 2 deletions docs/source/getting-started.rst
@@ -67,7 +67,7 @@ Example
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
+from lightning.pytorch.tuner import Tuner
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
# load data
@@ -127,7 +127,7 @@ Example
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
# find optimal learning rate (set limit_train_batches to 1.0 and log_interval = -1)
-res = trainer.tuner.lr_find(
+res = Tuner(trainer).lr_find(
tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)
372 changes: 104 additions & 268 deletions docs/source/tutorials/ar.ipynb

Large diffs are not rendered by default.

