Skip to content

Commit

Permalink
feat: 🎸 add NBeats
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 15, 2023
1 parent ef73a61 commit 24be5c0
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 1 deletion.
120 changes: 120 additions & 0 deletions baselines/NBeats/M4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import os
import sys

# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
from easydict import EasyDict
from basicts.losses import masked_mae
from basicts.data import M4ForecastingDataset
from basicts.runners import M4ForecastingRunner

from .arch import NBeats

def get_cfg(seasonal_pattern):
    """Build the BasicTS configuration for training N-BEATS on one M4 subset.

    Args:
        seasonal_pattern (str): One of "Yearly", "Quarterly", "Monthly",
            "Weekly", "Daily", "Hourly" — selects the M4 sub-dataset.

    Returns:
        EasyDict: complete training/test/eval configuration.

    Raises:
        AssertionError: if `seasonal_pattern` is not a known M4 subset.
    """
    assert seasonal_pattern in ["Yearly", "Quarterly", "Monthly", "Weekly", "Daily", "Hourly"]
    # Forecast horizon and number of series per M4 subset.
    prediction_len = {"Yearly": 6, "Quarterly": 8, "Monthly": 18, "Weekly": 13, "Daily": 14, "Hourly": 48}[seasonal_pattern]
    # NOTE(review): num_nodes is not used below in this config; kept as a
    # reference for the subset sizes — confirm whether the runner needs it.
    num_nodes = {"Yearly": 23000, "Quarterly": 24000, "Monthly": 48000, "Weekly": 359, "Daily": 4227, "Hourly": 414}[seasonal_pattern]
    history_size = 2  # look-back window is 2x the forecast horizon
    history_len = history_size * prediction_len

    CFG = EasyDict()

    # ================= general ================= #
    # BUGFIX: description previously read "Multi-layer perceptron model
    # configuration " — a copy-paste leftover; this config is for N-BEATS.
    CFG.DESCRIPTION = "NBeats model configuration"
    CFG.RUNNER = M4ForecastingRunner
    CFG.DATASET_CLS = M4ForecastingDataset
    CFG.DATASET_NAME = "M4_" + seasonal_pattern
    CFG.DATASET_INPUT_LEN = history_len
    CFG.DATASET_OUTPUT_LEN = prediction_len
    CFG.GPU_NUM = 1

    # ================= environment ================= #
    CFG.ENV = EasyDict()
    CFG.ENV.SEED = 1
    CFG.ENV.CUDNN = EasyDict()
    CFG.ENV.CUDNN.ENABLED = True

    # ================= model ================= #
    CFG.MODEL = EasyDict()
    CFG.MODEL.NAME = "NBeats"
    CFG.MODEL.ARCH = NBeats
    # Generic variant hyper-parameters (see NBeats.__init__).
    CFG.MODEL.PARAM = {
        "type": "generic",
        "input_size": CFG.DATASET_INPUT_LEN,
        "output_size": CFG.DATASET_OUTPUT_LEN,
        "layer_size": 512,
        "layers": 4,
        "stacks": 30
    }
    # Alternative: interpretable variant (trend + seasonality stacks).
    # CFG.MODEL.PARAM = {
    #     "type": "interpretable",
    #     "input_size": CFG.DATASET_INPUT_LEN,
    #     "output_size": CFG.DATASET_OUTPUT_LEN,
    #     "seasonality_layer_size": 2048,
    #     "seasonality_blocks": 3,
    #     "seasonality_layers": 4,
    #     "trend_layer_size": 256,
    #     "degree_of_polynomial": 2,
    #     "trend_blocks": 3,
    #     "trend_layers": 4,
    #     "num_of_harmonics": 1
    # }
    CFG.MODEL.FORWARD_FEATURES = [0]  # feed only the raw values channel
    CFG.MODEL.TARGET_FEATURES = [0]

    # ================= optim ================= #
    CFG.TRAIN = EasyDict()
    CFG.TRAIN.LOSS = masked_mae
    CFG.TRAIN.OPTIM = EasyDict()
    CFG.TRAIN.OPTIM.TYPE = "Adam"
    CFG.TRAIN.OPTIM.PARAM = {
        "lr": 0.002,
        "weight_decay": 0.0001,
    }
    CFG.TRAIN.LR_SCHEDULER = EasyDict()
    CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
    CFG.TRAIN.LR_SCHEDULER.PARAM = {
        "milestones": [1, 50, 80],
        "gamma": 0.5
    }

    # ================= train ================= #
    CFG.TRAIN.CLIP_GRAD_PARAM = {
        'max_norm': 5.0
    }
    CFG.TRAIN.NUM_EPOCHS = 100
    CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
        "checkpoints",
        "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
    )
    # train data
    CFG.TRAIN.DATA = EasyDict()
    # read data
    CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
    # dataloader args, optional
    CFG.TRAIN.DATA.BATCH_SIZE = 64
    CFG.TRAIN.DATA.PREFETCH = False
    CFG.TRAIN.DATA.SHUFFLE = True
    CFG.TRAIN.DATA.NUM_WORKERS = 2
    CFG.TRAIN.DATA.PIN_MEMORY = False

    # ================= test ================= #
    CFG.TEST = EasyDict()
    # Only evaluate once, at the end of training.
    CFG.TEST.INTERVAL = CFG.TRAIN.NUM_EPOCHS
    # test data
    CFG.TEST.DATA = EasyDict()
    # read data
    CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
    # dataloader args, optional
    CFG.TEST.DATA.BATCH_SIZE = 64
    CFG.TEST.DATA.PREFETCH = False
    CFG.TEST.DATA.SHUFFLE = False
    CFG.TEST.DATA.NUM_WORKERS = 2
    CFG.TEST.DATA.PIN_MEMORY = False

    # ================= evaluate ================= #
    CFG.EVAL = EasyDict()
    CFG.EVAL.HORIZONS = []  # empty: M4 runner evaluates the full horizon
    CFG.EVAL.SAVE_PATH = os.path.abspath(__file__ + "/..")

    return CFG
