diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 0774f3a..db0b0ac 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -9,12 +9,15 @@ jobs:
   build:
     runs-on: ubuntu-latest
     name: Build the Sphinx docs
+    strategy:
+      matrix:
+        python-version: ["3.11"]
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v3
         with:
-          python-version: 3.8
+          python-version: ${{ matrix.python-version }}
       - name: Install package dependencies
         run: pip install -e .
       - name: Install sphinx dependencies
diff --git a/.github/workflows/pypi.yaml b/.github/workflows/pypi.yaml
index e5415f1..aa4f2e4 100644
--- a/.github/workflows/pypi.yaml
+++ b/.github/workflows/pypi.yaml
@@ -9,13 +9,15 @@ jobs:
   build-and-publish:
     name: Build and publish rxn-onmt-utils on PyPI
     runs-on: ubuntu-latest
-
+    strategy:
+      matrix:
+        python-version: ["3.11"]
     steps:
       - uses: actions/checkout@master
-      - name: Python setup 3.9
+      - name: Python setup ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
-          python-version: 3.9
+          python-version: ${{ matrix.python-version }}
      - name: Install build package (for packaging)
        run: pip install --upgrade build
      - name: Build dist
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 2c92ef9..af548a2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -6,12 +6,15 @@ jobs:
   tests:
     runs-on: ubuntu-latest
     name: Style, mypy
+    strategy:
+      matrix:
+        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v3
-      - name: Set up Python 3.7
+      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
-          python-version: 3.7
+          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: pip install -e .[dev]
      - name: Check black
diff --git a/pyproject.toml b/pyproject.toml
index 740ff57..9087f7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,8 @@ check_untyped_defs = true
 [[tool.mypy.overrides]]
 module = [
     "onmt.*",
+    "yaml.*",
+    "torch.*",
 ]
 ignore_missing_imports = true
diff --git a/setup.cfg b/setup.cfg
index 1922784..65d33f0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -25,7 +25,7 @@ include_package_data = True
 install_requires =
     attrs>=21.2.0
     click>=8.0
-    rxn-opennmt-py>=1.1.1
+    # rxn-opennmt-py>=1.1.1  # removed: drop the dependency on the OpenNMT-py fork
     rxn-utils>=1.6.0

 [options.packages.find]
diff --git a/src/rxn/onmt_utils/internal_translation_utils.py b/src/rxn/onmt_utils/internal_translation_utils.py
index 6eba198..cc106ba 100644
--- a/src/rxn/onmt_utils/internal_translation_utils.py
+++ b/src/rxn/onmt_utils/internal_translation_utils.py
@@ -1,17 +1,47 @@
 import copy
 import os
 from argparse import Namespace
-from itertools import repeat
+from itertools import islice, repeat
 from typing import Any, Iterable, Iterator, List, Optional

 import attr
 import onmt.opts as opts
+import torch
+from onmt.constants import CorpusTask
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.translate.translator import build_translator
-from onmt.utils.misc import split_corpus
+
+# from onmt.utils.misc import split_corpus
 from onmt.utils.parse import ArgumentParser
 from rxn.utilities.files import named_temporary_path

+# Reintroduce _split_corpus and split_corpus, which originally lived in
+# onmt.utils.misc and were removed by this commit:
+# https://github.com/OpenNMT/OpenNMT-py/commit/4dcb2b9478eba32a480364e595f5fff7bd8ca887
+# Since their only dependency is itertools, it is easier to vendor them here.
+def _split_corpus(path, shard_size):
+    """Yield a `list` containing `shard_size` lines of `path`."""
+    with open(path, "rb") as f:
+        if shard_size <= 0:
+            yield f.readlines()
+        else:
+            while True:
+                shard = list(islice(f, shard_size))
+                if not shard:
+                    break
+                yield shard
+
+
+def split_corpus(path, shard_size=0, default=None):
+    """Yield a `list` containing `shard_size` lines of `path`,
+    or repeatedly yield `default` if `path` is None.
+    """
+    if path is not None:
+        return _split_corpus(path, shard_size)
+    else:
+        return repeat(default)
+
+
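The two helpers vendored above are copied essentially unchanged from the removed OpenNMT-py code, so their behavior can be pinned down with a quick sketch. The file name and contents are hypothetical, and `split_corpus` is assumed to be in scope as defined above:

```python
# Sketch: behavior of the vendored split_corpus (file name/content hypothetical).
with open("corpus.txt", "wb") as f:
    f.write(b"line1\nline2\nline3\n")

# shard_size=2 -> shards of at most two (bytes) lines each
shards = list(split_corpus("corpus.txt", shard_size=2))
assert shards == [[b"line1\n", b"line2\n"], [b"line3\n"]]

# path=None -> an endless iterator of the default value
defaults = split_corpus(None, default="N/A")
assert next(defaults) == "N/A"
```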

 @attr.s(auto_attribs=True)
 class TranslationResult:
     """
@@ -88,6 +118,7 @@ def translate_sentences_with_onmt(
         else:
             yield translation_results

+    @torch.no_grad()
     def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
         """
         Do the translation (in tokenized format) with OpenNMT.
@@ -101,29 +132,60 @@ def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
         """
         # for some versions, it seems that n_best is not updated; we therefore do it manually here
         self.internal_translator.n_best = opt.n_best
-
-        src_shards = split_corpus(opt.src, opt.shard_size)
-        tgt_shards = (
-            split_corpus(opt.tgt, opt.shard_size)
-            if opt.tgt is not None
-            else repeat(None)
-        )
-        shard_pairs = zip(src_shards, tgt_shards)
-
-        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
-            l1, l2 = self.internal_translator.translate(
-                src=src_shard,
-                tgt=tgt_shard,
-                src_dir=opt.src_dir,
-                batch_size=opt.batch_size,
-                batch_type=opt.batch_type,
-                attn_debug=opt.attn_debug,
-            )
-            for score_list, translation_list in zip(l1, l2):
-                yield [
-                    TranslationResult(text=t, score=s.item())
-                    for s, t in zip(score_list, translation_list)
-                ]
+
+        opt.src_dir = opt.src.parent
+
+        # OpenNMT-py v3 builds a dynamic dataset iterator instead of iterating
+        # over manually split corpus shards.
+        infer_iter = build_dynamic_dataset_iter(
+            opt=opt,
+            transforms_cls=opt.transforms,
+            vocabs=self.internal_translator.vocabs,
+            task=CorpusTask.INFER,
+            device_id=opt.gpu,
+        )
+
+        l1_total, l2_total = self.internal_translator._translate(
+            infer_iter=infer_iter,
+            attn_debug=opt.attn_debug,
+        )
+
+        del infer_iter
+        for score_list, translation_list in zip(l1_total, l2_total):
+            yield [
+                TranslationResult(text=t, score=s)
+                for s, t in zip(score_list, translation_list)
+            ]
+        del l1_total, l2_total
+
+        # If sharding is needed again, a dynamic dataset iterator can be built
+        # per (src_shard, tgt_shard) pair (passing src=src_shard, tgt=tgt_shard)
+        # and handed to _translate in the same way.


 def get_onmt_opt(
@@ -165,7 +231,8 @@ def onmt_parser() -> ArgumentParser:

     parser = ArgumentParser(description="translate.py")

-    opts.config_opts(parser)
+    # opts.config_opts(parser)
     opts.translate_opts(parser)

     return parser
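For context, the v3-style inference flow used in `translate_with_onmt` looks roughly as follows in isolation. This is a sketch only: the paths are hypothetical, and the exact keyword arguments of `build_translator` can vary between OpenNMT-py versions; the `build_dynamic_dataset_iter` and `_translate` calls mirror the ones in the patch:

```python
# Sketch of the OpenNMT-py v3 inference flow (paths hypothetical).
import onmt.opts as opts
import torch
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.translate.translator import build_translator
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description="translate.py")
opts.translate_opts(parser)
opt = parser.parse_args(["-model", "model.pt", "-src", "src.txt"])
ArgumentParser.validate_translate_opts(opt)

translator = build_translator(opt, report_score=False)
with torch.no_grad():
    infer_iter = build_dynamic_dataset_iter(
        opt=opt,
        transforms_cls=opt.transforms,
        vocabs=translator.vocabs,
        task=CorpusTask.INFER,
        device_id=opt.gpu,
    )
    scores, predictions = translator._translate(
        infer_iter=infer_iter, attn_debug=opt.attn_debug
    )
```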
diff --git a/src/rxn/onmt_utils/model_introspection.py b/src/rxn/onmt_utils/model_introspection.py
index e6e9cd2..d07b6dc 100644
--- a/src/rxn/onmt_utils/model_introspection.py
+++ b/src/rxn/onmt_utils/model_introspection.py
@@ -2,7 +2,8 @@
 from typing import Any, Dict, List

 import torch
-from onmt.inputters.text_dataset import TextMultiField
+
+# from onmt.inputters.text_dataset import TextMultiField
 from rxn.utilities.files import PathLike

@@ -29,6 +30,14 @@ def get_preprocessed_vocab(vocab_path: PathLike) -> List[str]:
     return _torch_vocab_to_list(vocab)


+def read_vocab_file(file_path: PathLike) -> List[str]:
+    """Read a plain-text vocab file with one entry per line, where the token is
+    the first whitespace-separated field."""
+    vocab = []
+    with open(file_path, "r") as file:
+        for line in file:
+            vocab.append(line.split()[0])  # keep the token, drop the count
+    return vocab
+
+
 def model_vocab_is_compatible(model_pt: PathLike, vocab_pt: PathLike) -> bool:
     """
     Determine whether the vocabulary contained in a model checkpoint contains
@@ -39,20 +48,20 @@ def model_vocab_is_compatible(model_pt: PathLike, vocab_pt: PathLike) -> bool:
         vocab_pt: vocab file, such as `preprocessed.vocab.pt`.
     """
     model_vocab = set(get_model_vocab(model_pt))
-    data_vocab = set(get_preprocessed_vocab(vocab_pt))
+    data_vocab = set(read_vocab_file(vocab_pt))
     return data_vocab.issubset(model_vocab)


 def _torch_vocab_to_list(vocab: Dict[str, Any]) -> List[str]:
-    src_vocab = _multifield_vocab_to_list(vocab["src"])
-    tgt_vocab = _multifield_vocab_to_list(vocab["tgt"])
+    # The "src" and "tgt" entries are used directly here (previously they had
+    # to be unwrapped from TextMultiField).
+    src_vocab = vocab["src"]
+    tgt_vocab = vocab["tgt"]
     if src_vocab != tgt_vocab:
         raise RuntimeError("Handling of different src/tgt vocab not implemented")

     return src_vocab


-def _multifield_vocab_to_list(multifield: TextMultiField) -> List[str]:
-    return multifield.base_field.vocab.itos[:]
+# _multifield_vocab_to_list (based on TextMultiField) is no longer needed with
+# OpenNMT-py v3 vocabularies.


 def get_model_opt(model_path: PathLike) -> Namespace:
diff --git a/src/rxn/onmt_utils/train_command.py b/src/rxn/onmt_utils/train_command.py
index 0a1e3dd..4298543 100644
--- a/src/rxn/onmt_utils/train_command.py
+++ b/src/rxn/onmt_utils/train_command.py
@@ -1,15 +1,40 @@
 import logging
 from enum import Flag
-from typing import Any, List, Optional, Tuple
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple

+import torch
+import yaml
 from rxn.utilities.files import PathLike

-from .model_introspection import get_model_rnn_size
+# from .model_introspection import get_model_rnn_size

 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())


+def get_paths_src_tgt(data_dir: PathLike, model_task: str) -> Tuple[str, str, str, str]:
+    # The reaction is A -> B, regardless of the retro or forward task.
+    if model_task == "forward":
+        A = "precursors"
+        B = "products"
+    elif model_task == "retro":
+        A = "products"
+        B = "precursors"
+    else:
+        raise ValueError(
+            f"Argument model_task can only be 'forward' or 'retro', but received {model_task}"
+        )
+
+    corpus_path_src = f"{data_dir}/data.processed.train.{A}_tokens"
+    corpus_path_tgt = f"{data_dir}/data.processed.train.{B}_tokens"
+    valid_path_src = f"{data_dir}/data.processed.validation.{A}_tokens"
+    valid_path_tgt = f"{data_dir}/data.processed.validation.{B}_tokens"
+
+    return corpus_path_src, corpus_path_tgt, valid_path_src, valid_path_tgt
+
+
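A quick illustration of the paths the helper above produces (directory name hypothetical):

```python
# Sketch: file paths returned by get_paths_src_tgt for both tasks.
src, tgt, val_src, val_tgt = get_paths_src_tgt("/data/run1", model_task="forward")
assert src == "/data/run1/data.processed.train.precursors_tokens"
assert tgt == "/data/run1/data.processed.train.products_tokens"

# For "retro", the src/tgt roles are swapped:
src, tgt, _, _ = get_paths_src_tgt("/data/run1", model_task="retro")
assert src == "/data/run1/data.processed.train.products_tokens"
assert tgt == "/data/run1/data.processed.train.precursors_tokens"
```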
 class RxnCommand(Flag):
     """
     Flag indicating which command(s) the parameters relate to.
@@ -46,6 +71,7 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
         self.needed_for = needed_for


+# See https://opennmt.net/OpenNMT-py/options/train.html
 ONMT_TRAIN_ARGS: List[Arg] = [
     Arg("accum_count", "4", RxnCommand.TCF),
     Arg("adam_beta1", "0.9", RxnCommand.TF),
@@ -69,23 +95,27 @@ def __init__(self, key: str, default: Any, needed_for: RxnCommand):
     Arg("normalization", "tokens", RxnCommand.TCF),
     Arg("optim", "adam", RxnCommand.TF),
     Arg("param_init", "0", RxnCommand.T),
-    Arg("param_init_glorot", "", RxnCommand.T),  # note: empty means "nothing"
-    Arg("position_encoding", "", RxnCommand.T),  # note: empty means "nothing"
+    Arg("param_init_glorot", "true", RxnCommand.T),  # boolean flag, now explicit
+    Arg("position_encoding", "true", RxnCommand.T),  # boolean flag, now explicit
     Arg("report_every", "1000", RxnCommand.TCF),
     Arg("reset_optim", None, RxnCommand.CF),
-    Arg("rnn_size", None, RxnCommand.TF),
+    Arg("hidden_size", None, RxnCommand.TF),
     Arg("save_checkpoint_steps", "5000", RxnCommand.TCF),
     Arg("save_model", None, RxnCommand.TCF),
     Arg("seed", None, RxnCommand.TCF),
     Arg("self_attn_type", "scaled-dot", RxnCommand.T),
-    Arg("share_embeddings", "", RxnCommand.T),  # note: empty means "nothing"
+    Arg("share_embeddings", "true", RxnCommand.T),  # boolean flag, now explicit
+    Arg("src_vocab", None, RxnCommand.T),
+    Arg("tgt_vocab", None, RxnCommand.T),
     Arg("train_from", None, RxnCommand.CF),
     Arg("train_steps", None, RxnCommand.TCF),
     Arg("transformer_ff", None, RxnCommand.T),
     Arg("valid_batch_size", "8", RxnCommand.TCF),
     Arg("warmup_steps", None, RxnCommand.TF),
     Arg("word_vec_size", None, RxnCommand.T),
+    Arg("pos_ffn_activation_fn", "relu", RxnCommand.T),
 ]
+# TODO: (Irina) add new v3.5.1 arguments (e.g. lora_layers, quant_layers) if necessary.


 class OnmtTrainCommand:
@@ -99,18 +129,21 @@ def __init__(
         command_type: RxnCommand,
         no_gpu: bool,
         data_weights: Tuple[int, ...],
+        model_task: str,
         **kwargs: Any,
     ):
         self._command_type = command_type
         self._no_gpu = no_gpu
         self._data_weights = data_weights
         self._kwargs = kwargs
+        self._model_task = model_task

     def _build_cmd(self) -> List[str]:
         """
         Build the base command.
         """
         command = ["onmt_train"]
+
         for arg in ONMT_TRAIN_ARGS:
             arg_given = arg.key in self._kwargs
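As a reminder of why `RxnCommand` is a `Flag`: combined members act as sets, so a single `Arg` entry can declare every command it applies to. A minimal sketch, assuming the combined members are unions such as `TCF = T | C | F` (consistent with how they are used above, but not shown in this diff):

```python
# Sketch: Flag membership as used for ONMT_TRAIN_ARGS (member values assumed).
from enum import Flag, auto

class RxnCommandSketch(Flag):
    T = auto()   # train
    C = auto()   # continue training
    F = auto()   # finetune
    TF = T | F
    CF = C | F
    TCF = T | C | F

# An Arg tagged TCF is needed for all three commands:
assert RxnCommandSketch.T in RxnCommandSketch.TCF
# ...while one tagged TF is not needed when continuing training:
assert RxnCommandSketch.C not in RxnCommandSketch.TF
```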
""" command = ["onmt_train"] + for arg in ONMT_TRAIN_ARGS: arg_given = arg.key in self._kwargs @@ -167,11 +200,97 @@ def cmd(self) -> List[str]: """ return self._build_cmd() - def save_to_config_cmd(self, config_file: PathLike) -> List[str]: + def is_valid_kwarg_value(self, kwarg, value) -> bool: + # NOTE: upgrade to v.3.5.1 + # A lot of the code below is from self._build_cmd() + # In theory, self._build_cmd() could be deprecated but to avoid breaking something, + # it will stay until 100% sure + # Here we jsut need the checks and not construct a command + # TODO: assess deprecation of self._build_cmd() + + # Check if argument is in ONMT_TRAIN_ARGS + for arg in ONMT_TRAIN_ARGS: + if arg.key == kwarg: + onmt_train_kwarg = arg + + try: + onmt_train_kwarg + except NameError: + NameError(f"Argument {kwarg} doesn't exist in ONMT_TRAIN_ARGS.") + + # Check argument is needed for command + if self._command_type not in onmt_train_kwarg.needed_for: + raise ValueError( + f'"{value}" value given for arg {kwarg}, but not necessary for command {self._command_type}' + ) + # Check if argument has no default and needs a value + if onmt_train_kwarg.default is None and value is None: + raise ValueError(f"No value given for {kwarg} and needs one.") + + return True + + def save_to_config_cmd(self, config_file_path: PathLike) -> None: """ - Return the command for saving the config to a file. + Save the training config to a file. + See https://opennmt.net/OpenNMT-py/quickstart.html part 2 """ - return self._build_cmd() + ["-save_config", str(config_file)] + # Build train config content, it will not include defaults not specified in cli + # See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 2: Train) + train_config: Dict[str, Any] = {} + + # GPUs + if torch.cuda.is_available() and self._no_gpu is False: + train_config["gpu_ranks"] = [0] + + for arg in ONMT_TRAIN_ARGS: + arg_given = arg.key in self._kwargs + + if arg_given: + train_config[arg.key] = self._kwargs[arg.key] + else: + train_config[arg.key] = arg.default + + # Dump all cli arguments to dict + # for kwarg, value in self._kwargs.items(): + # if self.is_valid_kwarg_value(kwarg, value): + # train_config[kwarg] = value + # else: + # raise ValueError(f'"Value {value}" for argument {kwarg} is invalid') + + # Reformat "data" argument as in ONMT-py v.3.5.0 + path_save_prepr_data = train_config["data"] + train_config["save_data"] = str(path_save_prepr_data) + # TODO: update to > 1 corpus + train_config["data"] = {"corpus_1": {}, "valid": {}} + + # get data files path, caution depends on task because ONMT preprocessed files in v.3.5.1 aren't fully processed as with earlier versions + corpus_path_src, corpus_path_tgt, valid_path_src, valid_path_tgt = ( + get_paths_src_tgt( + dir=path_save_prepr_data.parent.parent, model_task=self.model_task + ) + ) + + train_config["data"]["corpus_1"]["path_src"] = corpus_path_src + train_config["data"]["corpus_1"]["path_tgt"] = corpus_path_tgt + train_config["data"]["valid"]["path_src"] = valid_path_src + train_config["data"]["valid"]["path_tgt"] = valid_path_tgt + + train_config["src_vocab"] = str( + train_config["src_vocab"] + ) # avoid posix bad format in yaml + train_config["tgt_vocab"] = str( + train_config["tgt_vocab"] + ) # avoid posix bad format in yaml + train_config["save_model"] = str( + train_config["save_model"] + ) # avoid posix bad format in yaml + + # share_vocab + train_config["share_vocab"] = str(True) + + # Dump to config.yaml + with open(config_file_path, "w+") as file: + 
@@ -185,11 +304,13 @@ def train(
         cls,
         batch_size: int,
         data: PathLike,
+        src_vocab: Path,
+        tgt_vocab: Path,
         dropout: float,
         heads: int,
         layers: int,
         learning_rate: float,
-        rnn_size: int,
+        hidden_size: int,
         save_model: PathLike,
         seed: int,
         train_steps: int,
@@ -198,6 +319,7 @@ def train(
         word_vec_size: int,
         no_gpu: bool,
         data_weights: Tuple[int, ...],
+        model_task: str,
         keep_checkpoint: int = -1,
     ) -> "OnmtTrainCommand":
         return cls(
@@ -206,18 +328,21 @@ def train(
             data_weights=data_weights,
             batch_size=batch_size,
             data=data,
+            src_vocab=src_vocab,
+            tgt_vocab=tgt_vocab,
             dropout=dropout,
             heads=heads,
             keep_checkpoint=keep_checkpoint,
             layers=layers,
             learning_rate=learning_rate,
-            rnn_size=rnn_size,
+            hidden_size=hidden_size,
             save_model=save_model,
             seed=seed,
             train_steps=train_steps,
             transformer_ff=transformer_ff,
             warmup_steps=warmup_steps,
             word_vec_size=word_vec_size,
+            model_task=model_task,
         )

     @classmethod
@@ -225,6 +350,8 @@ def continue_training(
         cls,
         batch_size: int,
         data: PathLike,
+        src_vocab: Path,
+        tgt_vocab: Path,
         dropout: float,
         save_model: PathLike,
         seed: int,
@@ -232,6 +359,7 @@ def continue_training(
         train_steps: int,
         no_gpu: bool,
         data_weights: Tuple[int, ...],
+        model_task: str,
         keep_checkpoint: int = -1,
     ) -> "OnmtTrainCommand":
         return cls(
@@ -240,6 +368,8 @@ def continue_training(
             data_weights=data_weights,
             batch_size=batch_size,
             data=data,
+            src_vocab=src_vocab,
+            tgt_vocab=tgt_vocab,
             dropout=dropout,
             keep_checkpoint=keep_checkpoint,
             reset_optim="none",
@@ -247,6 +377,7 @@ def continue_training(
             seed=seed,
             train_from=train_from,
             train_steps=train_steps,
+            model_task=model_task,
         )

     @classmethod
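A hedged usage sketch of the updated `train` classmethod together with the new YAML-based config flow; all paths and hyperparameter values are hypothetical:

```python
# Sketch: building a train command and materializing its v3 YAML config.
from pathlib import Path

cmd = OnmtTrainCommand.train(
    batch_size=6144,
    data=Path("/runs/exp1/onmt/preprocessed"),
    src_vocab=Path("/runs/exp1/onmt/vocab.src"),
    tgt_vocab=Path("/runs/exp1/onmt/vocab.tgt"),
    dropout=0.1,
    heads=8,
    layers=4,
    learning_rate=2.0,
    hidden_size=384,
    save_model=Path("/runs/exp1/models/model"),
    seed=42,
    train_steps=100000,
    transformer_ff=2048,
    warmup_steps=8000,
    word_vec_size=384,
    no_gpu=False,
    data_weights=(),
    model_task="forward",
)
cmd.save_to_config_cmd("/runs/exp1/config.yaml")
# Training is then launched from the config, e.g.:
#   onmt_train -config /runs/exp1/config.yaml
```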
@@ -265,15 +396,18 @@ def finetune(
         data_weights: Tuple[int, ...],
         report_every: int,
         save_checkpoint_steps: int,
+        model_task: str,
         keep_checkpoint: int = -1,
-        rnn_size: Optional[int] = None,
+        hidden_size: Optional[int] = None,
     ) -> "OnmtTrainCommand":
-        if rnn_size is None:
-            # In principle, the rnn_size should not be needed for finetuning. However,
-            # when resetting the decay algorithm for the learning rate, this value
-            # is necessary - and does not get it from the model checkpoint (OpenNMT bug).
-            rnn_size = get_model_rnn_size(train_from)
-            logger.info(f"Loaded the value of rnn_size from the model: {rnn_size}.")
+        if hidden_size is None:
+            # In principle, hidden_size should not be needed for finetuning. However,
+            # when resetting the decay algorithm for the learning rate, this value
+            # is necessary - and it is not read from the model checkpoint (OpenNMT bug).
+            # get_model_rnn_size is no longer usable here, so the value cannot be
+            # loaded from the model anymore.
+            # hidden_size = get_model_rnn_size(train_from)
+            logger.warning(
+                "hidden_size was not provided and could not be loaded from the model."
+            )

         return cls(
             command_type=RxnCommand.F,
@@ -285,7 +419,7 @@ def finetune(
             keep_checkpoint=keep_checkpoint,
             learning_rate=learning_rate,
             reset_optim="all",
-            rnn_size=rnn_size,
+            hidden_size=hidden_size,
             save_model=save_model,
             seed=seed,
             train_from=train_from,
@@ -293,6 +427,7 @@ def finetune(
             warmup_steps=warmup_steps,
             report_every=report_every,
             save_checkpoint_steps=save_checkpoint_steps,
+            model_task=model_task,
         )

diff --git a/src/rxn/onmt_utils/translate.py b/src/rxn/onmt_utils/translate.py
index 3c3a689..8c0d1ad 100644
--- a/src/rxn/onmt_utils/translate.py
+++ b/src/rxn/onmt_utils/translate.py
@@ -158,7 +158,7 @@ def translate_as_external_command(
         str(src),
         "-output",
         str(output),
-        "-log_probs",
+        # "-log_probs",  # not available without the rxn OpenNMT-py fork
         "-n_best",
         str(n_best),
         "-beam_size",
diff --git a/src/rxn/onmt_utils/translator.py b/src/rxn/onmt_utils/translator.py
index cc97173..415043d 100644
--- a/src/rxn/onmt_utils/translator.py
+++ b/src/rxn/onmt_utils/translator.py
@@ -1,6 +1,8 @@
 from argparse import Namespace
 from typing import Any, Iterable, Iterator, List, Optional, Union

+import torch
+
 from .internal_translation_utils import RawTranslator, TranslationResult, get_onmt_opt

@@ -27,6 +29,7 @@ def translate_single(self, sentence: str) -> str:
         assert len(translations) == 1
         return translations[0]

+    @torch.no_grad()
     def translate_sentences(self, sentences: Iterable[str]) -> List[str]:
         """
         Translate multiple sentences.
@@ -34,6 +37,7 @@ def translate_sentences(self, sentences: Iterable[str]) -> List[str]:
         translations = self.translate_multiple_with_scores(sentences)
         return [t[0].text for t in translations]

+    @torch.no_grad()
     def translate_multiple_with_scores(
         self, sentences: Iterable[str], n_best: Optional[int] = None
     ) -> Iterator[List[TranslationResult]]:
@@ -51,7 +55,6 @@ def translate_multiple_with_scores(
         translations = self.onmt_translator.translate_sentences_with_onmt(
             sentences, **additional_opt_kwargs
         )
-
         yield from translations

     @classmethod
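The `@torch.no_grad()` decorators added in this patch matter because inference through these methods would otherwise record autograd history. A self-contained illustration of the effect:

```python
# Sketch: what @torch.no_grad() changes for inference-only code.
import torch

model = torch.nn.Linear(4, 4)  # stand-in for a translation model
x = torch.randn(2, 4)

y = model(x)
assert y.requires_grad  # autograd history is being recorded

@torch.no_grad()
def infer(inp: torch.Tensor) -> torch.Tensor:
    return model(inp)

y = infer(x)
assert not y.requires_grad  # no graph kept -> lower memory use at inference
```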