From 59054f2b59f680c9ab327d5da934de658d4429eb Mon Sep 17 00:00:00 2001 From: Taido Purason Date: Fri, 26 Nov 2021 14:52:31 +0200 Subject: [PATCH] Making the sampled multilingual translation changes compliant with the linter. --- .../multilingual/sampled_multi_dataset.py | 1 + .../sampled_multilingual_dataset.py | 15 +-- fairseq/models/multilingual_transformer.py | 14 ++- .../tasks/multilingual_translation_sampled.py | 100 ++++++++++-------- 4 files changed, 77 insertions(+), 53 deletions(-) diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py index 4d1e707d6d..d3d7edd498 100644 --- a/fairseq/data/multilingual/sampled_multi_dataset.py +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -14,6 +14,7 @@ import numpy as np import torch + from fairseq.data import FairseqDataset, data_utils from fairseq.distributed import utils as distributed_utils diff --git a/fairseq/data/multilingual/sampled_multilingual_dataset.py b/fairseq/data/multilingual/sampled_multilingual_dataset.py index 1acbddc0a0..95c7e128df 100644 --- a/fairseq/data/multilingual/sampled_multilingual_dataset.py +++ b/fairseq/data/multilingual/sampled_multilingual_dataset.py @@ -6,6 +6,7 @@ import logging import numpy as np + from fairseq.data import SampledMultiDataset logger = logging.getLogger(__name__) @@ -18,11 +19,11 @@ class SampledMultilingualDataset(SampledMultiDataset): """ def batch_by_size( - self, - indices, - max_tokens=None, - max_sentences=None, - required_batch_size_multiple=1, + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, ): dataset_indices = [[] for _ in range(len(self.datasets))] for i in indices: @@ -37,7 +38,9 @@ def batch_by_size( max_sentences, required_batch_size_multiple, ) - logger.info(f"Created {len(cur_batches)} batches for dataset {self.keys[ds_idx]}") + logger.info( + f"Created {len(cur_batches)} batches for dataset {self.keys[ds_idx]}" + ) batches += cur_batches return batches diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index 54d75f8f2b..6ba7d14664 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -69,7 +69,7 @@ def add_args(parser): parser.add_argument( "--share-language-specific-embeddings", action="store_true", - help="share encoder and decoder embeddings between the encoder and the decoder of the same language" + help="share encoder and decoder embeddings between the encoder and the decoder of the same language", ) @classmethod @@ -116,7 +116,7 @@ def check_encoder_decoder_embed_args_equal(): "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" ) if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path + args.decoder_embed_path != args.encoder_embed_path ): raise ValueError( "--share-all-embeddings not compatible with --decoder-embed-path" @@ -159,7 +159,11 @@ def check_encoder_decoder_embed_args_equal(): ) check_encoder_decoder_embed_args_equal() language_specific_embeddings = { - lang: build_embedding(task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path) + lang: build_embedding( + task.dicts[lang], + args.encoder_embed_dim, + args.encoder_embed_path, + ) for lang in task.langs } @@ -238,7 +242,9 @@ def base_multilingual_architecture(args): base_architecture(args) args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False) args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False) - args.share_language_specific_embeddings = getattr(args, "share_language_specific_embeddings", False) + args.share_language_specific_embeddings = getattr( + args, "share_language_specific_embeddings", False + ) args.share_encoders = getattr(args, "share_encoders", False) args.share_decoders = getattr(args, "share_decoders", False) diff --git a/fairseq/tasks/multilingual_translation_sampled.py b/fairseq/tasks/multilingual_translation_sampled.py index fc8f84139c..150d5d6361 100644 --- a/fairseq/tasks/multilingual_translation_sampled.py +++ b/fairseq/tasks/multilingual_translation_sampled.py @@ -8,41 +8,47 @@ from collections import OrderedDict from fairseq import utils -from fairseq.tasks.translation import load_langpair_dataset - from fairseq.tasks import register_task from fairseq.tasks.multilingual_translation import MultilingualTranslationTask +from fairseq.tasks.translation import load_langpair_dataset from fairseq.utils import csv_str_list -from .translation_multi_simple_epoch import get_time_gap -from ..data import iterators, FairseqDataset, data_utils +from ..data import FairseqDataset, data_utils, iterators from ..data.multilingual.sampled_multi_dataset import CollateFormat from ..data.multilingual.sampled_multilingual_dataset import SampledMultilingualDataset from ..data.multilingual.sampling_method import SamplingMethod +from .translation_multi_simple_epoch import get_time_gap logger = logging.getLogger(__name__) @register_task("multilingual_translation_sampled") class SampledMultilingualTranslationTask(MultilingualTranslationTask): - @staticmethod def add_args(parser): MultilingualTranslationTask.add_args(parser) SamplingMethod.add_arguments(parser) - parser.add_argument('--model-lang-pairs', - default=None, - type=csv_str_list, - help='language pairs that will be used for building the model. --lang-pairs are used by default.') - parser.add_argument('--eval-lang-pairs', - default=None, - type=csv_str_list, - help='language pairs that will be used for evaluating the model. --lang-pairs are used by default.') + parser.add_argument( + "--model-lang-pairs", + default=None, + type=csv_str_list, + help="language pairs that will be used for building the model. --lang-pairs are used by default.", + ) + parser.add_argument( + "--eval-lang-pairs", + default=None, + type=csv_str_list, + help="language pairs that will be used for evaluating the model. --lang-pairs are used by default.", + ) def __init__(self, args, dicts, training): super().__init__(args, dicts, training) - self.model_lang_pairs = self.lang_pairs if args.model_lang_pairs is None else args.model_lang_pairs - self.eval_lang_pairs = self.lang_pairs if args.eval_lang_pairs is None else args.eval_lang_pairs + self.model_lang_pairs = ( + self.lang_pairs if args.model_lang_pairs is None else args.model_lang_pairs + ) + self.eval_lang_pairs = ( + self.lang_pairs if args.eval_lang_pairs is None else args.eval_lang_pairs + ) def load_dataset(self, split, epoch=1, **kwargs): """Load a dataset split.""" @@ -75,13 +81,19 @@ def language_pair_dataset(lang_pair): tgt_lang=tgt, ) - datasets = OrderedDict([ - (lang_pair, language_pair_dataset(lang_pair)) - for lang_pair in self.lang_pairs - ]) + datasets = OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in self.lang_pairs + ] + ) sampling_method = SamplingMethod(self.args, self).sampling_method_selector() - ratios = None if sampling_method is None else sampling_method([len(dataset) for dataset in datasets.values()]) + ratios = ( + None + if sampling_method is None + else sampling_method([len(dataset) for dataset in datasets.values()]) + ) self.datasets[split] = SampledMultilingualDataset( datasets, @@ -89,7 +101,9 @@ def language_pair_dataset(lang_pair): sampling_ratios=ratios, seed=self.args.seed, collate_format=CollateFormat.ordered_dict, - eval_key=None if self.training else f"{self.args.source_lang}-{self.args.target_lang}" + eval_key=None + if self.training + else f"{self.args.source_lang}-{self.args.target_lang}", ) # needs to be overridden to work with SampledMultiDataset @@ -111,7 +125,7 @@ def max_positions(self): ) def train_step( - self, sample, model, criterion, optimizer, update_num, ignore_grad=False + self, sample, model, criterion, optimizer, update_num, ignore_grad=False ): # each sample contains one language-pair assert len(sample) == 1 @@ -128,13 +142,13 @@ def train_step( # from translation_multi_simple_epoch def create_batch_sampler_func( - self, - max_positions, - ignore_invalid_inputs, - max_tokens, - max_sentences, - required_batch_size_multiple=1, - seed=1, + self, + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, + required_batch_size_multiple=1, + seed=1, ): def construct_batch_sampler(dataset, epoch): splits = [ @@ -192,20 +206,20 @@ def construct_batch_sampler(dataset, epoch): # from translation_multi_simple_epoch # we need to override get_batch_iterator because we want to reset the epoch iterator each time def get_batch_iterator( - self, - dataset, - max_tokens=None, - max_sentences=None, - max_positions=None, - ignore_invalid_inputs=False, - required_batch_size_multiple=1, - seed=1, - num_shards=1, - shard_id=0, - num_workers=0, - epoch=1, - data_buffer_size=0, - disable_iterator_cache=False, + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, ): """ Get an iterator that yields batches of data from the given dataset.