diff --git a/data/data_features/src-test-with-feats.txt b/data/data_features/src-test-with-feats.txt
new file mode 100644
index 0000000000..4a41985c0d
--- /dev/null
+++ b/data/data_features/src-test-with-feats.txt
@@ -0,0 +1 @@
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/src-test.feat0 b/data/data_features/src-test.feat0
deleted file mode 100644
index 4ab4a9e651..0000000000
--- a/data/data_features/src-test.feat0
+++ /dev/null
@@ -1 +0,0 @@
-C B A B
\ No newline at end of file
diff --git a/data/data_features/src-train-with-feats.txt b/data/data_features/src-train-with-feats.txt
new file mode 100644
index 0000000000..42cd4995a4
--- /dev/null
+++ b/data/data_features/src-train-with-feats.txt
@@ -0,0 +1,3 @@
+however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
+however,│A according│B to│C the│D logs,│E
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/src-train.feat0 b/data/data_features/src-train.feat0
deleted file mode 100644
index 7e189f2c33..0000000000
--- a/data/data_features/src-train.feat0
+++ /dev/null
@@ -1,3 +0,0 @@
-A A A A B A A A C
-A B C D E
-C B A B
\ No newline at end of file
diff --git a/data/data_features/src-val-with-feats.txt b/data/data_features/src-val-with-feats.txt
new file mode 100644
index 0000000000..4a41985c0d
--- /dev/null
+++ b/data/data_features/src-val-with-feats.txt
@@ -0,0 +1 @@
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/src-val.feat0 b/data/data_features/src-val.feat0
deleted file mode 100644
index 4ab4a9e651..0000000000
--- a/data/data_features/src-val.feat0
+++ /dev/null
@@ -1 +0,0 @@
-C B A B
\ No newline at end of file
diff --git a/data/data_features/tgt-test-with-feats.txt b/data/data_features/tgt-test-with-feats.txt
new file mode 100644
index 0000000000..4a41985c0d
--- /dev/null
+++ b/data/data_features/tgt-test-with-feats.txt
@@ -0,0 +1 @@
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/tgt-test.txt b/data/data_features/tgt-test.txt
new file mode 100644
index 0000000000..0cc723ce39
--- /dev/null
+++ b/data/data_features/tgt-test.txt
@@ -0,0 +1 @@
+she is a hard-working.
\ No newline at end of file
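
The fixture files above fold the old side-car `.feat0` files into the text itself: every token carries its features after a `│` separator (U+2502, not the ASCII pipe `|`). A minimal sketch of how one such line decomposes, assuming exactly one feature per token:

```python
# Minimal sketch: split a feature-annotated line into text and features.
# The "│" separator (U+2502) is the one used in the fixture files above.
line = "she│C is│B a│A hard-working.│B"

tokens, feats = [], []
for tok in line.split(" "):
    word, feat = tok.split("│")
    tokens.append(word)
    feats.append(feat)

print(" ".join(tokens))  # she is a hard-working.
print(" ".join(feats))   # C B A B
```
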
diff --git a/data/data_features/tgt-train-with-feats.txt b/data/data_features/tgt-train-with-feats.txt
new file mode 100644
index 0000000000..42cd4995a4
--- /dev/null
+++ b/data/data_features/tgt-train-with-feats.txt
@@ -0,0 +1,3 @@
+however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
+however,│A according│B to│C the│D logs,│E
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/tgt-val-with-feats.txt b/data/data_features/tgt-val-with-feats.txt
new file mode 100644
index 0000000000..4a41985c0d
--- /dev/null
+++ b/data/data_features/tgt-val-with-feats.txt
@@ -0,0 +1 @@
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/features_configs/source_and_target_features.yaml b/data/features_configs/source_and_target_features.yaml
new file mode 100644
index 0000000000..53f331f5f0
--- /dev/null
+++ b/data/features_configs/source_and_target_features.yaml
@@ -0,0 +1,20 @@
+# Corpus opts:
+data:
+    corpus_1:
+        path_src: data/data_features/src-train-with-feats.txt
+        path_tgt: data/data_features/tgt-train-with-feats.txt
+        transforms: [inferfeats]
+    corpus_2:
+        path_src: data/data_features/src-train.txt
+        path_tgt: data/data_features/tgt-train.txt
+        transforms: [inferfeats]
+    valid:
+        path_src: data/data_features/src-val-with-feats.txt
+        path_tgt: data/data_features/tgt-val-with-feats.txt
+        transforms: [inferfeats]
+
+# # Feats options
+n_src_feats: 1
+n_tgt_feats: 1
+src_feats_defaults: "0"
+tgt_feats_defaults: "1"
diff --git a/data/features_configs/source_features_only.yaml b/data/features_configs/source_features_only.yaml
new file mode 100644
index 0000000000..087d8dae24
--- /dev/null
+++ b/data/features_configs/source_features_only.yaml
@@ -0,0 +1,18 @@
+# Corpus opts:
+data:
+    corpus_1:
+        path_src: data/data_features/src-train-with-feats.txt
+        path_tgt: data/data_features/tgt-train.txt
+        transforms: [inferfeats]
+    corpus_2:
+        path_src: data/data_features/src-train.txt
+        path_tgt: data/data_features/tgt-train.txt
+        transforms: [inferfeats]
+    valid:
+        path_src: data/data_features/src-val-with-feats.txt
+        path_tgt: data/data_features/tgt-val.txt
+        transforms: [inferfeats]
+
+# # Feats options
+n_src_feats: 1
+src_feats_defaults: "0"
diff --git a/data/features_configs/target_features_only.yaml b/data/features_configs/target_features_only.yaml
new file mode 100644
index 0000000000..2212d593ab
--- /dev/null
+++ b/data/features_configs/target_features_only.yaml
@@ -0,0 +1,18 @@
+# Corpus opts:
+data:
+    corpus_1:
+        path_src: data/data_features/src-train.txt
+        path_tgt: data/data_features/tgt-train-with-feats.txt
+        transforms: [inferfeats]
+    corpus_2:
+        path_src: data/data_features/src-train.txt
+        path_tgt: data/data_features/tgt-train.txt
+        transforms: [inferfeats]
+    valid:
+        path_src: data/data_features/src-val.txt
+        path_tgt: data/data_features/tgt-val-with-feats.txt
+        transforms: [inferfeats]
+
+# # Feats options
+n_tgt_feats: 1
+tgt_feats_defaults: "0"
diff --git a/data/features_data.yaml b/data/features_data.yaml
deleted file mode 100644
index fa9b665f9c..0000000000
--- a/data/features_data.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-# Corpus opts:
-data:
-    corpus_1:
-        path_src: data/data_features/src-train.txt
-        path_tgt: data/data_features/tgt-train.txt
-        src_feats:
-            feat0: data/data_features/src-train.feat0
-        transforms: [filterfeats, inferfeats]
-    valid:
-        path_src: data/data_features/src-val.txt
-        path_tgt: data/data_features/tgt-val.txt
diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py
index 63adf02908..4c2facac51 100644
--- a/onmt/bin/build_vocab.py
+++ b/onmt/bin/build_vocab.py
@@ -9,7 +9,7 @@
 from onmt.inputters.text_utils import process
 from onmt.transforms import make_transforms, get_transforms_cls
 from onmt.constants import CorpusName, CorpusTask
-from collections import Counter, defaultdict
+from collections import Counter
 import multiprocessing as mp
@@ -42,19 +42,26 @@ def write_files_from_queues(sample_path, queues):
 
 # Just for debugging purposes
 # It appends features to subwords when dumping to file
-def append_features_to_example(example, features):
-    ex_toks = example.split(' ')
-    feat_toks = features.split(' ')
-    toks = [f"{subword}│{feat}" for subword, feat in
-            zip(ex_toks, feat_toks)]
-    return " ".join(toks)
+def append_features_to_text(text, features):
+    text_tok = text.split(' ')
+    feats_tok = [x.split(' ') for x in features]
+
+    pretty_toks = []
+    for tok, *feats in zip(text_tok, *feats_tok):
+        feats = '│'.join(feats)
+        if feats:
+            pretty_toks.append(f"{tok}│{feats}")
+        else:
+            pretty_toks.append(tok)
+    return " ".join(pretty_toks)
 
 
 def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
     """Build vocab on (strided) subpart of the data."""
     sub_counter_src = Counter()
     sub_counter_tgt = Counter()
-    sub_counter_src_feats = defaultdict(Counter)
+    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    sub_counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
     datasets_iterables = build_corpora_iters(
         corpora, transforms, opts.data,
         skip_empty_level=opts.skip_empty_level,
@@ -68,28 +75,35 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
             if opts.dump_samples:
                 build_sub_vocab.queues[c_name][offset].put("blank")
                 continue
-            src_line, tgt_line = (maybe_example['src']['src'],
-                                  maybe_example['tgt']['tgt'])
-            src_line_pretty = src_line
-            for feat_name, feat_line in maybe_example["src"].items():
-                if feat_name not in ["src", "src_original"]:
-                    sub_counter_src_feats[feat_name].update(
-                        feat_line.split(' '))
-                    if opts.dump_samples:
-                        src_line_pretty = append_features_to_example(
-                            src_line_pretty, feat_line)
+            src_line = maybe_example.src
+            tgt_line = maybe_example.tgt
+            src_feats_lines = maybe_example.src_feats
+            tgt_feats_lines = maybe_example.tgt_feats
+
             sub_counter_src.update(src_line.split(' '))
             sub_counter_tgt.update(tgt_line.split(' '))
+            for i in range(opts.n_src_feats):
+                sub_counter_src_feats[i].update(src_feats_lines[i].split(' '))
+            for i in range(opts.n_tgt_feats):
+                sub_counter_tgt_feats[i].update(tgt_feats_lines[i].split(' '))
+            if opts.dump_samples:
+                src_pretty_line = append_features_to_text(
+                    src_line, src_feats_lines)
+                tgt_pretty_line = append_features_to_text(
+                    tgt_line, tgt_feats_lines)
                 build_sub_vocab.queues[c_name][offset].put(
-                    (i, src_line_pretty, tgt_line))
+                    (i, src_pretty_line, tgt_pretty_line))
             if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
                 if opts.dump_samples:
                     build_sub_vocab.queues[c_name][offset].put("break")
                 break
         if opts.dump_samples:
             build_sub_vocab.queues[c_name][offset].put("break")
-    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
+    return (sub_counter_src,
+            sub_counter_tgt,
+            sub_counter_src_feats,
+            sub_counter_tgt_feats)
 
 
 def init_pool(queues):
@@ -113,7 +127,8 @@ def build_vocab(opts, transforms, n_sample=3):
     corpora = get_corpora(opts, task=CorpusTask.TRAIN)
     counter_src = Counter()
     counter_tgt = Counter()
-    counter_src_feats = defaultdict(Counter)
+    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
     from functools import partial
     queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                        for i in range(opts.num_threads)]
@@ -130,14 +145,18 @@ def build_vocab(opts, transforms, n_sample=3):
         func = partial(
             build_sub_vocab, corpora, transforms,
             opts, n_sample, opts.num_threads)
-        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
-                func, range(0, opts.num_threads)):
+        for (sub_counter_src, sub_counter_tgt,
+             sub_counter_src_feats, sub_counter_tgt_feats) \
+                in p.imap(func, range(0, opts.num_threads)):
             counter_src.update(sub_counter_src)
             counter_tgt.update(sub_counter_tgt)
-            counter_src_feats.update(sub_counter_src_feats)
+            for i in range(opts.n_src_feats):
+                counter_src_feats[i].update(sub_counter_src_feats[i])
+            for i in range(opts.n_tgt_feats):
+                counter_tgt_feats[i].update(sub_counter_tgt_feats[i])
     if opts.dump_samples:
         write_process.join()
-    return counter_src, counter_tgt, counter_src_feats
+    return counter_src, counter_tgt, counter_src_feats, counter_tgt_feats
 
 
 def build_vocab_main(opts):
@@ -163,13 +182,16 @@ def build_vocab_main(opts):
     transforms = make_transforms(opts, transforms_cls, None)
 
     logger.info(f"Counter vocab from {opts.n_sample} samples.")
-    src_counter, tgt_counter, src_feats_counter = build_vocab(
+    (src_counter, tgt_counter,
+     src_feats_counter, tgt_feats_counter) = build_vocab(
         opts, transforms, n_sample=opts.n_sample)
 
-    logger.info(f"Counters src:{len(src_counter)}")
-    logger.info(f"Counters tgt:{len(tgt_counter)}")
-    for feat_name, feat_counter in src_feats_counter.items():
-        logger.info(f"Counters {feat_name}:{len(feat_counter)}")
+    logger.info(f"Counters src: {len(src_counter)}")
+    logger.info(f"Counters tgt: {len(tgt_counter)}")
+    for i, feat_counter in enumerate(src_feats_counter):
+        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
+    for i, feat_counter in enumerate(tgt_feats_counter):
+        logger.info(f"Counters tgt feat_{i}: {len(feat_counter)}")
 
     def save_counter(counter, save_path):
         check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -186,8 +208,10 @@ def save_counter(counter, save_path):
         save_counter(src_counter, opts.src_vocab)
         save_counter(tgt_counter, opts.tgt_vocab)
 
-    for k, v in src_feats_counter.items():
-        save_counter(v, opts.src_feats_vocab[k])
+    for i, c in enumerate(src_feats_counter):
+        save_counter(c, f"{opts.src_vocab}_feat{i}")
+    for i, c in enumerate(tgt_feats_counter):
+        save_counter(c, f"{opts.tgt_vocab}_feat{i}")
 
 
 def _get_parser():
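
`build_sub_vocab` now keeps one `Counter` per feature column instead of a dict keyed by feature name. For intuition, a standalone sketch of that counting over toy feature lines (toy data, not the real corpus iterator):

```python
from collections import Counter

# One Counter per feature column, as build_sub_vocab now keeps.
n_feats = 1
feat_counters = [Counter() for _ in range(n_feats)]

# Feature lines as parse_features returns them: one space-joined
# string per feature column.
examples = [["A A A A B A A A C"], ["A B C D E"]]
for feats_lines in examples:
    for i in range(n_feats):
        feat_counters[i].update(feats_lines[i].split(" "))

print(feat_counters[0].most_common(2))  # [('A', 8), ('B', 2)]
```
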
""" - distributions = torch.stack( - [mg(h) if attn is None else mg(h, attn, src_map) - for h, mg in zip(hidden, self.model_generators)] - ) + + distributions, feats_distributions = [], [] + n_feats = len(self.model_generators[0].feats_generators) + for h, mg in zip(hidden, self.model_generators): + scores, feats_scores = \ + (mg(h) if attn is None else mg(h, attn, src_map)) + distributions.append(scores) + feats_distributions.append(feats_scores) + + distributions = torch.stack(distributions) + + stacked_feats_distributions = [] + for i in range(n_feats): + stacked_feats_distributions.append( + torch.stack([feat_distribution[i] + for feat_distribution in feats_distributions + for i in range(n_feats)])) + if self._raw_probs: - return torch.log(torch.exp(distributions).mean(0)) + return (torch.log(torch.exp(distributions).mean(0)), + [torch.log(torch.exp(d).mean(0)) + for d in stacked_feats_distributions]) else: - return distributions.mean(0) + return (distributions.mean(0), + [d.mean(0) for d in stacked_feats_distributions]) class EnsembleModel(NMTModel): diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py index 9ad1cdcec9..79e4dff1e2 100644 --- a/onmt/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -4,15 +4,14 @@ e.g., from a line of text to a sequence of vectors. """ from onmt.inputters.inputter import build_vocab, IterOnDevice -from onmt.inputters.text_utils import text_sort_key, process,\ - numericalize, tensorify +from onmt.inputters.text_utils import text_sort_key, process, tensorify from onmt.inputters.text_corpus import ParallelCorpus, ParallelCorpusIterator from onmt.inputters.dynamic_iterator import MixingStrategy, SequentialMixer,\ WeightedMixer, DynamicDatasetIter __all__ = ['IterOnDevice', 'build_vocab', 'text_sort_key', - 'process', 'numericalize', 'tensorify', + 'process', 'tensorify', 'ParallelCorpus', 'ParallelCorpusIterator', 'MixingStrategy', 'SequentialMixer', 'WeightedMixer', 'DynamicDatasetIter'] diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py index 5bca1f71d6..5fbe76f7d1 100644 --- a/onmt/inputters/dynamic_iterator.py +++ b/onmt/inputters/dynamic_iterator.py @@ -4,7 +4,7 @@ from onmt.constants import CorpusTask from onmt.inputters.text_corpus import get_corpora, build_corpora_iters from onmt.inputters.text_utils import text_sort_key, process,\ - numericalize, tensorify, _addcopykeys + tensorify from onmt.transforms import make_transforms from onmt.utils.logging import init_logger, logger from onmt.utils.misc import RandomShuffler @@ -209,8 +209,9 @@ def _tuple_to_json_with_tokIDs(self, tuple_bucket): for example in tuple_bucket: if example is not None: if self.copy: - example = _addcopykeys(self.vocabs, example) - bucket.append(numericalize(self.vocabs, example)) + example.addcopykeys(self.vocabs) + example.numericalize(self.vocabs) + bucket.append(example) return bucket def _bucketing(self): @@ -252,11 +253,8 @@ def batch_size_fn(nbsents, maxlen): minibatch, maxlen, size_so_far, seen = [], 0, 0, [] for ex in data: - if ( - (ex['src']['src'] not in seen) or - (self.task != CorpusTask.TRAIN) - ): - seen.append(ex['src']['src']) + if ((ex.src not in seen) or (self.task != CorpusTask.TRAIN)): + seen.append(ex.src) minibatch.append(ex) nbsents = len(minibatch) maxlen = max(text_sort_key(ex), maxlen) diff --git a/onmt/inputters/example.py b/onmt/inputters/example.py new file mode 100644 index 0000000000..e3a696e113 --- /dev/null +++ b/onmt/inputters/example.py @@ -0,0 +1,157 @@ +from onmt.constants import 
DefaultTokens, ModelTask +import pyonmttok +from collections import Counter + + +class Example(object): + """Container for each of the parallel corpus examples""" + + def __init__(self, src, src_original, src_feats=None, + tgt=None, tgt_original=None, tgt_feats=None, + align=None): + # 'src_original' and 'tgt_original' store the + # original line before tokenization. These + # fields are used later on in the feature + # transforms. + self.src = src + self.src_ids = None + self.src_original = src_original + self.src_feats = src_feats + self.src_feats_ids = None + + self.tgt = tgt + self.tgt_ids = None + self.tgt_original = tgt_original + self.tgt_feats = tgt_feats + self.tgt_feats_ids = None + + # Alignments + self.align = align + + # Copy mechanism + self.src_map = None + self.src_ex_vocab = None + self.alignment = None + + def tokenize(self): + self.src = self.src.strip('\n').split() + self.src_original = \ + self.src_original.strip("\n").split() + if self.src_feats is not None: + self.src_feats = \ + [feat.split() for feat in self.src_feats] + if self.tgt is not None: + self.tgt = self.tgt.strip('\n').split() + self.tgt_original = \ + self.tgt_original.strip("\n").split() + if self.tgt_feats is not None: + self.tgt_feats = \ + [feat.split() for feat in self.tgt_feats] + if self.align is not None: + self.align = self.align.strip('\n').split() + + def add_index(self, idx): + self.index = idx + + def is_empty(self): + if len(self.src) == 0: + return True + if self.tgt is not None: + if len(self.tgt) == 0: + return True + if self.align is not None: + if len(self.align) == 0: + return True + return False + + def clean(self): + self.src = ' '.join(self.src) + if self.src_feats is not None: + self.src_feats = [' '.join(x) for x in self.src_feats] + if self.tgt is not None: + self.tgt = ' '.join(self.tgt) + if self.tgt_feats is not None: + self.tgt_feats = [' '.join(x) for x in self.tgt_feats] + if self.align is not None: + self.align = ' '.join(self.align) + + def numericalize(self, vocabs): + data_task = vocabs['data_task'] + assert data_task in [ModelTask.SEQ2SEQ, ModelTask.LANGUAGE_MODEL], \ + f"Something went wrong with task {vocabs['data_task']}" + + src_toks = self.src.split() + if data_task == ModelTask.SEQ2SEQ: + self.src_ids = vocabs['src'](src_toks) + elif data_task == ModelTask.LANGUAGE_MODEL: + self.src_ids = vocabs['src']([DefaultTokens.BOS] + src_toks) + + if self.src_feats is not None: + self.src_feats_ids = [] + for fv, feat in zip(vocabs['src_feats'], self.src_feats): + feat_toks = feat.split() + if data_task == ModelTask.SEQ2SEQ: + self.src_feats_ids.append(fv(feat_toks)) + else: + self.src_feats_ids.append( + fv([DefaultTokens.BOS] + feat_toks)) + + if self.tgt is not None: + tgt_toks = self.tgt.split() + if data_task == ModelTask.SEQ2SEQ: + self.tgt_ids = vocabs['tgt']([DefaultTokens.BOS] + + tgt_toks + + [DefaultTokens.EOS]) + elif data_task == ModelTask.LANGUAGE_MODEL: + self.tgt_ids = vocabs['tgt'](tgt_toks + [DefaultTokens.EOS]) + + if self.tgt_feats is not None: + self.tgt_feats_ids = [] + for fv, feat in zip(vocabs['tgt_feats'], self.tgt_feats): + feat_toks = feat.split() + if data_task == ModelTask.SEQ2SEQ: + self.tgt_feats_ids.append( + fv([DefaultTokens.BOS] + feat_toks + + [DefaultTokens.EOS])) + else: + self.tgt_feats_ids.append( + fv(feat_toks + [DefaultTokens.EOS])) + + def addcopykeys(self, vocabs): + """Create copy-vocab and numericalize with it. + In-place adds ``"src_map"`` to ``example``. 
That is the copy-vocab + numericalization of the tokenized ``example["src"]``. If ``example`` + has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the + copy-vocab numericalization of the tokenized ``example["tgt"]``. The + alignment has an initial and final UNK token to match the BOS and EOS + tokens. + Args: + vocabs + example (dict): An example dictionary with a ``"src"`` key and + maybe a ``"tgt"`` key. (This argument changes in place!) + Returns: + ``example``, changed as described. + """ + src = self.src.split() + src_ex_vocab = pyonmttok.build_vocab_from_tokens( + Counter(src), + maximum_size=0, + minimum_frequency=1, + special_tokens=[DefaultTokens.UNK, + DefaultTokens.PAD, + DefaultTokens.BOS, + DefaultTokens.EOS]) + src_ex_vocab.default_id = src_ex_vocab[DefaultTokens.UNK] + # make a small vocab containing just the tokens in the source sequence + + # Map source tokens to indices in the dynamic dict. + self.src_map = src_ex_vocab(src) + self.src_ex_vocab = src_ex_vocab + + if self.tgt is not None: + if vocabs['data_task'] == ModelTask.SEQ2SEQ: + tgt = [DefaultTokens.UNK] + self.tgt.split() \ + + [DefaultTokens.UNK] + elif vocabs['data_task'] == ModelTask.LANGUAGE_MODEL: + tgt = self.tgt.split() + [DefaultTokens.UNK] + self.alignment = src_ex_vocab(tgt) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 0bcf80c2f0..d97f676734 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -34,11 +34,11 @@ def build_vocab(opt, specials): """ Build vocabs dict to be stored in the checkpoint based on vocab files having each line [token, count] Args: - opt: src_vocab, tgt_vocab, src_feats_vocab + opt: src_vocab, tgt_vocab, n_src_feats, n_tgt_feats Return: vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab, - 'src_feats' : {'feat0': pyonmttok.Vocab, - 'feat1': pyonmttok.Vocab, ...}, + 'src_feats' : [pyonmttok.Vocab, pyonmttok.Vocab, ...], + 'tgt_feats' : [pyonmttok.Vocab, pyonmttok.Vocab, ...], 'data_task': seq2seq or lm } """ @@ -85,10 +85,10 @@ def _pad_vocab_to_multiple(vocab, multiple): opt.vocab_size_multiple) vocabs['tgt'] = tgt_vocab - if opt.src_feats_vocab: - src_feats = {} - for feat_name, filepath in opt.src_feats_vocab.items(): - src_f_vocab = _read_vocab_file(filepath, 1) + if opt.n_src_feats > 0: + src_feats_vocabs = [] + for i in range(opt.n_src_feats): + src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1) src_f_vocab = pyonmttok.build_vocab_from_tokens( src_f_vocab, maximum_size=0, @@ -101,8 +101,31 @@ def _pad_vocab_to_multiple(vocab, multiple): if opt.vocab_size_multiple > 1: src_f_vocab = _pad_vocab_to_multiple(src_f_vocab, opt.vocab_size_multiple) - src_feats[feat_name] = src_f_vocab - vocabs['src_feats'] = src_feats + src_feats_vocabs.append(src_f_vocab) + vocabs["src_feats"] = src_feats_vocabs + else: + vocabs["src_feats"] = [] + + if opt.n_tgt_feats > 0: + tgt_feats_vocabs = [] + for i in range(opt.n_tgt_feats): + tgt_f_vocab = _read_vocab_file(f"{opt.tgt_vocab}_feat{i}", 1) + tgt_f_vocab = pyonmttok.build_vocab_from_tokens( + tgt_f_vocab, + maximum_size=0, + minimum_frequency=1, + special_tokens=[DefaultTokens.UNK, + DefaultTokens.PAD, + DefaultTokens.BOS, + DefaultTokens.EOS]) + tgt_f_vocab.default_id = tgt_f_vocab[DefaultTokens.UNK] + if opt.vocab_size_multiple > 1: + tgt_f_vocab = _pad_vocab_to_multiple(tgt_f_vocab, + opt.vocab_size_multiple) + tgt_feats_vocabs.append(tgt_f_vocab) + vocabs["tgt_feats"] = tgt_feats_vocabs + else: + vocabs["tgt_feats"] = [] vocabs['data_task'] = opt.data_task @@ 
-146,10 +169,11 @@ def vocabs_to_dict(vocabs): vocabs_dict['src'] = vocabs['src'].ids_to_tokens vocabs_dict['tgt'] = vocabs['tgt'].ids_to_tokens if 'src_feats' in vocabs.keys(): - vocabs_dict['src_feats'] = {} - for feat in vocabs['src_feats'].keys(): - vocabs_dict['src_feats'][feat] = \ - vocabs['src_feats'][feat].ids_to_tokens + vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens + for feat_vocab in vocabs['src_feats']] + if 'tgt_feats' in vocabs.keys(): + vocabs_dict['tgt_feats'] = [feat_vocab.ids_to_tokens + for feat_vocab in vocabs['tgt_feats']] vocabs_dict['data_task'] = vocabs['data_task'] return vocabs_dict @@ -167,9 +191,13 @@ def dict_to_vocabs(vocabs_dict): else: vocabs['tgt'] = pyonmttok.build_vocab_from_tokens(vocabs_dict['tgt']) if 'src_feats' in vocabs_dict.keys(): - vocabs['src_feats'] = {} - for feat in vocabs_dict['src_feats'].keys(): - vocabs['src_feats'][feat] = \ - pyonmttok.build_vocab_from_tokens( - vocabs_dict['src_feats'][feat]) + vocabs['src_feats'] = [] + for feat_vocab in vocabs_dict['src_feats']: + vocabs['src_feats'].append( + pyonmttok.build_vocab_from_tokens(feat_vocab)) + if 'tgt_feats' in vocabs_dict.keys(): + vocabs['tgt_feats'] = [] + for feat_vocab in vocabs_dict['tgt_feats']: + vocabs['tgt_feats'].append( + pyonmttok.build_vocab_from_tokens(feat_vocab)) return vocabs diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py index 442c6b074d..83936d85c5 100644 --- a/onmt/inputters/text_corpus.py +++ b/onmt/inputters/text_corpus.py @@ -3,8 +3,9 @@ from onmt.utils.logging import logger from onmt.constants import CorpusName, CorpusTask from onmt.transforms import TransformPipe -from onmt.inputters.text_utils import process +from onmt.inputters.text_utils import process, parse_features from contextlib import contextmanager +from onmt.inputters.example import Example @contextmanager @@ -38,13 +39,18 @@ def exfile_open(filename, *args, **kwargs): class ParallelCorpus(object): """A parallel corpus file pair that can be loaded to iterate.""" - def __init__(self, name, src, tgt, align=None, src_feats=None): + def __init__(self, name, src, tgt, align=None, + n_src_feats=0, n_tgt_feats=0, + src_feats_defaults=None, tgt_feats_defaults=None): """Initialize src & tgt side file path.""" self.id = name self.src = src self.tgt = tgt self.align = align - self.src_feats = src_feats + self.n_src_feats = n_src_feats + self.n_tgt_feats = n_tgt_feats + self.src_feats_defaults = src_feats_defaults + self.tgt_feats_defaults = tgt_feats_defaults def load(self, offset=0, stride=1): """ @@ -52,48 +58,47 @@ def load(self, offset=0, stride=1): `offset` and `stride` allow to iterate only on every `stride` example, starting from `offset`. 
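
Before the corpus changes below, it may help to see the lifecycle the new `Example` container goes through: `tokenize()` for the transforms, `clean()` to re-join tokens, then `numericalize()` against the vocabs dict. A rough sketch, with a toy vocab standing in for `pyonmttok.Vocab` (the real one comes from `build_vocab`):

```python
from onmt.constants import ModelTask
from onmt.inputters.example import Example

# Stand-in vocab: assigns ids on first sight, mimicking a callable
# pyonmttok.Vocab for the purpose of this sketch only.
class ToyVocab(dict):
    def __call__(self, toks):
        return [self.setdefault(t, len(self)) for t in toks]

vocabs = {"data_task": ModelTask.SEQ2SEQ,
          "src": ToyVocab(), "tgt": ToyVocab(),
          "src_feats": [ToyVocab()], "tgt_feats": []}

# Fields as ParallelCorpus.load() builds them from parse_features().
ex = Example(src="she is a hard-working.",
             src_original="she is a hard-working.",
             src_feats=["C B A B"],
             tgt="she is a hard-working.",
             tgt_original="she is a hard-working.")
ex.tokenize()            # str -> token lists (transforms operate here)
ex.clean()               # token lists -> space-joined strings again
ex.numericalize(vocabs)  # fills src_ids, tgt_ids, src_feats_ids
print(ex.src_ids, ex.src_feats_ids)  # [0, 1, 2, 3] [[0, 1, 2, 1]]
```
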
""" - if self.src_feats: - features_names = [] - features_files = [] - for feat_name, feat_path in self.src_feats.items(): - features_names.append(feat_name) - features_files.append(open(feat_path, mode='rb')) - else: - features_files = [] with exfile_open(self.src, mode='rb') as fs,\ - exfile_open(self.tgt, mode='rb') as ft,\ - exfile_open(self.align, mode='rb') as fa: - for i, (sline, tline, align, *features) in \ - enumerate(zip(fs, ft, fa, *features_files)): + exfile_open(self.tgt, mode='rb') as ft,\ + exfile_open(self.align, mode='rb') as fa: + for i, (sline, tline, align) in \ + enumerate(zip(fs, ft, fa)): if (i % stride) == offset: sline = sline.decode('utf-8') + sline, sfeats = parse_features( + sline, + n_feats=self.n_src_feats, + defaults=self.src_feats_defaults) if tline is not None: tline = tline.decode('utf-8') - # 'src_original' and 'tgt_original' store the - # original line before tokenization. These - # fields are used later on in the feature - # transforms. - example = { - 'src': sline, - 'tgt': tline, - 'src_original': sline, - 'tgt_original': tline - } + tline, tfeats = parse_features( + tline, + n_feats=self.n_tgt_feats, + defaults=self.tgt_feats_defaults) + else: + tfeats = None + if align is not None: - example['align'] = align.decode('utf-8') - if features: - example['src_feats'] = dict() - for j, feat in enumerate(features): - example['src_feats'][features_names[j]] = \ - feat.decode("utf-8") + align = align.decode('utf-8') + + example = Example( + src=sline, + src_original=sline, + src_feats=sfeats, + tgt=tline, + tgt_original=tline, + tgt_feats=tfeats, + align=align) yield example - for f in features_files: - f.close() def __str__(self): cls_name = type(self).__name__ - return '{}({}, {}, align={}, src_feats={})'.format( - cls_name, self.src, self.tgt, self.align, self.src_feats) + return f'{cls_name}({self.id}, {self.src}, {self.tgt}, ' \ + f'align={self.align}, ' \ + f'n_src_feats={self.n_src_feats}, ' \ + f'n_tgt_feats={self.n_tgt_feats}, ' \ + f'src_feats_defaults="{self.src_feats_defaults}", ' \ + f'tgt_feats_defaults="{self.tgt_feats_defaults}")' def get_corpora(opts, task=CorpusTask.TRAIN): @@ -106,7 +111,10 @@ def get_corpora(opts, task=CorpusTask.TRAIN): corpus_dict["path_src"], corpus_dict["path_tgt"], corpus_dict["path_align"], - corpus_dict["src_feats"]) + n_src_feats=opts.n_src_feats, + n_tgt_feats=opts.n_tgt_feats, + src_feats_defaults=opts.src_feats_defaults, + tgt_feats_defaults=opts.tgt_feats_defaults) elif task == CorpusTask.VALID: if CorpusName.VALID in opts.data.keys(): corpora_dict[CorpusName.VALID] = ParallelCorpus( @@ -114,7 +122,10 @@ def get_corpora(opts, task=CorpusTask.TRAIN): opts.data[CorpusName.VALID]["path_src"], opts.data[CorpusName.VALID]["path_tgt"], opts.data[CorpusName.VALID]["path_align"], - opts.data[CorpusName.VALID]["src_feats"]) + n_src_feats=opts.n_src_feats, + n_tgt_feats=opts.n_tgt_feats, + src_feats_defaults=opts.src_feats_defaults, + tgt_feats_defaults=opts.tgt_feats_defaults) else: return None else: @@ -122,7 +133,10 @@ def get_corpora(opts, task=CorpusTask.TRAIN): CorpusName.INFER, opts.src, opts.tgt, - src_feats=opts.src_feats) + n_src_feats=opts.n_src_feats, + n_tgt_feats=opts.n_tgt_feats, + src_feats_defaults=opts.src_feats_defaults, + tgt_feats_defaults=opts.tgt_feats_defaults) return corpora_dict @@ -151,19 +165,7 @@ def __init__(self, corpus, transform, def _tokenize(self, stream): for example in stream: - example['src'] = example['src'].strip('\n').split() - example['src_original'] = \ - 
example['src_original'].strip("\n").split() - if example['tgt'] is not None: - example['tgt'] = example['tgt'].strip('\n').split() - example['tgt_original'] = \ - example['tgt_original'].strip("\n").split() - if 'align' in example: - example['align'] = example['align'].strip('\n').split() - if 'src_feats' in example: - for k in example['src_feats'].keys(): - example['src_feats'][k] = \ - example['src_feats'][k].strip('\n').split() + example.tokenize() yield example def _transform(self, stream): @@ -186,17 +188,15 @@ def _add_index(self, stream): for i, item in enumerate(stream): example = item[0] line_number = i * self.stride + self.offset - example['indices'] = line_number - if example['tgt'] is not None: - if (len(example['src']) == 0 or len(example['tgt']) == 0 or - ('align' in example and example['align'] == 0)): - # empty example: skip - empty_msg = f"Empty line in {self.cid}#{line_number}." - if self.skip_empty_level == 'error': - raise IOError(empty_msg) - elif self.skip_empty_level == 'warning': - logger.warning(empty_msg) - continue + example.add_index(line_number) + if example.is_empty(): + # empty example: skip + empty_msg = f"Empty line in {self.cid}#{line_number}." + if self.skip_empty_level == 'error': + raise IOError(empty_msg) + elif self.skip_empty_level == 'warning': + logger.warning(empty_msg) + continue yield item def __iter__(self): diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py index 1728c44e39..9567f43e94 100644 --- a/onmt/inputters/text_utils.py +++ b/onmt/inputters/text_utils.py @@ -1,94 +1,39 @@ import torch -import pyonmttok -from onmt.constants import DefaultTokens, CorpusTask, ModelTask +from onmt.constants import DefaultTokens, CorpusTask from torch.nn.utils.rnn import pad_sequence from onmt.utils.logging import logger -from collections import Counter +from onmt.inputters.example import Example -def text_sort_key(ex): - """Sort using the number of tokens in the sequence.""" - if ex['tgt']: - return max(len(ex['src']['src_ids']), len(ex['tgt']['tgt_ids'])) - return len(ex['src']['src_ids']) - - -def clean_example(maybe_example): - maybe_example['src'] = {"src": ' '.join(maybe_example['src'])} - # Make features part of src like - # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}} - if 'src_feats' in maybe_example: - for feat_name, feat_value in maybe_example['src_feats'].items(): - maybe_example['src'][feat_name] = ' '.join(feat_value) - del maybe_example['src_feats'] - if maybe_example['tgt'] is not None: - maybe_example['tgt'] = {'tgt': ' '.join(maybe_example['tgt'])} - if 'align' in maybe_example: - maybe_example['align'] = ' '.join(maybe_example['align']) - return maybe_example - - -def process(task, bucket, **kwargs): - """Returns valid transformed bucket from bucket.""" - _, transform, cid = bucket[0] - # We apply the same TransformPipe to all the bucket - processed_bucket = transform.batch_apply( - bucket, is_train=(task == CorpusTask.TRAIN), corpus_name=cid) - if processed_bucket: - for i in range(len(processed_bucket)): - (example, transform, cid) = processed_bucket[i] - example = clean_example(example) - processed_bucket[i] = example - # at this point an example looks like: - # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}, - # 'tgt': {'tgt': ...}, - # 'src_original': ['tok1', ...'tokn'], - # 'tgt_original': ['tok1', ...'tokm'], - # 'indices' : seq in bucket - # 'align': ..., - # } - return processed_bucket - else: - return None - - -def numericalize(vocabs, example): +def parse_features(line, n_feats=0, 
defaults=None): """ + Parses text lines with features appended to each token. + Ex.: This│A│B is│A│A a│C│A test│A│B """ - numeric = example - numeric['src']['src_ids'] = [] - if vocabs['data_task'] == ModelTask.SEQ2SEQ: - src_text = example['src']['src'].split() - numeric['src']['src_ids'] = vocabs['src'](src_text) - if example['tgt'] is not None: - numeric['tgt']['tgt_ids'] = [] - tgt_text = example['tgt']['tgt'].split() - numeric['tgt']['tgt_ids'] = \ - vocabs['tgt']([DefaultTokens.BOS] + tgt_text - + [DefaultTokens.EOS]) - - elif vocabs['data_task'] == ModelTask.LANGUAGE_MODEL: - src_text = example['src']['src'].split() - numeric['src']['src_ids'] = \ - vocabs['src']([DefaultTokens.BOS] + src_text) - if example['tgt'] is not None: - numeric['tgt']['tgt_ids'] = [] - tgt_text = example['tgt']['tgt'].split() - numeric['tgt']['tgt_ids'] = \ - vocabs['tgt'](tgt_text + [DefaultTokens.EOS]) - else: - raise ValueError( - f"Something went wrong with task {vocabs['data_task']}" - ) - - if 'src_feats' in vocabs.keys(): - for featname in vocabs['src_feats'].keys(): - src_feat = example['src'][featname].split() - vf = vocabs['src_feats'][featname] - # we'll need to change this if we introduce tgt feat - numeric['src'][featname] = vf(src_feat) - - return numeric + text, feats = [], [[] for _ in range(n_feats)] + check, count = 0, 0 + for token in line.split(' '): + tok, *fts = token.strip().split("│") + check += len(fts) + count += 1 + if not fts and defaults is not None: + if isinstance(defaults, str): + defaults = defaults.split("│") + assert len(defaults) == n_feats, \ + "The number of provided defaults does not " \ + "match the number of feats" + fts = defaults + assert len(fts) == n_feats, \ + f"The number of fetures does not match the " \ + f"expected number of features. 
Found {len(fts)} " \ + f"features in the data but {n_feats} were expected" + text.append(tok) + for i in range(n_feats): + feats[i].append(fts[i]) + # Check if all tokens have features or none at all + assert check == 0 or check == count*n_feats, "Some features are missing" + feats = [" ".join(x) for x in feats] if n_feats > 0 else None + return " ".join(text), feats def parse_align_idx(align_pharaoh): @@ -108,20 +53,29 @@ def parse_align_idx(align_pharaoh): return flatten_align_idx +def process(task, bucket, **kwargs): + """Returns valid transformed bucket from bucket.""" + _, transform, cid = bucket[0] + # We apply the same TransformPipe to all the bucket + processed_bucket = transform.batch_apply( + bucket, is_train=(task == CorpusTask.TRAIN), corpus_name=cid) + if processed_bucket: + for i in range(len(processed_bucket)): + (example, transform, cid) = processed_bucket[i] + example.clean() + processed_bucket[i] = example + return processed_bucket + else: + return None + + def tensorify(vocabs, minibatch): """ - This function transforms a batch of example in tensors - Each example looks like - {'src': {'src': ..., 'feat1': ..., 'feat2': ..., 'src_ids': ...}, - 'tgt': {'tgt': ..., 'tgt_ids': ...}, - 'src_original': ['tok1', ...'tokn'], - 'tgt_original': ['tok1', ...'tokm'], - 'indices' : seq in bucket - 'align': ..., - } + This function transforms a batch of Examples in tensors + Returns Dict of batch Tensors - {'src': [seqlen, batchsize, n_feats], - 'tgt' : [seqlen, batchsize, n_feats=1], + {'src': [batchsize, seq_len, n_feats+1], + 'tgt' : [batchsize, seq_len, n_feats+1], 'indices' : [batchsize], 'srclen': [batchsize], 'tgtlen': [batchsize], @@ -129,76 +83,88 @@ def tensorify(vocabs, minibatch): } """ tensor_batch = {} - tbatchsrc = [torch.LongTensor(ex['src']['src_ids']) for ex in minibatch] + tbatchsrc = [torch.LongTensor(ex.src_ids) for ex in minibatch] padidx = vocabs['src'][DefaultTokens.PAD] tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx) - if len(minibatch[0]['src'].keys()) > 2: - tbatchfs = [tbatchsrc] - for feat in minibatch[0]['src'].keys(): - if feat not in ['src', 'src_ids']: - tbatchfeat = [torch.LongTensor(ex['src'][feat]) - for ex in minibatch] - padidx = vocabs['src_feats'][feat][DefaultTokens.PAD] - tbatchfeat = pad_sequence(tbatchfeat, batch_first=True, - padding_value=padidx) - tbatchfs.append(tbatchfeat) - tbatchsrc = torch.stack(tbatchfs, dim=2) - else: - tbatchsrc = tbatchsrc[:, :, None] - # Need to add features in last dimensions + tbatchfs = [tbatchsrc] + if minibatch[0].src_feats is not None: + for feat_id in range(len(minibatch[0].src_feats_ids)): + tbatchfeat = [torch.LongTensor(ex.src_feats_ids[feat_id]) + for ex in minibatch] + padidx = vocabs['src_feats'][feat_id][DefaultTokens.PAD] + tbatchfeat = pad_sequence(tbatchfeat, batch_first=True, + padding_value=padidx) + tbatchfs.append(tbatchfeat) + tbatchsrc = torch.stack(tbatchfs, dim=2) tensor_batch['src'] = tbatchsrc - tensor_batch['indices'] = torch.LongTensor([ex['indices'] - for ex in minibatch]) - tensor_batch['srclen'] = torch.LongTensor([len(ex['src']['src_ids']) - for ex in minibatch]) - if minibatch[0]['tgt'] is not None: - tbatchtgt = [torch.LongTensor(ex['tgt']['tgt_ids']) - for ex in minibatch] + tensor_batch['indices'] = \ + torch.LongTensor([ex.index for ex in minibatch]) + tensor_batch['srclen'] = \ + torch.LongTensor([len(ex.src_ids) for ex in minibatch]) + + if minibatch[0].tgt is not None: + tbatchtgt = [torch.LongTensor(ex.tgt_ids) for ex in minibatch] padidx = 
vocabs['tgt'][DefaultTokens.PAD] tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx) - tbatchtgt = tbatchtgt[:, :, None] - tbatchtgtlen = torch.LongTensor([len(ex['tgt']['tgt_ids']) - for ex in minibatch]) + tensor_batch['tgtlen'] = \ + torch.LongTensor([len(ex.tgt_ids) for ex in minibatch]) + + tbatchfs = [tbatchtgt] + if minibatch[0].tgt_feats is not None: + for feat_id in range(len(minibatch[0].tgt_feats_ids)): + tbatchfeat = [torch.LongTensor(ex.tgt_feats_ids[feat_id]) + for ex in minibatch] + padidx = vocabs['tgt_feats'][feat_id][DefaultTokens.PAD] + tbatchfeat = pad_sequence(tbatchfeat, batch_first=True, + padding_value=padidx) + tbatchfs.append(tbatchfeat) + tbatchtgt = torch.stack(tbatchfs, dim=2) tensor_batch['tgt'] = tbatchtgt - tensor_batch['tgtlen'] = tbatchtgtlen - if 'align' in minibatch[0].keys() and minibatch[0]['align'] is not None: + if minibatch[0].align is not None: sparse_idx = [] for i, ex in enumerate(minibatch): - for src, tgt in parse_align_idx(ex['align']): + for src, tgt in parse_align_idx(ex.align): sparse_idx.append([i, tgt + 1, src]) tbatchalign = torch.LongTensor(sparse_idx) tensor_batch['align'] = tbatchalign - if 'src_map' in minibatch[0].keys(): - src_vocab_size = max([max(ex['src_map']) for ex in minibatch]) + 1 + if minibatch[0].src_map is not None: + src_vocab_size = max([max(ex.src_map) for ex in minibatch]) + 1 src_map = torch.zeros(len(tensor_batch['srclen']), tbatchsrc.size(1), src_vocab_size) for i, ex in enumerate(minibatch): - for j, t in enumerate(ex['src_map']): + for j, t in enumerate(ex.src_map): src_map[i, j, t] = 1 tensor_batch['src_map'] = src_map - if 'alignment' in minibatch[0].keys(): + if minibatch[0].alignment is not None: alignment = torch.zeros(len(tensor_batch['srclen']), tbatchtgt.size(1)).long() for i, ex in enumerate(minibatch): - alignment[i, :len(ex['alignment'])] = \ - torch.LongTensor(ex['alignment']) + alignment[i, :len(ex.alignment)] = \ + torch.LongTensor(ex.alignment) tensor_batch['alignment'] = alignment - if 'src_ex_vocab' in minibatch[0].keys(): - tensor_batch['src_ex_vocab'] = [ex['src_ex_vocab'] - for ex in minibatch] + if minibatch[0].src_ex_vocab: + tensor_batch['src_ex_vocab'] = \ + [ex.src_ex_vocab for ex in minibatch] return tensor_batch +def text_sort_key(ex): + """Sort using the number of tokens in the sequence.""" + if ex.tgt is not None: + return max(len(ex.src_ids), len(ex.tgt_ids)) + return len(ex.src_ids) + + def textbatch_to_tensor(vocabs, batch, is_train=False): """ This is a hack to transform a simple batch of texts @@ -213,56 +179,11 @@ def textbatch_to_tensor(vocabs, batch, is_train=False): toks = ex else: toks = ex.strip("\n").split() - idxs = vocabs['src'](toks) - # Need to add features also in 'src' - numeric.append({'src': {'src': toks, - 'src_ids': idxs}, - 'srclen': len(toks), - 'tgt': None, - 'indices': i, - 'align': None}) + example = Example(toks, toks) + example.add_index(i) + example.numericalize(vocabs) + numeric.append(example) + numeric.sort(key=text_sort_key, reverse=True) infer_iter = [tensorify(vocabs, numeric)] return infer_iter - - -def _addcopykeys(vocabs, example): - """Create copy-vocab and numericalize with it. - In-place adds ``"src_map"`` to ``example``. That is the copy-vocab - numericalization of the tokenized ``example["src"]``. If ``example`` - has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the - copy-vocab numericalization of the tokenized ``example["tgt"]``. 
The - alignment has an initial and final UNK token to match the BOS and EOS - tokens. - Args: - vocabs - example (dict): An example dictionary with a ``"src"`` key and - maybe a ``"tgt"`` key. (This argument changes in place!) - Returns: - ``example``, changed as described. - """ - src = example['src']['src'].split() - src_ex_vocab = pyonmttok.build_vocab_from_tokens( - Counter(src), - maximum_size=0, - minimum_frequency=1, - special_tokens=[DefaultTokens.UNK, - DefaultTokens.PAD, - DefaultTokens.BOS, - DefaultTokens.EOS]) - src_ex_vocab.default_id = src_ex_vocab[DefaultTokens.UNK] - # make a small vocab containing just the tokens in the source sequence - - # Map source tokens to indices in the dynamic dict. - example['src_map'] = src_ex_vocab(src) - example['src_ex_vocab'] = src_ex_vocab - - if example['tgt'] is not None: - if vocabs['data_task'] == ModelTask.SEQ2SEQ: - tgt = [DefaultTokens.UNK] + example['tgt']['tgt'].split() \ - + [DefaultTokens.UNK] - elif vocabs['data_task'] == ModelTask.LANGUAGE_MODEL: - tgt = example['tgt']['tgt'].split() \ - + [DefaultTokens.UNK] - example['alignment'] = src_ex_vocab(tgt) - return example diff --git a/onmt/model_builder.py b/onmt/model_builder.py index e261709781..52eb5ba59d 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -4,14 +4,13 @@ """ import re import torch -import torch.nn as nn from torch.nn.init import xavier_uniform_ import onmt.modules from onmt.encoders import str2enc from onmt.decoders import str2dec from onmt.inputters.inputter import dict_to_vocabs -from onmt.modules import Embeddings, CopyGenerator +from onmt.modules import Embeddings, Generator from onmt.utils.misc import use_gpu from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser @@ -31,14 +30,13 @@ def build_embeddings(opt, vocabs, for_encoder=True): emb_dim = opt.src_word_vec_size word_padding_idx = vocabs['src'][DefaultTokens.PAD] num_word_embeddings = len(vocabs['src']) - if 'src_feats' in vocabs.keys(): - feat_pad_indices = [vocabs['src_feats'][feat][DefaultTokens.PAD] - for feat in vocabs['src_feats'].keys()] - num_feat_embeddings = [len(vocabs['src_feats'][feat]) - for feat in vocabs['src_feats'].keys()] + if 'src_feats' in vocabs: + feat_pad_indices = [feat_vocab[DefaultTokens.PAD] + for feat_vocab in vocabs['src_feats']] + num_feat_embeddings = [len(feat_vocab) + for feat_vocab in vocabs['src_feats']] freeze_word_vecs = opt.freeze_word_vecs_enc else: - emb_dim = opt.tgt_word_vec_size word_padding_idx = vocabs['tgt'][DefaultTokens.PAD] num_word_embeddings = len(vocabs['tgt']) @@ -194,12 +192,11 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint): emb_name ][old_i] if side == 'tgt': - generator.state_dict()['weight'][i] = checkpoint[ - 'generator' - ]['weight'][old_i] - generator.state_dict()['bias'][i] = checkpoint[ - 'generator' - ]['bias'][old_i] + # TODO: check feats generators + generator.state_dict()['tgt_generator.weight'][i] = \ + checkpoint['generator']['tgt_generator.weight'][old_i] + generator.state_dict()['tgt_generator.bias'][i] = \ + checkpoint['generator']['tgt_generator.bias'][old_i] else: # Just for debugging purposes new_tokens.append(tok) @@ -207,7 +204,37 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint): # Remove old vocabulary associated embeddings del checkpoint['model'][emb_name] - del checkpoint['generator']['weight'], checkpoint['generator']['bias'] + del checkpoint['generator']['tgt_generator.weight'] + del 
checkpoint['generator']['tgt_generator.bias'] + + +def build_generator(model_opt, vocabs, decoder): + gen_sizes = [len(vocabs['tgt'])] + if 'tgt_feats' in vocabs: + gen_sizes += [len(feat_vocab) for feat_vocab in vocabs['tgt_feats']] + + if model_opt.share_decoder_embeddings: + hid_sizes = ([model_opt.dec_hid_size - + (model_opt.feat_vec_size * (len(gen_sizes) - 1))] + + [model_opt.feat_vec_size] * (len(gen_sizes) - 1)) + else: + hid_sizes = [model_opt.dec_hid_size] * len(gen_sizes) + + pad_idx = vocabs['tgt'][DefaultTokens.PAD] + generator = Generator(hid_sizes, gen_sizes, + shared=model_opt.share_decoder_embeddings, + copy_attn=model_opt.copy_attn, + pad_idx=pad_idx) + + if model_opt.share_decoder_embeddings: + if not model_opt.share_decoder_embeddings: + generator.generators[0].weight = \ + decoder.embeddings.word_lut.weight + else: + generator.generators[0].linear.weight = \ + decoder.embeddings.word_lut.weight + + return generator def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None): @@ -244,18 +271,9 @@ def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None): model = build_task_specific_model(model_opt, vocabs) - # Build Generator. - if not model_opt.copy_attn: - generator = nn.Linear(model_opt.dec_hid_size, - len(vocabs['tgt'])) - if model_opt.share_decoder_embeddings: - generator.weight = model.decoder.embeddings.word_lut.weight - else: - vocab_size = len(vocabs['tgt']) - pad_idx = vocabs['tgt'][DefaultTokens.PAD] - generator = CopyGenerator(model_opt.dec_hid_size, vocab_size, pad_idx) - if model_opt.share_decoder_embeddings: - generator.linear.weight = model.decoder.embeddings.word_lut.weight + # Build Generators + # Next token prediction and possibly target features generators + generator = build_generator(model_opt, vocabs, model.decoder) # Load the model states from checkpoint or initialize them. 
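
The head sizing in `build_generator` above is easiest to see with numbers: with shared decoder embeddings each feature head is given `feat_vec_size` dimensions of the decoder output and the token head keeps the remainder; otherwise every head sees the full `dec_hid_size`. A sketch of that arithmetic with illustrative values only:

```python
# Illustrative values only: a 512-dim decoder, one target feature.
dec_hid_size, feat_vec_size = 512, 8
gen_sizes = [30000, 4]  # [len(tgt vocab), len(feat vocab)]

share_decoder_embeddings = True
if share_decoder_embeddings:
    n_feat_heads = len(gen_sizes) - 1
    hid_sizes = ([dec_hid_size - feat_vec_size * n_feat_heads] +
                 [feat_vec_size] * n_feat_heads)
else:
    hid_sizes = [dec_hid_size] * len(gen_sizes)

print(hid_sizes)  # [504, 8]: token head gets 504 dims, feature head 8
```
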
diff --git a/onmt/models/model.py b/onmt/models/model.py
index 7246b2f9ee..91f6839aa6 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -66,7 +66,7 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False):
             * enc_out + enc_final_hs in the case of CNNs
             * src in the case of Transformer
         """
-        dec_in = tgt[:, :-1, :]
+        dec_in = tgt[:, :-1, :1]
         enc_out, enc_final_hs, src_len = self.encoder(src, src_len)
         if not bptt:
             self.decoder.init_state(src, enc_out, enc_final_hs)
diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py
index 44f3d51c9d..25f18578b9 100644
--- a/onmt/modules/__init__.py
+++ b/onmt/modules/__init__.py
@@ -3,6 +3,7 @@
 from onmt.modules.gate import context_gate_factory, ContextGate
 from onmt.modules.global_attention import GlobalAttention
 from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention
+from onmt.modules.generator import Generator
 from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss
 from onmt.modules.multi_headed_attn import MultiHeadedAttention
 from onmt.modules.embeddings import Embeddings, PositionalEncoding
@@ -11,7 +12,7 @@
 __all__ = ["Elementwise", "context_gate_factory", "ContextGate",
            "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
-           "CopyGeneratorLoss",
+           "CopyGeneratorLoss", "Generator",
            "MultiHeadedAttention", "Embeddings", "PositionalEncoding",
            "WeightNormConv2d", "AverageAttention",
            "CopyGeneratorLMLossCompute"]
diff --git a/onmt/modules/criterions.py b/onmt/modules/criterions.py
new file mode 100644
index 0000000000..11fed50357
--- /dev/null
+++ b/onmt/modules/criterions.py
@@ -0,0 +1,39 @@
+import onmt
+from onmt.constants import DefaultTokens
+from onmt.modules.sparse_losses import SparsemaxLoss
+import torch.nn as nn
+
+
+class Criterions:
+
+    def __init__(self, opt, vocabs):
+        tgt_vocab = vocabs['tgt']
+        padding_idx = tgt_vocab[DefaultTokens.PAD]
+        unk_idx = tgt_vocab[DefaultTokens.UNK]
+
+        if opt.copy_attn:
+            self.tgt_criterion = onmt.modules.CopyGeneratorLoss(
+                len(tgt_vocab), opt.copy_attn_force,
+                unk_index=unk_idx, ignore_index=padding_idx
+            )
+        else:
+            if opt.generator_function == 'sparsemax':
+                self.tgt_criterion = SparsemaxLoss(
+                    ignore_index=padding_idx,
+                    reduction='sum')
+            else:
+                self.tgt_criterion = nn.CrossEntropyLoss(
+                    ignore_index=padding_idx,
+                    reduction='sum',
+                    label_smoothing=opt.label_smoothing)
+
+        # Add as many criterions as tgt features we have
+        self.feats_criterions = []
+        if 'tgt_feats' in vocabs:
+            for feat_vocab in vocabs["tgt_feats"]:
+                padding_idx = feat_vocab[DefaultTokens.PAD]
+                self.feats_criterions.append(
+                    nn.CrossEntropyLoss(
+                        ignore_index=padding_idx,
+                        reduction='sum')
+                )
diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py
new file mode 100644
index 0000000000..035740fdf8
--- /dev/null
+++ b/onmt/modules/generator.py
@@ -0,0 +1,40 @@
+"""Generator: next-token prediction plus target-feature heads."""
+import torch.nn as nn
+
+from onmt.modules.copy_generator import CopyGenerator
+
+
+class Generator(nn.Module):
+
+    def __init__(self, hid_sizes, gen_sizes,
+                 shared=False, copy_attn=False, pad_idx=None):
+        super(Generator, self).__init__()
+        self.feats_generators = nn.ModuleList()
+        self.shared = shared
+        self.hid_sizes = hid_sizes
+        self.gen_sizes = gen_sizes
+
+        def simple_generator(hid_size, gen_size):
+            return nn.Linear(hid_size, gen_size)
+
+        # First generator: next token prediction
+        if copy_attn:
+            self.tgt_generator = \
+                CopyGenerator(hid_sizes[0], gen_sizes[0], pad_idx)
+        else:
+            self.tgt_generator = \
+                simple_generator(hid_sizes[0], gen_sizes[0])
+
+        # Additional generators: target features
+        for hid_size, gen_size in zip(hid_sizes[1:], gen_sizes[1:]):
+            self.feats_generators.append(
+                simple_generator(hid_size, gen_size))
+
+    def forward(self, dec_out, *args):
+        scores = self.tgt_generator(dec_out, *args)
+
+        feats_scores = []
+        for generator in self.feats_generators:
+            feats_scores.append(generator(dec_out))
+
+        return scores, feats_scores
diff --git a/onmt/opts.py b/onmt/opts.py
index 9e1c2b1838..d561e1b2ab 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -128,6 +128,20 @@ def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
               help="Size of queues used in the build_vocab dump path.")
 
 
+def _add_features_opts(parser):
+    group = parser.add_argument_group("Features")
+    group.add("-n_src_feats", "--n_src_feats", type=int,
+              default=0, help="Number of source feats.")
+    group.add("-n_tgt_feats", "--n_tgt_feats", type=int,
+              default=0, help="Number of target feats.")
+    group.add("-src_feats_defaults", "--src_feats_defaults",
+              help="Default features to apply in source in case "
+                   "they are not annotated")
+    group.add("-tgt_feats_defaults", "--tgt_feats_defaults",
+              help="Default features to apply in target in case "
+                   "they are not annotated")
+
+
 def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
     """Options related to vocabulary and features.
 
@@ -145,12 +159,7 @@ def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
     group.add("-share_vocab", "--share_vocab", action="store_true",
               help="Share source and target vocabulary.")
 
-    group.add("-src_feats_vocab", "--src_feats_vocab",
-              help=("List of paths to save"
-                    if build_vocab_only
-                    else "List of paths to")
-              + " src features vocabulary files. "
-              "Files format: one <word> or <word>\t<count> per line.")
+    _add_features_opts(parser)
 
     if not build_vocab_only:
         group.add("-src_vocab_size", "--src_vocab_size",
@@ -448,7 +457,7 @@ def model_opts(parser):
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")
     group.add('--apex_opt_level', '-apex_opt_level', type=str, default="",
-              choices=["", "O0", "O1", "O2", "O3"],
+              choices=["O0", "O1", "O2", "O3"],
              help="For FP16 training, the opt_level to use."
                   "See https://nvidia.github.io/apex/amp.html#opt-levels.")
 
@@ -791,9 +800,6 @@ def translate_opts(parser, dynamic=False):
     group.add('--src', '-src', required=True,
              help="Source sequence to decode (one line per "
                   "sequence)")
-    group.add("-src_feats", "--src_feats", required=False,
-              help="Source sequence features (dict format). "
-                   "Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}")  # noqa: E501
     group.add('--tgt', '-tgt',
               help='True target sequence (optional)')
     group.add('--tgt_file_prefix', '-tgt_file_prefix', action='store_true',
@@ -806,6 +812,9 @@ def translate_opts(parser, dynamic=False):
     group.add('--report_time', '-report_time', action='store_true',
               help="Report some translation time metrics")
 
+    # Adding options related to source and target features
+    _add_features_opts(parser)
+
     # Adding options relate to decoding strategy
     _add_decoding_opts(parser)
diff --git a/onmt/tests/test_events.py b/onmt/tests/test_events.py
index cc3cd78048..62004ab5ed 100644
--- a/onmt/tests/test_events.py
+++ b/onmt/tests/test_events.py
@@ -49,5 +49,5 @@ def check_scalars(self, scalars, logdir):
     args = parser.parse_args()
     test_event = TestEvents()
     scalars = test_event.scalars[args.tensorboard_checks]
-    print("looking for scalars: ", scalars)
+    print("\nlooking for scalars: ", scalars)
     test_event.check_scalars(scalars, args.logdir)
diff --git a/onmt/train_single.py b/onmt/train_single.py
index 243a53e57b..4e3a4d9c68 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -153,8 +153,15 @@ def main(opt, device_id):
     # Build model.
     model = build_model(model_opt, opt, vocabs, checkpoint)
     model.count_parameters(log=logger.info)
-    logger.info(' * src vocab size = %d' % len(vocabs['src']))
-    logger.info(' * tgt vocab size = %d' % len(vocabs['tgt']))
+    logger.info('* src vocab size = %d' % len(vocabs['src']))
+    logger.info('* tgt vocab size = %d' % len(vocabs['tgt']))
+    if "src_feats" in vocabs:
+        for i, feat_vocab in enumerate(vocabs["src_feats"]):
+            logger.info(f'* src_feat {i} vocab size = {len(feat_vocab)}')
+    if "tgt_feats" in vocabs:
+        for i, feat_vocab in enumerate(vocabs["tgt_feats"]):
+            logger.info(f'* tgt_feat {i} vocab size = {len(feat_vocab)}')
+
     # Build optimizer.
     optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint) diff --git a/onmt/trainer.py b/onmt/trainer.py index 516fffadd8..71ad5297e7 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -35,8 +35,8 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None): used to save the model """ - train_loss = LossCompute.from_opts(opt, model, vocabs['tgt']) - valid_loss = LossCompute.from_opts(opt, model, vocabs['tgt'], train=False) + train_loss = LossCompute.from_opts(opt, model, vocabs) + valid_loss = LossCompute.from_opts(opt, model, vocabs, train=False) scoring_preparator = ScoringPreparator(vocabs=vocabs, opt=opt) validset_transforms = opt.data.get("valid", {}).get("transforms", None) diff --git a/onmt/transforms/bart.py b/onmt/transforms/bart.py index 02b9647d48..4cea480df7 100644 --- a/onmt/transforms/bart.py +++ b/onmt/transforms/bart.py @@ -407,8 +407,8 @@ def warm_up(self, vocabs): def apply(self, example, is_train=False, stats=None, **kwargs): """Apply BART noise to src side tokens.""" if is_train: - src = self.bart_noise.apply(example['src']) - example['src'] = src + src = self.bart_noise.apply(example.src) + example.src = src return example def _repr_args(self): diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py index 3c2ad09f25..b1ff5822b8 100644 --- a/onmt/transforms/features.py +++ b/onmt/transforms/features.py @@ -1,42 +1,7 @@ -from onmt.utils.logging import logger from onmt.transforms import register_transform from .transform import Transform from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer import re -from collections import defaultdict - - -@register_transform(name='filterfeats') -class FilterFeatsTransform(Transform): - """Filter out examples with a mismatch between source and features.""" - - def __init__(self, opts): - super().__init__(opts) - - @classmethod - def add_options(cls, parser): - pass - - def _parse_opts(self): - pass - - def apply(self, example, is_train=False, stats=None, **kwargs): - """Return None if mismatch""" - - if 'src_feats' not in example: - # Do nothing - return example - - for feat_name, feat_values in example['src_feats'].items(): - if len(example['src']) != len(feat_values): - logger.warning( - f"Skipping example due to mismatch " - f"between source and feature {feat_name}") - return None - return example - - def _repr_args(self): - return '' @register_transform(name='inferfeats') @@ -63,34 +28,41 @@ def _parse_opts(self): self.reversible_tokenization = self.opts.reversible_tokenization self.prior_tokenization = self.opts.prior_tokenization - def apply(self, example, is_train=False, stats=None, **kwargs): - - if "src_feats" not in example: - # Do nothing - return example - + def _infer(self, example, side): if self.reversible_tokenization == "joiner": - original_src = example["src_original"] \ + original_text = getattr(example, f"{side}_original") \ if self.prior_tokenization else None word_to_subword_mapping = subword_map_by_joiner( - example["src"], original_subwords=original_src) + getattr(example, side), original_subwords=original_text) else: # Spacer - word_to_subword_mapping = subword_map_by_spacer(example["src"]) + word_to_subword_mapping = subword_map_by_spacer( + getattr(example, side)) - inferred_feats = defaultdict(list) - for subword, word_id in zip(example["src"], word_to_subword_mapping): - for feat_name, feat_values in example["src_feats"].items(): + new_feats = [[] for _ in range(len(getattr(example, f"{side}_feats")))] + for subword, word_id in zip( + 
diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py
index 04dc309178..24fe92e7b5 100644
--- a/onmt/transforms/misc.py
+++ b/onmt/transforms/misc.py
@@ -40,8 +40,8 @@ def _parse_opts(self):
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Return None if too long else return as is."""
-        if (len(example['src']) > self.src_seq_length or
-                len(example['tgt']) > self.tgt_seq_length - 2):
+        if (len(example.src) > self.src_seq_length or
+                len(example.tgt) > self.tgt_seq_length - 2):
             if stats is not None:
                 stats.update(FilterTooLongStats())
             return None
@@ -124,10 +124,11 @@ def warm_up(self, vocabs=None):
     def _prepend(self, example, prefix):
         """Prepend `prefix` to `tokens`."""
         for side, side_prefix in prefix.items():
-            if example.get(side) is not None:
-                example[side] = side_prefix.split() + example[side]
+            if getattr(example, side) is not None:
+                setattr(example, side,
+                        side_prefix.split() + getattr(example, side))
             elif len(side_prefix) > 0:
-                example[side] = side_prefix.split()
+                setattr(example, side, side_prefix.split())
         return example
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
@@ -218,10 +219,11 @@ def warm_up(self, vocabs=None):
     def _append(self, example, suffix):
         """Prepend `suffix` to `tokens`."""
         for side, side_suffix in suffix.items():
-            if example.get(side) is not None:
-                example[side] = example[side] + side_suffix.split()
+            if getattr(example, side) is not None:
+                setattr(example, side,
+                        getattr(example, side) + side_suffix.split())
            elif len(side_suffix) > 0:
-                example[side] = side_suffix.split()
+                setattr(example, side, side_suffix.split())
         return example
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
diff --git a/onmt/transforms/normalize.py b/onmt/transforms/normalize.py
index 2332389478..126e09935a 100644
--- a/onmt/transforms/normalize.py
+++ b/onmt/transforms/normalize.py
@@ -272,11 +272,11 @@ def warm_up(self, vocabs=None):
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Normalize source and target examples."""
-        src_str = self.src_mpn.normalize(' '.join(example['src']))
-        example['src'] = src_str.split()
+        src_str = self.src_mpn.normalize(' '.join(example.src))
+        example.src = src_str.split()
 
-        if example['tgt'] is not None:
-            tgt_str = self.tgt_mpn.normalize(' '.join(example['tgt']))
-            example['tgt'] = tgt_str.split()
+        if example.tgt is not None:
+            tgt_str = self.tgt_mpn.normalize(' '.join(example.tgt))
+            example.tgt = tgt_str.split()
 
         return example
diff --git a/onmt/transforms/sampling.py b/onmt/transforms/sampling.py
index 232256e379..39c417dff2 100644
--- a/onmt/transforms/sampling.py
+++ b/onmt/transforms/sampling.py
@@ -101,10 +101,10 @@ def _switchout(self, tokens, vocab, stats=None):
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply switchout to both src and tgt side tokens."""
         if is_train:
-            example['src'] = self._switchout(
-                example['src'], self.vocabs['src'].ids_to_tokens, stats)
-            example['tgt'] = self._switchout(
-                example['tgt'], self.vocabs['tgt'].ids_to_tokens, stats)
+            example.src = self._switchout(
+                example.src, self.vocabs['src'].ids_to_tokens, stats)
+            example.tgt = self._switchout(
+                example.tgt, self.vocabs['tgt'].ids_to_tokens, stats)
         return example
 
     def _repr_args(self):
@@ -160,8 +160,8 @@ def _token_drop(self, tokens, stats=None):
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply token drop to both src and tgt side tokens."""
         if is_train:
-            example['src'] = self._token_drop(example['src'], stats)
-            example['tgt'] = self._token_drop(example['tgt'], stats)
+            example.src = self._token_drop(example.src, stats)
+            example.tgt = self._token_drop(example.tgt, stats)
         return example
 
     def _repr_args(self):
@@ -223,7 +223,7 @@ def _token_mask(self, tokens, stats=None):
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply word drop to both src and tgt side tokens."""
         if is_train:
-            example['src'] = self._token_mask(example['src'], stats)
+            example.src = self._token_mask(example.src, stats)
         return example
 
     def _repr_args(self):
diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py
index 590630a9e1..21e617efa0 100644
--- a/onmt/transforms/tokenize.py
+++ b/onmt/transforms/tokenize.py
@@ -193,20 +193,20 @@ def _detokenize(self, tokens, side="src"):
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply sentencepiece subword encode to src & tgt."""
-        src_out = self._tokenize(example['src'], 'src', is_train)
-        if example['tgt'] is not None:
-            tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
+        src_out = self._tokenize(example.src, 'src', is_train)
+        if example.tgt is not None:
+            tgt_out = self._tokenize(example.tgt, 'tgt', is_train)
             if stats is not None:
-                n_words = len(example['src']) + len(example['tgt'])
+                n_words = len(example.src) + len(example.tgt)
                 n_subwords = len(src_out) + len(tgt_out)
                 stats.update(SubwordStats(n_subwords, n_words))
         else:
             tgt_out = None
             if stats is not None:
-                n_words = len(example['src'])
+                n_words = len(example.src)
                 n_subwords = len(src_out)
                 stats.update(SubwordStats(n_subwords, n_words))
-        example['src'], example['tgt'] = src_out, tgt_out
+        example.src, example.tgt = src_out, tgt_out
         return example
 
     def apply_reverse(self, translated):
@@ -285,20 +285,20 @@ def _detokenize(self, tokens, side="src", is_train=False):
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply bpe subword encode to src & tgt."""
-        src_out = self._tokenize(example['src'], 'src', is_train)
-        if example['tgt'] is not None:
-            tgt_out = self._tokenize(example['tgt'], 'tgt', is_train)
+        src_out = self._tokenize(example.src, 'src', is_train)
+        if example.tgt is not None:
+            tgt_out = self._tokenize(example.tgt, 'tgt', is_train)
             if stats is not None:
-                n_words = len(example['src']) + len(example['tgt'])
+                n_words = len(example.src) + len(example.tgt)
                 n_subwords = len(src_out) + len(tgt_out)
                 stats.update(SubwordStats(n_subwords, n_words))
         else:
             tgt_out = None
             if stats is not None:
-                n_words = len(example['src'])
+                n_words = len(example.src)
                 n_subwords = len(src_out)
                 stats.update(SubwordStats(n_subwords, n_words))
-        example['src'], example['tgt'] = src_out, tgt_out
+        example.src, example.tgt = src_out, tgt_out
         return example
 
     def apply_reverse(self, translated):
@@ -459,20 +459,20 @@ def _detokenize(self, tokens, side='src', is_train=False):
 
     def apply(self, example, is_train=False, stats=None, **kwargs):
         """Apply OpenNMT Tokenizer to src & tgt."""
-        src_out = self._tokenize(example['src'], 'src')
-        if example['tgt'] is not None:
-            tgt_out = self._tokenize(example['tgt'], 'tgt')
+        src_out = self._tokenize(example.src, 'src')
+        if example.tgt is not None:
+            tgt_out = self._tokenize(example.tgt, 'tgt')
             if stats is not None:
-                n_words = len(example['src']) + len(example['tgt'])
+                n_words = len(example.src) + len(example.tgt)
                 n_subwords = len(src_out) + len(tgt_out)
                 stats.update(SubwordStats(n_subwords, n_words))
         else:
             tgt_out = None
             if stats is not None:
-                n_words = len(example['src'])
+                n_words = len(example.src)
                 n_subwords = len(src_out)
                 stats.update(SubwordStats(n_subwords, n_words))
-        example['src'], example['tgt'] = src_out, tgt_out
+        example.src, example.tgt = src_out, tgt_out
         return example
 
     def apply_reverse(self, translated):
diff --git a/onmt/transforms/uppercase.py b/onmt/transforms/uppercase.py
index 22ca1eea38..6c47fc395b 100644
--- a/onmt/transforms/uppercase.py
+++ b/onmt/transforms/uppercase.py
@@ -36,16 +36,16 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
         if random.random() > self.upper_corpus_ratio:
             return example
 
-        src_str = ' '.join(example['src'])
+        src_str = ' '.join(example.src)
         src_str = ''.join(c for c in unicodedata.normalize('NFD',
                           src_str.upper())
                           if unicodedata.category(c) != 'Mn')
-        example['src'] = src_str.split()
+        example.src = src_str.split()
 
-        if example['tgt'] is not None:
-            tgt_str = ' '.join(example['tgt'])
+        if example.tgt is not None:
+            tgt_str = ' '.join(example.tgt)
             tgt_str = ''.join(c for c in unicodedata.normalize('NFD',
                               tgt_str.upper())
                               if unicodedata.category(c) != 'Mn')
-            example['tgt'] = tgt_str.split()
+            example.tgt = tgt_str.split()
 
         return example
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index 4f253ff2f3..8431d41297 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -58,11 +58,11 @@ class BeamSearchBase(DecodeStrategy):
     def __init__(self, beam_size, batch_size, pad, bos, eos, unk, start,
                  n_best, global_scorer, min_length, max_length,
                  return_attention, block_ngram_repeat, exclusion_tokens,
-                 stepwise_penalty, ratio, ban_unk_token):
+                 stepwise_penalty, ratio, ban_unk_token, n_tgt_feats):
         super(BeamSearchBase, self).__init__(
             pad, bos, eos, unk, start, batch_size, beam_size, global_scorer,
             min_length, block_ngram_repeat, exclusion_tokens,
-            return_attention, max_length, ban_unk_token)
+            return_attention, max_length, ban_unk_token, n_tgt_feats)
         # beam parameters
         self.beam_size = beam_size
         self.n_best = n_best
@@ -117,7 +117,7 @@ def initialize_(self, enc_out, src_len, src_map, device,
 
     @property
     def current_predictions(self):
-        return self.alive_seq[:, -1]
+        return self.alive_seq[:, 0, -1]
 
     @property
     def current_backptr(self):
@@ -152,6 +152,17 @@ def _pick(self, log_probs, out=None):
         topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
         return topk_scores, topk_ids
 
+    def _pick_features(self, log_probs):
+        if len(log_probs) > 0:
+            features_id = []
+            for probs in log_probs:
+                _, topk_ids = probs.topk(1, dim=-1)
+                features_id.append(topk_ids)
+            features_id = torch.cat(features_id, dim=-1)
+            return features_id
+        else:
+            return None
+
     def update_finished(self):
         # Penalize beams that finished.
         _B_old = self.topk_log_probs.shape[0]
@@ -161,7 +172,7 @@ def update_finished(self):
         # it's faster to not move this back to the original device
         self.is_finished = self.is_finished.to('cpu')
         self.top_beam_finished |= self.is_finished[:, 0].eq(1)
-        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
+        predictions = self.alive_seq.view(_B_old, self.beam_size, -1, step)
         attention = (
             self.alive_attn.view(
                 step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
@@ -178,9 +189,12 @@ def update_finished(self):
                     self.best_scores[b] = s
                 self.hypotheses[b].append((
                     self.topk_scores[i, j],
-                    predictions[i, j, 1:],  # Ignore start_token.
+                    predictions[i, j, 0, 1:],  # Ignore start_token.
                     attention[:, i, j, :self.src_len[i]]
-                    if attention is not None else None))
+                    if attention is not None else None,
+                    [predictions[i, j, 1+k, 1:]
+                     for k in range(self.n_tgt_feats)]
+                    if predictions.size(-2) > 1 else None))
             # End condition is the top beam finished and we can return
             # n_best hypotheses.
             if self.ratio > 0:
@@ -194,11 +208,12 @@ def update_finished(self):
                 best_hyp = sorted(
                     self.hypotheses[b],
                     key=lambda x: x[0], reverse=True)[:self.n_best]
-                for n, (score, pred, attn) in enumerate(best_hyp):
+                for n, (score, pred, attn, feats) in enumerate(best_hyp):
                     self.scores[b].append(score)
                     self.predictions[b].append(pred)  # ``(batch, n_best,)``
                     self.attention[b].append(
                         attn if attn is not None else [])
+                    self.features[b].append(feats if feats is not None else [])
             else:
                 non_finished_batch.append(i)
 
@@ -224,7 +239,7 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished,
         self._batch_index = self._batch_index.index_select(0, non_finished)
         self.select_indices = self._batch_index.view(_B_new * self.beam_size)
         self.alive_seq = predictions.index_select(0, non_finished) \
-            .view(-1, self.alive_seq.size(-1))
+            .view(-1, self.alive_seq.size(-2), self.alive_seq.size(-1))
         self.topk_scores = self.topk_scores.index_select(0, non_finished)
         self.topk_ids = self.topk_ids.index_select(0, non_finished)
         self.maybe_update_target_prefix(self.select_indices)
@@ -241,7 +256,11 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished,
             self._prev_penalty = self._prev_penalty.index_select(
                 0, non_finished)
 
-    def advance(self, log_probs, attn):
+    def advance(self, log_probs, attn, feats_log_probs):
+        # Pick up candidates for target features
+        # we take the top-1 prediction for each feature
+        features_ids = self._pick_features(feats_log_probs)
+
         vocab_size = log_probs.size(-1)
 
         # using integer division to get an integer _B without casting
@@ -285,10 +304,18 @@ def advance(self, log_probs, attn):
         self.select_indices = self._batch_index.view(_B * self.beam_size)
         self.topk_ids.fmod_(vocab_size)  # resolve true word ids
 
+        # Concatenate topk_ids for tokens and feats.
+        if features_ids is not None:
+            topk_ids = torch.cat((
+                self.topk_ids.view(_B * self.beam_size, 1),
+                features_ids), dim=1)
+        else:
+            topk_ids = self.topk_ids.view(_B * self.beam_size, 1)
+
         # Append last prediction.
         self.alive_seq = torch.cat(
             [self.alive_seq.index_select(0, self.select_indices),
-             self.topk_ids.view(_B * self.beam_size, 1)], -1)
+             topk_ids.unsqueeze(-1)], -1)
 
         self.maybe_update_forbidden_tokens()
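
Note: the beam-search changes above widen `alive_seq` from `(batch * beam, step)` to `(batch * beam, n_tgt_feats + 1, step)`: row 0 carries the beam-searched token ids, rows 1..n carry the greedily picked (top-1) feature ids. A toy illustration of this bookkeeping (assumed shapes and ids; not OpenNMT code):

```python
import torch

batch_x_beam, n_tgt_feats = 2, 1
# After initialize(): a single start column, one row for tokens, one per feature.
alive_seq = torch.full((batch_x_beam, n_tgt_feats + 1, 1), 2)  # 2 ~ BOS id

topk_ids = torch.tensor([[7], [9]])      # beam-searched token ids, (B*beam, 1)
features_ids = torch.tensor([[3], [5]])  # top-1 feature ids, (B*beam, 1)

# Mirrors advance(): concatenate token and feature ids, then append one step.
new_ids = torch.cat((topk_ids, features_ids), dim=1)           # (B*beam, 2)
alive_seq = torch.cat([alive_seq, new_ids.unsqueeze(-1)], -1)  # one more step
print(alive_seq.shape)      # torch.Size([2, 2, 2])
print(alive_seq[:, 0, -1])  # current_predictions: last token of each path
```
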
diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py
index 7eeb4bb365..586d4dcbb6 100644
--- a/onmt/translate/decode_strategy.py
+++ b/onmt/translate/decode_strategy.py
@@ -68,7 +68,7 @@ class DecodeStrategy(object):
     def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
                  global_scorer, min_length, block_ngram_repeat,
                  exclusion_tokens, return_attention, max_length,
-                 ban_unk_token):
+                 ban_unk_token, n_tgt_feats):
 
         # magic indices
         self.pad = pad
@@ -86,6 +86,7 @@ def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
         self.scores = [[] for _ in range(batch_size)]
         self.attention = [[] for _ in range(batch_size)]
         self.hypotheses = [[] for _ in range(batch_size)]
+        self.features = [[] for _ in range(batch_size)]
 
         self.alive_attn = None
 
@@ -102,6 +103,8 @@ def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
 
         self.done = False
 
+        self.n_tgt_feats = n_tgt_feats
+
     def get_device_from_enc_out(self, enc_out):
         if isinstance(enc_out, tuple):
             mb_device = enc_out[0].device
@@ -140,8 +143,8 @@ def initialize(self, enc_out, src_len, src_map=None, device=None,
             device = torch.device('cpu')
         # Here we set the decoder to start with self.start (BOS or EOS)
         self.alive_seq = torch.full(
-            [self.batch_size * self.parallel_paths, 1], self.start,
-            dtype=torch.long, device=device)
+            [self.batch_size * self.parallel_paths, self.n_tgt_feats+1, 1],
+            self.start, dtype=torch.long, device=device)
         self.is_finished = torch.zeros(
             [self.batch_size, self.parallel_paths],
             dtype=torch.uint8, device=device)
@@ -161,13 +164,15 @@ def initialize(self, enc_out, src_len, src_map=None, device=None,
         return None, enc_out, src_len, src_map
 
     def __len__(self):
-        return self.alive_seq.shape[1]
+        return self.alive_seq.shape[-1]
 
     def ensure_min_length(self, log_probs):
+        # TODO: check if needed for target_features probs
         if len(self) <= self.min_length:
             log_probs[:, self.eos] = -1e20
 
     def ensure_unk_removed(self, log_probs):
+        # TODO: check if needed for target_features probs
         if self.ban_unk_token:
             log_probs[:, self.unk] = -1e20
 
@@ -296,7 +301,7 @@ def maybe_update_target_prefix(self, select_index):
             return
         self.target_prefix = self.target_prefix.index_select(0, select_index)
 
-    def advance(self, log_probs, attn):
+    def advance(self, log_probs, attn, feats_log_probs):
        """DecodeStrategy subclasses should override :func:`advance()`.
 
         Advance is used to update ``self.alive_seq``, ``self.is_finished``,
diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py
index ba51241f36..07d2e8d6e0 100644
--- a/onmt/translate/translation.py
+++ b/onmt/translate/translation.py
@@ -43,19 +43,28 @@ def _build_source_tokens(self, src):
                 break
         return tokens
 
-    def _build_target_tokens(self, src, src_raw, pred, attn):
+    def _build_target_tokens(self, src, src_raw, pred, attn, feats):
         tokens = []
-        for tok in pred:
+
+        if feats is not None:
+            pred_iter = zip(pred, *feats)
+        else:
+            pred_iter = [(item,) for item in pred]
+
+        for tok, *tok_feats in pred_iter:
             if tok < len(self.vocabs['tgt']):
-                tokens.append(self.vocabs['tgt'].lookup_index(tok))
+                token = self.vocabs['tgt'].lookup_index(tok)
             else:
                 vl = len(self.vocabs['tgt'])
-                tokens.append(self.vocabs['src'].lookup_index(tok - vl))
-            if tokens[-1] == DefaultTokens.EOS:
-                tokens = tokens[:-1]
+                token = self.vocabs['src'].lookup_index(tok - vl)
+            if token == DefaultTokens.EOS:
                 break
+            if len(tok_feats) > 0:
+                for feat, fv in zip(tok_feats, self.vocabs['tgt_feats']):
+                    token += "│" + fv.lookup_index(feat)
+            tokens.append(token)
         if self.replace_unk and attn is not None and src is not None:
+            assert False, "TODO"
             for i in range(len(tokens)):
                 if tokens[i] == DefaultTokens.UNK:
                     _, max_index = attn[i][:len(src_raw)].max(0)
@@ -72,8 +81,9 @@ def from_batch(self, translation_batch):
                len(translation_batch["predictions"]))
         batch_size = len(batch['srclen'])
 
-        preds, pred_score, attn, align, gold_score, indices = list(zip(
+        preds, feats, pred_score, attn, align, gold_score, indices = list(zip(
             *sorted(zip(translation_batch["predictions"],
+                        translation_batch["features"],
                         translation_batch["scores"],
                         translation_batch["attention"],
                         translation_batch["alignment"],
@@ -104,7 +114,8 @@ def from_batch(self, translation_batch):
                 src[b, :] if src is not None else None,
                 src_raw,
                 preds[b][n],
-                align[b][n] if align[b] is not None else attn[b][n])
+                align[b][n] if align[b] is not None else attn[b][n],
+                feats[b][n] if len(feats[0]) > 0 else None)
             for n in range(self.n_best)]
         gold_sent = None
         if tgt is not None:
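
Note: at detokenization time each predicted feature id is looked up in its `tgt_feats` vocab and glued back onto the token with the `│` separator, reproducing the format of the `*-with-feats.txt` files. A minimal sketch of that assembly step (plain lists stand in for the vocab objects):

```python
tokens = ["she", "is", "a", "hard-working."]
feats = [["C", "B", "A", "B"]]  # one list of labels per target feature

annotated = ["│".join([tok] + [feat[i] for feat in feats])
             for i, tok in enumerate(tokens)]
print(" ".join(annotated))
# she│C is│B a│A hard-working.│B
```
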
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 895099a2bb..6485b52c7d 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -129,7 +129,8 @@ def __init__(
         logger=None,
         seed=-1,
         with_score=False,
-        decoder_start_token=DefaultTokens.BOS
+        decoder_start_token=DefaultTokens.BOS,
+        n_tgt_feats=0
     ):
         self.model = model
         self.vocabs = vocabs
@@ -208,6 +209,7 @@ def __init__(
             set_random_seed(seed, self._use_cuda)
 
         self.with_score = with_score
+        self.n_tgt_feats = n_tgt_feats
 
     @classmethod
     def from_opt(
@@ -273,7 +275,8 @@ def from_opt(
             logger=logger,
             seed=opt.seed,
             with_score=opt.with_score,
-            decoder_start_token=opt.decoder_start_token
+            decoder_start_token=opt.decoder_start_token,
+            n_tgt_feats=opt.n_tgt_feats
         )
 
     def _log(self, msg):
@@ -551,17 +554,22 @@ def _decode_and_generate(
             else:
                 attn = None
-            scores = self.model.generator(dec_out.squeeze(1))
+            scores, feats_scores = self.model.generator(dec_out.squeeze(1))
             log_probs = F.log_softmax(scores.to(torch.float32), dim=-1)
+            feats_log_probs = [F.log_softmax(s.to(torch.float32), dim=-1)
+                               for s in feats_scores]
             # returns [(batch_size x beam_size) , vocab ] when 1 step
             # or [batch_size, tgt_len, vocab ] when full sentence
         else:
             attn = dec_attn["copy"]
-            scores = self.model.generator(
+            scores, feats_scores = self.model.generator(
                 dec_out.view(-1, dec_out.size(2)),
                 attn.view(-1, attn.size(2)),
                 src_map,
             )
+            # TODO: allow target feats inference with the copy mechanism
+            assert not feats_scores
+            feats_log_probs = []
             # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
             if batch_offset is None:
                 scores = scores.view(-1, len(batch['srclen']),
@@ -581,7 +588,7 @@ def _decode_and_generate(
             log_probs = scores.squeeze(0).log()
             # returns [(batch_size x beam_size) , vocab ] when 1 step
             # or [batch_size, tgt_len, vocab ] when full sentence
-        return log_probs, attn
+        return log_probs, attn, feats_log_probs
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
@@ -603,16 +610,14 @@ def report_results(
         decode_strategy,
     ):
         results = {
-            "predictions": None,
-            "scores": None,
-            "attention": None,
+            "predictions": decode_strategy.predictions,
+            "scores": decode_strategy.scores,
+            "attention": decode_strategy.attention,
+            "features": decode_strategy.features,
             "batch": batch,
             "gold_score": gold_score,
         }
-        results["scores"] = decode_strategy.scores
-        results["predictions"] = decode_strategy.predictions
-        results["attention"] = decode_strategy.attention
         if self.report_align:
             results["alignment"] = self._align_forward(
                 batch, decode_strategy.predictions
@@ -726,6 +731,7 @@ def translate_batch(self, batch, attn_debug):
                 stepwise_penalty=self.stepwise_penalty,
                 ratio=self.ratio,
                 ban_unk_token=self.ban_unk_token,
+                n_tgt_feats=self.n_tgt_feats,
             )
         return self._translate_batch_with_strategy(
             batch, decode_strategy
@@ -803,11 +809,9 @@ def _translate_batch_with_strategy(
 
         # (3) Begin decoding step by step:
         for step in range(decode_strategy.max_length):
-            # decoder_input = decode_strategy.current_predictions.view(1, -1,
-            #                                                          1)
             decoder_input = decode_strategy.current_predictions.view(-1, 1, 1)
 
-            log_probs, attn = self._decode_and_generate(
+            log_probs, attn, feats_log_probs = self._decode_and_generate(
                 decoder_input,
                 enc_out,
                 batch,
@@ -817,7 +821,7 @@ def _translate_batch_with_strategy(
                 batch_offset=decode_strategy.batch_offset,
             )
 
-            decode_strategy.advance(log_probs, attn)
+            decode_strategy.advance(log_probs, attn, feats_log_probs)
             any_finished = decode_strategy.is_finished.any()
             if any_finished:
                 decode_strategy.update_finished()
@@ -861,7 +865,7 @@ def _score_target(
         tgt = batch['tgt']
         tgt_in = tgt[:, :-1, :]
 
-        log_probs, attn = self._decode_and_generate(
+        log_probs, attn, feats_log_probs = self._decode_and_generate(
             tgt_in,
             enc_out,
             batch,
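
Note: `_decode_and_generate` now expects `model.generator` to return a pair `(scores, feats_scores)`, with one extra score tensor per target feature. The generator module itself is not part of this diff; the following is a hedged sketch of the expected interface only (class and attribute names are assumptions, not the PR's actual code):

```python
import torch.nn as nn

class GeneratorWithFeats(nn.Module):  # hypothetical name
    """Projects decoder output to token logits plus one logit set per feature."""
    def __init__(self, dec_dim, vocab_size, feat_vocab_sizes):
        super().__init__()
        self.tgt_proj = nn.Linear(dec_dim, vocab_size)
        self.feat_projs = nn.ModuleList(
            [nn.Linear(dec_dim, n) for n in feat_vocab_sizes])

    def forward(self, dec_out):
        scores = self.tgt_proj(dec_out)
        feats_scores = [proj(dec_out) for proj in self.feat_projs]
        # shapes: (N, vocab) and [(N, feat_vocab_i), ...]
        return scores, feats_scores
```
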
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index cc72aad538..76a398ecf9 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -11,6 +11,7 @@
 from onmt.constants import ModelTask, DefaultTokens
 from onmt.modules.copy_generator import collapse_copy_scores
 from onmt.model_builder import load_test_model
+from onmt.modules.criterions import Criterions
 try:
     import ctranslate2
 except ImportError:
@@ -36,13 +37,13 @@ class LossCompute(nn.Module):
         lm_prior_lambda (float): weight of LM model in loss
         lm_prior_tau (float): scaler for LM loss
     """
-    def __init__(self, criterion, generator,
+    def __init__(self, criterions, generator,
                  copy_attn=False, lambda_coverage=0.0, lambda_align=0.0,
                  tgt_shift_index=1, vocab=None, lm_generator=None,
                  lm_prior_lambda=None, lm_prior_tau=None,
                  lm_prior_model=None):
         super(LossCompute, self).__init__()
-        self.criterion = criterion
+        self.criterions = criterions
         self.generator = generator
         self.lambda_coverage = lambda_coverage
         self.lambda_align = lambda_align
@@ -55,7 +56,7 @@ def __init__(self, criterion, generator,
         self.lm_prior_model = lm_prior_model
 
     @classmethod
-    def from_opts(cls, opt, model, vocab, train=True):
+    def from_opts(cls, opt, model, vocabs, train=True):
         """
         Returns a subclass which wraps around an nn.Module subclass
         (such as nn.NLLLoss) which defines the loss criterion. The LossCompute
@@ -66,8 +67,7 @@ def from_opts(cls, opt, model, vocab, train=True):
 
         device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt)
                               else "cpu")
-        padding_idx = vocab[DefaultTokens.PAD]
-        unk_idx = vocab[DefaultTokens.UNK]
+        tgt_vocab = vocabs['tgt']
 
         if opt.lambda_coverage != 0:
             assert opt.coverage_attn, "--coverage_attn needs to be set in " \
@@ -75,21 +75,7 @@ def from_opts(cls, opt, model, vocab, train=True):
 
         tgt_shift_idx = 1 if opt.model_task == ModelTask.SEQ2SEQ else 0
 
-        if opt.copy_attn:
-            criterion = onmt.modules.CopyGeneratorLoss(
-                len(vocab), opt.copy_attn_force,
-                unk_index=unk_idx, ignore_index=padding_idx
-            )
-        else:
-            if opt.generator_function == 'sparsemax':
-                criterion = SparsemaxLoss(ignore_index=padding_idx,
-                                          reduction='sum')
-            else:
-                criterion = nn.CrossEntropyLoss(
-                    ignore_index=padding_idx,
-                    reduction='sum',
-                    label_smoothing=opt.label_smoothing
-                )
+        criterions = Criterions(opt, vocabs)
 
         lm_prior_lambda = opt.lm_prior_lambda
         lm_prior_tau = opt.lm_prior_tau
@@ -116,12 +102,12 @@ def from_opts(cls, opt, model, vocab, train=True):
             lm_generator = None
             lm_prior_model = None
 
-        compute = cls(criterion, model.generator,
+        compute = cls(criterions, model.generator,
                       copy_attn=opt.copy_attn,
                       lambda_coverage=opt.lambda_coverage,
                       lambda_align=opt.lambda_align,
                       tgt_shift_index=tgt_shift_idx,
-                      vocab=vocab, lm_generator=lm_generator,
+                      vocab=tgt_vocab, lm_generator=lm_generator,
                       lm_prior_lambda=lm_prior_lambda,
                       lm_prior_tau=lm_prior_tau,
                       lm_prior_model=lm_prior_model)
@@ -131,7 +117,7 @@ def from_opts(cls, opt, model, vocab, train=True):
 
     @property
     def padding_idx(self):
-        return self.criterion.ignore_index
+        return self.criterions.tgt_criterion.ignore_index
 
     def _compute_coverage_loss(self, std_attn, cov_attn, tgt):
         """compute coverage loss"""
@@ -164,12 +150,12 @@ def _compute_copy_loss(self, batch, output, target, align, attns):
         Returns:
             A tuple with the loss and raw scores.
         """
-        scores = self.generator(self._bottle(output),
-                                self._bottle(attns['copy']),
-                                batch['src_map'])
-        loss = self.criterion(scores, align, target).sum()
+        scores, feats_scores = self.generator(self._bottle(output),
+                                              self._bottle(attns['copy']),
+                                              batch['src_map'])
+        loss = self.criterions.tgt_criterion(scores, align, target).sum()
 
-        return loss, scores
+        return loss, scores, feats_scores
 
     def _compute_lm_loss_ct2(self, output, target):
         """
@@ -277,8 +263,8 @@ def forward(self, batch, output, attns,
                 align = batch['alignment'][
                     :, trunc_range[0]:trunc_range[1]
                 ].contiguous().view(-1)
-                loss, scores = self._compute_copy_loss(batch, output, flat_tgt,
-                                                       align, attns)
+                loss, scores, feats_scores = \
+                    self._compute_copy_loss(batch, output, flat_tgt, align, attns)
                 scores_data = collapse_copy_scores(
                     self._unbottle(scores.clone(), len(batch['srclen'])),
                     batch, self.vocab, None)
@@ -287,7 +273,7 @@ def forward(self, batch, output, attns,
                 # tgt[i] = align[i] + len(tgt_vocab)
                 # for i such that tgt[i] == 0 and align[i] != 0
                 target_data = flat_tgt.clone()
-                unk = self.criterion.unk_index
+                unk = self.criterions.tgt_criterion.unk_index
                 correct_mask = (target_data == unk) & (align != unk)
                 offset_align = align[correct_mask] + len(self.vocab)
                 target_data[correct_mask] += offset_align
@@ -296,10 +282,12 @@ def forward(self, batch, output, attns,
 
         else:
-            scores = self.generator(self._bottle(output))
-            if isinstance(self.criterion, SparsemaxLoss):
+            scores, feats_scores = self.generator(self._bottle(output))
+            if isinstance(self.criterions.tgt_criterion, SparsemaxLoss):
                 scores = LogSparsemax(scores.to(torch.float32), dim=-1)
-            loss = self.criterion(scores.to(torch.float32), flat_tgt)
+
+            loss = self.criterions.tgt_criterion(
+                scores.to(torch.float32), flat_tgt)
 
         if self.lambda_align != 0.0:
             align_head = attns['align']
@@ -319,6 +307,15 @@ def forward(self, batch, output, attns,
                 align_head=align_head, ref_align=ref_align)
             loss += align_loss
 
+        # Compute target features losses
+        assert len(feats_scores) == \
+            len(self.criterions.feats_criterions)  # Security check
+        for i, (feat_scores, criterion) in enumerate(
+                zip(feats_scores, self.criterions.feats_criterions)):
+            loss += criterion(
+                feat_scores.to(torch.float32),
+                target[:, :, i+1].contiguous().view(-1))
+
         if self.lambda_coverage != 0.0:
             coverage_loss = self._compute_coverage_loss(
                 attns['std'], attns['coverage'], flat_tgt)
@@ -332,6 +329,7 @@ def forward(self, batch, output, attns,
             lm_loss = self._compute_lm_loss(output, batch['tgt'])
             loss = loss + lm_loss * self.lm_prior_lambda
 
+        # TODO: pass feat scores to stats
         stats = self._stats(len(batch['srclen']), loss.sum().item(),
                             scores, flat_tgt)
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 76c4bced4a..9d3f192d05 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -77,21 +77,11 @@ def _validate_data(cls, opt):
                 corpus['weight'] = 1
 
             # Check features
-            src_feats = corpus.get("src_feats", None)
-            if src_feats is not None:
-                for feature_name, feature_file in src_feats.items():
-                    cls._validate_file(
-                        feature_file, info=f'{cname}/path_{feature_name}')
+            if opt.n_src_feats > 0 or opt.n_tgt_feats > 0:
                 if 'inferfeats' not in corpus["transforms"]:
                     raise ValueError(
                         "'inferfeats' transform is required "
-                        "when setting source features")
-                if 'filterfeats' not in corpus["transforms"]:
-                    raise ValueError(
-                        "'filterfeats' transform is required "
-                        "when setting source features")
-            else:
-                corpus["src_feats"] = None
+                        "when using source or target features")
 
         logger.info(f"Parsed {len(corpora)} corpora from -data.")
         opt.data = corpora
@@ -129,19 +119,6 @@ def _get_all_transform_translate(cls, opt):
     @classmethod
     def _validate_vocab_opts(cls, opt, build_vocab_only=False):
         """Check options relate to vocab."""
-
-        for cname, corpus in opt.data.items():
-            if cname != CorpusName.VALID and corpus["src_feats"] is not None:
-                assert opt.src_feats_vocab, \
-                    "-src_feats_vocab is required if using source features."
-                if isinstance(opt.src_feats_vocab, str):
-                    import yaml
-                    opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab)
-
-                for feature in corpus["src_feats"].keys():
-                    assert feature in opt.src_feats_vocab, \
-                        f"No vocab file set for feature {feature}"
-
         if build_vocab_only:
             if not opt.share_vocab:
                 assert opt.tgt_vocab, \
@@ -322,7 +299,7 @@ def validate_train_opts(cls, opt):
 
     @classmethod
     def validate_translate_opts(cls, opt):
-        opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}
+        pass
 
     @classmethod
     def validate_translate_opts_dynamic(cls, opt):