Commit

fix: added updated defaults arguments
irinaespejo committed Apr 30, 2024
1 parent 765d97e commit f1a0b97
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/rxn/onmt_utils/train_command.py
@@ -73,16 +73,16 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
     Arg("normalization", "tokens", RxnCommand.TCF),
     Arg("optim", "adam", RxnCommand.TF),
     Arg("param_init", "0", RxnCommand.T),
-    Arg("param_init_glorot", "", RxnCommand.T), # note: empty means "nothing"
-    Arg("position_encoding", "", RxnCommand.T), # note: empty means "nothing"
+    Arg("param_init_glorot", "true", RxnCommand.T), # note: empty means "nothing"
+    Arg("position_encoding", "true", RxnCommand.T), # note: empty means "nothing"
     Arg("report_every", "1000", RxnCommand.TCF),
     Arg("reset_optim", None, RxnCommand.CF),
     Arg("hidden_size", None, RxnCommand.TF),
     Arg("save_checkpoint_steps", "5000", RxnCommand.TCF),
     Arg("save_model", None, RxnCommand.TCF),
     Arg("seed", None, RxnCommand.TCF),
     Arg("self_attn_type", "scaled-dot", RxnCommand.T),
-    Arg("share_embeddings", "", RxnCommand.T), # note: empty means "nothing"
+    Arg("share_embeddings", "true", RxnCommand.T), # note: empty means "nothing"
     Arg("src_vocab", None, RxnCommand.T),
     Arg("tgt_vocab", None, RxnCommand.T),
     Arg("train_from", None, RxnCommand.CF),
@@ -217,12 +217,20 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None:
         if torch.cuda.is_available() and self._no_gpu is False:
             train_config["gpu_ranks"] = [0]
 
-        # Dump all cli arguments to dict
-        for kwarg, value in self._kwargs.items():
-            if self.is_valid_kwarg_value(kwarg, value):
-                train_config[kwarg] = value
+        for arg in ONMT_TRAIN_ARGS:
+            arg_given = arg.key in self._kwargs
+
+            if arg_given:
+                train_config[arg.key] = self._kwargs[arg.key]
             else:
-                raise ValueError(f'"Value {value}" for argument {kwarg} is invalid')
+                train_config[arg.key] = arg.default
+
+        # Dump all cli arguments to dict
+        # for kwarg, value in self._kwargs.items():
+        #     if self.is_valid_kwarg_value(kwarg, value):
+        #         train_config[kwarg] = value
+        #     else:
+        #         raise ValueError(f'"Value {value}" for argument {kwarg} is invalid')
 
         # Reformat "data" argument as in ONMT-py v.3.5.0
         path_save_prepr_data = train_config["data"]
@@ -255,6 +263,9 @@ def save_to_config_cmd(self, config_file_path: PathLike) -> None:
             train_config["save_model"]
         )  # avoid posix bad format in yaml
 
+        # share_vocab
+        train_config["share_vocab"] = str(True)
+
         # Dump to config.yaml
        with open(config_file_path, "w+") as file:
            yaml.dump(train_config, file)
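For readers skimming the diff, the sketch below illustrates the behaviour this commit introduces: every argument declared in ONMT_TRAIN_ARGS now ends up in the train config, taking the caller's value when one is given and the declared default otherwise, and share_vocab is set to "True" before the config is dumped to YAML. It is a minimal stand-alone sketch, not the package's real code: the simplified Arg dataclass, the trimmed-down argument list, and the build_train_config helper are assumptions made for illustration, and the RxnCommand flags are omitted.

# Minimal, self-contained sketch of the default-filling behaviour added in this
# commit. Arg, ONMT_TRAIN_ARGS and build_train_config are simplified stand-ins
# for the real definitions in src/rxn/onmt_utils/train_command.py.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Arg:
    key: str
    default: Any


# Small subset of the declared arguments; after this commit the previously
# empty defaults are the string "true" instead of "".
ONMT_TRAIN_ARGS = [
    Arg("param_init_glorot", "true"),
    Arg("position_encoding", "true"),
    Arg("share_embeddings", "true"),
    Arg("report_every", "1000"),
    Arg("save_model", None),
]


def build_train_config(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Merge user-provided kwargs with the declared defaults.

    Old behaviour: only the user-provided kwargs were dumped (and an invalid
    value raised a ValueError). New behaviour: every declared argument appears
    in the config, falling back to its default when not given explicitly.
    """
    train_config: Dict[str, Any] = {}
    for arg in ONMT_TRAIN_ARGS:
        if arg.key in kwargs:
            train_config[arg.key] = kwargs[arg.key]
        else:
            train_config[arg.key] = arg.default

    # Mirrors the hard-coded addition at the end of save_to_config_cmd().
    train_config["share_vocab"] = str(True)
    return train_config


if __name__ == "__main__":
    # Only "save_model" is passed explicitly; everything else uses defaults.
    print(build_train_config({"save_model": "models/model"}))
    # -> {'param_init_glorot': 'true', 'position_encoding': 'true',
    #     'share_embeddings': 'true', 'report_every': '1000',
    #     'save_model': 'models/model', 'share_vocab': 'True'}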
