[WIP] Support source and target features #2289

Closed
wants to merge 15 commits
1 change: 1 addition & 0 deletions data/data_features/src-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-test.feat0

This file was deleted.

3 changes: 3 additions & 0 deletions data/data_features/src-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
3 changes: 0 additions & 3 deletions data/data_features/src-train.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/src-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-val.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/tgt-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 1 addition & 0 deletions data/data_features/tgt-test.txt
@@ -0,0 +1 @@
she is a hard-working.
3 changes: 3 additions & 0 deletions data/data_features/tgt-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
1 change: 1 addition & 0 deletions data/data_features/tgt-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
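In the samples above, "│" glues one or more feature values onto each token. A minimal parsing sketch (illustrative only, not code from this PR) shows how such a line splits back into tokens and per-token feature lists:

def split_features(line):
    """Split a feature-annotated line into tokens and per-token features."""
    tokens, feats = [], []
    for item in line.split(" "):
        tok, *tok_feats = item.split("│")
        tokens.append(tok)
        feats.append(tok_feats)
    return tokens, feats

split_features("she│C is│B a│A hard-working.│B")
# -> (['she', 'is', 'a', 'hard-working.'], [['C'], ['B'], ['A'], ['B']])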
20 changes: 20 additions & 0 deletions data/features_configs/source_and_target_features.yaml
@@ -0,0 +1,20 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train-with-feats.txt
        path_tgt: data/data_features/tgt-train-with-feats.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val-with-feats.txt
        path_tgt: data/data_features/tgt-val-with-feats.txt
        transforms: [inferfeats]

# Feats options
n_src_feats: 1
n_tgt_feats: 1
src_feats_defaults: "0"
tgt_feats_defaults: "1"
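Note that corpus_2 points at plain files with no annotations; the inferfeats transform back-fills them from the defaults above. A rough sketch of the idea (illustrative only, not the actual transform):

def backfill_features(line, n_feats=1, default="0"):
    # Append the configured default feature value(s) to every token
    # of a line that carries no annotations.
    suffix = "│".join([default] * n_feats)
    return " ".join(f"{tok}│{suffix}" for tok in line.split(" "))

backfill_features("she is a hard-working.")
# -> 'she│0 is│0 a│0 hard-working.│0'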
18 changes: 18 additions & 0 deletions data/features_configs/source_features_only.yaml
@@ -0,0 +1,18 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train-with-feats.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val-with-feats.txt
        path_tgt: data/data_features/tgt-val.txt
        transforms: [inferfeats]

# Feats options
n_src_feats: 1
src_feats_defaults: "0"
18 changes: 18 additions & 0 deletions data/features_configs/target_features_only.yaml
@@ -0,0 +1,18 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train-with-feats.txt
        transforms: [inferfeats]
    corpus_2:
        path_src: data/data_features/src-train.txt
        path_tgt: data/data_features/tgt-train.txt
        transforms: [inferfeats]
    valid:
        path_src: data/data_features/src-val.txt
        path_tgt: data/data_features/tgt-val-with-feats.txt
        transforms: [inferfeats]

# Feats options
n_tgt_feats: 1
tgt_feats_defaults: "0"
11 changes: 0 additions & 11 deletions data/features_data.yaml

This file was deleted.

88 changes: 56 additions & 32 deletions onmt/bin/build_vocab.py
@@ -9,7 +9,7 @@
from onmt.inputters.text_utils import process
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter, defaultdict
from collections import Counter
import multiprocessing as mp


@@ -42,19 +42,26 @@ def write_files_from_queues(sample_path, queues):

# Just for debugging purposes
# It appends features to subwords when dumping to file
def append_features_to_example(example, features):
    ex_toks = example.split(' ')
    feat_toks = features.split(' ')
    toks = [f"{subword}│{feat}" for subword, feat in
            zip(ex_toks, feat_toks)]
    return " ".join(toks)
def append_features_to_text(text, features):
    text_tok = text.split(' ')
    feats_tok = [x.split(' ') for x in features]

    pretty_toks = []
    for tok, *feats in zip(text_tok, *feats_tok):
        feats = '│'.join(feats)
        if feats:
            pretty_toks.append(f"{tok}│{feats}")
        else:
            pretty_toks.append(tok)
    return " ".join(pretty_toks)


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = defaultdict(Counter)
    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    sub_counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
    datasets_iterables = build_corpora_iters(
        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level,
@@ -68,28 +75,35 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
            if opts.dump_samples:
                build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (maybe_example['src']['src'],
                                  maybe_example['tgt']['tgt'])
            src_line_pretty = src_line
            for feat_name, feat_line in maybe_example["src"].items():
                if feat_name not in ["src", "src_original"]:
                    sub_counter_src_feats[feat_name].update(
                        feat_line.split(' '))
                    if opts.dump_samples:
                        src_line_pretty = append_features_to_example(
                            src_line_pretty, feat_line)
            src_line = maybe_example.src
            tgt_line = maybe_example.tgt
            src_feats_lines = maybe_example.src_feats
            tgt_feats_lines = maybe_example.tgt_feats

            sub_counter_src.update(src_line.split(' '))
            sub_counter_tgt.update(tgt_line.split(' '))
            # use a separate index here so the example counter `i` from
            # enumerate is not clobbered by the feature loops
            for k in range(opts.n_src_feats):
                sub_counter_src_feats[k].update(src_feats_lines[k].split(' '))
            for k in range(opts.n_tgt_feats):
                sub_counter_tgt_feats[k].update(tgt_feats_lines[k].split(' '))

            if opts.dump_samples:
                src_pretty_line = append_features_to_text(
                    src_line, src_feats_lines)
                tgt_pretty_line = append_features_to_text(
                    tgt_line, tgt_feats_lines)
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_line_pretty, tgt_line))
                    (i, src_pretty_line, tgt_pretty_line))
            if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
    return (sub_counter_src,
            sub_counter_tgt,
            sub_counter_src_feats,
            sub_counter_tgt_feats)


def init_pool(queues):
@@ -113,7 +127,8 @@ def build_vocab(opts, transforms, n_sample=3):
    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = defaultdict(Counter)
    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
    counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
    from functools import partial
    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                       for i in range(opts.num_threads)]
@@ -130,14 +145,18 @@ def build_vocab(opts, transforms, n_sample=3):
        func = partial(
            build_sub_vocab, corpora, transforms,
            opts, n_sample, opts.num_threads)
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
                func, range(0, opts.num_threads)):
        for (sub_counter_src, sub_counter_tgt,
             sub_counter_src_feats, sub_counter_tgt_feats) \
                in p.imap(func, range(0, opts.num_threads)):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            counter_src_feats.update(sub_counter_src_feats)
            for i in range(opts.n_src_feats):
                counter_src_feats[i].update(sub_counter_src_feats[i])
            for i in range(opts.n_tgt_feats):
                counter_tgt_feats[i].update(sub_counter_tgt_feats[i])
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats
    return counter_src, counter_tgt, counter_src_feats, counter_tgt_feats


def build_vocab_main(opts):
@@ -163,13 +182,16 @@ def build_vocab_main(opts):
    transforms = make_transforms(opts, transforms_cls, None)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
    (src_counter, tgt_counter,
     src_feats_counter, tgt_feats_counter) = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")
    logger.info(f"Counters src: {len(src_counter)}")
    logger.info(f"Counters tgt: {len(tgt_counter)}")
    for i, feat_counter in enumerate(src_feats_counter):
        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
    for i, feat_counter in enumerate(tgt_feats_counter):
        logger.info(f"Counters tgt feat_{i}: {len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -186,8 +208,10 @@ def save_counter(counter, save_path):
    save_counter(src_counter, opts.src_vocab)
    save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
    for i, c in enumerate(src_feats_counter):
        save_counter(c, f"{opts.src_vocab}_feat{i}")
    for i, c in enumerate(tgt_feats_counter):
        save_counter(c, f"{opts.tgt_vocab}_feat{i}")
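    # Illustration of the naming scheme above (hypothetical paths): with
    # opts.src_vocab == "vocab.src", opts.tgt_vocab == "vocab.tgt" and one
    # feature on each side, the loops write "vocab.src_feat0" and
    # "vocab.tgt_feat0" alongside the main vocab files.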


def _get_parser():
29 changes: 23 additions & 6 deletions onmt/decoders/ensemble.py
@@ -99,14 +99,31 @@ def forward(self, hidden, attn=None, src_map=None):
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.
        """
        distributions = torch.stack(
            [mg(h) if attn is None else mg(h, attn, src_map)
             for h, mg in zip(hidden, self.model_generators)]
        )

        distributions, feats_distributions = [], []
        n_feats = len(self.model_generators[0].feats_generators)
        for h, mg in zip(hidden, self.model_generators):
            scores, feats_scores = \
                (mg(h) if attn is None else mg(h, attn, src_map))
            distributions.append(scores)
            feats_distributions.append(feats_scores)

        distributions = torch.stack(distributions)

        # stack the i-th feature distribution of every model together
        stacked_feats_distributions = []
        for i in range(n_feats):
            stacked_feats_distributions.append(
                torch.stack([feat_distribution[i]
                             for feat_distribution in feats_distributions]))

        if self._raw_probs:
            return torch.log(torch.exp(distributions).mean(0))
            return (torch.log(torch.exp(distributions).mean(0)),
                    [torch.log(torch.exp(d).mean(0))
                     for d in stacked_feats_distributions])
        else:
            return distributions.mean(0)
            return (distributions.mean(0),
                    [d.mean(0) for d in stacked_feats_distributions])
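        # A quick numerical sketch of the two branches above (illustrative
        # only), assuming two models emitting log-probabilities over a
        # 2-word vocabulary:
        #   d = torch.stack([torch.tensor([0.9, 0.1]).log(),
        #                    torch.tensor([0.5, 0.5]).log()])
        #   torch.log(torch.exp(d).mean(0))  # mean in probability space
        #   d.mean(0)                        # mean in log space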


class EnsembleModel(NMTModel):
5 changes: 2 additions & 3 deletions onmt/inputters/__init__.py
@@ -4,15 +4,14 @@
e.g., from a line of text to a sequence of vectors.
"""
from onmt.inputters.inputter import build_vocab, IterOnDevice
from onmt.inputters.text_utils import text_sort_key, process,\
    numericalize, tensorify
from onmt.inputters.text_utils import text_sort_key, process, tensorify
from onmt.inputters.text_corpus import ParallelCorpus, ParallelCorpusIterator
from onmt.inputters.dynamic_iterator import MixingStrategy, SequentialMixer,\
    WeightedMixer, DynamicDatasetIter


__all__ = ['IterOnDevice', 'build_vocab', 'text_sort_key',
           'process', 'numericalize', 'tensorify',
           'process', 'tensorify',
           'ParallelCorpus', 'ParallelCorpusIterator',
           'MixingStrategy', 'SequentialMixer', 'WeightedMixer',
           'DynamicDatasetIter']
14 changes: 6 additions & 8 deletions onmt/inputters/dynamic_iterator.py
@@ -4,7 +4,7 @@
from onmt.constants import CorpusTask
from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
from onmt.inputters.text_utils import text_sort_key, process,\
    numericalize, tensorify, _addcopykeys
    tensorify
from onmt.transforms import make_transforms
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import RandomShuffler
@@ -209,8 +209,9 @@ def _tuple_to_json_with_tokIDs(self, tuple_bucket):
        for example in tuple_bucket:
            if example is not None:
                if self.copy:
                    example = _addcopykeys(self.vocabs, example)
                bucket.append(numericalize(self.vocabs, example))
                    example.addcopykeys(self.vocabs)
                example.numericalize(self.vocabs)
                bucket.append(example)
        return bucket
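    # Hypothetical sketch of the example interface assumed by the refactor
    # above (method names are from the diff; bodies are illustrative):
    #   class ParallelExample:
    #       def addcopykeys(self, vocabs): ...   # attach copy-attn keys in place
    #       def numericalize(self, vocabs): ...  # map tokens/features to ids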

    def _bucketing(self):
@@ -252,11 +253,8 @@ def batch_size_fn(nbsents, maxlen):

        minibatch, maxlen, size_so_far, seen = [], 0, 0, []
        for ex in data:
            if (
                (ex['src']['src'] not in seen) or
                (self.task != CorpusTask.TRAIN)
            ):
                seen.append(ex['src']['src'])
            if ((ex.src not in seen) or (self.task != CorpusTask.TRAIN)):
                seen.append(ex.src)
                minibatch.append(ex)
                nbsents = len(minibatch)
                maxlen = max(text_sort_key(ex), maxlen)