1 change: 1 addition & 0 deletions baselines/NBeats/arch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .nbeats import NBeats
197 changes: 197 additions & 0 deletions baselines/NBeats/arch/nbeats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# This source code is provided for the purposes of scientific reproducibility
# under the following limited license from Element AI Inc. The code is an
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
# expansion analysis for interpretable time series forecasting,
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
# International license (CC BY-NC 4.0):
# https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
# for the benefit of third parties or internally in production) requires an
# explicit license. The subject-matter of the N-BEATS model and associated
# materials are the property of Element AI Inc. and may be subject to patent
# protection. No license to patents is granted hereunder (whether express or
# implied). Copyright © 2020 Element AI Inc. All rights reserved.

# Modified from:

"""
N-BEATS Model.
"""
from typing import Tuple

import numpy as np
import torch as t


class NBeatsBlock(t.nn.Module):
    """
    One N-BEATS block: a fully-connected stack that produces basis-expansion
    coefficients, which the supplied basis function maps to a
    (backcast, forecast) pair.
    """
    def __init__(self,
                 input_size,
                 theta_size: int,
                 basis_function: t.nn.Module,
                 layers: int,
                 layer_size: int):
        """
        N-BEATS block.
        :param input_size: Insample size.
        :param theta_size: Number of parameters for the basis function.
        :param basis_function: Basis function which takes the parameters and produces backcast and forecast.
        :param layers: Number of layers.
        :param layer_size: Layer size.
        """
        super().__init__()
        # First layer projects the input window; the rest are square layers.
        fc_stack = [t.nn.Linear(in_features=input_size, out_features=layer_size)]
        fc_stack.extend(t.nn.Linear(in_features=layer_size, out_features=layer_size)
                        for _ in range(layers - 1))
        self.layers = t.nn.ModuleList(fc_stack)
        self.basis_parameters = t.nn.Linear(in_features=layer_size, out_features=theta_size)
        self.basis_function = basis_function

    def forward(self, x: t.Tensor) -> Tuple[t.Tensor, t.Tensor]:
        hidden = x
        for fc in self.layers:
            hidden = t.relu(fc(hidden))
        theta = self.basis_parameters(hidden)
        return self.basis_function(theta)


