From 3761512915d4b26653dd553a2d92d7ad67b6d083 Mon Sep 17 00:00:00 2001 From: duyifan Date: Thu, 28 Nov 2024 17:15:14 +0800 Subject: [PATCH] style: 💄 lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix lint issues reported by pylint. --- .pylintrc | 9 +++- basicts/runners/base_tsf_runner.py | 2 +- basicts/runners/optim/optimizers.py | 1 - basicts/runners/runner_zoo/no_bp_runner.py | 1 - .../runners/runner_zoo/simple_tsf_runner.py | 4 +- basicts/scaler/min_max_scaler.py | 2 +- basicts/utils/misc.py | 1 - examples/arch.py | 1 + examples/complete_config.py | 54 ++++++++++++------- examples/complete_config_cn.py | 46 +++++++++------- examples/regular_config.py | 8 ++- experiments/evaluate.py | 1 + experiments/train.py | 3 +- .../data_test/test_simple_tsf_dataset.py | 42 ++++++++------- tests/basicts_test/metrics_test/test_mae.py | 7 ++- tests/basicts_test/metrics_test/test_mape.py | 7 ++- tests/basicts_test/metrics_test/test_mse.py | 8 ++- tests/basicts_test/metrics_test/test_wape.py | 7 ++- .../scaler_test/test_min_max_scaler.py | 3 ++ .../scaler_test/test_z_score_scaler.py | 7 ++- tests/experiments_test/test_evaluate.py | 21 ++++---- tests/experiments_test/test_train.py | 10 ++-- tests/run_all_test.py | 2 +- 23 files changed, 152 insertions(+), 95 deletions(-) diff --git a/.pylintrc b/.pylintrc index 093ba6ab..b6388c40 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,11 +8,11 @@ [MASTER] # Files or directories to be skipped. They should be base names, not paths. -ignore=third_party +ignore=baselines,assets # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. -ignore-patterns= +ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.CFF|^LICENSE # Pickle collected data for later comparisons.
persistent=no @@ -85,6 +85,7 @@ disable=abstract-method, input-builtin, intern-builtin, invalid-str-codec, + invalid-name, locally-disabled, logging-format-interpolation, logging-fstring-interpolation, @@ -433,3 +434,7 @@ valid-metaclass-classmethod-first-arg=mcs overgeneral-exceptions=builtins.StandardError, builtins.Exception, builtins.BaseException + +[DESIGN] +# https://pylint.pycqa.org/en/latest/user_guide/messages/refactor/too-many-positional-arguments.html +max-positional-arguments=10 diff --git a/basicts/runners/base_tsf_runner.py b/basicts/runners/base_tsf_runner.py index a366df1c..1388900d 100644 --- a/basicts/runners/base_tsf_runner.py +++ b/basicts/runners/base_tsf_runner.py @@ -343,7 +343,7 @@ def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tup forward_return['prediction'] = forward_return['prediction'][:, :cl_length, :, :] forward_return['target'] = forward_return['target'][:, :cl_length, :, :] loss = self.metric_forward(self.loss, forward_return) - self.update_epoch_meter(f'train/loss', loss.item()) + self.update_epoch_meter('train/loss', loss.item()) for metric_name, metric_func in self.metrics.items(): metric_item = self.metric_forward(metric_func, forward_return) diff --git a/basicts/runners/optim/optimizers.py b/basicts/runners/optim/optimizers.py index 380bf922..4ca59f31 100644 --- a/basicts/runners/optim/optimizers.py +++ b/basicts/runners/optim/optimizers.py @@ -1,5 +1,4 @@ # define more optimizers here -import os import inspect from typing import Union, Tuple, Optional diff --git a/basicts/runners/runner_zoo/no_bp_runner.py b/basicts/runners/runner_zoo/no_bp_runner.py index 937b2a3d..c4407df3 100644 --- a/basicts/runners/runner_zoo/no_bp_runner.py +++ b/basicts/runners/runner_zoo/no_bp_runner.py @@ -7,4 +7,3 @@ class NoBPRunner(SimpleTimeSeriesForecastingRunner): def backward(self, loss: torch.Tensor): pass - return diff --git a/basicts/runners/runner_zoo/simple_tsf_runner.py b/basicts/runners/runner_zoo/simple_tsf_runner.py index f8b5e4e8..617f976e 100644 --- a/basicts/runners/runner_zoo/simple_tsf_runner.py +++ b/basicts/runners/runner_zoo/simple_tsf_runner.py @@ -90,13 +90,13 @@ def forward(self, data: Dict, epoch: int = None, iter_num: int = None, train: bo # Select input features history_data = self.select_input_features(history_data) future_data_4_dec = self.select_input_features(future_data) - + if not train: # For non-training phases, use only temporal features future_data_4_dec[..., 0] = torch.empty_like(future_data_4_dec[..., 0]) # Forward pass through the model - model_return = self.model(history_data=history_data, future_data=future_data_4_dec, + model_return = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train) # Parse model return diff --git a/basicts/scaler/min_max_scaler.py b/basicts/scaler/min_max_scaler.py index e0294a55..d6ade22d 100644 --- a/basicts/scaler/min_max_scaler.py +++ b/basicts/scaler/min_max_scaler.py @@ -72,7 +72,7 @@ def transform(self, input_data: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The normalized data with the same shape as the input. 
""" - + _min = self.min.to(input_data.device) _max = self.max.to(input_data.device) input_data[..., self.target_channel] = (input_data[..., self.target_channel] - _min) / (_max - _min) diff --git a/basicts/utils/misc.py b/basicts/utils/misc.py index 530da7f3..dc74959a 100644 --- a/basicts/utils/misc.py +++ b/basicts/utils/misc.py @@ -1,5 +1,4 @@ import time -from typing import Dict from functools import partial import torch diff --git a/examples/arch.py b/examples/arch.py index e9af348a..7b7dac86 100644 --- a/examples/arch.py +++ b/examples/arch.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument import torch from torch import nn diff --git a/examples/complete_config.py b/examples/complete_config.py index 12edc8b5..282ed560 100644 --- a/examples/complete_config.py +++ b/examples/complete_config.py @@ -1,12 +1,4 @@ ############################## Import Dependencies ############################## - -import os -import sys -from easydict import EasyDict - -# TODO: Remove this when basicts can be installed via pip -sys.path.append(os.path.abspath(__file__ + '/../../..')) - # Import metrics & loss functions from basicts.metrics import masked_mae, masked_mape, masked_rmse # Import dataset class @@ -15,10 +7,14 @@ from basicts.runners import SimpleTimeSeriesForecastingRunner # Import scaler class from basicts.scaler import ZScoreScaler +# Import dataset settings +from basicts.utils import get_regular_settings # Import model architecture from .arch import MultiLayerPerceptron as MLP -from basicts.utils import get_regular_settings + +import os +from easydict import EasyDict ############################## Hot Parameters ############################## @@ -82,7 +78,9 @@ ############################## Scaler Configuration ############################## -CFG.SCALER = EasyDict() # Scaler settings. Default: None. If not specified, the data will not be normalized, i.e., the data will be used directly for training, validation, and test. +# Scaler settings. Default: None. +# If not specified, the data will not be normalized, i.e., the data will be used directly for training, validation, and test. +CFG.SCALER = EasyDict() # Scaler settings CFG.SCALER.TYPE = ZScoreScaler # Scaler class @@ -101,11 +99,22 @@ CFG.MODEL.NAME = MODEL_ARCH.__name__ # Model name, must be specified, used for saving checkpoints and set the process title. CFG.MODEL.ARCH = MODEL_ARCH # Model architecture, must be specified. CFG.MODEL.PARAM = MODEL_PARAM # Model parameters -CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] # Features used as input. The size of input data `history_data` is usually [B, L, N, C], this parameter specifies the index of the last dimension, i.e., history_data[:, :, :, CFG.MODEL.FORWARD_FEATURES]. -CFG.MODEL.TARGET_FEATURES = [0] # Features used as output. The size of target data `future_data` is usually [B, L, N, C], this parameter specifies the index of the last dimension, i.e., future_data[:, :, :, CFG.MODEL.TARGET_FEATURES]. -CFG.MODEL.TARGET_TIME_SERIES = [5, 6] # The index of the time series to be predicted, default is None. This setting is particularly useful in a Multivariate-to-Univariate setup. For example, if 7 time series are input and the last two need to be predicted, you can set `CFG.MODEL.TARGET_TIME_SERIES=[5, 6]` to achieve this. -CFG.MODEL.SETUP_GRAPH = False # Whether to set up the computation graph. Default: False. Implementation of many works (e.g., DCRNN, GTS) acts like TensorFlow, which creates parameters in the first feedforward process. 
-CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = False # Controls the `find_unused_parameters parameter` of `torch.nn.parallel.DistributedDataParallel`. In distributed computing, if there are unused parameters in the forward process, PyTorch usually raises a RuntimeError. In such cases, this parameter should be set to True. +# Features used as input. The size of input data `history_data` is usually [B, L, N, C], +# this parameter specifies the index of the last dimension, i.e., history_data[:, :, :, CFG.MODEL.FORWARD_FEATURES]. +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] +# Features used as output. The size of target data `future_data` is usually [B, L, N, C], +# this parameter specifies the index of the last dimension, i.e., future_data[:, :, :, CFG.MODEL.TARGET_FEATURES]. +CFG.MODEL.TARGET_FEATURES = [0] +# The index of the time series to be predicted, default is None. This setting is particularly useful in a Multivariate-to-Univariate setup. +# For example, if 7 time series are input and the last two need to be predicted, you can set `CFG.MODEL.TARGET_TIME_SERIES=[5, 6]` to achieve this. +CFG.MODEL.TARGET_TIME_SERIES = [5, 6] +# Whether to set up the computation graph. Default: False. +# The implementations of many works (e.g., DCRNN, GTS) act like TensorFlow, creating parameters during the first forward pass. +CFG.MODEL.SETUP_GRAPH = False +# Controls the `find_unused_parameters` parameter of `torch.nn.parallel.DistributedDataParallel`. +# In distributed computing, if there are unused parameters in the forward process, PyTorch usually raises a RuntimeError. +# In such cases, this parameter should be set to True. +CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = False ############################## Metrics Configuration ############################## @@ -132,7 +141,12 @@ MODEL_ARCH.__name__, '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) ) # Directory to save checkpoints. Default: 'checkpoints/{MODEL_NAME}/{DATASET_NAME}_{NUM_EPOCHS}_{INPUT_LEN}_{OUTPUT_LEN}' -CFG.TRAIN.CKPT_SAVE_STRATEGY = None # Checkpoint save strategy. `CFG.TRAIN.CKPT_SAVE_STRATEGY` should be None, an int value, a list or a tuple. None: remove last checkpoint file every epoch. Default: None. Int: save checkpoint every `CFG.TRAIN.CKPT_SAVE_STRATEGY` epoch. List or Tuple: save checkpoint when epoch in `CFG.TRAIN.CKPT_SAVE_STRATEGY, remove last checkpoint file when last_epoch not in ckpt_save_strategy +# Checkpoint save strategy. `CFG.TRAIN.CKPT_SAVE_STRATEGY` should be None, an int value, a list or a tuple. +# Default: None. +# None: remove the last checkpoint file every epoch. +# Int: save a checkpoint every `CFG.TRAIN.CKPT_SAVE_STRATEGY` epochs. +# List or Tuple: save a checkpoint when the epoch is in `CFG.TRAIN.CKPT_SAVE_STRATEGY`; remove the last checkpoint file when `last_epoch` is not in it. +CFG.TRAIN.CKPT_SAVE_STRATEGY = None CFG.TRAIN.FINETUNE_FROM = None # Checkpoint path for fine-tuning. Default: None. If not specified, the model will be trained from scratch. CFG.TRAIN.STRICT_LOAD = True # Whether to strictly load the checkpoint. Default: True. @@ -205,10 +219,12 @@ CFG.TEST.DATA.NUM_WORKERS = 0 CFG.TEST.DATA.PIN_MEMORY = False -############################## Evaluation Configuration ############################## - +############################## Evaluation Configuration ############################## CFG.EVAL = EasyDict() # Evaluation parameters -CFG.EVAL.HORIZONS = [3, 6, 12] # The prediction horizons for evaluation. Default value: [].
NOTE: HORIZONS[i] refers to testing **on the i-th** time step, representing the loss for that specific time step. This is a common setting in spatiotemporal forecasting. For long-sequence predictions, it is recommended to keep HORIZONS set to the default value [] to avoid confusion. +# The prediction horizons for evaluation. Default value: []. +# NOTE: HORIZONS[i] refers to testing **on the i-th** time step, representing the loss for that specific time step. +# This is a common setting in spatiotemporal forecasting. For long-sequence predictions, it is recommended to keep HORIZONS set to the default value [] to avoid confusion. +CFG.EVAL.HORIZONS = [] CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True diff --git a/examples/complete_config_cn.py b/examples/complete_config_cn.py index dd758bec..75cceb45 100644 --- a/examples/complete_config_cn.py +++ b/examples/complete_config_cn.py @@ -1,13 +1,4 @@ - ############################## 导入依赖 ############################## - -import os -import sys -from easydict import EasyDict - -# TODO: 当 basicts 可以通过 pip 安装时,移除这行代码 -sys.path.append(os.path.abspath(__file__ + '/../../..')) - # 导入指标和损失函数 from basicts.metrics import masked_mae, masked_mape, masked_rmse # 导入数据集类 @@ -16,10 +7,14 @@ from basicts.runners import SimpleTimeSeriesForecastingRunner # 导入缩放器类 from basicts.scaler import ZScoreScaler +# 导入数据集配置 +from basicts.utils import get_regular_settings # 导入模型架构 from .arch import MultiLayerPerceptron as MLP -from basicts.utils import get_regular_settings + +import os +from easydict import EasyDict ############################## 热门参数 ############################## @@ -102,11 +97,19 @@ CFG.MODEL.NAME = MODEL_ARCH.__name__ # 模型名称,必须指定,用于保存检查点和设置进程标题。 CFG.MODEL.ARCH = MODEL_ARCH # 模型架构,必须指定。 CFG.MODEL.PARAM = MODEL_PARAM # 模型参数,必须指定。 -CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] # 作为输入使用的特征。输入数据的大小通常为 [B, L, N, C],此参数指定最后一个维度的索引,即 history_data[:, :, :, CFG.MODEL.FORWARD_FEATURES]。 -CFG.MODEL.TARGET_FEATURES = [0] # 作为输出使用的特征。目标数据的大小通常为 [B, L, N, C],此参数指定最后一个维度的索引,即 future_data[:, :, :, CFG.MODEL.TARGET_FEATURES]。 -CFG.MODEL.TARGET_TIME_SERIES = None # 待预测的时间序列索引,默认为None。该参数在多变量到单变量预测(Multivariate-to-Univariate)的场景下特别有用。例如,当输入7条时间序列时,若需要预测最后两条序列,可以通过设置`CFG.MODEL.TARGET_TIME_SERIES=[5, 6]`来实现。 -CFG.MODEL.SETUP_GRAPH = False # 是否设置计算图。默认值:False。许多论文的实现(如 DCRNN,GTS)类似于 TensorFlow,需要第一次前向传播时建立计算图并创建参数。 -CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = False # 控制 torch.nn.parallel.DistributedDataParallel 的 `find_unused_parameters` 参数。在分布式计算中,如果前向传播过程中存在未使用的参数,PyTorch 通常会抛出 RuntimeError。在这种情况下,应将此参数设置为 True。 +# 作为输入使用的特征。输入数据的大小通常为 [B, L, N, C], +# 此参数指定最后一个维度的索引,即 history_data[:, :, :, CFG.MODEL.FORWARD_FEATURES]。 +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] +# 作为输出使用的特征。目标数据的大小通常为 [B, L, N, C],此参数指定最后一个维度的索引,即 future_data[:, :, :, CFG.MODEL.TARGET_FEATURES]。 +CFG.MODEL.TARGET_FEATURES = [0] +# 待预测的时间序列索引,默认为None。该参数在多变量到单变量预测(Multivariate-to-Univariate)的场景下特别有用。 +# 例如,当输入7条时间序列时,若需要预测最后两条序列,可以通过设置`CFG.MODEL.TARGET_TIME_SERIES=[5, 6]`来实现。 +CFG.MODEL.TARGET_TIME_SERIES = None +# 是否设置计算图。默认值:False。许多论文的实现(如 DCRNN,GTS)类似于 TensorFlow,需要第一次前向传播时建立计算图并创建参数。 +CFG.MODEL.SETUP_GRAPH = False +# 控制 torch.nn.parallel.DistributedDataParallel 的 `find_unused_parameters` 参数。 +# 在分布式计算中,如果前向传播过程中存在未使用的参数,PyTorch 通常会抛出 RuntimeError。在这种情况下,应将此参数设置为 True。 +CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = False ############################## 指标配置 ############################## @@ -128,12 +131,17 @@ # 训练参数 CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS + # 
保存检查点的目录。默认值:'checkpoints/{MODEL_NAME}/{DATASET_NAME}_{NUM_EPOCHS}_{INPUT_LEN}_{OUTPUT_LEN}' CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 'checkpoints', MODEL_ARCH.__name__, '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) -) # 保存检查点的目录。默认值:'checkpoints/{MODEL_NAME}/{DATASET_NAME}_{NUM_EPOCHS}_{INPUT_LEN}_{OUTPUT_LEN}' -CFG.TRAIN.CKPT_SAVE_STRATEGY = None # 检查点保存策略。`CFG.TRAIN.CKPT_SAVE_STRATEGY` 可以是 None、整数值、列表或元组。默认值:None。None:每个 epoch 移除最后一个检查点文件。整数:每隔 `CFG.TRAIN.CKPT_SAVE_STRATEGY` 个 epoch 保存一次检查点。列表或元组:当 epoch 在 `CFG.TRAIN.CKPT_SAVE_STRATEGY` 中时保存检查点,当 last_epoch 不在 ckpt_save_strategy 中时移除最后一个检查点文件。“移除”操作是将最后一个检查点文件重命名为 .bak 文件,BasicTS会每个10个epoch清空一次.bak文件。 +) +# 检查点保存策略。`CFG.TRAIN.CKPT_SAVE_STRATEGY` 可以是 None、整数值、列表或元组。默认值:None。 +# None:每个 epoch 移除最后一个检查点文件。整数:每隔 `CFG.TRAIN.CKPT_SAVE_STRATEGY` 个 epoch 保存一次检查点。 +# 列表或元组:当 epoch 在 `CFG.TRAIN.CKPT_SAVE_STRATEGY` 中时保存检查点,当 `last_epoch` 不在 `ckpt_save_strategy` 中时移除最后一个检查点文件。 +# “移除”操作是将最后一个检查点文件重命名为 .bak 文件,BasicTS 会每隔 10 个 epoch 清空一次 .bak 文件。 +CFG.TRAIN.CKPT_SAVE_STRATEGY = None CFG.TRAIN.FINETUNE_FROM = None # 微调的检查点路径。默认值:None。如果未指定,模型将从头开始训练。 CFG.TRAIN.STRICT_LOAD = True # 是否严格加载检查点。默认值:True。 @@ -213,5 +221,7 @@ CFG.EVAL = EasyDict() # 评估参数 -CFG.EVAL.HORIZONS = [3, 6, 12] # 评估时的预测时间范围。默认值为 []。注意:HORIZONS[i] 指的是在 ”第 i 个时间片“ 上进行测试,表示该时间片的损失(Loss)。这是时空预测中的常见配置。对于长序列预测,建议将 HORIZONS 保持为默认值 [],以避免引发误解。 +# 评估时的预测时间范围。默认值为 []。注意:HORIZONS[i] 指的是在 “第 i 个时间片” 上进行测试,表示该时间片的损失(Loss)。 +# 这是时空预测中的常见配置。对于长序列预测,建议将 HORIZONS 保持为默认值 [],以避免引发误解。 +CFG.EVAL.HORIZONS = [] CFG.EVAL.USE_GPU = True # 是否在评估时使用 GPU。默认值:True diff --git a/examples/regular_config.py b/examples/regular_config.py index 6557544e..d71756bb 100644 --- a/examples/regular_config.py +++ b/examples/regular_config.py @@ -1,8 +1,3 @@ -import os -import sys -from easydict import EasyDict -sys.path.append(os.path.abspath(__file__ + '/../../..')) - from basicts.metrics import masked_mae, masked_mape, masked_rmse from basicts.data import TimeSeriesForecastingDataset from basicts.runners import SimpleTimeSeriesForecastingRunner @@ -10,6 +5,9 @@ from basicts.utils import get_regular_settings from .arch import MultiLayerPerceptron as MLP +import os +from easydict import EasyDict + ############################## Hot Parameters ############################## # Dataset & Metrics configuration DATA_NAME = 'PEMS08' # Dataset name diff --git a/experiments/evaluate.py b/experiments/evaluate.py index 78f2f72a..99a9815c 100644 --- a/experiments/evaluate.py +++ b/experiments/evaluate.py @@ -1,3 +1,4 @@ +# pylint: disable=wrong-import-position import os import sys from argparse import ArgumentParser diff --git a/experiments/train.py b/experiments/train.py index 83fa9db6..76ff60c5 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -1,4 +1,5 @@ # Run a baseline model in BasicTS framework.
+# pylint: disable=wrong-import-position import os import sys from argparse import ArgumentParser @@ -23,5 +24,5 @@ def main(): basicts.launch_training(args.cfg, args.gpus, node_rank=0) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/tests/basicts_test/data_test/test_simple_tsf_dataset.py b/tests/basicts_test/data_test/test_simple_tsf_dataset.py index 777caeed..23d7e8b4 100644 --- a/tests/basicts_test/data_test/test_simple_tsf_dataset.py +++ b/tests/basicts_test/data_test/test_simple_tsf_dataset.py @@ -1,11 +1,14 @@ +# pylint: disable=unused-argument import unittest import numpy as np import json -import os from unittest.mock import patch, mock_open from basicts.data.simple_tsf_dataset import TimeSeriesForecastingDataset class TestTimeSeriesForecastingDataset(unittest.TestCase): + """ + Test the TimeSeriesForecastingDataset class. + """ def setUp(self): self.dataset_name = 'test_dataset' @@ -27,14 +30,14 @@ def setUp(self): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_description(self, mock_memmap, mock_open): + def test_load_description(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( dataset_name=self.dataset_name, train_val_test_ratio=self.train_val_test_ratio, mode=self.mode, - input_len=self.input_len, + input_len=self.input_len, output_len=self.output_len, overlap=self.overlap, logger=self.logger @@ -43,7 +46,7 @@ def test_load_description(self, mock_memmap, mock_open): self.assertEqual(dataset.description, self.description) @patch('builtins.open', side_effect=FileNotFoundError) - def test_load_description_file_not_found(self, mock_open): + def test_load_description_file_not_found(self, mocked_open): with self.assertRaises(FileNotFoundError): TimeSeriesForecastingDataset( dataset_name=self.dataset_name+'nonexistent', @@ -56,7 +59,7 @@ def test_load_description_file_not_found(self, mock_open): ) @patch('builtins.open', new_callable=mock_open, read_data='not a json') - def test_load_description_json_decode_error(self, mock_open): + def test_load_description_json_decode_error(self, mocked_open): with self.assertRaises(ValueError): TimeSeriesForecastingDataset( dataset_name=self.dataset_name, @@ -70,7 +73,7 @@ def test_load_description_json_decode_error(self, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap', side_effect=FileNotFoundError) - def test_load_data_file_not_found(self, mock_memmap, mock_open): + def test_load_data_file_not_found(self, mock_memmap, mocked_open): with self.assertRaises(ValueError): TimeSeriesForecastingDataset( dataset_name=self.dataset_name, @@ -84,7 +87,7 @@ def test_load_data_file_not_found(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap', side_effect=ValueError) - def test_load_data_value_error(self, mock_memmap, mock_open): + def test_load_data_value_error(self, mock_memmap, mocked_open): with self.assertRaises(ValueError): TimeSeriesForecastingDataset( dataset_name=self.dataset_name, @@ -99,7 +102,7 @@ def test_load_data_value_error(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_train_mode(self, mock_memmap, mock_open): + def test_load_data_train_mode(self, mock_memmap, mocked_open): mock_memmap.return_value 
= self.data dataset = TimeSeriesForecastingDataset( @@ -117,10 +120,10 @@ def test_load_data_train_mode(self, mock_memmap, mock_open): test_len = int(total_len * self.train_val_test_ratio[2]) expected_data_len = total_len - valid_len - test_len self.assertEqual(len(dataset.data), expected_data_len) - + @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_train_mode_overlap(self, mock_memmap, mock_open): + def test_load_data_train_mode_overlap(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -141,7 +144,7 @@ def test_load_data_train_mode_overlap(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_valid_mode(self, mock_memmap, mock_open): + def test_load_data_valid_mode(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -157,12 +160,12 @@ def test_load_data_valid_mode(self, mock_memmap, mock_open): valid_len = int(len(self.data) * self.train_val_test_ratio[1]) expected_data_len = valid_len self.assertEqual(len(dataset.data), expected_data_len) - + @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_valid_mode_overlap(self, mock_memmap, mock_open): + def test_load_data_valid_mode_overlap(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data - + dataset = TimeSeriesForecastingDataset( dataset_name=self.dataset_name, train_val_test_ratio=self.train_val_test_ratio, @@ -176,11 +179,11 @@ def test_load_data_valid_mode_overlap(self, mock_memmap, mock_open): valid_len = int(len(self.data) * self.train_val_test_ratio[1]) expected_data_len = valid_len + self.input_len - 1 + self.output_len self.assertEqual(len(dataset.data), expected_data_len) - + @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_test_mode(self, mock_memmap, mock_open): + def test_load_data_test_mode(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -200,7 +203,7 @@ def test_load_data_test_mode(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_load_data_test_mode_overlap(self, mock_memmap, mock_open): + def test_load_data_test_mode_overlap(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -221,7 +224,7 @@ def test_load_data_test_mode_overlap(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_getitem(self, mock_memmap, mock_open): + def test_getitem(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -243,7 +246,7 @@ def test_getitem(self, mock_memmap, mock_open): @patch('builtins.open', new_callable=mock_open, read_data=json.dumps({'shape': [100,]})) @patch('numpy.memmap') - def test_len(self, mock_memmap, mock_open): + def test_len(self, mock_memmap, mocked_open): mock_memmap.return_value = self.data dataset = TimeSeriesForecastingDataset( @@ -260,7 +263,6 @@ def test_len(self, mock_memmap, mock_open): expected_len = 
len(self.data)*self.train_val_test_ratio[0] - self.input_len - self.output_len + 1 self.assertEqual(len(dataset), expected_len) - if __name__ == '__main__': unittest.main() diff --git a/tests/basicts_test/metrics_test/test_mae.py b/tests/basicts_test/metrics_test/test_mae.py index d17597ad..c71c7f68 100644 --- a/tests/basicts_test/metrics_test/test_mae.py +++ b/tests/basicts_test/metrics_test/test_mae.py @@ -4,6 +4,9 @@ from basicts.metrics.mae import masked_mae class TestMaskedMAE(unittest.TestCase): + """ + Test the masked_mae function from basicts.metrics.mae. + """ def test_masked_mae_no_nulls(self): prediction = torch.tensor([1.0, 2.0, 3.0]) @@ -40,5 +43,5 @@ def test_masked_mae_all_nulls(self): expected = torch.tensor(0.0) # Since all are nulls, the MAE should be zero self.assertTrue(torch.allclose(result, expected), f"Expected {expected}, but got {result}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/basicts_test/metrics_test/test_mape.py b/tests/basicts_test/metrics_test/test_mape.py index 389b2667..79bff641 100644 --- a/tests/basicts_test/metrics_test/test_mape.py +++ b/tests/basicts_test/metrics_test/test_mape.py @@ -4,6 +4,9 @@ from basicts.metrics.mape import masked_mape class TestMaskedMAPE(unittest.TestCase): + """ + Test the masked MAPE function. + """ def test_basic_functionality(self): prediction = torch.tensor([2.0, 3.0, 3.0]) @@ -40,5 +43,5 @@ def test_all_zeros_in_target(self): expected = torch.tensor(0.0) # No valid entries, should return 0 self.assertTrue(torch.allclose(result, expected), f"Expected {expected}, but got {result}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/basicts_test/metrics_test/test_mse.py b/tests/basicts_test/metrics_test/test_mse.py index 16268b87..85a4a186 100644 --- a/tests/basicts_test/metrics_test/test_mse.py +++ b/tests/basicts_test/metrics_test/test_mse.py @@ -4,6 +4,10 @@ from basicts.metrics.mse import masked_mse class TestMaskedMSE(unittest.TestCase): + """ + Test the masked MSE function. + """ + def test_masked_mse_no_nulls(self): prediction = torch.tensor([1.0, 3.0, 3.0, 5.0]) target = torch.tensor([1.0, 2.0, 3.0, 4.0]) @@ -32,5 +36,5 @@ def test_masked_mse_with_all_nulls(self): expected = torch.tensor(0.0) self.assertTrue(torch.allclose(result, expected), f"Expected {expected}, but got {result}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/basicts_test/metrics_test/test_wape.py b/tests/basicts_test/metrics_test/test_wape.py index 08d3de24..61dcd738 100644 --- a/tests/basicts_test/metrics_test/test_wape.py +++ b/tests/basicts_test/metrics_test/test_wape.py @@ -4,6 +4,9 @@ from basicts.metrics.wape import masked_wape class TestMaskedWape(unittest.TestCase): + """ + Test the masked WAPE function. 
+ """ def test_masked_wape_basic(self): prediction = torch.tensor([[2.0, 2.0, 3.0], [6.0, 5.0, 7.0]]) @@ -33,5 +36,5 @@ def test_masked_wape_with_all_null_vals(self): expected = torch.tensor(0.0) # No valid entries, should return 0 self.assertTrue(torch.allclose(result, expected), f"Expected {expected}, but got {result}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/basicts_test/scaler_test/test_min_max_scaler.py b/tests/basicts_test/scaler_test/test_min_max_scaler.py index 2ab683e1..a833dd5d 100644 --- a/tests/basicts_test/scaler_test/test_min_max_scaler.py +++ b/tests/basicts_test/scaler_test/test_min_max_scaler.py @@ -7,6 +7,9 @@ import os class TestMinMaxScaler(unittest.TestCase): + """ + Test the MinMaxScaler class. + """ def setUp(self): # Mock dataset description and data diff --git a/tests/basicts_test/scaler_test/test_z_score_scaler.py b/tests/basicts_test/scaler_test/test_z_score_scaler.py index d82def34..7afce9fe 100644 --- a/tests/basicts_test/scaler_test/test_z_score_scaler.py +++ b/tests/basicts_test/scaler_test/test_z_score_scaler.py @@ -6,6 +6,9 @@ from basicts.scaler.z_score_scaler import ZScoreScaler class TestZScoreScaler(unittest.TestCase): + """ + Test the ZScoreScaler class. + """ def setUp(self): # Create a mock dataset description and data @@ -64,7 +67,7 @@ def test_inverse_transform(self): # Check if the inverse transformed data is approximately equal to the original data self.assertTrue(torch.allclose(inverse_transformed_data, raw_data, atol=1e-6)) - + def tearDown(self): # Remove the mock dataset directory os.remove(f'datasets/{self.dataset_name}/desc.json') @@ -72,4 +75,4 @@ def tearDown(self): os.rmdir(f'datasets/{self.dataset_name}') if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/experiments_test/test_evaluate.py b/tests/experiments_test/test_evaluate.py index 43faa189..7e192cd7 100644 --- a/tests/experiments_test/test_evaluate.py +++ b/tests/experiments_test/test_evaluate.py @@ -3,26 +3,29 @@ from experiments.evaluate import parse_args class TestEvaluate(unittest.TestCase): + """ + Test the evaluate.py script. 
+ """ @patch('sys.argv', ['evaluate.py']) def test_default_args(self): args = parse_args() - self.assertEqual(args.config, "baselines/STID/PEMS08_LTSF.py") - self.assertEqual(args.checkpoint, "checkpoints/STID/PEMS08_100_336_336/97d131cadc14bd2b9ffa892d59d55129/STID_best_val_MAE.pt") - self.assertEqual(args.gpus, "5") - self.assertEqual(args.device_type, "gpu") + self.assertEqual(args.config, 'baselines/STID/PEMS08_LTSF.py') + self.assertEqual(args.checkpoint, 'checkpoints/STID/PEMS08_100_336_336/97d131cadc14bd2b9ffa892d59d55129/STID_best_val_MAE.pt') + self.assertEqual(args.gpus, '5') + self.assertEqual(args.device_type, 'gpu') self.assertIsNone(args.batch_size) @patch('sys.argv', ['evaluate.py', '-cfg', 'custom_config.py', '-ckpt', 'custom_checkpoint.pt', '-g', '0', '-d', 'cpu', '-b', '32']) def test_custom_args(self): args = parse_args() - self.assertEqual(args.config, "custom_config.py") - self.assertEqual(args.checkpoint, "custom_checkpoint.pt") - self.assertEqual(args.gpus, "0") - self.assertEqual(args.device_type, "cpu") + self.assertEqual(args.config, 'custom_config.py') + self.assertEqual(args.checkpoint, 'custom_checkpoint.pt') + self.assertEqual(args.gpus, '0') + self.assertEqual(args.device_type, 'cpu') self.assertEqual(args.batch_size, '32') if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/experiments_test/test_train.py b/tests/experiments_test/test_train.py index a6579a78..a3ac8ad1 100644 --- a/tests/experiments_test/test_train.py +++ b/tests/experiments_test/test_train.py @@ -1,15 +1,19 @@ import unittest -from unittest.mock import patch, call +from unittest.mock import patch import sys import os from experiments.train import parse_args from experiments.train import main # Add the path to the train.py file -sys.path.append(os.path.abspath(__file__ + "/../..")) +sys.path.append(os.path.abspath(__file__ + '/../..')) class TestTrain(unittest.TestCase): + """ + Test the train.py script. + """ + @patch('experiments.train.basicts.launch_training') @patch('sys.argv', ['train.py', '-c', 'baselines/STID/PEMS04.py', '-g', '0']) def test_launch_training_called_with_correct_args(self, mock_launch_training): @@ -28,4 +32,4 @@ def test_launch_training_called_with_correct_args(self, mock_launch_training): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/run_all_test.py b/tests/run_all_test.py index cc66ddc4..27271374 100644 --- a/tests/run_all_test.py +++ b/tests/run_all_test.py @@ -9,4 +9,4 @@ # run all tests test_runner = unittest.TextTestRunner() -test_runner.run(test_suite) \ No newline at end of file +test_runner.run(test_suite)