diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py
index b0a617424e..d3d7edd498 100644
--- a/fairseq/data/multilingual/sampled_multi_dataset.py
+++ b/fairseq/data/multilingual/sampled_multi_dataset.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
+
 from fairseq.data import FairseqDataset, data_utils
 from fairseq.distributed import utils as distributed_utils
@@ -253,7 +254,7 @@ def collater(self, samples, **extra_args):
         """Merge a list of samples to form a mini-batch."""
         if len(samples) == 0:
             return None
-        if self.collate_format == "ordered_dict":
+        if self.collate_format == CollateFormat.ordered_dict:
             collect_samples = [[] for _ in range(len(self.datasets))]
             for (i, sample) in samples:
                 collect_samples[i].append(sample)
diff --git a/fairseq/data/multilingual/sampled_multilingual_dataset.py b/fairseq/data/multilingual/sampled_multilingual_dataset.py
new file mode 100644
index 0000000000..95c7e128df
--- /dev/null
+++ b/fairseq/data/multilingual/sampled_multilingual_dataset.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+
+from fairseq.data import SampledMultiDataset
+
+logger = logging.getLogger(__name__)
+
+
+class SampledMultilingualDataset(SampledMultiDataset):
+    """
+    Ensures that each batch contains samples only from one dataset.
+    Otherwise functions like SampledMultiDataset.
+    """
+
+    def batch_by_size(
+        self,
+        indices,
+        max_tokens=None,
+        max_sentences=None,
+        required_batch_size_multiple=1,
+    ):
+        dataset_indices = [[] for _ in range(len(self.datasets))]
+        for i in indices:
+            ds_idx, _ = self._get_dataset_and_index(i)
+            dataset_indices[ds_idx].append(i)
+
+        batches = []
+        for ds_idx, indices in enumerate(dataset_indices):
+            cur_batches = super().batch_by_size(
+                np.array(indices, dtype=np.int64),
+                max_tokens,
+                max_sentences,
+                required_batch_size_multiple,
+            )
+            logger.info(
+                f"Created {len(cur_batches)} batches for dataset {self.keys[ds_idx]}"
+            )
+            batches += cur_batches
+
+        return batches
+
+    def filter_indices_by_size(self, indices, max_sizes):
+        if isinstance(max_sizes, dict):
+            max_sizes = next(iter(max_sizes.values()))
+        return super().filter_indices_by_size(indices, max_sizes)
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index b081e6cabf..354e663311 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -892,7 +892,7 @@ class GenerationConfig(FairseqDataclass):
         default=None,
         metadata={
             "help": "if set, uses attention feedback to compute and print alignment to source tokens "
-            "(valid options are: hard, soft, otherwise treated as hard alignment)",
+            "(valid options are: hard, hard_shifted, soft, otherwise treated as hard alignment)",
             "argparse_const": "hard",
         },
     )
diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py
index 7e5aef7067..f521f503e2 100644
--- a/fairseq/dataclass/constants.py
+++ b/fairseq/dataclass/constants.py
@@ -51,4 +51,4 @@ def ChoiceEnum(choices: List[str]):
 )
 ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
 PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
-PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])
+PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft", "hard_shifted"])
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index e55c7ba1ad..6d31fb5839 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -13,6 +13,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from omegaconf import DictConfig
+from torch import Tensor
+
 from fairseq import utils
 from fairseq.data import Dictionary
 from fairseq.dataclass.utils import (
@@ -20,17 +23,15 @@
     gen_parser_from_dataclass,
 )
 from fairseq.models import FairseqDecoder, FairseqEncoder
-from omegaconf import DictConfig
-from torch import Tensor
-
 logger = logging.getLogger(__name__)
 
 
 def check_type(module, expected_type):
     if hasattr(module, "unwrapped_module"):
-        assert isinstance(module.unwrapped_module, expected_type), \
-            f"{type(module.unwrapped_module)} != {expected_type}"
+        assert isinstance(
+            module.unwrapped_module, expected_type
+        ), f"{type(module.unwrapped_module)} != {expected_type}"
     else:
         assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"
@@ -114,7 +115,9 @@ def load_state_dict(
         """
         if model_cfg is None and args is not None:
-            logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
+            logger.warn(
+                "using 'args' is deprecated, please update your code to use dataclass config"
+            )
             model_cfg = convert_namespace_to_omegaconf(args).model
 
         self.upgrade_state_dict(state_dict)
@@ -366,7 +369,7 @@ def __init__(self, *args, **kwargs):
 class FairseqMultiModel(BaseFairseqModel):
     """Base class for combining multiple encoder-decoder models."""
 
-    def __init__(self, encoders, decoders):
+    def __init__(self, encoders, decoders, args):
         super().__init__()
         assert encoders.keys() == decoders.keys()
         self.keys = list(encoders.keys())
@@ -376,11 +379,15 @@ def __init__(self, encoders, decoders):
 
         self.models = nn.ModuleDict(
             {
-                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
+                key: self.build_submodel(encoders[key], decoders[key], args)
                 for key in self.keys
             }
         )
 
+    @staticmethod
+    def build_submodel(encoder, decoder, args):
+        return FairseqEncoderDecoderModel(encoder, decoder)
+
     @staticmethod
     def build_shared_embeddings(
         dicts: Dict[str, Dictionary],
@@ -454,7 +461,9 @@ def load_state_dict(
         """
         if model_cfg is None and args is not None:
-            logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
+            logger.warn(
+                "using 'args' is deprecated, please update your code to use dataclass config"
+            )
             model_cfg = convert_namespace_to_omegaconf(args).model
 
         self.upgrade_state_dict(state_dict)
diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py
index e722b647ed..64becee17e 100644
--- a/fairseq/models/multilingual_transformer.py
+++ b/fairseq/models/multilingual_transformer.py
@@ -34,17 +34,22 @@ class MultilingualTransformerModel(FairseqMultiModel):
     Args:
         --share-encoder-embeddings: share encoder embeddings across all source languages
         --share-decoder-embeddings: share decoder embeddings across all target languages
+        --share-language-specific-embeddings: share encoder and decoder embeddings language-specifically
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
     """
 
-    def __init__(self, encoders, decoders):
-        super().__init__(encoders, decoders)
+    def __init__(self, encoders, decoders, args):
+        super().__init__(encoders, decoders, args)
 
     @staticmethod
-    def add_args(parser):
+    def get_encoder_decoder_model_class():
+        return TransformerModel
+
+    @classmethod
+    def add_args(cls, parser):
         """Add model-specific arguments to the parser."""
-        TransformerModel.add_args(parser)
+        cls.get_encoder_decoder_model_class().add_args(parser)
         parser.add_argument(
             "--share-encoder-embeddings",
             action="store_true",
@@ -65,6 +70,11 @@ def add_args(parser):
             action="store_true",
             help="share decoders across languages",
         )
+        parser.add_argument(
+            "--share-language-specific-embeddings",
+            action="store_true",
+            help="share encoder and decoder embeddings between the encoder and the decoder of the same language",
+        )
 
     @classmethod
     def build_model(cls, args, task):
@@ -101,7 +111,10 @@ def build_embedding(dictionary, embed_dim, path=None):
         # build shared embeddings (if applicable)
         shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
-        if args.share_all_embeddings:
+        language_specific_embeddings = None
+
+        # checks if the encoder and decoder embed dims and paths are equal (for shared embeddings)
+        def check_encoder_decoder_embed_args_equal():
             if args.encoder_embed_dim != args.decoder_embed_dim:
                 raise ValueError(
                     "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                 )
@@ -112,6 +125,9 @@ def build_embedding(dictionary, embed_dim, path=None):
                 raise ValueError(
                     "--share-all-embeddings not compatible with --decoder-embed-path"
                 )
+
+        if args.share_all_embeddings:
+            check_encoder_decoder_embed_args_equal()
             shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                 dicts=task.dicts,
                 langs=task.langs,
@@ -139,6 +155,22 @@ def build_embedding(dictionary, embed_dim, path=None):
                 pretrained_embed_path=args.decoder_embed_path,
             )
 
+        if args.share_language_specific_embeddings:
+            if args.share_encoder_embeddings or args.share_decoder_embeddings:
+                raise ValueError(
+                    "--share-language-specific-embeddings is not compatible with "
+                    "--share-encoder-embeddings or --share-decoder-embeddings."
+                )
+            check_encoder_decoder_embed_args_equal()
+            language_specific_embeddings = {
+                lang: build_embedding(
+                    task.dicts[lang],
+                    args.encoder_embed_dim,
+                    args.encoder_embed_path,
+                )
+                for lang in task.langs
+            }
+
         # encoders/decoders for each language
         lang_encoders, lang_decoders = {}, {}
 
@@ -146,6 +178,8 @@ def get_encoder(lang):
             if lang not in lang_encoders:
                 if shared_encoder_embed_tokens is not None:
                     encoder_embed_tokens = shared_encoder_embed_tokens
+                elif language_specific_embeddings is not None:
+                    encoder_embed_tokens = language_specific_embeddings[lang]
                 else:
                     encoder_embed_tokens = build_embedding(
                         task.dicts[lang],
@@ -161,6 +195,8 @@ def get_decoder(lang):
             if lang not in lang_decoders:
                 if shared_decoder_embed_tokens is not None:
                     decoder_embed_tokens = shared_decoder_embed_tokens
+                elif language_specific_embeddings is not None:
+                    decoder_embed_tokens = language_specific_embeddings[lang]
                 else:
                     decoder_embed_tokens = build_embedding(
                         task.dicts[lang],
@@ -188,7 +224,7 @@ def get_decoder(lang):
                 shared_decoder if shared_decoder is not None else get_decoder(tgt)
             )
 
-        return MultilingualTransformerModel(encoders, decoders)
+        return cls(encoders, decoders, args)
 
     @classmethod
     def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
@@ -210,6 +246,9 @@ def base_multilingual_architecture(args):
     base_architecture(args)
     args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False)
     args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False)
+    args.share_language_specific_embeddings = getattr(
+        args, "share_language_specific_embeddings", False
+    )
     args.share_encoders = getattr(args, "share_encoders", False)
     args.share_decoders = getattr(args, "share_decoders", False)
diff --git a/fairseq/models/multilingual_transformer_align.py b/fairseq/models/multilingual_transformer_align.py
new file mode 100644
index 0000000000..e99a32d403
--- /dev/null
+++ b/fairseq/models/multilingual_transformer_align.py
@@ -0,0 +1,28 @@
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.multilingual_transformer import (
+    MultilingualTransformerModel,
+    base_multilingual_architecture,
+)
+from fairseq.models.transformer_align import TransformerAlignModel, transformer_align
+
+
+@register_model("multilingual_transformer_align")
+class MultilingualTransformerAlignModel(MultilingualTransformerModel):
+    @staticmethod
+    def build_submodel(encoder, decoder, args):
+        return TransformerAlignModel(encoder, decoder, args)
+
+    @staticmethod
+    def get_encoder_decoder_model_class():
+        return TransformerAlignModel
+
+    def forward_decoder(self, prev_output_tokens, **kwargs):
+        return self.models[self.keys[0]].forward_decoder(prev_output_tokens, **kwargs)
+
+
+@register_model_architecture(
+    "multilingual_transformer_align", "multilingual_transformer_align"
+)
+def multilingual_transformer_aligned(args):
+    base_multilingual_architecture(args)
+    transformer_align(args)
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 2e61140dd8..793a44461a 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -775,14 +775,14 @@ def forward_decoder(
             encoder_out = encoder_outs[i]
             # decode each model
             if self.has_incremental_states():
-                decoder_out = model.decoder.forward(
+                decoder_out = model.forward_decoder(
                     tokens,
                     encoder_out=encoder_out,
                     incremental_state=incremental_states[i],
                 )
             else:
                 if hasattr(model, "decoder"):
-                    decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
+                    decoder_out = model.forward_decoder(tokens, encoder_out=encoder_out)
                 else:
                     decoder_out = model.forward(tokens)
 
@@ -883,6 +883,8 @@ def __init__(
 
         if print_alignment == "hard":
             self.extract_alignment = utils.extract_hard_alignment
+        elif print_alignment == "hard_shifted":
+            self.extract_alignment = utils.extract_hard_alignment_shifted
         elif print_alignment == "soft":
             self.extract_alignment = utils.extract_soft_alignment
diff --git a/fairseq/tasks/multilingual_translation_sampled.py b/fairseq/tasks/multilingual_translation_sampled.py
new file mode 100644
index 0000000000..a4a349c377
--- /dev/null
+++ b/fairseq/tasks/multilingual_translation_sampled.py
@@ -0,0 +1,318 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import time
+from collections import OrderedDict
+
+from fairseq import utils
+from fairseq.tasks import register_task
+from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+from fairseq.tasks.translation import load_langpair_dataset
+from fairseq.utils import csv_str_list
+
+from ..data import FairseqDataset, data_utils, iterators
+from ..data.multilingual.sampled_multi_dataset import CollateFormat
+from ..data.multilingual.sampled_multilingual_dataset import SampledMultilingualDataset
+from ..data.multilingual.sampling_method import SamplingMethod
+from .translation_multi_simple_epoch import get_time_gap
+
+logger = logging.getLogger(__name__)
+
+
+@register_task("multilingual_translation_sampled")
+class SampledMultilingualTranslationTask(MultilingualTranslationTask):
+    @staticmethod
+    def add_args(parser):
+        MultilingualTranslationTask.add_args(parser)
+        SamplingMethod.add_arguments(parser)
+        parser.add_argument(
+            "--model-lang-pairs",
+            default=None,
+            type=csv_str_list,
+            help="language pairs that will be used for building the model. --lang-pairs are used by default.",
+        )
+        parser.add_argument(
+            "--eval-lang-pairs",
+            default=None,
+            type=csv_str_list,
+            help="language pairs that will be used for evaluating the model. --lang-pairs are used by default.",
+        )
+        parser.add_argument(
+            "--load-alignments",
+            action="store_true",
+            help="load the binarized alignments",
+        )
+
+    def __init__(self, args, dicts, training):
+        super().__init__(args, dicts, training)
+        self.model_lang_pairs = (
+            self.lang_pairs if args.model_lang_pairs is None else args.model_lang_pairs
+        )
+        self.eval_lang_pairs = (
+            self.lang_pairs if args.eval_lang_pairs is None else args.eval_lang_pairs
+        )
+
+    def load_dataset(self, split, epoch=1, **kwargs):
+        """Load a dataset split."""
+        paths = utils.split_paths(self.args.data)
+        assert len(paths) > 0
+        data_path = paths[(epoch - 1) % len(paths)]
+
+        def language_pair_dataset(lang_pair):
+            src, tgt = lang_pair.split("-")
+            langpair_dataset = load_langpair_dataset(
+                data_path,
+                split,
+                src,
+                self.dicts[src],
+                tgt,
+                self.dicts[tgt],
+                combine=True,
+                dataset_impl=self.args.dataset_impl,
+                upsample_primary=self.args.upsample_primary,
+                left_pad_source=self.args.left_pad_source,
+                left_pad_target=self.args.left_pad_target,
+                max_source_positions=self.args.max_source_positions,
+                max_target_positions=self.args.max_target_positions,
+                load_alignments=self.args.load_alignments,
+            )
+            return self.alter_dataset_langtok(
+                langpair_dataset,
+                src_eos=self.dicts[src].eos(),
+                src_lang=src,
+                tgt_eos=self.dicts[tgt].eos(),
+                tgt_lang=tgt,
+            )
+
+        datasets = OrderedDict(
+            [
+                (lang_pair, language_pair_dataset(lang_pair))
+                for lang_pair in self.lang_pairs
+            ]
+        )
+
+        sampling_method = SamplingMethod(self.args, self).sampling_method_selector()
+        ratios = (
+            None
+            if sampling_method is None
+            else sampling_method([len(dataset) for dataset in datasets.values()])
+        )
+
+        self.datasets[split] = SampledMultilingualDataset(
+            datasets,
+            epoch=epoch,
+            sampling_ratios=ratios,
+            seed=self.args.seed,
+            collate_format=CollateFormat.ordered_dict,
+            eval_key=None
+            if self.training
+            else f"{self.args.source_lang}-{self.args.target_lang}",
+        )
+
+    # needs to be overridden to work with SampledMultiDataset
+    def max_positions(self):
+        """Return the max sentence length allowed by the task."""
+        if len(self.datasets.values()) == 0:
+            return {
+                f"{self.args.source_lang}-{self.args.target_lang}": (
+                    self.args.max_source_positions,
+                    self.args.max_target_positions,
+                )
+            }
+        return OrderedDict(
+            [
+                (key, (self.args.max_source_positions, self.args.max_target_positions))
+                for split in self.datasets.keys()
+                for key in self.datasets[split].keys
+            ]
+        )
+
+    def train_step(
+        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
+    ):
+        # each sample contains one language-pair
+        assert len(sample) == 1
+        model.train()
+
+        lang_pair = next(iter(sample.keys()))
+        loss, sample_size, logging_output = criterion(
+            model.models[lang_pair], sample[lang_pair]
+        )
+        if ignore_grad:
+            loss *= 0
+        optimizer.backward(loss)
+        return loss, sample_size, logging_output
+
+    # from translation_multi_simple_epoch
+    def create_batch_sampler_func(
+        self,
+        max_positions,
+        ignore_invalid_inputs,
+        max_tokens,
+        max_sentences,
+        required_batch_size_multiple=1,
+        seed=1,
+    ):
+        def construct_batch_sampler(dataset, epoch):
+            splits = [
+                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
+            ]
+            split = splits[0] if len(splits) > 0 else None
+            # NEW implementation
+            if epoch is not None:
+                # initialize the dataset with the correct starting epoch
+                dataset.set_epoch(epoch)
+
+            # get indices ordered by example size
+            start_time = time.time()
+            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")
+            with data_utils.numpy_seed(seed):
+                indices = dataset.ordered_indices()
+            logger.info(
+                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
+            )
+            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
+
+            # filter examples that are too large
+            if max_positions is not None:
+                my_time = time.time()
+                indices = self.filter_indices_by_size(
+                    indices, dataset, max_positions, ignore_invalid_inputs
+                )
+                logger.info(
+                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
+                )
+                logger.info(f"mem usage: {data_utils.get_mem_usage()}")
+
+            # create mini-batches with given size constraints
+            my_time = time.time()
+            batch_sampler = dataset.batch_by_size(
+                indices,
+                max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+
+            logger.info(
+                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
+            )
+            logger.info(
+                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
+            )
+            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
+
+            return batch_sampler
+
+        return construct_batch_sampler
+
+    # from translation_multi_simple_epoch
+    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
+    def get_batch_iterator(
+        self,
+        dataset,
+        max_tokens=None,
+        max_sentences=None,
+        max_positions=None,
+        ignore_invalid_inputs=False,
+        required_batch_size_multiple=1,
+        seed=1,
+        num_shards=1,
+        shard_id=0,
+        num_workers=0,
+        epoch=1,
+        data_buffer_size=0,
+        disable_iterator_cache=False,
+        skip_remainder_batch=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
+    ):
+        """
+        Get an iterator that yields batches of data from the given dataset.
+
+        Args:
+            dataset (~fairseq.data.FairseqDataset): dataset to batch
+            max_tokens (int, optional): max number of tokens in each batch
+                (default: None).
+            max_sentences (int, optional): max number of sentences in each
+                batch (default: None).
+            max_positions (optional): max sentence length supported by the
+                model (default: None).
+            ignore_invalid_inputs (bool, optional): don't raise Exception for
+                sentences that are too long (default: False).
+            required_batch_size_multiple (int, optional): require batch size to
+                be a multiple of N (default: 1).
+            seed (int, optional): seed for random number generator for
+                reproducibility (default: 1).
+            num_shards (int, optional): shard the data iterator into N
+                shards (default: 1).
+            shard_id (int, optional): which shard of the data iterator to
+                return (default: 0).
+            num_workers (int, optional): how many subprocesses to use for data
+                loading. 0 means the data will be loaded in the main process
+                (default: 0).
+            epoch (int, optional): the epoch to start the iterator from
+                (default: 0).
+            data_buffer_size (int, optional): number of batches to
+                preload (default: 0).
+            disable_iterator_cache (bool, optional): don't cache the
+                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
+                (default: False).
+            grouped_shuffling (bool, optional): group batches with each group
+                containing num_shards batches and shuffle groups. Reduces difference
+                between sequence lengths among workers for batches sorted by length.
+            update_epoch_batch_itr (bool, optional): if true then don't use the cached
+                batch iterator for the epoch
+        Returns:
+            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
+                given dataset split
+        """
+        # initialize the dataset with the correct starting epoch
+        assert isinstance(dataset, FairseqDataset)
+        if dataset in self.dataset_to_epoch_iter:
+            return self.dataset_to_epoch_iter[dataset]
+        if self.args.sampling_method == "RoundRobin":
+            batch_iter = super().get_batch_iterator(
+                dataset,
+                max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                max_positions=max_positions,
+                ignore_invalid_inputs=ignore_invalid_inputs,
+                required_batch_size_multiple=required_batch_size_multiple,
+                seed=seed,
+                num_shards=num_shards,
+                shard_id=shard_id,
+                num_workers=num_workers,
+                epoch=epoch,
+                data_buffer_size=data_buffer_size,
+                disable_iterator_cache=disable_iterator_cache,
+                skip_remainder_batch=False,
+                grouped_shuffling=False,
+                update_epoch_batch_itr=False,
+            )
+            self.dataset_to_epoch_iter[dataset] = batch_iter
+            return batch_iter
+
+        construct_batch_sampler = self.create_batch_sampler_func(
+            max_positions,
+            ignore_invalid_inputs,
+            max_tokens,
+            max_sentences,
+            required_batch_size_multiple=required_batch_size_multiple,
+            seed=seed,
+        )
+
+        epoch_iter = iterators.EpochBatchIterator(
+            dataset=dataset,
+            collate_fn=dataset.collater,
+            batch_sampler=construct_batch_sampler,
+            seed=seed,
+            num_shards=num_shards,
+            shard_id=shard_id,
+            num_workers=num_workers,
+            epoch=epoch,
+        )
+        return epoch_iter
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 94114ce15c..0f3f3bfe86 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -680,6 +680,33 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
     return alignment
 
 
+def extract_hard_alignment_shifted(attn, src_sent, tgt_sent, pad, eos):
+    tgt_idxs = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)
+    tgt_valid = (
+        ((tgt_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1)
+    )
+    src_invalid = (
+        ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)
+    )
+    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
+    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [pad])
+    alignment = []
+    if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
+        attn_valid = attn[tgt_valid]
+        attn_valid[:, src_invalid] = float("-inf")
+
+        attn_valid = attn_valid[1:, :]
+        _, src_indices = attn_valid.max(dim=1)
+        for tgt_idx, src_idx in zip(tgt_idxs, src_indices):
+            alignment.append(
+                (
+                    src_token_to_word[src_idx.item()] - 1,
+                    tgt_token_to_word[tgt_idx.item()] - 1,
+                )
+            )
+    return alignment
+
+
 def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
     tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False)
     src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1)
diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py
index 7e887e8864..6377c6eb52 100644
--- a/fairseq_cli/generate.py
+++ b/fairseq_cli/generate.py
@@ -298,7 +298,7 @@ def decode_fn(x):
                         file=output_file,
                     )
 
-                if cfg.generation.print_alignment == "hard":
+                if cfg.generation.print_alignment in ["hard", "hard_shifted"]:
                     print(
                         "A-{}\t{}".format(
                             sample_id,
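For readers skimming the patch, the following standalone sketch (not part of the diff above, plain torch only) illustrates the core idea behind the new hard_shifted option: the hard-alignment argmax is taken after discarding the first target row of the attention matrix, so each remaining target position is paired with the attention of the position that follows it. The helper name toy_hard_shifted_alignment and the example matrix are made up for illustration and are not part of fairseq's API; the real extract_hard_alignment_shifted additionally masks out padding/eos and maps token indices back to word indices.

import torch

def toy_hard_shifted_alignment(attn):
    # attn: (tgt_len, src_len) attention weights for one sentence, no padding.
    # Drop the first target row, then take the per-row argmax over source positions,
    # mirroring the `attn_valid = attn_valid[1:, :]` step in extract_hard_alignment_shifted.
    shifted = attn[1:, :]
    src_indices = shifted.argmax(dim=1)
    return [(int(s), t) for t, s in enumerate(src_indices.tolist())]

attn = torch.tensor(
    [
        [0.7, 0.2, 0.1],  # target position 0
        [0.1, 0.8, 0.1],  # target position 1
        [0.1, 0.2, 0.7],  # target position 2
    ]
)
print(toy_hard_shifted_alignment(attn))  # [(1, 0), (2, 1)]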