From 1ab01a9519c42abed8052ae26d4e43c11076d685 Mon Sep 17 00:00:00 2001 From: Irina Espejo Morales Date: Tue, 16 Apr 2024 09:28:03 +0200 Subject: [PATCH] fix: change OnmtTrainCommand cli config argument pass for explicit yaml --- src/rxn/onmt_utils/train_command.py | 42 +++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/src/rxn/onmt_utils/train_command.py b/src/rxn/onmt_utils/train_command.py index 3d628f9..701c3bc 100644 --- a/src/rxn/onmt_utils/train_command.py +++ b/src/rxn/onmt_utils/train_command.py @@ -1,8 +1,9 @@ import logging from enum import Flag +from pathlib import Path from typing import Any, List, Optional, Tuple -from onmt.trainer import Trainer +import yaml from rxn.utilities.files import PathLike #from .model_introspection import get_model_rnn_size @@ -46,7 +47,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand): self.default = default self.needed_for = needed_for - +# See https://opennmt.net/OpenNMT-py/options/train.html ONMT_TRAIN_ARGS: List[Arg] = [ Arg("accum_count", "4", RxnCommand.TCF), Arg("adam_beta1", "0.9", RxnCommand.TF), @@ -74,12 +75,14 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand): Arg("position_encoding", "", RxnCommand.T), # note: empty means "nothing" Arg("report_every", "1000", RxnCommand.TCF), Arg("reset_optim", None, RxnCommand.CF), - Arg("rnn_size", None, RxnCommand.TF), + Arg("hidden_size", None, RxnCommand.TF), Arg("save_checkpoint_steps", "5000", RxnCommand.TCF), Arg("save_model", None, RxnCommand.TCF), Arg("seed", None, RxnCommand.TCF), Arg("self_attn_type", "scaled-dot", RxnCommand.T), Arg("share_embeddings", "", RxnCommand.T), # note: empty means "nothing" + Arg("src_vocab", None, RxnCommand.T), + Arg("tgt_vocab", None, RxnCommand.T), Arg("train_from", None, RxnCommand.CF), Arg("train_steps", None, RxnCommand.TCF), Arg("transformer_ff", None, RxnCommand.T), @@ -89,7 +92,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand): ] -class OnmtTrainCommand(Trainer): +class OnmtTrainCommand: """ Class to build the onmt_command for training models, continuing the training, or finetuning. @@ -112,6 +115,7 @@ def _build_cmd(self) -> List[str]: Build the base command. """ command = ["onmt_train"] + for arg in ONMT_TRAIN_ARGS: arg_given = arg.key in self._kwargs @@ -168,11 +172,18 @@ def cmd(self) -> List[str]: """ return self._build_cmd() - def save_to_config_cmd(self, config_file: PathLike) -> List[str]: + def save_to_config_cmd(self, config_file_path: PathLike) -> None: """ - Return the command for saving the config to a file. + Save the training config to a file. + See https://opennmt.net/OpenNMT-py/quickstart.html part 2 """ - return self._build_cmd() + ["-save_config", str(config_file)] + # Build dictionary with build vocab config content + # See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 1: Prepare the data) + train_config: Dict[str, Any] = {} + + + with open(config_file_path, "w+") as file: + yaml.dump(train_config, file) @staticmethod def execute_from_config_cmd(config_file: PathLike) -> List[str]: @@ -186,11 +197,13 @@ def train( cls, batch_size: int, data: PathLike, + src_vocab: Path, + tgt_vocab: Path, dropout: float, heads: int, layers: int, learning_rate: float, - rnn_size: int, + hidden_size: int, save_model: PathLike, seed: int, train_steps: int, @@ -207,12 +220,14 @@ def train( data_weights=data_weights, batch_size=batch_size, data=data, + src_vocab=src_vocab, + tgt_vocab=tgt_vocab, dropout=dropout, heads=heads, keep_checkpoint=keep_checkpoint, layers=layers, learning_rate=learning_rate, - rnn_size=rnn_size, + hidden_size=hidden_size, save_model=save_model, seed=seed, train_steps=train_steps, @@ -263,18 +278,17 @@ def finetune( train_steps: int, warmup_steps: int, no_gpu: bool, - data_weights: Tuple[int, ...], report_every: int, save_checkpoint_steps: int, keep_checkpoint: int = -1, - rnn_size: Optional[int] = None, + hidden_size: Optional[int] = None, ) -> "OnmtTrainCommand": - if rnn_size is None: + if hidden_size is None: # In principle, the rnn_size should not be needed for finetuning. However, # when resetting the decay algorithm for the learning rate, this value # is necessary - and does not get it from the model checkpoint (OpenNMT bug). #rnn_size = get_model_rnn_size(train_from) - logger.info(f"Loaded the value of rnn_size from the model: {rnn_size}.") + logger.info(f"Loaded the value of hidden_size from the model: {hidden_size}.") return cls( command_type=RxnCommand.F, @@ -286,7 +300,7 @@ def finetune( keep_checkpoint=keep_checkpoint, learning_rate=learning_rate, reset_optim="all", - #rnn_size=rnn_size, + hidden_size=hidden_size, save_model=save_model, seed=seed, train_from=train_from,