From 9fed131b2b6d57012574e2358c3dc8ac98071558 Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Tue, 5 Dec 2023 01:24:59 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=F0=9F=92=A1=20update=20the=20`resc?= =?UTF-8?q?ale=5Fdata`=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/losses/losses.py | 6 +-- basicts/runners/base_tsf_runner.py | 85 ++++++++++++++++++++---------- basicts/utils/xformer.py | 2 +- 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/basicts/losses/losses.py b/basicts/losses/losses.py index 5910d28b..95ef6304 100644 --- a/basicts/losses/losses.py +++ b/basicts/losses/losses.py @@ -1,17 +1,17 @@ -import torch import numpy as np +import torch import torch.nn.functional as F from ..utils import check_nan_inf -def l1_loss(input_data, target_data, **kwargs): +def l1_loss(input_data, target_data): """unmasked mae.""" return F.l1_loss(input_data, target_data) -def l2_loss(input_data, target_data, **kwargs): +def l2_loss(input_data, target_data): """unmasked mse""" check_nan_inf(input_data) diff --git a/basicts/runners/base_tsf_runner.py b/basicts/runners/base_tsf_runner.py index 5b7a2e77..e63aed2e 100644 --- a/basicts/runners/base_tsf_runner.py +++ b/basicts/runners/base_tsf_runner.py @@ -1,6 +1,6 @@ import math import functools -from typing import Tuple, Union, Optional +from typing import Tuple, Union, Optional, List import torch import numpy as np @@ -37,7 +37,11 @@ def __init__(self, cfg: dict): self.need_setup_graph = cfg["MODEL"].get("SETUP_GRAPH", False) # read scaler for re-normalization - self.scaler = load_pkl("{0}/scaler_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))) + self.scaler = load_pkl("{0}/scaler_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["TRAIN"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True))) # define loss self.loss = cfg["TRAIN"]["LOSS"] # define metric @@ -130,7 +134,7 @@ def build_train_dataset(self, cfg: dict): 2. Normalize on EACH channel (i.e., calculate the mean and std of each channel). The reason why there are two different preprocessing methods is that each channel of the dataset may have a different value range. - 1. Normalizing the WHOLE data set will preserve the relative size relationship between channels. + 1. Normalizing the WHOLE data set will preserve the relative size relationship between channels. Larger channels usually produce larger loss values, so more attention will be paid to these channels when optimizing the model. Therefore, this approach will achieve better performance when we evaluate on the rescaled dataset. For example, when evaluating rescaled data for two channels with values in the range [0, 1], [9000, 10000], the prediction on channel [0,1] is trivial. @@ -143,7 +147,8 @@ def build_train_dataset(self, cfg: dict): For example, the first approach is often adopted in the field of Spatial-Temporal Forecasting (STF). The second approach is often adopted in the field of Long-term Time Series Forecasting (LTSF). - To avoid confusion for users and facilitate them to obtain results comparable to existing studies, we automatically select data based on the cfg.get("RESCALE") flag (default to True). + To avoid confusion for users and facilitate them to obtain results comparable to existing studies, we + automatically select data based on the cfg.get("RESCALE") flag (default to True). if_rescale == True: use the data that is normalized across the WHOLE dataset if_rescale == False: use the data that is normalized on EACH channel @@ -153,8 +158,16 @@ def build_train_dataset(self, cfg: dict): Returns: train dataset (Dataset) """ - data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) - index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) + data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["TRAIN"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) + index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["TRAIN"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) # build dataset args dataset_args = cfg.get("DATASET_ARGS", {}) @@ -182,8 +195,16 @@ def build_val_dataset(cfg: dict): validation dataset (Dataset) """ # see build_train_dataset for details - data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["VAL"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) - index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["VAL"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) + data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["VAL"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) + index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["VAL"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) # build dataset args dataset_args = cfg.get("DATASET_ARGS", {}) @@ -207,8 +228,16 @@ def build_test_dataset(cfg: dict): Returns: train dataset (Dataset) """ - data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TEST"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) - index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TEST"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)) + data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["TEST"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) + index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format( + cfg["TEST"]["DATA"]["DIR"], + cfg["DATASET_INPUT_LEN"], + cfg["DATASET_OUTPUT_LEN"], + cfg.get("RESCALE", True)) # build dataset args dataset_args = cfg.get("DATASET_ARGS", {}) @@ -277,17 +306,21 @@ def metric_forward(self, metric_func, args): raise TypeError("Unknown metric type: {0}".format(type(metric_func))) return metric_item - def rescale_data(self, data: torch.Tensor) -> torch.Tensor: + def rescale_data(self, input_data: List[torch.Tensor]) -> List[torch.Tensor]: """Rescale data. Args: - data (torch.Tensor): data to be re-scaled. + data (List[torch.Tensor]): list of data to be re-scaled. Returns: - torch.Tensor: re-scaled data. + List[torch.Tensor]: list of re-scaled data. """ - return SCALER_REGISTRY.get(self.scaler["func"])(data, **self.scaler["args"]) + # prediction, real_value = input_data[:2] + if self.if_rescale: + input_data[0] = SCALER_REGISTRY.get(self.scaler["func"])(input_data[0], **self.scaler["args"]) + input_data[1] = SCALER_REGISTRY.get(self.scaler["func"])(input_data[1], **self.scaler["args"]) + return input_data def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor: """Training details. @@ -304,20 +337,16 @@ def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tup iter_num = (epoch-1) * self.iter_per_epoch + iter_index forward_return = list(self.forward(data=data, epoch=epoch, iter_num=iter_num, train=True)) # re-scale data - prediction = self.rescale_data(forward_return[0]) if self.if_rescale else forward_return[0] - real_value = self.rescale_data(forward_return[1]) if self.if_rescale else forward_return[1] + forward_return = self.rescale_data(forward_return) # loss if self.cl_param: cl_length = self.curriculum_learning(epoch=epoch) - forward_return[0] = prediction[:, :cl_length, :, :] - forward_return[1] = real_value[:, :cl_length, :, :] - else: - forward_return[0] = prediction - forward_return[1] = real_value + forward_return[0] = forward_return[0][:, :cl_length, :, :] # prediction + forward_return[1] = forward_return[1][:, :cl_length, :, :] # real_value loss = self.metric_forward(self.loss, forward_return) # metrics for metric_name, metric_func in self.metrics.items(): - metric_item = self.metric_forward(metric_func, [prediction, real_value]) + metric_item = self.metric_forward(metric_func, forward_return[:2]) self.update_epoch_meter("train_"+metric_name, metric_item.item()) return loss @@ -329,13 +358,12 @@ def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]): data (Union[torch.Tensor, Tuple]): Data provided by DataLoader """ - forward_return = self.forward(data=data, epoch=None, iter_num=iter_index, train=False) + forward_return = list(self.forward(data=data, epoch=None, iter_num=iter_index, train=False)) # re-scale data - prediction = self.rescale_data(forward_return[0]) if self.if_rescale else forward_return[0] - real_value = self.rescale_data(forward_return[1]) if self.if_rescale else forward_return[1] + forward_return = self.rescale_data(forward_return) # metrics for metric_name, metric_func in self.metrics.items(): - metric_item = self.metric_forward(metric_func, [prediction, real_value]) + metric_item = self.metric_forward(metric_func, forward_return[:2]) self.update_epoch_meter("val_"+metric_name, metric_item.item()) def evaluate(self, prediction, real_value): @@ -385,8 +413,9 @@ def test(self): prediction = torch.cat(prediction, dim=0) real_value = torch.cat(real_value, dim=0) # re-scale data - prediction = self.rescale_data(prediction) if self.if_rescale else prediction - real_value = self.rescale_data(real_value) if self.if_rescale else real_value + if self.if_rescale: + prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"]) + real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"]) # evaluate self.evaluate(prediction, real_value) diff --git a/basicts/utils/xformer.py b/basicts/utils/xformer.py index 2b8f9237..b0fb7398 100644 --- a/basicts/utils/xformer.py +++ b/basicts/utils/xformer.py @@ -6,7 +6,7 @@ def data_transformation_4_xformer(history_data: torch.Tensor, future_data: torch Args: history_data (torch.Tensor): history data with shape: [B, L1, N, C]. - future_data (torch.Tensor): future data with shape: [B, L2, N, C]. + future_data (torch.Tensor): future data with shape: [B, L2, N, C]. L1 and L2 are input sequence length and output sequence length, respectively. start_token_length (int): length of the decoder start token. Ref: Informer paper.