From bb9a2168aeabbfc265c4da4c3ef118a510cf130e Mon Sep 17 00:00:00 2001 From: Irina Espejo Morales Date: Wed, 29 May 2024 15:05:23 +0200 Subject: [PATCH] temptative fix: translation utils upgrade --- .../onmt_utils/internal_translation_utils.py | 86 ++++++++++++++----- 1 file changed, 63 insertions(+), 23 deletions(-) diff --git a/src/rxn/onmt_utils/internal_translation_utils.py b/src/rxn/onmt_utils/internal_translation_utils.py index 513908d..8675cd5 100644 --- a/src/rxn/onmt_utils/internal_translation_utils.py +++ b/src/rxn/onmt_utils/internal_translation_utils.py @@ -6,6 +6,8 @@ import attr import onmt.opts as opts +from onmt.constants import CorpusTask +from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter from onmt.translate.translator import build_translator # from onmt.utils.misc import split_corpus @@ -29,7 +31,7 @@ def _split_corpus(path, shard_size): yield shard -def split_corpus(path, shard_size, default=None): +def split_corpus(path, shard_size=0, default=None): """yield a `list` containing `shard_size` line of `path`, or repeatly generate `default` if `path` is None. """ @@ -128,29 +130,64 @@ def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]: """ # for some versions, it seems that n_best is not updated, we therefore do it manually here self.internal_translator.n_best = opt.n_best - - src_shards = split_corpus(opt.src, opt.shard_size) - tgt_shards = ( - split_corpus(opt.tgt, opt.shard_size) - if opt.tgt is not None - else repeat(None) + opt.shard_size = 0 # IRINA + + opt.src_dir = opt.src.parent + #import ipdb + #ipdb.set_trace() + #src_shards = split_corpus(opt.src, opt.shard_size) + #tgt_shards = ( + # split_corpus(opt.tgt, opt.shard_size) + # if opt.tgt is not None + # else repeat(None) + #) + #shard_pairs = zip(src_shards, tgt_shards) + + infer_iter = build_dynamic_dataset_iter( + opt=opt, + transforms_cls=opt.transforms, + vocabs=self.internal_translator.vocabs, + task=CorpusTask.INFER, + device_id=opt.gpu, ) - shard_pairs = zip(src_shards, tgt_shards) - - for i, (src_shard, tgt_shard) in enumerate(shard_pairs): - l1, l2 = self.internal_translator.translate( - src=src_shard, - tgt=tgt_shard, - src_dir=opt.src_dir, - batch_size=opt.batch_size, - batch_type=opt.batch_type, + l1_total, l2_total = self.internal_translator._translate( # IRINA + infer_iter=infer_iter, attn_debug=opt.attn_debug, - ) - for score_list, translation_list in zip(l1, l2): - yield [ - TranslationResult(text=t, score=s.item()) - for s, t in zip(score_list, translation_list) - ] + + ) + for score_list, translation_list in zip(l1_total, l2_total): + yield [ + TranslationResult(text=t, score=s) + for s, t in zip(score_list, translation_list) + ] + + # for i, (src_shard, tgt_shard) in enumerate(shard_pairs): + # #import ipdb + # #ipdb.set_trace() + # infer_iter = build_dynamic_dataset_iter( + # opt=opt, + # transforms_cls=opt.transforms, + # vocabs=self.internal_translator.vocabs, + # task=CorpusTask.INFER, + # #device_id=self.device_id, + # src=src_shard, + # tgt=tgt_shard, + # ) + # l1, l2 = self.internal_translator._translate( # IRINA + # infer_iter=infer_iter, + # attn_debug=opt.attn_debug, + # # src=src_shard, + # # tgt=tgt_shard, + # # src_dir=opt.src_dir, + # # batch_size=opt.batch_size, + # # batch_type=opt.batch_type, + # # attn_debug=opt.attn_debug, + # ) + # for score_list, translation_list in zip(l1, l2): + # yield [ + # TranslationResult(text=t, score=s.item()) + # for s, t in zip(score_list, translation_list) + # ] def get_onmt_opt( @@ -192,7 +229,10 @@ def onmt_parser() -> ArgumentParser: parser = ArgumentParser(description="translate.py") - opts.config_opts(parser) + #opts.config_opts(parser) # IRINA + + #import ipdb + #ipdb.set_trace() opts.translate_opts(parser) return parser