[WIP] Upgrade onmt to v.3.5.1 #7

Status: Draft
Wants to merge 17 commits into base: main
7 changes: 5 additions & 2 deletions .github/workflows/docs.yaml
@@ -9,12 +9,15 @@ jobs:
   build:
     runs-on: ubuntu-latest
     name: Build the Sphinx docs
+    strategy:
+      matrix:
+        python-version: ["3.11"]
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v3
         with:
-          python-version: 3.8
+          python-version: ${{ matrix.python-version }}
       - name: Install package dependencies
         run: pip install -e .
       - name: Install sphinx dependencies
8 changes: 5 additions & 3 deletions .github/workflows/pypi.yaml
@@ -9,13 +9,15 @@ jobs:
   build-and-publish:
     name: Build and publish rxn-onmt-utils on PyPI
     runs-on: ubuntu-latest
-
+    strategy:
+      matrix:
+        python-version: ["3.11"]
     steps:
       - uses: actions/checkout@master
-      - name: Python setup 3.9
+      - name: Python setup ${{ matrix.python-version }}
         uses: actions/setup-python@v1
         with:
-          python-version: 3.9
+          python-version: ${{ matrix.python-version }}
       - name: Install build package (for packaging)
         run: pip install --upgrade build
       - name: Build dist
7 changes: 5 additions & 2 deletions .github/workflows/tests.yaml
@@ -6,12 +6,15 @@ jobs:
   tests:
     runs-on: ubuntu-latest
     name: Style, mypy
+    strategy:
+      matrix:
+        python-version: ["3.11"]
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.7
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v3
         with:
-          python-version: 3.7
+          python-version: ${{ matrix.python-version }}
       - name: Install Dependencies
         run: pip install -e .[dev]
       - name: Check black
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -8,6 +8,8 @@ check_untyped_defs = true
 [[tool.mypy.overrides]]
 module = [
     "onmt.*",
+    "yaml.*",
+    "torch.*",
 ]
 ignore_missing_imports = true
 
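
Side note on the two new mypy overrides: ignore_missing_imports = true simply silences mypy's unresolved-import errors for the listed modules. A minimal illustration of the kind of code that triggers them (hypothetical snippet, not part of this PR):

import torch  # without the override, mypy may flag this as a missing import/stub
import yaml   # same for pyyaml

def tensor_from_yaml(path: str) -> torch.Tensor:
    # Load a list of numbers from a YAML file and wrap it in a tensor.
    with open(path) as f:
        values = yaml.safe_load(f)
    return torch.tensor(values)
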
2 changes: 1 addition & 1 deletion setup.cfg
@@ -25,7 +25,7 @@ include_package_data = True
 install_requires =
     attrs>=21.2.0
     click>=8.0
-    rxn-opennmt-py>=1.1.1
+    # rxn-opennmt-py>=1.1.1  # Remove opennmt-py fork dependence
     rxn-utils>=1.6.0
 
 [options.packages.find]
117 changes: 92 additions & 25 deletions src/rxn/onmt_utils/internal_translation_utils.py
@@ -1,17 +1,47 @@
 import copy
 import os
 from argparse import Namespace
-from itertools import repeat
+from itertools import islice, repeat
 from typing import Any, Iterable, Iterator, List, Optional
 
 import attr
 import onmt.opts as opts
+import torch
+from onmt.constants import CorpusTask
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.translate.translator import build_translator
-from onmt.utils.misc import split_corpus
+
+# from onmt.utils.misc import split_corpus
 from onmt.utils.parse import ArgumentParser
 from rxn.utilities.files import named_temporary_path
 
 
+# Reintroduce _split_corpus and split_corpus, which originally lived in onmt.utils.misc.
+# They were removed in https://github.com/OpenNMT/OpenNMT-py/commit/4dcb2b9478eba32a480364e595f5fff7bd8ca887
+# Since their only dependency is itertools, it is easiest to inline them here.
+def _split_corpus(path, shard_size):
+    """Yield lists containing `shard_size` lines of `path`."""
+    with open(path, "rb") as f:
+        if shard_size <= 0:
+            yield f.readlines()
+        else:
+            while True:
+                shard = list(islice(f, shard_size))
+                if not shard:
+                    break
+                yield shard
+
+
+def split_corpus(path, shard_size=0, default=None):
+    """Yield lists containing `shard_size` lines of `path`,
+    or repeatedly generate `default` if `path` is None.
+    """
+    if path is not None:
+        return _split_corpus(path, shard_size)
+    else:
+        return repeat(default)
+
+
 @attr.s(auto_attribs=True)
 class TranslationResult:
     """
@@ -88,6 +118,7 @@ def translate_sentences_with_onmt(
         else:
             yield translation_results
 
+    @torch.no_grad()
     def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
         """
         Do the translation (in tokenized format) with OpenNMT.
@@ -101,29 +132,60 @@ def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
         """
         # for some versions, it seems that n_best is not updated, we therefore do it manually here
         self.internal_translator.n_best = opt.n_best
 
-        src_shards = split_corpus(opt.src, opt.shard_size)
-        tgt_shards = (
-            split_corpus(opt.tgt, opt.shard_size)
-            if opt.tgt is not None
-            else repeat(None)
-        )
-        shard_pairs = zip(src_shards, tgt_shards)
-
-        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
-            l1, l2 = self.internal_translator.translate(
-                src=src_shard,
-                tgt=tgt_shard,
-                src_dir=opt.src_dir,
-                batch_size=opt.batch_size,
-                batch_type=opt.batch_type,
-                attn_debug=opt.attn_debug,
-            )
-            for score_list, translation_list in zip(l1, l2):
-                yield [
-                    TranslationResult(text=t, score=s.item())
-                    for s, t in zip(score_list, translation_list)
-                ]
+        opt.src_dir = opt.src.parent
+
+        # pprint.pprint(opt)
+
+
+        infer_iter = build_dynamic_dataset_iter(
+            opt=opt,
+            transforms_cls=opt.transforms,
+            vocabs=self.internal_translator.vocabs,
+            task=CorpusTask.INFER,
+            device_id=opt.gpu,
+        )
+
+        l1_total, l2_total = self.internal_translator._translate(  # IRINA
+            infer_iter=infer_iter,
+            attn_debug=opt.attn_debug,
+        )
+
+        del infer_iter
+        for score_list, translation_list in zip(l1_total, l2_total):
+            yield [
+                TranslationResult(text=t, score=s)
+                for s, t in zip(score_list, translation_list)
+            ]
+        del l1_total, l2_total
+
+        # for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
+        #     # import ipdb
+        #     # ipdb.set_trace()
+        #     infer_iter = build_dynamic_dataset_iter(
+        #         opt=opt,
+        #         transforms_cls=opt.transforms,
+        #         vocabs=self.internal_translator.vocabs,
+        #         task=CorpusTask.INFER,
+        #         # device_id=self.device_id,
+        #         src=src_shard,
+        #         tgt=tgt_shard,
+        #     )
+        #     l1, l2 = self.internal_translator._translate(  # IRINA
+        #         infer_iter=infer_iter,
+        #         attn_debug=opt.attn_debug,
+        #         # src=src_shard,
+        #         # tgt=tgt_shard,
+        #         # src_dir=opt.src_dir,
+        #         # batch_size=opt.batch_size,
+        #         # batch_type=opt.batch_type,
+        #         # attn_debug=opt.attn_debug,
+        #     )
+        #     for score_list, translation_list in zip(l1, l2):
+        #         yield [
+        #             TranslationResult(text=t, score=s.item())
+        #             for s, t in zip(score_list, translation_list)
+        #         ]
 
 
 def get_onmt_opt(
@@ -155,6 +217,10 @@ def get_onmt_opt(
         setattr(opt, key, value)
     ArgumentParser.validate_translate_opts(opt)
 
+    # opt.random_sampling_topk = 1.0
+    # opt.length_penalty = "none"
+    # opt.alpha = 0
+
     return opt
 
 

@@ -165,7 +231,8 @@ def onmt_parser() -> ArgumentParser:
 
     parser = ArgumentParser(description="translate.py")
 
-    opts.config_opts(parser)
+    # opts.config_opts(parser)  # IRINA
+
     opts.translate_opts(parser)
 
-    return parser
+    return parser
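
Quick illustration of the reintroduced split_corpus helper (first hunk of this file): a minimal sketch with a hypothetical input file, not code from this PR.

from rxn.onmt_utils.internal_translation_utils import split_corpus

# Hypothetical tokenized source file, one sequence per line.
src_path = "precursors.tokenized.txt"

# A positive shard_size yields successive lists of at most that many lines
# (as bytes, since the helper opens the file in binary mode); shard_size=0
# yields the whole file as a single shard.
for shard in split_corpus(src_path, shard_size=1000):
    print(f"shard with {len(shard)} lines")

# With path=None, split_corpus degenerates to repeat(default), which is how an
# absent target-side corpus was handled in the old sharded code path.
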
21 changes: 15 additions & 6 deletions src/rxn/onmt_utils/model_introspection.py
@@ -2,7 +2,8 @@
 from typing import Any, Dict, List
 
 import torch
-from onmt.inputters.text_dataset import TextMultiField
+
+# from onmt.inputters.text_dataset import TextMultiField
 from rxn.utilities.files import PathLike
 
 

Expand All @@ -29,6 +30,14 @@ def get_preprocessed_vocab(vocab_path: PathLike) -> List[str]:
return _torch_vocab_to_list(vocab)


def read_vocab_file(file_path):
vocab = []
with open(file_path, 'r') as file:
for line in file:
vocab.append(line.split()[0]) # Split each line and take the first element
return vocab


def model_vocab_is_compatible(model_pt: PathLike, vocab_pt: PathLike) -> bool:
"""
Determine whether the vocabulary contained in a model checkpoint contains
@@ -39,20 +48,20 @@ def model_vocab_is_compatible(model_pt: PathLike, vocab_pt: PathLike) -> bool:
         vocab_pt: vocab file, such as `preprocessed.vocab.pt`.
     """
     model_vocab = set(get_model_vocab(model_pt))
-    data_vocab = set(get_preprocessed_vocab(vocab_pt))
+    data_vocab = set(read_vocab_file(vocab_pt))
     return data_vocab.issubset(model_vocab)
 
 
 def _torch_vocab_to_list(vocab: Dict[str, Any]) -> List[str]:
-    src_vocab = _multifield_vocab_to_list(vocab["src"])
-    tgt_vocab = _multifield_vocab_to_list(vocab["tgt"])
+    src_vocab = vocab["src"]  # _multifield_vocab_to_list(vocab["src"])
+    tgt_vocab = vocab["tgt"]  # _multifield_vocab_to_list(vocab["tgt"])
     if src_vocab != tgt_vocab:
         raise RuntimeError("Handling of different src/tgt vocab not implemented")
     return src_vocab
 
 
-def _multifield_vocab_to_list(multifield: TextMultiField) -> List[str]:
-    return multifield.base_field.vocab.itos[:]
+# def _multifield_vocab_to_list(multifield: TextMultiField) -> List[str]:
+#     return multifield.base_field.vocab.itos[:]
 
 
 def get_model_opt(model_path: PathLike) -> Namespace:
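
To clarify the read_vocab_file change above: it assumes the OpenNMT 3.x plain-text vocabulary format (one token per line, optionally followed by its count), which is why only the first whitespace-separated field is kept. A minimal sketch with hypothetical file names:

from rxn.onmt_utils.model_introspection import get_model_vocab, read_vocab_file

# Hypothetical paths; "vocab.src.txt" is assumed to contain lines like "CC(=O)O 1234".
model_vocab = set(get_model_vocab("model_step_100000.pt"))
data_vocab = set(read_vocab_file("vocab.src.txt"))

# Same check as model_vocab_is_compatible(model_pt, vocab_pt).
print(data_vocab.issubset(model_vocab))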