Commit

Making the sampled multilingual translation changes compliant with the linter.
taidopurason committed Nov 26, 2021
1 parent c60e1a5 commit 59054f2
Showing 4 changed files with 77 additions and 53 deletions.
1 change: 1 addition & 0 deletions fairseq/data/multilingual/sampled_multi_dataset.py
@@ -14,6 +14,7 @@

import numpy as np
import torch
+
from fairseq.data import FairseqDataset, data_utils
from fairseq.distributed import utils as distributed_utils

15 changes: 9 additions & 6 deletions fairseq/data/multilingual/sampled_multilingual_dataset.py
@@ -6,6 +6,7 @@
import logging

import numpy as np
+
from fairseq.data import SampledMultiDataset

logger = logging.getLogger(__name__)
@@ -18,11 +19,11 @@ class SampledMultilingualDataset(SampledMultiDataset):
"""

def batch_by_size(
-self,
-indices,
-max_tokens=None,
-max_sentences=None,
-required_batch_size_multiple=1,
+self,
+indices,
+max_tokens=None,
+max_sentences=None,
+required_batch_size_multiple=1,
):
dataset_indices = [[] for _ in range(len(self.datasets))]
for i in indices:
@@ -37,7 +38,9 @@ def batch_by_size(
max_sentences,
required_batch_size_multiple,
)
logger.info(f"Created {len(cur_batches)} batches for dataset {self.keys[ds_idx]}")
logger.info(
f"Created {len(cur_batches)} batches for dataset {self.keys[ds_idx]}"
)
batches += cur_batches

return batches
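The batch_by_size override above keeps every batch monolingual-pair-pure: global indices are first bucketed by their constituent dataset, and each bucket is batched separately, so a batch never mixes language pairs. A minimal sketch of that grouping idea outside fairseq (group_batches and dataset_of are illustrative names, not the library's API):

    from collections import defaultdict

    def group_batches(indices, dataset_of, max_sentences=2):
        # Bucket global indices by dataset, then batch each bucket on its
        # own, so no batch ever mixes two language pairs.
        buckets = defaultdict(list)
        for i in indices:
            buckets[dataset_of(i)].append(i)
        batches = []
        for ds_idx in sorted(buckets):
            bucket = buckets[ds_idx]
            batches.extend(
                bucket[start:start + max_sentences]
                for start in range(0, len(bucket), max_sentences)
            )
        return batches

    # Indices 0-3 belong to dataset 0 (say de-en), 4-7 to dataset 1 (fr-en).
    print(group_batches(range(8), lambda i: 0 if i < 4 else 1))
    # [[0, 1], [2, 3], [4, 5], [6, 7]] -- each batch is single-pair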
14 changes: 10 additions & 4 deletions fairseq/models/multilingual_transformer.py
@@ -69,7 +69,7 @@ def add_args(parser):
parser.add_argument(
"--share-language-specific-embeddings",
action="store_true",
help="share encoder and decoder embeddings between the encoder and the decoder of the same language"
help="share encoder and decoder embeddings between the encoder and the decoder of the same language",
)

@classmethod
@@ -116,7 +116,7 @@ def check_encoder_decoder_embed_args_equal():
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
-args.decoder_embed_path != args.encoder_embed_path
+args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
@@ -159,7 +159,11 @@ def check_encoder_decoder_embed_args_equal():
)
check_encoder_decoder_embed_args_equal()
language_specific_embeddings = {
-lang: build_embedding(task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path)
+lang: build_embedding(
+task.dicts[lang],
+args.encoder_embed_dim,
+args.encoder_embed_path,
+)
for lang in task.langs
}

@@ -238,7 +242,9 @@ def base_multilingual_architecture(args):
base_architecture(args)
args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
-args.share_language_specific_embeddings = getattr(args, "share_language_specific_embeddings", False)
+args.share_language_specific_embeddings = getattr(
+args, "share_language_specific_embeddings", False
+)
args.share_encoders = getattr(args, "share_encoders", False)
args.share_decoders = getattr(args, "share_decoders", False)

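The --share-language-specific-embeddings flag added above ties the encoder and decoder embeddings of the same language to one table per language, built by the language_specific_embeddings dict in the hunk. A rough sketch of the sharing pattern in plain PyTorch (the vocabulary sizes and two-language setup are assumptions for illustration; in fairseq the tables come from build_embedding over task.dicts):

    import torch
    import torch.nn as nn

    vocab_sizes = {"de": 10000, "en": 12000}  # hypothetical vocabularies
    embed_dim = 512

    # One table per language, reused by every encoder and decoder of that
    # language, so e.g. de-en and en-de train the same two tables.
    language_specific_embeddings = {
        lang: nn.Embedding(size, embed_dim) for lang, size in vocab_sizes.items()
    }

    encoder_embed = language_specific_embeddings["de"]  # de-en encoder side
    decoder_embed = language_specific_embeddings["en"]  # de-en decoder side
    assert decoder_embed is language_specific_embeddings["en"]  # same object

    tokens = torch.randint(0, vocab_sizes["de"], (2, 7))  # (batch, seq)
    print(encoder_embed(tokens).shape)  # torch.Size([2, 7, 512])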
100 changes: 57 additions & 43 deletions fairseq/tasks/multilingual_translation_sampled.py
@@ -8,41 +8,47 @@
from collections import OrderedDict

from fairseq import utils
-from fairseq.tasks.translation import load_langpair_dataset
-
from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+from fairseq.tasks.translation import load_langpair_dataset
from fairseq.utils import csv_str_list

-from .translation_multi_simple_epoch import get_time_gap
-from ..data import iterators, FairseqDataset, data_utils
+from ..data import FairseqDataset, data_utils, iterators
from ..data.multilingual.sampled_multi_dataset import CollateFormat
from ..data.multilingual.sampled_multilingual_dataset import SampledMultilingualDataset
from ..data.multilingual.sampling_method import SamplingMethod
+from .translation_multi_simple_epoch import get_time_gap

logger = logging.getLogger(__name__)


@register_task("multilingual_translation_sampled")
class SampledMultilingualTranslationTask(MultilingualTranslationTask):
-
@staticmethod
def add_args(parser):
MultilingualTranslationTask.add_args(parser)
SamplingMethod.add_arguments(parser)
-parser.add_argument('--model-lang-pairs',
-default=None,
-type=csv_str_list,
-help='language pairs that will be used for building the model. --lang-pairs are used by default.')
-parser.add_argument('--eval-lang-pairs',
-default=None,
-type=csv_str_list,
-help='language pairs that will be used for evaluating the model. --lang-pairs are used by default.')
+parser.add_argument(
+"--model-lang-pairs",
+default=None,
+type=csv_str_list,
+help="language pairs that will be used for building the model. --lang-pairs are used by default.",
+)
+parser.add_argument(
+"--eval-lang-pairs",
+default=None,
+type=csv_str_list,
+help="language pairs that will be used for evaluating the model. --lang-pairs are used by default.",
+)

def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
-self.model_lang_pairs = self.lang_pairs if args.model_lang_pairs is None else args.model_lang_pairs
-self.eval_lang_pairs = self.lang_pairs if args.eval_lang_pairs is None else args.eval_lang_pairs
+self.model_lang_pairs = (
+self.lang_pairs if args.model_lang_pairs is None else args.model_lang_pairs
+)
+self.eval_lang_pairs = (
+self.lang_pairs if args.eval_lang_pairs is None else args.eval_lang_pairs
+)

def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
@@ -75,21 +81,29 @@ def language_pair_dataset(lang_pair):
tgt_lang=tgt,
)

-datasets = OrderedDict([
-(lang_pair, language_pair_dataset(lang_pair))
-for lang_pair in self.lang_pairs
-])
+datasets = OrderedDict(
+[
+(lang_pair, language_pair_dataset(lang_pair))
+for lang_pair in self.lang_pairs
+]
+)

sampling_method = SamplingMethod(self.args, self).sampling_method_selector()
-ratios = None if sampling_method is None else sampling_method([len(dataset) for dataset in datasets.values()])
+ratios = (
+None
+if sampling_method is None
+else sampling_method([len(dataset) for dataset in datasets.values()])
+)

self.datasets[split] = SampledMultilingualDataset(
datasets,
epoch=epoch,
sampling_ratios=ratios,
seed=self.args.seed,
collate_format=CollateFormat.ordered_dict,
-eval_key=None if self.training else f"{self.args.source_lang}-{self.args.target_lang}"
+eval_key=None
+if self.training
+else f"{self.args.source_lang}-{self.args.target_lang}",
)

# needs to be overridden to work with SampledMultiDataset
@@ -111,7 +125,7 @@ def max_positions(self):
)

def train_step(
-self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
# each sample contains one language-pair
assert len(sample) == 1
@@ -128,13 +142,13 @@ def train_step(

# from translation_multi_simple_epoch
def create_batch_sampler_func(
-self,
-max_positions,
-ignore_invalid_inputs,
-max_tokens,
-max_sentences,
-required_batch_size_multiple=1,
-seed=1,
+self,
+max_positions,
+ignore_invalid_inputs,
+max_tokens,
+max_sentences,
+required_batch_size_multiple=1,
+seed=1,
):
def construct_batch_sampler(dataset, epoch):
splits = [
@@ -192,20 +206,20 @@ def construct_batch_sampler(dataset, epoch):
# from translation_multi_simple_epoch
# we need to override get_batch_iterator because we want to reset the epoch iterator each time
def get_batch_iterator(
-self,
-dataset,
-max_tokens=None,
-max_sentences=None,
-max_positions=None,
-ignore_invalid_inputs=False,
-required_batch_size_multiple=1,
-seed=1,
-num_shards=1,
-shard_id=0,
-num_workers=0,
-epoch=1,
-data_buffer_size=0,
-disable_iterator_cache=False,
+self,
+dataset,
+max_tokens=None,
+max_sentences=None,
+max_positions=None,
+ignore_invalid_inputs=False,
+required_batch_size_multiple=1,
+seed=1,
+num_shards=1,
+shard_id=0,
+num_workers=0,
+epoch=1,
+data_buffer_size=0,
+disable_iterator_cache=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
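In load_dataset above, SamplingMethod turns the per-pair dataset sizes into sampling ratios that SampledMultilingualDataset uses to over- or under-sample pairs each epoch. A self-contained sketch of the common temperature-based scheme, p_i proportional to (n_i / sum(n))^(1/T) (fairseq's SamplingMethod offers several strategies, so treat this as one illustrative choice, not the exact implementation):

    def temperature_ratios(sizes, temperature=5.0):
        # T=1 keeps the natural data distribution; larger T flattens it
        # toward uniform, upsampling low-resource pairs.
        probs = [size / sum(sizes) for size in sizes]
        scaled = [p ** (1.0 / temperature) for p in probs]
        total = sum(scaled)
        return [s / total for s in scaled]

    sizes = [1_000_000, 10_000]  # e.g. de-en vs. et-en sentence counts
    print([round(r, 3) for r in temperature_ratios(sizes, temperature=1.0)])
    # [0.99, 0.01]  -- proportional to data
    print([round(r, 3) for r in temperature_ratios(sizes, temperature=5.0)])
    # [0.715, 0.285] -- low-resource pair upsampled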
