From 7e6412b4f6bd1414959dc2883d12976b9e212842 Mon Sep 17 00:00:00 2001 From: superarthurlx <40826115+superarthurlx@users.noreply.github.com> Date: Thu, 5 Dec 2024 15:36:51 +0800 Subject: [PATCH] complete BigST with preprocess (#206) --- baselines/BigST/PEMS08.py | 182 ++++++++++++++++++ .../BigST/{PEMS04.py => PreprocessPEMS08.py} | 43 ++--- baselines/BigST/arch/__init__.py | 4 +- baselines/BigST/arch/bigst_arch.py | 163 +++++----------- baselines/BigST/arch/model.py | 122 ++++++++++++ .../{preprocess/model.py => preprocess.py} | 23 ++- baselines/BigST/arch/preprocess/metrics.py | 53 ----- baselines/BigST/arch/preprocess/pipeline.py | 38 ---- baselines/BigST/arch/preprocess/preprocess.py | 127 ------------ baselines/BigST/arch/preprocess/util.py | 147 -------------- baselines/BigST/runner/__init__.py | 1 + .../BigST/runner/bigstpreprocess_runner.py | 48 +++++ 12 files changed, 442 insertions(+), 509 deletions(-) create mode 100644 baselines/BigST/PEMS08.py rename baselines/BigST/{PEMS04.py => PreprocessPEMS08.py} (84%) create mode 100644 baselines/BigST/arch/model.py rename baselines/BigST/arch/{preprocess/model.py => preprocess.py} (91%) delete mode 100644 baselines/BigST/arch/preprocess/metrics.py delete mode 100644 baselines/BigST/arch/preprocess/pipeline.py delete mode 100644 baselines/BigST/arch/preprocess/preprocess.py delete mode 100644 baselines/BigST/arch/preprocess/util.py create mode 100644 baselines/BigST/runner/__init__.py create mode 100644 baselines/BigST/runner/bigstpreprocess_runner.py diff --git a/baselines/BigST/PEMS08.py b/baselines/BigST/PEMS08.py new file mode 100644 index 00000000..9c9a4b2e --- /dev/null +++ b/baselines/BigST/PEMS08.py @@ -0,0 +1,182 @@ +import os +import sys +import torch +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) + +from basicts.metrics import masked_mae, masked_mape, masked_rmse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings, load_adj + +from .arch import BigST +# from .runner import BigSTPreprocessRunner +from .loss import bigst_loss + +import pdb + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'PEMS08' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = 2016 # regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = 12 # regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +PREPROCESSED_FILE = "checkpoints\\BigSTPreprocess\\PEMS08_100_2016_12\\db8308a2c87de35e5f3db6177c5714ff\\BigSTPreprocess_best_val_MAE.pt" +MODEL_ARCH = BigST + +adj_mx, _ = load_adj("datasets/" + DATA_NAME + + "/adj_mx.pkl", "doubletransition") +MODEL_PARAM = { + "bigst_args":{ + "num_nodes": 170, + "seq_num": 12, + "in_dim": 3, + "out_dim": OUTPUT_LEN, # 源代码固定成12了 + "hid_dim": 32, + "tau" : 0.25, + "random_feature_dim": 64, + "node_emb_dim": 32, + "time_emb_dim": 32, + "use_residual": True, + "use_bn": True, + "use_long": True, + "use_spatial": True, + "dropout": 0.3, + "supports": [torch.tensor(i) for i in adj_mx], + "time_of_day_size": 288, + "day_of_week_size": 7 + }, + "preprocess_path": PREPROCESSED_FILE, + "preprocess_args":{ + "num_nodes": 170, + "in_dim": 3, + "dropout": 0.3, + "input_length": 2016, + "output_length": 12, + "nhid": 32, + "tiny_batch_size": 64, + } + + +} + +NUM_EPOCHS = 100 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Environment Configuration ############################## + +CFG.ENV = EasyDict() # Environment settings. Default: None +CFG.ENV.SEED = 0 # Random seed. Default: None + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MAPE': masked_mape, + 'RMSE': masked_rmse, + }) +CFG.METRICS.TARGET = 'MAE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) + + +CFG.TRAIN.LOSS = bigst_loss if MODEL_PARAM['bigst_args']['use_spatial'] else masked_mae +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "AdamW" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.002, + "weight_decay": 0.0001, +} +# Learning rate scheduler settings +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "milestones": [1, 50], + "gamma": 0.5 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.SHUFFLE = True +# Gradient clipping settings +CFG.TRAIN.CLIP_GRAD_PARAM = { + "max_norm": 5.0 +} + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 64 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 64 + +############################## Evaluation Configuration ############################## +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [3, 6, 12] # Prediction horizons for evaluation. Default: [] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True + + diff --git a/baselines/BigST/PEMS04.py b/baselines/BigST/PreprocessPEMS08.py similarity index 84% rename from baselines/BigST/PEMS04.py rename to baselines/BigST/PreprocessPEMS08.py index 6dcc698a..39d7f4b9 100644 --- a/baselines/BigST/PEMS04.py +++ b/baselines/BigST/PreprocessPEMS08.py @@ -10,41 +10,32 @@ from basicts.scaler import ZScoreScaler from basicts.utils import get_regular_settings, load_adj -from .arch import BigST -from .loss import bigst_loss +from .arch import BigSTPreprocess +from .runner import BigSTPreprocessRunner ############################## Hot Parameters ############################## # Dataset & Metrics configuration -DATA_NAME = 'PEMS04' # Dataset name +DATA_NAME = 'PEMS08' # Dataset name regular_settings = get_regular_settings(DATA_NAME) -INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence -OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +INPUT_LEN = 2016 +OUTPUT_LEN = 12 TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data RESCALE = regular_settings['RESCALE'] # Whether to rescale the data NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data # Model architecture and parameters -MODEL_ARCH = BigST +MODEL_ARCH = BigSTPreprocess adj_mx, _ = load_adj("datasets/" + DATA_NAME + "/adj_mx.pkl", "doubletransition") MODEL_PARAM = { - "num_nodes": 307, - "seq_num": INPUT_LEN, + "num_nodes": 170, "in_dim": 3, - "out_dim": OUTPUT_LEN, - "hid_dim": 32, - "tau" : 0.25, - "random_feature_dim": 64, - "node_emb_dim": 32, - "time_emb_dim": 32, - "use_residual": True, - "use_bn": True, - "use_spatial": True, - "use_long": False, "dropout": 0.3, - "supports": [torch.tensor(i) for i in adj_mx], - "time_of_day_size": 288, - "day_of_week_size": 7, + "input_length": INPUT_LEN, + "output_length": OUTPUT_LEN, + "nhid": 32, + "tiny_batch_size": 64, + } NUM_EPOCHS = 100 @@ -55,7 +46,7 @@ CFG.DESCRIPTION = 'An Example Config' CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) # Runner -CFG.RUNNER = SimpleTimeSeriesForecastingRunner +CFG.RUNNER = BigSTPreprocessRunner ############################## Environment Configuration ############################## @@ -115,7 +106,7 @@ MODEL_ARCH.__name__, '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) ) -CFG.TRAIN.LOSS = bigst_loss +CFG.TRAIN.LOSS = masked_mae # Optimizer settings CFG.TRAIN.OPTIM = EasyDict() CFG.TRAIN.OPTIM.TYPE = "AdamW" @@ -132,7 +123,7 @@ } # Train data loader settings CFG.TRAIN.DATA = EasyDict() -CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.BATCH_SIZE = 1 CFG.TRAIN.DATA.SHUFFLE = True # Gradient clipping settings CFG.TRAIN.CLIP_GRAD_PARAM = { @@ -143,13 +134,13 @@ CFG.VAL = EasyDict() CFG.VAL.INTERVAL = 1 CFG.VAL.DATA = EasyDict() -CFG.VAL.DATA.BATCH_SIZE = 64 +CFG.VAL.DATA.BATCH_SIZE = 1 ############################## Test Configuration ############################## CFG.TEST = EasyDict() CFG.TEST.INTERVAL = 1 CFG.TEST.DATA = EasyDict() -CFG.TEST.DATA.BATCH_SIZE = 64 +CFG.TEST.DATA.BATCH_SIZE = 1 ############################## Evaluation Configuration ############################## diff --git a/baselines/BigST/arch/__init__.py b/baselines/BigST/arch/__init__.py index 7cb17069..e2d419fd 100644 --- a/baselines/BigST/arch/__init__.py +++ b/baselines/BigST/arch/__init__.py @@ -1,3 +1,5 @@ from .bigst_arch import BigST +from .preprocess import BigSTPreprocess -__all__ = ["BigST"] + +__all__ = ["BigST", "BigSTPreprocess"] diff --git a/baselines/BigST/arch/bigst_arch.py b/baselines/BigST/arch/bigst_arch.py index 5e8c6034..dd3d0342 100644 --- a/baselines/BigST/arch/bigst_arch.py +++ b/baselines/BigST/arch/bigst_arch.py @@ -1,3 +1,4 @@ +import os import math import torch import torch.nn as nn @@ -5,6 +6,19 @@ from .linear_conv import * from torch.autograd import Variable import pdb +from .preprocess import BigSTPreprocess +from .model import Model + +def sample_period(x, time_num): + # trainx (B, N, T, F) + history_length = x.shape[-2] + idx_list = [i for i in range(history_length)] + period_list = [idx_list[i:i+12] for i in range(0, history_length, time_num)] + period_feat = [x[:,:,sublist,0] for sublist in period_list] + period_feat = torch.stack(period_feat) + period_feat = torch.mean(period_feat, dim=0) + + return period_feat class BigST(nn.Module): """ @@ -14,126 +28,51 @@ class BigST(nn.Module): Venue: VLDB 2024 Task: Spatial-Temporal Forecasting """ - def __init__(self, seq_num, in_dim, out_dim, hid_dim, num_nodes, tau, random_feature_dim, node_emb_dim, time_emb_dim, \ - use_residual, use_bn, use_spatial, use_long, dropout, time_of_day_size, day_of_week_size, supports=None, edge_indices=None): + + def __init__(self, bigst_args, preprocess_path, preprocess_args): super(BigST, self).__init__() - self.tau = tau - self.layer_num = 3 - self.in_dim = in_dim - self.random_feature_dim = random_feature_dim - - self.use_residual = use_residual - self.use_bn = use_bn - self.use_spatial = use_spatial - self.use_long = use_long - - self.dropout = dropout - self.activation = nn.ReLU() - self.supports = supports - - self.time_num = time_of_day_size - self.week_num = day_of_week_size - - # node embedding layer - self.node_emb_layer = nn.Parameter(torch.empty(num_nodes, node_emb_dim)) - nn.init.xavier_uniform_(self.node_emb_layer) - - # time embedding layer - self.time_emb_layer = nn.Parameter(torch.empty(self.time_num, time_emb_dim)) - nn.init.xavier_uniform_(self.time_emb_layer) - self.week_emb_layer = nn.Parameter(torch.empty(self.week_num, time_emb_dim)) - nn.init.xavier_uniform_(self.week_emb_layer) - # embedding layer - self.input_emb_layer = nn.Conv2d(seq_num*in_dim, hid_dim, kernel_size=(1, 1), bias=True) - - self.W_1 = nn.Conv2d(node_emb_dim+time_emb_dim*2, hid_dim, kernel_size=(1, 1), bias=True) - self.W_2 = nn.Conv2d(node_emb_dim+time_emb_dim*2, hid_dim, kernel_size=(1, 1), bias=True) - - self.linear_conv = nn.ModuleList() - self.bn = nn.ModuleList() - - self.supports_len = 0 - if supports is not None: - self.supports_len += len(supports) - - for i in range(self.layer_num): - self.linear_conv.append(linearized_conv(hid_dim*4, hid_dim*4, self.dropout, self.tau, self.random_feature_dim)) - self.bn.append(nn.LayerNorm(hid_dim*4)) - + self.use_long = bigst_args['use_long'] + self.in_dim = bigst_args['in_dim'] + self.out_dim = bigst_args['out_dim'] + self.time_num = bigst_args['time_of_day_size'] + self.bigst = Model(**bigst_args) + if self.use_long: - self.regression_layer = nn.Conv2d(hid_dim*4*2+hid_dim+seq_num, out_dim, kernel_size=(1, 1), bias=True) - else: - self.regression_layer = nn.Conv2d(hid_dim*4*2, out_dim, kernel_size=(1, 1), bias=True) + self.feat_extractor = BigSTPreprocess(**preprocess_args) + self.load_pre_trained_model(preprocess_path) + + def load_pre_trained_model(self, preprocess_path): + """Load pre-trained model""" - # def forward(self, x, feat=None): - def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: - x = history_data[:, :, :, range(self.in_dim)] # (batch_size, in_len, data_dim) - x = x.transpose(1,2) - # input: (B, N, T, D) - B, N, T, D = x.size() - - time_emb = self.time_emb_layer[(x[:, :, -1, 1]*self.time_num).type(torch.LongTensor)] - week_emb = self.week_emb_layer[(x[:, :, -1, 2]).type(torch.LongTensor)] - - # input embedding - x = x.contiguous().view(B, N, -1).transpose(1, 2).unsqueeze(-1) # (B, D*T, N, 1) - input_emb = self.input_emb_layer(x) + # load parameters + checkpoint_dict = torch.load(preprocess_path) + self.feat_extractor.load_state_dict(checkpoint_dict["model_state_dict"]) + # freeze parameters + for param in self.feat_extractor.parameters(): + param.requires_grad = False - # node embeddings - node_emb = self.node_emb_layer.unsqueeze(0).expand(B, -1, -1).transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) + self.feat_extractor.eval() - # time embeddings - time_emb = time_emb.transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) - week_emb = week_emb.transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) - - x_g = torch.cat([node_emb, time_emb, week_emb], dim=1) # (B, dim*4, N, 1) - x = torch.cat([input_emb, node_emb, time_emb, week_emb], dim=1) # (B, dim*4, N, 1) - # linearized spatial convolution - x_pool = [x] # (B, dim*4, N, 1) - node_vec1 = self.W_1(x_g) # (B, dim, N, 1) - node_vec2 = self.W_2(x_g) # (B, dim, N, 1) - node_vec1 = node_vec1.permute(0, 2, 3, 1) # (B, N, 1, dim) - node_vec2 = node_vec2.permute(0, 2, 3, 1) # (B, N, 1, dim) - for i in range(self.layer_num): - if self.use_residual: - residual = x - x, node_vec1_prime, node_vec2_prime = self.linear_conv[i](x, node_vec1, node_vec2) - - if self.use_residual: - x = x+residual - - if self.use_bn: - x = x.permute(0, 2, 3, 1) # (B, N, 1, dim*4) - x = self.bn[i](x) - x = x.permute(0, 3, 1, 2) + def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: + history_data = history_data.transpose(1,2) # (B, N, T, D) + x = history_data[:, :, -self.out_dim:] # (batch_size, in_len, data_dim) - x_pool.append(x) - x = torch.cat(x_pool, dim=1) # (B, dim*4, N, 1) - - x = self.activation(x) # (B, dim*4, N, 1) - if self.use_long: - feat = feat.permute(0, 2, 1).unsqueeze(-1) # (B, F, N, 1) - x = torch.cat([x, feat], dim=1) - x = self.regression_layer(x) # (B, N, T) - x = x.squeeze(-1).permute(0, 2, 1) - else: - x = self.regression_layer(x) # (B, N, T) - x = x.squeeze(-1).permute(0, 2, 1) - - # if self.use_spatial: + feat = [] + for i in range(history_data.shape[0]): + with torch.no_grad(): + feat_sample = self.feat_extractor(history_data[[i],:,:,:], future_data, batch_seen, epoch, train) + feat.append(feat_sample['feat']) - # supports = [support.to(x.device) for support in self.supports] - # edge_indices = torch.nonzero(supports[0] > 0) + feat = torch.cat(feat, dim=0) + feat_period = sample_period(history_data, self.time_num) + feat = torch.cat([feat, feat_period], dim=2) + + return self.bigst(x, feat) + + else: + return self.bigst(x) - # # s_loss = spatial_loss(node_vec1_prime, node_vec2_prime, supports, edge_indices) - # return x.transpose(1,2).unsqueeze(-1), s_loss - # else: - # return x.transpose(1,2).unsqueeze(-1), 0 - return {"prediction": x.transpose(1,2).unsqueeze(-1) - , "node_vec1": node_vec1_prime - , "node_vec2": node_vec2_prime - , "supports": self.supports - , 'use_spatial': self.use_spatial} \ No newline at end of file + \ No newline at end of file diff --git a/baselines/BigST/arch/model.py b/baselines/BigST/arch/model.py new file mode 100644 index 00000000..31063da6 --- /dev/null +++ b/baselines/BigST/arch/model.py @@ -0,0 +1,122 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from .linear_conv import * +from torch.autograd import Variable +import pdb + +class Model(nn.Module): + def __init__(self, seq_num, in_dim, out_dim, hid_dim, num_nodes, tau, random_feature_dim, node_emb_dim, time_emb_dim, \ + use_residual, use_bn, use_spatial, use_long, dropout, time_of_day_size, day_of_week_size, supports=None, edge_indices=None): + super(Model, self).__init__() + + self.tau = tau + self.layer_num = 3 + self.in_dim = in_dim + self.random_feature_dim = random_feature_dim + + self.use_residual = use_residual + self.use_bn = use_bn + self.use_spatial = use_spatial + self.use_long = use_long + + self.dropout = dropout + self.activation = nn.ReLU() + self.supports = supports + + self.time_num = time_of_day_size + self.week_num = day_of_week_size + + # node embedding layer + self.node_emb_layer = nn.Parameter(torch.empty(num_nodes, node_emb_dim)) + nn.init.xavier_uniform_(self.node_emb_layer) + + # time embedding layer + self.time_emb_layer = nn.Parameter(torch.empty(self.time_num, time_emb_dim)) + nn.init.xavier_uniform_(self.time_emb_layer) + self.week_emb_layer = nn.Parameter(torch.empty(self.week_num, time_emb_dim)) + nn.init.xavier_uniform_(self.week_emb_layer) + + # embedding layer + self.input_emb_layer = nn.Conv2d(seq_num*in_dim, hid_dim, kernel_size=(1, 1), bias=True) + + self.W_1 = nn.Conv2d(node_emb_dim+time_emb_dim*2, hid_dim, kernel_size=(1, 1), bias=True) + self.W_2 = nn.Conv2d(node_emb_dim+time_emb_dim*2, hid_dim, kernel_size=(1, 1), bias=True) + + self.linear_conv = nn.ModuleList() + self.bn = nn.ModuleList() + + self.supports_len = 0 + if supports is not None: + self.supports_len += len(supports) + + for i in range(self.layer_num): + self.linear_conv.append(linearized_conv(hid_dim*4, hid_dim*4, self.dropout, self.tau, self.random_feature_dim)) + self.bn.append(nn.LayerNorm(hid_dim*4)) + + if self.use_long: + self.regression_layer = nn.Conv2d(hid_dim*4*2+hid_dim+seq_num, out_dim, kernel_size=(1, 1), bias=True) + else: + self.regression_layer = nn.Conv2d(hid_dim*4*2, out_dim, kernel_size=(1, 1), bias=True) + + def forward(self, x, feat=None): + + # x: (B, N, T, D) + B, N, T, D = x.size() + + time_emb = self.time_emb_layer[(x[:, :, -1, 1]*self.time_num).type(torch.LongTensor)] + week_emb = self.week_emb_layer[(x[:, :, -1, 2]).type(torch.LongTensor)] + + # input embedding + x = x.contiguous().view(B, N, -1).transpose(1, 2).unsqueeze(-1) # (B, D*T, N, 1) + input_emb = self.input_emb_layer(x) + + # node embeddings + node_emb = self.node_emb_layer.unsqueeze(0).expand(B, -1, -1).transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) + + # time embeddings + time_emb = time_emb.transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) + week_emb = week_emb.transpose(1, 2).unsqueeze(-1) # (B, dim, N, 1) + + x_g = torch.cat([node_emb, time_emb, week_emb], dim=1) # (B, dim*4, N, 1) + x = torch.cat([input_emb, node_emb, time_emb, week_emb], dim=1) # (B, dim*4, N, 1) + + # linearized spatial convolution + x_pool = [x] # (B, dim*4, N, 1) + node_vec1 = self.W_1(x_g) # (B, dim, N, 1) + node_vec2 = self.W_2(x_g) # (B, dim, N, 1) + node_vec1 = node_vec1.permute(0, 2, 3, 1) # (B, N, 1, dim) + node_vec2 = node_vec2.permute(0, 2, 3, 1) # (B, N, 1, dim) + for i in range(self.layer_num): + if self.use_residual: + residual = x + x, node_vec1_prime, node_vec2_prime = self.linear_conv[i](x, node_vec1, node_vec2) + + if self.use_residual: + x = x+residual + + if self.use_bn: + x = x.permute(0, 2, 3, 1) # (B, N, 1, dim*4) + x = self.bn[i](x) + x = x.permute(0, 3, 1, 2) + + x_pool.append(x) + x = torch.cat(x_pool, dim=1) # (B, dim*4, N, 1) + + x = self.activation(x) # (B, dim*4, N, 1) + + if self.use_long: + feat = feat.permute(0, 2, 1).unsqueeze(-1) # (B, F, N, 1) + x = torch.cat([x, feat], dim=1) + x = self.regression_layer(x) # (B, N, T) + x = x.squeeze(-1).permute(0, 2, 1) + else: + x = self.regression_layer(x) # (B, N, T) + x = x.squeeze(-1).permute(0, 2, 1) + + return {"prediction": x.transpose(1,2).unsqueeze(-1) + , "node_vec1": node_vec1_prime + , "node_vec2": node_vec2_prime + , "supports": self.supports + , 'use_spatial': self.use_spatial} \ No newline at end of file diff --git a/baselines/BigST/arch/preprocess/model.py b/baselines/BigST/arch/preprocess.py similarity index 91% rename from baselines/BigST/arch/preprocess/model.py rename to baselines/BigST/arch/preprocess.py index 44bd07c1..02926a47 100644 --- a/baselines/BigST/arch/preprocess/model.py +++ b/baselines/BigST/arch/preprocess.py @@ -4,6 +4,8 @@ import torch.nn.functional as F from torch.autograd import Variable import sys +import numpy as np +import pdb def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False): nb_full_blocks = int(m/d) @@ -147,9 +149,17 @@ def forward(self, x): return x -class linear_transformer(nn.Module): - def __init__(self, input_length, output_length, in_dim, num_nodes, nhid, dropout=0.3): - super(linear_transformer, self).__init__() + +class BigSTPreprocess(nn.Module): + """ + Paper: BigST: Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks + Link: https://dl.acm.org/doi/10.14778/3641204.3641217 + Official Code: https://github.com/usail-hkust/BigST?tab=readme-ov-file + Venue: VLDB 2024 + Task: Spatial-Temporal Forecasting + """ + def __init__(self, input_length, output_length, in_dim, num_nodes, nhid, tiny_batch_size, dropout=0.3): + super(BigSTPreprocess, self).__init__() self.tau = 1.0 self.layer_num = 3 self.random_feature_dim = nhid*2 @@ -175,7 +185,10 @@ def __init__(self, input_length, output_length, in_dim, num_nodes, nhid, dropout self.regression_layer = nn.Linear(nhid, output_length) - def forward(self, x): + self.tiny_batch_size = tiny_batch_size + + def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: + x = history_data # input: (1, 9638, 2016, 3) (B, N, T, D) B, N, T, D = x.size() pe = self.temporal_embedding.unsqueeze(0).expand(B*N, -1, -1) # (B*N, T/12, nhid) @@ -203,4 +216,4 @@ def forward(self, x): # x = torch.sum(x, dim=1) # (B*N, nhid) feat = x.view(B, N, -1) # (B, N, nhid) x = self.regression_layer(feat) # (B, N, output_length) - return x, feat + return {'prediction': x.transpose(1,2).unsqueeze(-1), 'feat':feat} \ No newline at end of file diff --git a/baselines/BigST/arch/preprocess/metrics.py b/baselines/BigST/arch/preprocess/metrics.py deleted file mode 100644 index aac0af60..00000000 --- a/baselines/BigST/arch/preprocess/metrics.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -import numpy as np - -def masked_mse(preds, labels, null_val=np.nan): - if np.isnan(null_val): - mask = ~torch.isnan(labels) - else: - mask = (labels!=null_val) - mask = mask.float() - mask /= torch.mean((mask)) - mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) - loss = (preds-labels)**2 - loss = loss * mask - loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) - return torch.mean(loss) - -def masked_rmse(preds, labels, null_val=np.nan): - return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val)) - -def masked_mae(preds, labels, null_val=np.nan): - if np.isnan(null_val): - mask = ~torch.isnan(labels) - else: - mask = (labels!=null_val) - mask = mask.float() - mask /= torch.mean((mask)) - mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) - loss = torch.abs(preds-labels) - loss = loss * mask - loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) - return torch.mean(loss) - - -def masked_mape(preds, labels, null_val=np.nan): - labels = torch.where(labels<0.01, torch.zeros_like(labels), labels) - if np.isnan(null_val): - mask = ~torch.isnan(labels) - else: - mask = (labels!=null_val) - mask = mask.float() - mask /= torch.mean((mask)) - mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) - loss = torch.abs(preds-labels)/labels - loss = loss * mask - loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) - return torch.mean(loss) - - -def metric(pred, real): - mae = masked_mae(pred,real,0.0).item() - mape = masked_mape(pred,real,0.0).item() - rmse = masked_rmse(pred,real,0.0).item() - return mae,mape,rmse \ No newline at end of file diff --git a/baselines/BigST/arch/preprocess/pipeline.py b/baselines/BigST/arch/preprocess/pipeline.py deleted file mode 100644 index 46499b73..00000000 --- a/baselines/BigST/arch/preprocess/pipeline.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch.optim as optim -from model import * -import metrics - -class train_pipeline(): - def __init__(self, scaler, input_length, output_length, in_dim, num_nodes, nhid, dropout, lrate, wdecay, device): - self.model = linear_transformer(input_length, output_length, in_dim, num_nodes, nhid, dropout) - self.model.to(device) - self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay) - self.loss = metrics.masked_mae - self.scaler = scaler - self.clip = 5 - - def train(self, input, real_val): - self.model.train() - self.optimizer.zero_grad() - output, _ = self.model(input) - real = self.scaler.inverse_transform(real_val) - predict = self.scaler.inverse_transform(output) - - loss = self.loss(predict, real, 0.0) - loss.backward() - if self.clip is not None: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) - self.optimizer.step() - mape = metrics.masked_mape(predict,real,0.0).item() - rmse = metrics.masked_rmse(predict,real,0.0).item() - return loss.item(), mape, rmse - - def eval(self, input, real_val): - self.model.eval() - output, _ = self.model(input) - real = self.scaler.inverse_transform(real_val) - predict = self.scaler.inverse_transform(output) - loss = self.loss(predict, real, 0.0) - mape = metrics.masked_mape(predict,real,0.0).item() - rmse = metrics.masked_rmse(predict,real,0.0).item() - return loss.item(), mape, rmse diff --git a/baselines/BigST/arch/preprocess/preprocess.py b/baselines/BigST/arch/preprocess/preprocess.py deleted file mode 100644 index feb795e3..00000000 --- a/baselines/BigST/arch/preprocess/preprocess.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import numpy as np -import argparse -import time -import util -from pipeline import train_pipeline - -parser = argparse.ArgumentParser() -parser.add_argument('--device',type=str,default='cuda:0',help='') -parser.add_argument('--data',type=str,default='/data/pems_data/pems_vldb/long_term',help='data path') -parser.add_argument('--input_length',type=int,default=2016,help='') -parser.add_argument('--output_length',type=int,default=12,help='') -parser.add_argument('--nhid',type=int,default=32,help='') -parser.add_argument('--in_dim',type=int,default=3,help='inputs dimension') -parser.add_argument('--num_nodes',type=int,default=9638,help='number of nodes') -parser.add_argument('--batch_size',type=int,default=1,help='batch size') -parser.add_argument('--tiny_batch_size',type=int,default=256,help='tiny batch size') -parser.add_argument('--learning_rate',type=float,default=0.001,help='learning rate') -parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate') -parser.add_argument('--weight_decay',type=float,default=0.0001,help='weight decay rate') -parser.add_argument('--epochs',type=int,default=100,help='') -parser.add_argument('--print_every',type=int,default=1,help='') -#parser.add_argument('--seed',type=int,default=99,help='random seed') -parser.add_argument('--save',type=str,default='checkpoint/',help='save path') -parser.add_argument('--expid',type=int,default=1,help='experiment id') - -args = parser.parse_args() - -def main(): - # set seed - # torch.manual_seed(args.seed) - # np.random.seed(args.seed) - # load data - device = torch.device(args.device) - dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size, - args.input_length, args.output_length) - scaler = dataloader['scaler'] - tiny_batch_size = args.tiny_batch_size - - print(args) - - trainer = train_pipeline(scaler, args.input_length, args.output_length, args.in_dim, args.num_nodes, - args.nhid, args.dropout, args.learning_rate, args.weight_decay, device) - - print("start training...",flush=True) - his_loss =[] - train_time = [] - val_time = [] - - for i in range(1, args.epochs+1): - # train - train_loss = [] - train_mape = [] - train_rmse = [] - t1 = time.time() - dataloader['train_loader'].shuffle() - for iter, (x, y) in enumerate(dataloader['train_loader'].get_iterator()): - B, T, N, F = x.shape - batch_num = int(B * N / tiny_batch_size) - idx_perm = np.random.permutation([i for i in range(B*N)]) - for j in range(batch_num): - if j==batch_num-1: - x_ = x[:, :, idx_perm[(j+1)*tiny_batch_size:], :] - y_ = y[:, :, idx_perm[(j+1)*tiny_batch_size:], :] - else: - x_ = x[:, :, idx_perm[j*tiny_batch_size:(j+1)*tiny_batch_size], :] - y_ = y[:, :, idx_perm[j*tiny_batch_size:(j+1)*tiny_batch_size], :] - - trainx = torch.Tensor(x_).to(device) # (B, T, N, F) - trainx = trainx.transpose(1, 2) # (B, N, T, F) - trainy = torch.Tensor(y_).to(device) # (B, T, N, F) - trainy = trainy.transpose(1, 2) # (B, N, T, F) - metrics = trainer.train(trainx, trainy[:,:,:,0]) - train_loss.append(metrics[0]) - train_mape.append(metrics[1]) - train_rmse.append(metrics[2]) - t2 = time.time() - train_time.append(t2-t1) - - if iter % args.print_every == 0: - log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}' - print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]),flush=True) - # Save the model parameters for subsequent preprocessing - torch.save(trainer.model.state_dict(), args.save+"linear_transformer.pth") - - # validation - valid_loss = [] - valid_mape = [] - valid_rmse = [] - - s1 = time.time() - for iter, (x, y) in enumerate(dataloader['val_loader'].get_iterator()): - B, T, N, F = x.shape - batch_num = int(B*N/tiny_batch_size) - for k in range(batch_num): - if k==batch_num-1: - x_ = x[:, :, (k+1)*tiny_batch_size:, :] - y_ = y[:, :, (k+1)*tiny_batch_size:, :] - else: - x_ = x[:, :, k*tiny_batch_size:(k+1)*tiny_batch_size, :] - y_ = y[:, :, k*tiny_batch_size:(k+1)*tiny_batch_size, :] - testx = torch.Tensor(x).to(device) - testx = testx.transpose(1, 2) - testy = torch.Tensor(y).to(device) - testy = testy.transpose(1, 2) - metrics = trainer.eval(testx, testy[:,:,:,0]) - valid_loss.append(metrics[0]) - valid_mape.append(metrics[1]) - valid_rmse.append(metrics[2]) - s2 = time.time() - mvalid_loss = np.mean(valid_loss) - mvalid_mape = np.mean(valid_mape) - mvalid_rmse = np.mean(valid_rmse) - log = 'Epoch: {:03d}, Validation Inference Time: {:.4f} secs' - print(log.format(i,(s2-s1))) - log = 'Valid MAE: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}' - print(log.format(mvalid_loss, mvalid_mape, mvalid_rmse), flush=True) - val_time.append(s2-s1) - - print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time))) - print("Average Inference Time: {:.4f} secs".format(np.mean(val_time))) - -if __name__ == "__main__": - t1 = time.time() - main() - t2 = time.time() - print("Total time spent: {:.4f}".format(t2-t1)) diff --git a/baselines/BigST/arch/preprocess/util.py b/baselines/BigST/arch/preprocess/util.py deleted file mode 100644 index 81bf2cd7..00000000 --- a/baselines/BigST/arch/preprocess/util.py +++ /dev/null @@ -1,147 +0,0 @@ -import pickle -import numpy as np -import os -import scipy.sparse as sp -import torch -from scipy.sparse import linalg - -class DataLoader(object): - def __init__(self, data, batch_size, input_length, output_length): - self.seq_length_x = input_length - self.seq_length_y = output_length - self.y_start = 1 - self.batch_size = batch_size - self.current_ind = 0 - self.x_offsets = np.sort(np.concatenate((np.arange(-(self.seq_length_x - 1), 1, 1),))) - self.y_offsets = np.sort(np.arange(self.y_start, (self.seq_length_y + 1), 1)) - self.min_t = abs(min(self.x_offsets)) - self.max_t = abs(data.shape[0] - abs(max(self.y_offsets))) - mod = (self.max_t-self.min_t) % batch_size - if mod != 0: - self.data = data[:-mod] - else: - self.data = data - self.max_t = abs(self.data.shape[0] - abs(max(self.y_offsets))) - self.permutation = [i for i in range(self.min_t, self.max_t)] - - def shuffle(self): - self.permutation = np.random.permutation([i for i in range(self.min_t, self.max_t)]) - - def get_iterator(self): - self.current_ind = 0 - - def _wrapper(): - while self.current_ind < len(self.permutation): - if self.batch_size > 1: - x_batch = [] - y_batch = [] - for i in range(self.batch_size): - x_i = self.data[self.permutation[self.current_ind+i] + self.x_offsets, ...] - y_i = self.data[self.permutation[self.current_ind+i] + self.y_offsets, ...] - x_batch.append(x_i) - y_batch.append(y_i) - - x_batch = np.stack(x_batch, axis=0) - y_batch = np.stack(y_batch, axis=0) - else: - x_batch = self.data[self.permutation[self.current_ind] + self.x_offsets, ...] - y_batch = self.data[self.permutation[self.current_ind] + self.y_offsets, ...] - x_batch = np.expand_dims(x_batch, axis=0) - y_batch = np.expand_dims(y_batch, axis=0) - yield (x_batch, y_batch) - self.current_ind += self.batch_size - - return _wrapper() - -class StandardScaler(): - """ - Standard the input - """ - - def __init__(self, mean, std): - self.mean = mean - self.std = std - - def transform(self, data): - return (data - self.mean) / self.std - - def inverse_transform(self, data): - return (data * self.std) + self.mean - -def sym_adj(adj): - """Symmetrically normalize adjacency matrix.""" - adj = sp.coo_matrix(adj) - rowsum = np.array(adj.sum(1)) - d_inv_sqrt = np.power(rowsum, -0.5).flatten() - d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. - d_mat_inv_sqrt = sp.diags(d_inv_sqrt) - return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense() - -def asym_adj(adj): - adj = sp.coo_matrix(adj) - rowsum = np.array(adj.sum(1)).flatten() - d_inv = np.power(rowsum, -1).flatten() - d_inv[np.isinf(d_inv)] = 0. - d_mat= sp.diags(d_inv) - return d_mat.dot(adj).astype(np.float32).todense() - -def calculate_normalized_laplacian(adj): - """ - # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 - # D = diag(A 1) - :param adj: - :return: - """ - adj = sp.coo_matrix(adj) - d = np.array(adj.sum(1)) - d_inv_sqrt = np.power(d, -0.5).flatten() - d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. - d_mat_inv_sqrt = sp.diags(d_inv_sqrt) - normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() - return normalized_laplacian - -def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): - if undirected: - adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) - L = calculate_normalized_laplacian(adj_mx) - if lambda_max is None: - lambda_max, _ = linalg.eigsh(L, 1, which='LM') - lambda_max = lambda_max[0] - L = sp.csr_matrix(L) - M, _ = L.shape - I = sp.identity(M, format='csr', dtype=L.dtype) - L = (2 / lambda_max * L) - I - return L.astype(np.float32).todense() - -def load_pickle(pickle_file): - try: - with open(pickle_file, 'rb') as f: - pickle_data = pickle.load(f) - except UnicodeDecodeError as e: - with open(pickle_file, 'rb') as f: - pickle_data = pickle.load(f, encoding='latin1') - except Exception as e: - print('Unable to load data ', pickle_file, ':', e) - raise - return pickle_data - -def load_adj(adj_filename, adjtype): - adj_mx = np.load(adj_filename) - print('adj_mx: ', adj_mx.shape) - adj = [asym_adj(adj_mx)] - return adj - -def load_dataset(dataset_dir, batch_size, valid_batch_size, test_batch_size, input_length, output_length): - data = {} - for category in ['train', 'val', 'test']: - data[category] = np.load(os.path.join(dataset_dir, category + '.npy')) - print('*'*10, category, data[category].shape, '*'*10) - scaler = StandardScaler(mean=data['train'][..., 0].mean(), std=data['train'][..., 0].std()) - # Data format - for category in ['train', 'val', 'test']: - data[category][..., 0] = scaler.transform(data[category][..., 0]) - data['train_loader'] = DataLoader(data['train'], batch_size, input_length, output_length) - data['val_loader'] = DataLoader(data['val'], valid_batch_size, input_length, output_length) - data['test_loader'] = DataLoader(data['test'], test_batch_size, input_length, output_length) - data['scaler'] = scaler - return data diff --git a/baselines/BigST/runner/__init__.py b/baselines/BigST/runner/__init__.py new file mode 100644 index 00000000..2a0ecce8 --- /dev/null +++ b/baselines/BigST/runner/__init__.py @@ -0,0 +1 @@ +from .bigstpreprocess_runner import BigSTPreprocessRunner \ No newline at end of file diff --git a/baselines/BigST/runner/bigstpreprocess_runner.py b/baselines/BigST/runner/bigstpreprocess_runner.py new file mode 100644 index 00000000..fbff8c45 --- /dev/null +++ b/baselines/BigST/runner/bigstpreprocess_runner.py @@ -0,0 +1,48 @@ +from typing import Tuple, Union, Dict +import torch +import numpy as np +import wandb +import pdb +import os + +from basicts.runners import SimpleTimeSeriesForecastingRunner + + +class BigSTPreprocessRunner(SimpleTimeSeriesForecastingRunner): + def __init__(self, cfg: dict): + super().__init__(cfg) + + self.tiny_batch_size = cfg.MODEL.PARAM['tiny_batch_size'] + + def preprocessing(self, input_data: Dict) -> Dict: + """Preprocess data. + + Args: + input_data (Dict): Dictionary containing data to be processed. + + Returns: + Dict: Processed data. + """ + + input_data = super().preprocessing(input_data) + + x = input_data['inputs'] + y = input_data['target'] + + B, T, N, F = x.shape + batch_num = int(B * N / self.tiny_batch_size) # 似乎要确保不能等于0 + idx_perm = np.random.permutation([i for i in range(B*N)]) + + for j in range(batch_num): + if j==batch_num-1: + x_ = x[:, :, idx_perm[(j+1)*self.tiny_batch_size:], :] + y_ = y[:, :, idx_perm[(j+1)*self.tiny_batch_size:], :] + else: + x_ = x[:, :, idx_perm[j*self.tiny_batch_size:(j+1)*self.tiny_batch_size], :] + y_ = y[:, :, idx_perm[j*self.tiny_batch_size:(j+1)*self.tiny_batch_size], :] + + input_data['inputs'] = x_.transpose(1,2) + input_data['target'] = y_ + return input_data + + \ No newline at end of file