From f1a0b970411aac308a3cba36c942297933a4dd91 Mon Sep 17 00:00:00 2001 From: Irina Espejo Morales Date: Tue, 30 Apr 2024 14:59:22 +0200 Subject: [PATCH] fix: update default arguments --- src/rxn/onmt_utils/train_command.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/rxn/onmt_utils/train_command.py b/src/rxn/onmt_utils/train_command.py index 8039f5e..bfebdee 100644 --- a/src/rxn/onmt_utils/train_command.py +++ b/src/rxn/onmt_utils/train_command.py @@ -73,8 +73,8 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand): Arg("normalization", "tokens", RxnCommand.TCF), Arg("optim", "adam", RxnCommand.TF), Arg("param_init", "0", RxnCommand.T), - Arg("param_init_glorot", "", RxnCommand.T), # note: empty means "nothing" - Arg("position_encoding", "", RxnCommand.T), # note: empty means "nothing" + Arg("param_init_glorot", "true", RxnCommand.T), # note: "true" enables this flag + Arg("position_encoding", "true", RxnCommand.T), # note: "true" enables this flag Arg("report_every", "1000", RxnCommand.TCF), Arg("reset_optim", None, RxnCommand.CF), Arg("hidden_size", None, RxnCommand.TF), @@ -82,7 +82,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand): Arg("save_model", None, RxnCommand.TCF), Arg("seed", None, RxnCommand.TCF), Arg("self_attn_type", "scaled-dot", RxnCommand.T), - Arg("share_embeddings", "", RxnCommand.T), # note: empty means "nothing" + Arg("share_embeddings", "true", RxnCommand.T), # note: "true" enables this flag Arg("src_vocab", None, RxnCommand.T), Arg("tgt_vocab", None, RxnCommand.T), Arg("train_from", None, RxnCommand.CF), @@ -217,12 +217,20 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None: if torch.cuda.is_available() and self._no_gpu is False: train_config["gpu_ranks"] = [0] - # Dump all cli arguments to dict - for kwarg, value in self._kwargs.items(): - if self.is_valid_kwarg_value(kwarg, value): - train_config[kwarg] = value + for arg in 
ONMT_TRAIN_ARGS: + arg_given = arg.key in self._kwargs + + if arg_given: + train_config[arg.key] = self._kwargs[arg.key] else: - raise ValueError(f'"Value {value}" for argument {kwarg} is invalid') + train_config[arg.key] = arg.default + + # Dump all cli arguments to dict + # for kwarg, value in self._kwargs.items(): + # if self.is_valid_kwarg_value(kwarg, value): + # train_config[kwarg] = value + # else: + # raise ValueError(f'"Value {value}" for argument {kwarg} is invalid') # Reformat "data" argument as in ONMT-py v.3.5.0 path_save_prepr_data = train_config["data"] @@ -255,6 +263,9 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None: train_config["save_model"] ) # avoid posix bad format in yaml + # share_vocab + train_config["share_vocab"] = str(True) + # Dump to config.yaml with open(config_file_path, "w+") as file: yaml.dump(train_config, file)