class NBeats(t.nn.Module):
    """
    Paper: N-BEATS: Neural basis expansion analysis for interpretable time series forecasting
    Link: https://arxiv.org/abs/1905.10437
    Official Code:
        https://github.com/ServiceNow/N-BEATS
        https://github.com/philipperemy/n-beats
    """

    def __init__(self, input_size: int, type: str, output_size: int, **kwargs):
        """
        Build either the "generic" or the "interpretable" N-BEATS variant.

        Args:
            input_size: length of the history (insample) window.
            type: "generic" or "interpretable". (Shadows the builtin, but the
                name is part of the config interface — callers pass it as a
                keyword — so it must stay.)
            output_size: length of the forecast horizon.
            **kwargs: variant-specific hyper-parameters.
                generic: stacks, layers, layer_size.
                interpretable: trend_blocks, trend_layers, trend_layer_size,
                    degree_of_polynomial, seasonality_blocks,
                    seasonality_layers, seasonality_layer_size,
                    num_of_harmonics.
        """
        super().__init__()
        assert type in ["generic", "interpretable"], "Unknown type of N-Beats model"
        if type == "generic":
            stacks = kwargs["stacks"]
            layers = kwargs["layers"]
            layer_size = kwargs["layer_size"]
            # Independent generic blocks: each stack gets its own weights.
            self.blocks = t.nn.ModuleList([NBeatsBlock(input_size=input_size,
                                                       theta_size=input_size + output_size,
                                                       basis_function=GenericBasis(backcast_size=input_size,
                                                                                   forecast_size=output_size),
                                                       layers=layers,
                                                       layer_size=layer_size)
                                           for _ in range(stacks)])
        else:
            trend_blocks = kwargs["trend_blocks"]
            trend_layers = kwargs["trend_layers"]
            trend_layer_size = kwargs["trend_layer_size"]
            degree_of_polynomial = kwargs["degree_of_polynomial"]
            seasonality_blocks = kwargs["seasonality_blocks"]
            seasonality_layers = kwargs["seasonality_layers"]
            seasonality_layer_size = kwargs["seasonality_layer_size"]
            num_of_harmonics = kwargs["num_of_harmonics"]

            trend_block = NBeatsBlock(input_size=input_size,
                                      theta_size=2 * (degree_of_polynomial + 1),
                                      basis_function=TrendBasis(degree_of_polynomial=degree_of_polynomial,
                                                                backcast_size=input_size,
                                                                forecast_size=output_size),
                                      layers=trend_layers,
                                      layer_size=trend_layer_size)
            seasonality_block = NBeatsBlock(input_size=input_size,
                                            theta_size=4 * int(
                                                np.ceil(num_of_harmonics / 2 * output_size) - (num_of_harmonics - 1)),
                                            basis_function=SeasonalityBasis(harmonics=num_of_harmonics,
                                                                            backcast_size=input_size,
                                                                            forecast_size=output_size),
                                            layers=seasonality_layers,
                                            layer_size=seasonality_layer_size)
            # The SAME block instance is repeated, so weights are shared across
            # blocks of each kind (deliberate in the interpretable variant).
            self.blocks = t.nn.ModuleList(
                [trend_block for _ in range(trend_blocks)] + [seasonality_block for _ in range(seasonality_blocks)])

    def forward(self, history_data: t.Tensor, history_mask: t.Tensor, **kwargs) -> t.Tensor:
        """
        Args:
            history_data: history window; singleton dims are squeezed down to
                (batch, input_size). NOTE(review): `squeeze()` would also drop
                a batch dimension of size 1 — presumably batch > 1; confirm.
            history_mask: same shape as history_data; 1 where observed.

        Returns:
            Forecast tensor with two trailing singleton dims appended.
        """
        x = history_data.squeeze()
        input_mask = history_mask.squeeze()

        # Reverse the time axis before feeding the blocks.
        residuals = x.flip(dims=(1,))
        input_mask = input_mask.flip(dims=(1,))
        # Seed the forecast with the last observed value (naive baseline).
        forecast = x[:, -1:]
        for block in self.blocks:
            backcast, block_forecast = block(residuals)
            # Doubly residual stacking: subtract what this block explained,
            # masking out unobserved positions.
            residuals = (residuals - backcast) * input_mask
            forecast = forecast + block_forecast
        forecast = forecast.unsqueeze(-1).unsqueeze(-1)
        return forecast


class GenericBasis(t.nn.Module):
    """
    Generic (unconstrained) basis: the first `backcast_size` coefficients are
    read as the backcast and the last `forecast_size` as the forecast.
    """
    def __init__(self, backcast_size: int, forecast_size: int):
        super().__init__()
        self.backcast_size = backcast_size
        self.forecast_size = forecast_size

    def forward(self, theta: t.Tensor):
        backcast = theta[:, :self.backcast_size]
        forecast = theta[:, -self.forecast_size:]
        return backcast, forecast


