Skip to content

Commit

Permalink
temptative fix: translation utils upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed May 29, 2024
1 parent 6c7b563 commit bb9a216
Showing 1 changed file with 63 additions and 23 deletions.
86 changes: 63 additions & 23 deletions src/rxn/onmt_utils/internal_translation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import attr
import onmt.opts as opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.translate.translator import build_translator

# from onmt.utils.misc import split_corpus
Expand All @@ -29,7 +31,7 @@ def _split_corpus(path, shard_size):
yield shard


def split_corpus(path, shard_size, default=None):
def split_corpus(path, shard_size=0, default=None):
"""yield a `list` containing `shard_size` line of `path`,
or repeatly generate `default` if `path` is None.
"""
Expand Down Expand Up @@ -128,29 +130,64 @@ def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
"""
# for some versions, it seems that n_best is not updated, we therefore do it manually here
self.internal_translator.n_best = opt.n_best

src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = (
split_corpus(opt.tgt, opt.shard_size)
if opt.tgt is not None
else repeat(None)
opt.shard_size = 0 # IRINA

opt.src_dir = opt.src.parent
#import ipdb
#ipdb.set_trace()
#src_shards = split_corpus(opt.src, opt.shard_size)
#tgt_shards = (
# split_corpus(opt.tgt, opt.shard_size)
# if opt.tgt is not None
# else repeat(None)
#)
#shard_pairs = zip(src_shards, tgt_shards)

infer_iter = build_dynamic_dataset_iter(
opt=opt,
transforms_cls=opt.transforms,
vocabs=self.internal_translator.vocabs,
task=CorpusTask.INFER,
device_id=opt.gpu,
)
shard_pairs = zip(src_shards, tgt_shards)

for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
l1, l2 = self.internal_translator.translate(
src=src_shard,
tgt=tgt_shard,
src_dir=opt.src_dir,
batch_size=opt.batch_size,
batch_type=opt.batch_type,
l1_total, l2_total = self.internal_translator._translate( # IRINA
infer_iter=infer_iter,
attn_debug=opt.attn_debug,
)
for score_list, translation_list in zip(l1, l2):
yield [
TranslationResult(text=t, score=s.item())
for s, t in zip(score_list, translation_list)
]

)
for score_list, translation_list in zip(l1_total, l2_total):
yield [
TranslationResult(text=t, score=s)
for s, t in zip(score_list, translation_list)
]

# for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
# #import ipdb
# #ipdb.set_trace()
# infer_iter = build_dynamic_dataset_iter(
# opt=opt,
# transforms_cls=opt.transforms,
# vocabs=self.internal_translator.vocabs,
# task=CorpusTask.INFER,
# #device_id=self.device_id,
# src=src_shard,
# tgt=tgt_shard,
# )
# l1, l2 = self.internal_translator._translate( # IRINA
# infer_iter=infer_iter,
# attn_debug=opt.attn_debug,
# # src=src_shard,
# # tgt=tgt_shard,
# # src_dir=opt.src_dir,
# # batch_size=opt.batch_size,
# # batch_type=opt.batch_type,
# # attn_debug=opt.attn_debug,
# )
# for score_list, translation_list in zip(l1, l2):
# yield [
# TranslationResult(text=t, score=s.item())
# for s, t in zip(score_list, translation_list)
# ]


def get_onmt_opt(
Expand Down Expand Up @@ -192,7 +229,10 @@ def onmt_parser() -> ArgumentParser:

parser = ArgumentParser(description="translate.py")

opts.config_opts(parser)
#opts.config_opts(parser) # IRINA

#import ipdb
#ipdb.set_trace()
opts.translate_opts(parser)

return parser

0 comments on commit bb9a216

Please sign in to comment.