diff --git a/fairseq/tasks/multilingual_translation_sampled.py b/fairseq/tasks/multilingual_translation_sampled.py index 034fec01de..a4a349c377 100644 --- a/fairseq/tasks/multilingual_translation_sampled.py +++ b/fairseq/tasks/multilingual_translation_sampled.py @@ -226,6 +226,9 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, ): """ Get an iterator that yields batches of data from the given dataset. @@ -258,6 +261,11 @@ def get_batch_iterator( disable_iterator_cache (bool, optional): don't cache the EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) (default: False). + grouped_shuffling (bool, optional): group batches with each group + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. + update_epoch_batch_itr (bool, optional): if true then do not use the cached + batch iterator for the epoch Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split @@ -281,6 +289,9 @@ def get_batch_iterator( epoch=epoch, data_buffer_size=data_buffer_size, disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, ) self.dataset_to_epoch_iter[dataset] = batch_iter return batch_iter