From e53374eb2ceafd8af97366ed70b9ad7d6425e0cd Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Fri, 15 Dec 2023 10:58:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20support=20empty=20evalua?= =?UTF-8?q?tion=20horizons?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/runners/base_tsf_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basicts/runners/base_tsf_runner.py b/basicts/runners/base_tsf_runner.py index efba6830..5c567849 100644 --- a/basicts/runners/base_tsf_runner.py +++ b/basicts/runners/base_tsf_runner.py @@ -59,7 +59,7 @@ def __init__(self, cfg: Dict): # evaluation self.if_evaluate_on_gpu = cfg.get("EVAL", EasyDict()).get("USE_GPU", True) # evaluate on gpu or cpu (gpu is faster but may cause OOM) self.evaluation_horizons = [_ - 1 for _ in cfg.get("EVAL", EasyDict()).get("HORIZONS", range(1, 13))] - assert min(self.evaluation_horizons) >= 0, "The horizon should start counting from 1." + assert len(self.evaluation_horizons) == 0 or min(self.evaluation_horizons) >= 0, "The horizon should start counting from 1." def setup_graph(self, cfg: Dict, train: bool): """Setup all parameters and the computation graph.