class TrendBasis(t.nn.Module):
    """
    Polynomial basis to model trend: fixed time-power templates of shape
    (polynomial_size, time), combined linearly by the learned theta.
    """
    def __init__(self, degree_of_polynomial: int, backcast_size: int, forecast_size: int):
        super().__init__()
        self.polynomial_size = degree_of_polynomial + 1  # degree of polynomial with constant term
        # BUGFIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24,
        # so this crashed with modern NumPy. `np.float64` is the exact alias
        # it used to resolve to, so the computed templates are unchanged.
        self.backcast_time = t.nn.Parameter(
            t.tensor(np.concatenate([np.power(np.arange(backcast_size, dtype=np.float64) / backcast_size, i)[None, :]
                                     for i in range(self.polynomial_size)]), dtype=t.float32),
            requires_grad=False)
        self.forecast_time = t.nn.Parameter(
            t.tensor(np.concatenate([np.power(np.arange(forecast_size, dtype=np.float64) / forecast_size, i)[None, :]
                                     for i in range(self.polynomial_size)]), dtype=t.float32), requires_grad=False)

    def forward(self, theta: t.Tensor):
        # First `polynomial_size` coefficients parameterize the forecast, the
        # remainder the backcast (matches the official implementation's layout).
        backcast = t.einsum('bp,pt->bt', theta[:, self.polynomial_size:], self.backcast_time)
        forecast = t.einsum('bp,pt->bt', theta[:, :self.polynomial_size], self.forecast_time)
        return backcast, forecast


class SeasonalityBasis(t.nn.Module):
    """
    Harmonic (Fourier) basis to model seasonality. `theta` holds, in order,
    forecast-cos, forecast-sin, backcast-cos, backcast-sin coefficient groups
    of equal width.
    """
    def __init__(self, harmonics: int, backcast_size: int, forecast_size: int):
        super().__init__()
        # Frequency row vector: a zero (constant) term followed by harmonic
        # frequencies scaled by the forecast horizon.
        self.frequency = np.append(np.zeros(1, dtype=np.float32),
                                   np.arange(harmonics, harmonics / 2 * forecast_size,
                                             dtype=np.float32) / harmonics)[None, :]
        backcast_grid = -2 * np.pi * (
            np.arange(backcast_size, dtype=np.float32)[:, None] / forecast_size) * self.frequency
        forecast_grid = 2 * np.pi * (
            np.arange(forecast_size, dtype=np.float32)[:, None] / forecast_size) * self.frequency

        def fixed_template(wave, grid):
            # Non-trainable template of shape (num_frequencies, time).
            return t.nn.Parameter(t.tensor(np.transpose(wave(grid)), dtype=t.float32),
                                  requires_grad=False)

        self.backcast_cos_template = fixed_template(np.cos, backcast_grid)
        self.backcast_sin_template = fixed_template(np.sin, backcast_grid)
        self.forecast_cos_template = fixed_template(np.cos, forecast_grid)
        self.forecast_sin_template = fixed_template(np.sin, forecast_grid)

    def forward(self, theta: t.Tensor):
        width = theta.shape[1] // 4
        # theta layout: [fc_cos | fc_sin | bc_cos | bc_sin], each `width` wide.
        backcast = (t.einsum('bp,pt->bt', theta[:, 3 * width:], self.backcast_sin_template)
                    + t.einsum('bp,pt->bt', theta[:, 2 * width:3 * width], self.backcast_cos_template))
        forecast = (t.einsum('bp,pt->bt', theta[:, width:2 * width], self.forecast_sin_template)
                    + t.einsum('bp,pt->bt', theta[:, :width], self.forecast_cos_template))
        return backcast, forecast
2 changes: 1 addition & 1 deletion baselines/STID_M4/M4.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_cfg(seasonal_pattern):
"day_of_week_size": 7
}
CFG.MODEL.FORWARD_FEATURES = [0, 1] # values, node id
CFG.MODEL.TARGET_FEATURES = [0] # traffic speed
CFG.MODEL.TARGET_FEATURES = [0]

# ================= optim ================= #
CFG.TRAIN = EasyDict()
Expand Down

0 comments on commit 24be5c0

Please sign in to comment.