fix: change OnmtTrainCommand CLI config argument passing to explicit YAML
irinaespejo committed Apr 16, 2024
1 parent bda7384 commit 1ab01a9
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/rxn/onmt_utils/train_command.py
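This commit replaces the old behaviour of save_to_config_cmd, which returned the onmt_train command with "-save_config" appended and left the file writing to OpenNMT, with an explicit yaml.dump of the training config. A minimal before/after sketch, assuming an already-built OnmtTrainCommand instance named command (the config path is illustrative):

    # old behaviour: onmt_train itself writes the config file
    cmd = command.cmd() + ["-save_config", "train_config.yaml"]

    # new behaviour: the config is written directly as YAML
    command.save_to_config_cmd("train_config.yaml")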
@@ -1,8 +1,9 @@
import logging
from enum import Flag
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from onmt.trainer import Trainer
import yaml
from rxn.utilities.files import PathLike

#from .model_introspection import get_model_rnn_size
@@ -46,7 +47,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
self.default = default
self.needed_for = needed_for


# See https://opennmt.net/OpenNMT-py/options/train.html
ONMT_TRAIN_ARGS: List[Arg] = [
Arg("accum_count", "4", RxnCommand.TCF),
Arg("adam_beta1", "0.9", RxnCommand.TF),
@@ -74,12 +75,14 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
Arg("position_encoding", "", RxnCommand.T), # note: empty means "nothing"
Arg("report_every", "1000", RxnCommand.TCF),
Arg("reset_optim", None, RxnCommand.CF),
Arg("rnn_size", None, RxnCommand.TF),
Arg("hidden_size", None, RxnCommand.TF),
Arg("save_checkpoint_steps", "5000", RxnCommand.TCF),
Arg("save_model", None, RxnCommand.TCF),
Arg("seed", None, RxnCommand.TCF),
Arg("self_attn_type", "scaled-dot", RxnCommand.T),
Arg("share_embeddings", "", RxnCommand.T), # note: empty means "nothing"
Arg("src_vocab", None, RxnCommand.T),
Arg("tgt_vocab", None, RxnCommand.T),
Arg("train_from", None, RxnCommand.CF),
Arg("train_steps", None, RxnCommand.TCF),
Arg("transformer_ff", None, RxnCommand.T),
@@ -89,7 +92,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
]


class OnmtTrainCommand(Trainer):
class OnmtTrainCommand:
"""
Class to build the onmt_command for training models, continuing the
training, or finetuning.
Expand All @@ -112,6 +115,7 @@ def _build_cmd(self) -> List[str]:
Build the base command.
"""
command = ["onmt_train"]

for arg in ONMT_TRAIN_ARGS:
arg_given = arg.key in self._kwargs

@@ -168,11 +172,18 @@ def cmd(self) -> List[str]:
"""
return self._build_cmd()

def save_to_config_cmd(self, config_file: PathLike) -> List[str]:
def save_to_config_cmd(self, config_file_path: PathLike) -> None:
"""
Return the command for saving the config to a file.
Save the training config to a file.
See https://opennmt.net/OpenNMT-py/quickstart.html part 2
"""
return self._build_cmd() + ["-save_config", str(config_file)]
# Build dictionary with the training config content
# See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 1: Prepare the data)
train_config: Dict[str, Any] = {}

with open(config_file_path, "w+") as file:
yaml.dump(train_config, file)

@staticmethod
def execute_from_config_cmd(config_file: PathLike) -> List[str]:
@@ -186,11 +197,13 @@ def train(
cls,
batch_size: int,
data: PathLike,
src_vocab: Path,
tgt_vocab: Path,
dropout: float,
heads: int,
layers: int,
learning_rate: float,
rnn_size: int,
hidden_size: int,
save_model: PathLike,
seed: int,
train_steps: int,
@@ -207,12 +220,14 @@
data_weights=data_weights,
batch_size=batch_size,
data=data,
src_vocab=src_vocab,
tgt_vocab=tgt_vocab,
dropout=dropout,
heads=heads,
keep_checkpoint=keep_checkpoint,
layers=layers,
learning_rate=learning_rate,
rnn_size=rnn_size,
hidden_size=hidden_size,
save_model=save_model,
seed=seed,
train_steps=train_steps,
@@ -263,18 +278,17 @@ def finetune(
train_steps: int,
warmup_steps: int,
no_gpu: bool,
data_weights: Tuple[int, ...],
report_every: int,
save_checkpoint_steps: int,
keep_checkpoint: int = -1,
rnn_size: Optional[int] = None,
hidden_size: Optional[int] = None,
) -> "OnmtTrainCommand":
if rnn_size is None:
if hidden_size is None:
# In principle, the rnn_size should not be needed for finetuning. However,
# when resetting the decay algorithm for the learning rate, this value
# is necessary - and does not get it from the model checkpoint (OpenNMT bug).
#rnn_size = get_model_rnn_size(train_from)
logger.info(f"Loaded the value of rnn_size from the model: {rnn_size}.")
logger.info(f"Loaded the value of hidden_size from the model: {hidden_size}.")

return cls(
command_type=RxnCommand.F,
@@ -286,7 +300,7 @@ def finetune(
keep_checkpoint=keep_checkpoint,
learning_rate=learning_rate,
reset_optim="all",
#rnn_size=rnn_size,
hidden_size=hidden_size,
save_model=save_model,
seed=seed,
train_from=train_from,
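A minimal usage sketch of the new workflow, assuming command is an already-built OnmtTrainCommand (e.g. from OnmtTrainCommand.train(...)); the config path and the subprocess call are illustrative:

    import subprocess

    config_path = "train_config.yaml"

    # Write the training config to a YAML file (new behaviour in this commit),
    # instead of appending "-save_config" to the onmt_train command line.
    command.save_to_config_cmd(config_path)

    # Launch training from the explicit config file,
    # presumably equivalent to: onmt_train -config train_config.yaml
    subprocess.check_call(OnmtTrainCommand.execute_from_config_cmd(config_path))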
