From bb8e7f62f618181ab5284942d74767eb175d38b1 Mon Sep 17 00:00:00 2001 From: co63oc Date: Fri, 20 Oct 2023 11:33:48 +0800 Subject: [PATCH] =?UTF-8?q?hydra=20=E6=94=B9=E9=80=A0=20phylstm=20(#579)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add phylstm hydra * Fix * Fix * Fix * Fix * Fix * Fix * Fix * Fix * ci_trigger --- docs/zh/examples/phylstm.md | 46 +++++-- examples/phylstm/conf/phylstm2.yaml | 49 +++++++ examples/phylstm/conf/phylstm3.yaml | 49 +++++++ examples/phylstm/phylstm2.py | 184 ++++++++++++++++++++++--- examples/phylstm/phylstm3.py | 203 +++++++++++++++++++++++----- 5 files changed, 464 insertions(+), 67 deletions(-) create mode 100644 examples/phylstm/conf/phylstm2.yaml create mode 100644 examples/phylstm/conf/phylstm3.yaml diff --git a/docs/zh/examples/phylstm.md b/docs/zh/examples/phylstm.md index fefb384bb..cc8b44029 100644 --- a/docs/zh/examples/phylstm.md +++ b/docs/zh/examples/phylstm.md @@ -1,5 +1,25 @@ # PhyLSTM +=== "模型训练命令" + + ``` sh + # linux + wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat --output data_boucwen.mat + python phylstm2.py + ``` + +=== "模型评估命令" + + ``` sh + # linux + wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_boucwen.mat --output data_boucwen.mat + python phylstm2.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phylstm/phylstm2_pretrained.pdparams + ``` + ## 1. 背景简介 我们引入了一种创新的物理知识LSTM框架,用于对缺乏数据的非线性结构系统进行元建模。基本概念是将可用但尚不完整的物理知识(如物理定律、科学原理)整合到深度长短时记忆(LSTM)网络中,该网络在可行的解决方案空间内限制和促进学习。物理约束嵌入在损失函数中,以强制执行模型训练,即使在可用训练数据集非常有限的情况下,也能准确地捕捉潜在的系统非线性。特别是对于动态结构,考虑运动方程的物理定律、状态依赖性和滞后本构关系来构建物理损失。嵌入式物理可以缓解过拟合问题,减少对大型训练数据集的需求,并提高训练模型的鲁棒性,使其具有外推能力,从而进行更可靠的预测。因此,物理知识指导的深度学习范式优于传统的非物理指导的数据驱动神经网络。 @@ -29,9 +49,9 @@ $$ 在 PhyLSTM 问题中,建立 LSTM 网络 Deep LSTM network,用 PaddleScience 代码表示如下 -``` py linenums="105" +``` py linenums="102" --8<-- -examples/phylstm/phylstm2.py:105:106 +examples/phylstm/phylstm2.py:102:107 --8<-- ``` @@ -47,9 +67,9 @@ wget -P ./ https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyLSTM/data_ 本案例涉及读取数据构建,如下所示 -``` py linenums="38" +``` py linenums="37" --8<-- -examples/phylstm/phylstm2.py:38:105 +examples/phylstm/phylstm2.py:37:100 --8<-- ``` @@ -57,9 +77,9 @@ examples/phylstm/phylstm2.py:38:105 设置训练数据集和损失计算函数,返回字段,代码如下所示: -``` py linenums="118" +``` py linenums="119" --8<-- -examples/phylstm/phylstm2.py:118:136 +examples/phylstm/phylstm2.py:119:137 --8<-- ``` @@ -67,9 +87,9 @@ examples/phylstm/phylstm2.py:118:136 设置评估数据集和损失计算函数,返回字段,代码如下所示: -``` py linenums="139" +``` py linenums="140" --8<-- -examples/phylstm/phylstm2.py:139:158 +examples/phylstm/phylstm2.py:140:159 --8<-- ``` @@ -77,9 +97,9 @@ examples/phylstm/phylstm2.py:139:158 接下来我们需要指定训练轮数,此处我们按实验经验,使用 100 轮训练轮数。 -``` py linenums="36" +``` py linenums="39" --8<-- -examples/phylstm/phylstm2.py:36:36 +examples/phylstm/conf/phylstm2.yaml:39:39 --8<-- ``` @@ -99,15 +119,15 @@ examples/phylstm/phylstm2.py:163:163 ``` py linenums="164" --8<-- -examples/phylstm/phylstm2.py:164:175 +examples/phylstm/phylstm2.py:164:178 --8<-- ``` 最后启动训练、评估即可: -``` py linenums="177" +``` py linenums="180" --8<-- -examples/phylstm/phylstm2.py:177:180 +examples/phylstm/phylstm2.py:180:183 --8<-- ``` diff --git a/examples/phylstm/conf/phylstm2.yaml b/examples/phylstm/conf/phylstm2.yaml new file mode 100644 index 000000000..3fc184d71 --- /dev/null +++ b/examples/phylstm/conf/phylstm2.yaml @@ -0,0 +1,49 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_PhyLSTM2/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set data file path +DATA_FILE_PATH: data_boucwen.mat + +# model settings +MODEL: + input_size: 1 + hidden_size: 100 + model_type: 2 + +# training settings +TRAIN: + epochs: 100 + iters_per_epoch: 1 + save_freq: 50 + learning_rate: 0.001 + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true diff --git a/examples/phylstm/conf/phylstm3.yaml b/examples/phylstm/conf/phylstm3.yaml new file mode 100644 index 000000000..0be68339b --- /dev/null +++ b/examples/phylstm/conf/phylstm3.yaml @@ -0,0 +1,49 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_PhyLSTM3/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set data file path +DATA_FILE_PATH: data_boucwen.mat + +# model settings +MODEL: + input_size: 1 + hidden_size: 100 + model_type: 3 + +# training settings +TRAIN: + epochs: 200 + iters_per_epoch: 1 + save_freq: 50 + learning_rate: 0.001 + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true diff --git a/examples/phylstm/phylstm2.py b/examples/phylstm/phylstm2.py index 0dcefa957..759cb9a5e 100755 --- a/examples/phylstm/phylstm2.py +++ b/examples/phylstm/phylstm2.py @@ -16,26 +16,25 @@ Reference: https://github.com/zhry10/PhyLSTM.git """ +from os import path as osp + import functions +import hydra import numpy as np import scipy.io +from omegaconf import DictConfig import ppsci -from ppsci.utils import config from ppsci.utils import logger -if __name__ == "__main__": - args = config.parse_args() + +def train(cfg: DictConfig): # set random seed for reproducibility - ppsci.utils.misc.set_random_seed(42) - # set output directory - OUTPUT_DIR = "./output_PhyLSTM2" if not args.output_dir else args.output_dir + ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") - # set training hyper-parameters - EPOCHS = 100 if not args.epochs else args.epochs + logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") - mat = scipy.io.loadmat("data_boucwen.mat") + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) ag_data = mat["input_tf"] # ag, ad, av u_data = mat["target_X_tf"] ut_data = mat["target_Xd_tf"] @@ -90,9 +89,6 @@ eta_tt_star = u_tt_all[0:10] ag_c_star = ag_all[0:50] lift_star = -ag_c_star - eta_c_star = u_all[0:50] - eta_t_c_star = u_t_all[0:50] - eta_tt_c_star = u_tt_all[0:50] eta = eta_star ag = ag_star @@ -103,7 +99,12 @@ g = -eta_tt - ag phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) - model = ppsci.arch.DeepPhyLSTM(1, eta.shape[2], 100, 2) + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) model.register_input_transform(functions.transform_in) model.register_output_transform(functions.transform_out) @@ -113,7 +114,7 @@ label_dict_train, input_dict_val, label_dict_val, - ) = dataset_obj.get(EPOCHS) + ) = dataset_obj.get(cfg.TRAIN.epochs) sup_constraint_pde = ppsci.constraint.SupervisedConstraint( { @@ -159,22 +160,161 @@ validator_pde = {sup_validator_pde.name: sup_validator_pde} # initialize solver - ITERS_PER_EPOCH = 1 - optimizer = ppsci.optimizer.Adam(1e-3)(model) + optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model) solver = ppsci.solver.Solver( model, constraint_pde, - OUTPUT_DIR, + cfg.output_dir, optimizer, None, - EPOCHS, - ITERS_PER_EPOCH, - save_freq=50, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + seed=cfg.seed, validator=validator_pde, - eval_with_no_grad=True, + checkpoint_path=cfg.TRAIN.checkpoint_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, ) # train model solver.train() # evaluate after finished training solver.eval() + + +def evaluate(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info") + + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + ag_data = mat["input_tf"] # ag, ad, av + u_data = mat["target_X_tf"] + ut_data = mat["target_Xd_tf"] + utt_data = mat["target_Xdd_tf"] + ag_data = ag_data.reshape([ag_data.shape[0], ag_data.shape[1], 1]) + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + ut_data = ut_data.reshape([ut_data.shape[0], ut_data.shape[1], 1]) + utt_data = utt_data.reshape([utt_data.shape[0], utt_data.shape[1], 1]) + + t = mat["time"] + dt = t[0, 1] - t[0, 0] + + ag_all = ag_data + u_all = u_data + u_t_all = ut_data + u_tt_all = utt_data + + # finite difference + N = u_data.shape[1] + phi1 = np.concatenate( + [ + np.array([-3 / 2, 2, -1 / 2]), + np.zeros([N - 3]), + ] + ) + temp1 = np.concatenate([-1 / 2 * np.identity(N - 2), np.zeros([N - 2, 2])], axis=1) + temp2 = np.concatenate([np.zeros([N - 2, 2]), 1 / 2 * np.identity(N - 2)], axis=1) + phi2 = temp1 + temp2 + phi3 = np.concatenate( + [ + np.zeros([N - 3]), + np.array([1 / 2, -2, 3 / 2]), + ] + ) + phi_t0 = ( + 1 + / dt + * np.concatenate( + [ + np.reshape(phi1, [1, phi1.shape[0]]), + phi2, + np.reshape(phi3, [1, phi3.shape[0]]), + ], + axis=0, + ) + ) + phi_t0 = np.reshape(phi_t0, [1, N, N]) + + ag_star = ag_all[0:10] + eta_star = u_all[0:10] + eta_t_star = u_t_all[0:10] + eta_tt_star = u_tt_all[0:10] + ag_c_star = ag_all[0:50] + lift_star = -ag_c_star + + eta = eta_star + ag = ag_star + lift = lift_star + eta_t = eta_t_star + eta_tt = eta_tt_star + ag_c = ag_c_star + g = -eta_tt - ag + phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) + + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) + model.register_input_transform(functions.transform_in) + model.register_output_transform(functions.transform_out) + + dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t) + ( + _, + _, + input_dict_val, + label_dict_val, + ) = dataset_obj.get(1) + + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + }, + ppsci.loss.FunctionalLoss(functions.train_loss_func2), + { + "eta_pred": lambda out: out["eta_pred"], + "eta_dot_pred": lambda out: out["eta_dot_pred"], + "g_pred": lambda out: out["g_pred"], + "eta_t_pred_c": lambda out: out["eta_t_pred_c"], + "eta_dot_pred_c": lambda out: out["eta_dot_pred_c"], + "lift_pred_c": lambda out: out["lift_pred_c"], + }, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + seed=cfg.seed, + validator=validator_pde, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # evaluate after finished training + solver.eval() + + +@hydra.main(version_base=None, config_path="./conf", config_name="phylstm2.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/examples/phylstm/phylstm3.py b/examples/phylstm/phylstm3.py index 718249f05..4119f7a3d 100755 --- a/examples/phylstm/phylstm3.py +++ b/examples/phylstm/phylstm3.py @@ -16,27 +16,25 @@ Reference: https://github.com/zhry10/PhyLSTM.git """ +from os import path as osp + import functions +import hydra import numpy as np import scipy.io +from omegaconf import DictConfig import ppsci -from ppsci.utils import config from ppsci.utils import logger -if __name__ == "__main__": - args = config.parse_args() + +def train(cfg: DictConfig): # set random seed for reproducibility - ppsci.utils.misc.set_random_seed(42) - # set output directory - OUTPUT_DIR = "./output_PhyLSTM3" if not args.output_dir else args.output_dir + ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") - # set training hyper-parameters - EPOCHS = 200 if not args.epochs else args.epochs - - mat = scipy.io.loadmat("data_boucwen.mat") + logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) t = mat["time"] dt = 0.02 n1 = int(dt / 0.005) @@ -96,9 +94,6 @@ eta_tt_star = utt_data ag_c_star = np.concatenate([ag_data, ag_pred[0:53]]) lift_star = -ag_c_star - eta_c_star = np.concatenate([u_data, u_pred[0:53]]) - eta_t_c_star = np.concatenate([ut_data, ut_pred[0:53]]) - eta_tt_c_star = np.concatenate([utt_data, utt_pred[0:53]]) eta = eta_star ag = ag_star @@ -108,18 +103,14 @@ g = -eta_tt - ag ag_c = ag_c_star - # Training Data - eta_train = eta - ag_train = ag - lift_train = lift - eta_t_train = eta_t - eta_tt_train = eta_tt - g_train = g - ag_c_train = ag_c - phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) - model = ppsci.arch.DeepPhyLSTM(1, eta.shape[2], 100, 3) + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) model.register_input_transform(functions.transform_in) model.register_output_transform(functions.transform_out) @@ -129,7 +120,7 @@ label_dict_train, input_dict_val, label_dict_val, - ) = dataset_obj.get(EPOCHS) + ) = dataset_obj.get(cfg.TRAIN.epochs) sup_constraint_pde = ppsci.constraint.SupervisedConstraint( { @@ -179,22 +170,170 @@ validator_pde = {sup_validator_pde.name: sup_validator_pde} # initialize solver - ITERS_PER_EPOCH = 1 - optimizer = ppsci.optimizer.Adam(1e-3)(model) + optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model) solver = ppsci.solver.Solver( model, constraint_pde, - OUTPUT_DIR, + cfg.output_dir, optimizer, None, - EPOCHS, - ITERS_PER_EPOCH, - save_freq=50, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + seed=cfg.seed, validator=validator_pde, - eval_with_no_grad=True, + checkpoint_path=cfg.TRAIN.checkpoint_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, ) # train model solver.train() # evaluate after finished training solver.eval() + + +def evaluate(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") + + mat = scipy.io.loadmat(cfg.DATA_FILE_PATH) + t = mat["time"] + dt = 0.02 + n1 = int(dt / 0.005) + t = t[::n1] + + ag_data = mat["input_tf"][:, ::n1] # ag, ad, av + u_data = mat["target_X_tf"][:, ::n1] + ut_data = mat["target_Xd_tf"][:, ::n1] + utt_data = mat["target_Xdd_tf"][:, ::n1] + ag_data = ag_data.reshape([ag_data.shape[0], ag_data.shape[1], 1]) + u_data = u_data.reshape([u_data.shape[0], u_data.shape[1], 1]) + ut_data = ut_data.reshape([ut_data.shape[0], ut_data.shape[1], 1]) + utt_data = utt_data.reshape([utt_data.shape[0], utt_data.shape[1], 1]) + + ag_pred = mat["input_pred_tf"][:, ::n1] # ag, ad, av + u_pred = mat["target_pred_X_tf"][:, ::n1] + ut_pred = mat["target_pred_Xd_tf"][:, ::n1] + utt_pred = mat["target_pred_Xdd_tf"][:, ::n1] + ag_pred = ag_pred.reshape([ag_pred.shape[0], ag_pred.shape[1], 1]) + u_pred = u_pred.reshape([u_pred.shape[0], u_pred.shape[1], 1]) + ut_pred = ut_pred.reshape([ut_pred.shape[0], ut_pred.shape[1], 1]) + utt_pred = utt_pred.reshape([utt_pred.shape[0], utt_pred.shape[1], 1]) + + N = u_data.shape[1] + phi1 = np.concatenate( + [ + np.array([-3 / 2, 2, -1 / 2]), + np.zeros([N - 3]), + ] + ) + temp1 = np.concatenate([-1 / 2 * np.identity(N - 2), np.zeros([N - 2, 2])], axis=1) + temp2 = np.concatenate([np.zeros([N - 2, 2]), 1 / 2 * np.identity(N - 2)], axis=1) + phi2 = temp1 + temp2 + phi3 = np.concatenate( + [ + np.zeros([N - 3]), + np.array([1 / 2, -2, 3 / 2]), + ] + ) + phi_t0 = ( + 1 + / dt + * np.concatenate( + [ + np.reshape(phi1, [1, phi1.shape[0]]), + phi2, + np.reshape(phi3, [1, phi3.shape[0]]), + ], + axis=0, + ) + ) + phi_t0 = np.reshape(phi_t0, [1, N, N]) + + ag_star = ag_data + eta_star = u_data + eta_t_star = ut_data + eta_tt_star = utt_data + ag_c_star = np.concatenate([ag_data, ag_pred[0:53]]) + lift_star = -ag_c_star + + eta = eta_star + ag = ag_star + lift = lift_star + eta_t = eta_t_star + eta_tt = eta_tt_star + g = -eta_tt - ag + ag_c = ag_c_star + + phi_t = np.repeat(phi_t0, ag_c_star.shape[0], axis=0) + + model = ppsci.arch.DeepPhyLSTM( + cfg.MODEL.input_size, + eta.shape[2], + cfg.MODEL.hidden_size, + cfg.MODEL.model_type, + ) + model.register_input_transform(functions.transform_in) + model.register_output_transform(functions.transform_out) + + dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t) + ( + _, + _, + input_dict_val, + label_dict_val, + ) = dataset_obj.get(1) + + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + }, + ppsci.loss.FunctionalLoss(functions.train_loss_func3), + { + "eta_pred": lambda out: out["eta_pred"], + "eta_dot_pred": lambda out: out["eta_dot_pred"], + "g_pred": lambda out: out["g_pred"], + "eta_t_pred_c": lambda out: out["eta_t_pred_c"], + "eta_dot_pred_c": lambda out: out["eta_dot_pred_c"], + "lift_pred_c": lambda out: out["lift_pred_c"], + "g_t_pred_c": lambda out: out["g_t_pred_c"], + "g_dot_pred_c": lambda out: out["g_dot_pred_c"], + }, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + seed=cfg.seed, + validator=validator_pde, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + + # evaluate after finished training + solver.eval() + + +@hydra.main(version_base=None, config_path="./conf", config_name="phylstm3.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main()