From cc247396becf3dc8ea7ffa2c293632e34550faea Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Tue, 21 Nov 2017 15:03:35 +0100 Subject: [PATCH] Remove RNN parameter packing, FusedRNN support; refactored core model components (#189) * Removed RNN parameter packing and FusedRNN support * Refactor embedding and output layers (#196) * Removed RNN parameter packing and FusedRNN support * Refactoring of sockeye model: source embed/target embed/output layers are now separate components in model * Make training and inference work. Remove lexical biasing code. --- CHANGELOG.md | 23 +- sockeye/arguments.py | 17 +- sockeye/constants.py | 1 + sockeye/decoder.py | 482 +++++++++----------------- sockeye/encoder.py | 218 +++++------- sockeye/inference.py | 80 +++-- sockeye/initializer.py | 12 +- sockeye/layers.py | 34 +- sockeye/model.py | 131 ++++--- sockeye/train.py | 79 ++--- sockeye/training.py | 38 +- sockeye/transformer.py | 8 - test/integration/test_seq_copy_int.py | 8 +- test/system/test_seq_copy_sys.py | 34 +- test/unit/test_arguments.py | 3 - test/unit/test_decoder.py | 4 +- test/unit/test_encoder.py | 5 +- 17 files changed, 497 insertions(+), 680 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c8e47185..bd4666b70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,28 @@ # Changelog -All notable changes to this project will be documented in this file. +All notable changes to the project are documented in this file. -We use version numbers with three digits such as 1.0.0. +Version numbers are of the form `1.0.0`. Any version bump in the last digit is backwards-compatible, in that a model trained with the previous version can still be used for translation with the new version. -Any bump in the second digit indicates potential backwards incompatibilities, e.g. due to changing the architecture or -simply modifying weight names. +Any bump in the second digit indicates a backwards-incompatible change, +e.g. due to changing the architecture or simply modifying model parameter names. Note that Sockeye has checks in place to not translate with an old model that was trained with an incompatible version. -For each item we will potentially have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. + ## [1.13.0] ### Fixed + - Transformer models do not ignore `--num-embed` anymore as they did silently before. + As a result there is an error thrown if `--num-embed` != `--transformer-model-size`. - Fixed the attention in upper layers (`--rnn-attention-in-upper-layers`), which was previously not passed correctly - to the decoder. + to the decoder. +### Removed + - Removed RNN parameter (un-)packing and support for FusedRNNCells (removed `--use-fused-rnns` flag). + These were not used, not correctly initialized, and performed worse than regular RNN cells. Moreover, + they made the code much more complex. RNN models trained with previous versions are no longer compatible. +- Removed the lexical biasing functionality (Arthur ETAL'16) (removed arguments `--lexical-bias` + and `--learn-lexical-bias`). ## [1.12.2] ### Changed @@ -120,7 +129,7 @@ For each item we will potentially have subsections for: _Added_, _Changed_, _Rem - Convolutional decoder. - Weight normalization (for CNN only so far). - Learned positional embeddings for the transformer. - + ### Changed - `--attention-*` CLI params renamed to `--rnn-attention-*`. - `--transformer-no-positional-encodings` generalized to `--transformer-positional-embedding-type`. diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 270ce47e4..e61fad876 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -489,15 +489,6 @@ def add_model_parameters(params): type=int, default=None, help='Number of heads for Multi-head dot attention. Default: %(default)s.') - model_params.add_argument('--lexical-bias', - default=None, - type=str, - help="Specify probabilistic lexicon (fast_align format) for lexical biasing (Arthur " - "ETAL'16). Set smoothing value epsilon by appending :") - model_params.add_argument('--learn-lexical-bias', - action='store_true', - help='Adjust lexicon probabilities during training. Default: %(default)s') - model_params.add_argument('--weight-tying', action='store_true', help='Turn on weight tying (see arxiv.org/abs/1608.05859). ' @@ -690,7 +681,8 @@ def add_training_args(params): default=C.EMBED_INIT_DEFAULT, choices=C.EMBED_INIT_TYPES, help='Type of embedding matrix weight initialization. If normal, initializes embedding ' - 'weights using a normal distribution with std=vocab_size. Default: %(default)s.') + 'weights using a normal distribution with std=1/srqt(vocab_size). ' + 'Default: %(default)s.') train_params.add_argument('--initial-learning-rate', type=float, default=0.0003, @@ -750,11 +742,6 @@ def add_training_args(params): "reduced due to the value of --learning-rate-reduce-num-not-improved. " "Default: %(default)s.") - train_params.add_argument('--use-fused-rnn', - default=False, - action="store_true", - help='Use FusedRNNCell in encoder (requires GPU device). Speeds up training.') - train_params.add_argument('--rnn-forget-bias', default=0.0, type=float, diff --git a/sockeye/constants.py b/sockeye/constants.py index 5c352f73b..10e807066 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -40,6 +40,7 @@ TRANSFORMER_ENCODER_PREFIX = ENCODER_PREFIX + "transformer_" CNN_ENCODER_PREFIX = ENCODER_PREFIX + "cnn_" CHAR_SEQ_ENCODER_PREFIX = ENCODER_PREFIX + "char_" +DEFAULT_OUTPUT_LAYER_PREFIX = "target_output_" # embedding prefixes SOURCE_EMBEDDING_PREFIX = "source_embed_" diff --git a/sockeye/decoder.py b/sockeye/decoder.py index 22e39200c..42359ead4 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -36,15 +36,13 @@ logger = logging.getLogger(__name__) -def get_decoder(config: Config, - lexicon: Optional[lexicons.Lexicon] = None, - embed_weight: Optional[mx.sym.Symbol] = None) -> 'Decoder': +def get_decoder(config: Config) -> 'Decoder': if isinstance(config, RecurrentDecoderConfig): - return RecurrentDecoder(config=config, lexicon=lexicon, embed_weight=embed_weight, prefix=C.RNN_DECODER_PREFIX) + return RecurrentDecoder(config=config, prefix=C.RNN_DECODER_PREFIX) elif isinstance(config, ConvolutionalDecoderConfig): - return ConvolutionalDecoder(config=config, embed_weight=embed_weight, prefix=C.CNN_DECODER_PREFIX) + return ConvolutionalDecoder(config=config, prefix=C.CNN_DECODER_PREFIX) elif isinstance(config, transformer.TransformerConfig): - return TransformerDecoder(config=config, embed_weight=embed_weight, prefix=C.TRANSFORMER_DECODER_PREFIX) + return TransformerDecoder(config=config, prefix=C.TRANSFORMER_DECODER_PREFIX) else: raise ValueError("Unsupported decoder configuration") @@ -58,54 +56,49 @@ class Decoder(ABC): For the inference module to be able to keep track of decoder's states a decoder provides methods to return initial states (init_states), state variables and their shapes. """ - def __init__(self) -> None: - # Tracked to find params for logit computation - self.output_layer = None # type: layers.OutputLayer @abstractmethod def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int, - target: mx.sym.Symbol, - target_lengths: mx.sym.Symbol, - target_max_length: int, - source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int) -> mx.sym.Symbol: """ - Decodes given a known target sequence and returns logits - with batch size and target length dimensions collapsed. - Used for training. + Decodes a sequence of embedded target words and returns sequence of last decoder + representations for each time step. :param source_encoded: Encoded source: (source_encoded_max_length, batch_size, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. - :param source_lexicon: Lexical biases for current sentence. - Shape: (batch_size, target_vocab_size, source_seq_len) - :return: Logits of next-word predictions for target sequence. - Shape: (batch_size * target_max_length, target_vocab_size) + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Dimension of the embedded target sequence. + :return: Decoder data. Shape: (batch_size, target_embed_max_length, decoder_depth). """ pass @abstractmethod def decode_step(self, - target: mx.sym.Symbol, - target_max_length: int, + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int, + target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, - *states: mx.sym.Symbol) \ - -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: + *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ - Decodes a single time step given the previous word ids in target and previous decoder states. - Returns logit inputs, logits, attention probabilities, and next decoder states. + Decodes a single time step given the embedded target sequence and previous decoder states. + Returns decoder representation for the next prediction, attention probabilities, and next decoder states. Implementations can maintain an arbitrary number of states. - :param target: Previous target word ids. Shape: (batch_size, target_max_length). - :param target_max_length: Size of time dimension in prev_word_ids. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. + :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. :param states: Arbitrary list of decoder states. - :return: logit inputs, logits, attention probabilities, next decoder states. + :return: logit inputs, attention probabilities, next decoder states. """ pass @@ -116,6 +109,13 @@ def reset(self): """ pass + @abstractmethod + def get_num_hidden(self) -> int: + """ + :return: The representation size of this decoder. + """ + pass + @abstractmethod def init_states(self, source_encoded: mx.sym.Symbol, @@ -157,13 +157,6 @@ def state_shapes(self, """ pass - def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: - """ - Returns a list of RNNCells used by this decoder. - - """ - return [] - def get_max_seq_len(self) -> Optional[int]: """ :return: The maximum length supported by the decoder if such a restriction exists. @@ -181,13 +174,11 @@ class TransformerDecoder(Decoder): time-step ensures correct self-attention scores and is updated with every step. :param config: Transformer configuration. - :param embed_weight: Optionally use an existing embedding matrix instead of creating a new target embedding. :param prefix: Name prefix for symbols of this decoder. """ def __init__(self, config: transformer.TransformerConfig, - embed_weight: Optional[mx.sym.Symbol] = None, prefix: str = C.TRANSFORMER_DECODER_PREFIX) -> None: self.config = config self.prefix = prefix @@ -198,90 +189,63 @@ def __init__(self, dropout=config.dropout_prepost, prefix="%sfinal_process_" % prefix) - # Embedding & output parameters - if embed_weight is None: - embed_weight = encoder.Embedding.get_embed_weight(config.vocab_size, - config.model_size, - C.TARGET_EMBEDDING_PREFIX) - # Note: Transformers use model_size as embedding size - self.embedding = encoder.Embedding(num_embed=config.model_size, - vocab_size=config.vocab_size, - prefix=C.TARGET_EMBEDDING_PREFIX, - dropout=config.dropout_embed, - embed_weight=embed_weight, - embed_scale=config.model_size ** 0.5) self.pos_embedding = encoder.get_positional_embedding(config.positional_embedding_type, config.model_size, max_seq_len=config.max_seq_len_target, + fixed_pos_embed_scale_up_input=True, + fixed_pos_embed_scale_down_positions=False, prefix=C.TARGET_POSITIONAL_EMBEDDING_PREFIX) - self.output_layer = layers.OutputLayer(num_hidden=self.config.model_size, - num_embed=self.config.model_size, - vocab_size=self.config.vocab_size, - weight_tying=self.config.weight_tying, - embed_weight=embed_weight, - weight_normalization=self.config.weight_normalization, - prefix=prefix) - def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int, - target: mx.sym.Symbol, - target_lengths: mx.sym.Symbol, - target_max_length: int, - source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int) -> mx.sym.Symbol: """ - Decodes given a known target sequence and returns logits - with batch size and target length dimensions collapsed. - Used for training. + Decodes a sequence of embedded target words and returns sequence of last decoder + representations for each time step. :param source_encoded: Encoded source: (source_encoded_max_length, batch_size, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. - :param source_lexicon: Lexical biases for current sentence. - Shape: (batch_size, target_vocab_size, source_seq_len) - :return: Logits of next-word predictions for target sequence. - Shape: (batch_size * target_max_length, target_vocab_size) + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Dimension of the embedded target sequence. + :return: Decoder data. Shape: (batch_size, target_embed_max_length, decoder_depth). """ # (batch_size, source_max_length, num_source_embed) source_encoded = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1) # (batch_size, target_max_length, model_size) target = self._decode(source_encoded, source_encoded_lengths, source_encoded_max_length, - target, target_lengths, target_max_length) - - # (batch_size * target_max_length, model_size) - target = mx.sym.reshape(data=target, shape=(-3, -1)) + target_embed, target_embed_lengths, target_embed_max_length) - # (batch_size * target_max_length, vocab_size) - logits = self.output_layer(target) - return logits + return target def _decode(self, source_encoded, source_encoded_lengths, source_encoded_max_length, - target, target_lengths, target_max_length): + target_embed, target_embed_lengths, target_embed_max_length): """ Runs stacked decoder transformer blocks. :param source_encoded: Batch-major encoded source: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. :return: Result of stacked transformer blocks. """ # (1, target_max_length, target_max_length) - target_bias = transformer.get_autoregressive_bias(target_max_length, name="%sbias" % self.prefix) + target_bias = transformer.get_autoregressive_bias(target_embed_max_length, name="%sbias" % self.prefix) # target: (batch_size, target_max_length, model_size) - target, target_lengths, target_max_length = self.embedding.encode(target, target_lengths, target_max_length) - target, target_lengths, target_max_length = self.pos_embedding.encode(target, target_lengths, target_max_length) + target, target_lengths, target_max_length = self.pos_embedding.encode(target_embed, + target_embed_lengths, + target_embed_max_length) if self.config.dropout_prepost > 0.0: target = mx.sym.Dropout(data=target, p=self.config.dropout_prepost) @@ -294,56 +258,61 @@ def _decode(self, return target def decode_step(self, - target: mx.sym.Symbol, - target_max_length: int, + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int, + target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, - *states: mx.sym.Symbol) \ - -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: + *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ - Decodes a single time step given the previous word ids in target and previous decoder states. - Returns logit inputs, logits, attention probabilities, and next decoder states. + Decodes a single time step given the embedded target sequence and previous decoder states. + Returns decoder representation for the next prediction, attention probabilities, and next decoder states. Implementations can maintain an arbitrary number of states. - :param target: Previous target word ids. Shape: (batch_size, target_max_length). - :param target_max_length: Size of time dimension in prev_word_ids. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. + :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. :param states: Arbitrary list of decoder states. - :return: logit inputs, logits, attention probabilities, next decoder states. + :return: logit inputs, attention probabilities, next decoder states. """ source_encoded, source_encoded_lengths = states - # lengths: (batch_size,) - target_lengths = utils.compute_lengths(target) - indices = target_lengths - 1 # type: mx.sym.Symbol + # indices: (batch_size,) + indices = target_embed_lengths - 1 # type: mx.sym.Symbol # (batch_size, target_max_length, 1) mask = mx.sym.expand_dims(mx.sym.one_hot(indices=indices, - depth=target_max_length, + depth=target_embed_max_length, on_value=1, off_value=0), axis=2) # (batch_size, target_max_length, model_size) target = self._decode(source_encoded, source_encoded_lengths, source_encoded_max_length, - target, target_lengths, target_max_length) + target_embed, target_embed_lengths, target_embed_max_length) # set all target positions to zero except for current time-step # target: (batch_size, target_max_length, model_size) target = mx.sym.broadcast_mul(target, mask) # reduce to single prediction # target: (batch_size, model_size) - target = mx.sym.sum(target, axis=1, keepdims=False, name=C.LOGIT_INPUTS_NAME) - - # logits: (batch_size, vocab_size) - logits = self.output_layer(target) + target = mx.sym.sum(target, axis=1, keepdims=False) # TODO(fhieber): no attention probs for now attention_probs = mx.sym.sum(mx.sym.zeros_like(source_encoded), axis=2, keepdims=False) new_states = [source_encoded, source_encoded_lengths] - return target, logits, attention_probs, new_states + return target, attention_probs, new_states def reset(self): pass + def get_num_hidden(self) -> int: + """ + :return: The representation size of this decoder. + """ + return self.config.model_size + def init_states(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, @@ -407,49 +376,34 @@ class RecurrentDecoderConfig(Config): """ Recurrent decoder configuration. - :param vocab_size: Target vocabulary size. :param max_seq_len_source: Maximum source sequence length - :param num_embed: Target word embedding size. :param rnn_config: RNN configuration. :param attention_config: Attention configuration. - :param embed_dropout: Dropout probability for target embeddings. :param hidden_dropout: Dropout probability on next decoder hidden state. - :param weight_tying: Whether to share embedding and prediction parameter matrices. :param state_init: Type of RNN decoder state initialization: zero, last, average. :param context_gating: Whether to use context gating. :param layer_normalization: Apply layer normalization. :param attention_in_upper_layers: Pass the attention value to all layers in the decoder. - :param weight_normalization: Weight normalization. """ def __init__(self, - vocab_size: int, max_seq_len_source: int, - num_embed: int, rnn_config: rnn.RNNConfig, attention_config: rnn_attention.AttentionConfig, - embed_dropout: float = .0, - hidden_dropout: float = .0, - weight_tying: bool = False, + hidden_dropout: float = .0, # TODO: move this dropout functionality to OutputLayer state_init: str = C.RNN_DEC_INIT_LAST, context_gating: bool = False, layer_normalization: bool = False, - attention_in_upper_layers: bool = False, - weight_normalization: bool = False) -> None: + attention_in_upper_layers: bool = False) -> None: super().__init__() - self.vocab_size = vocab_size self.max_seq_len_source = max_seq_len_source - self.num_embed = num_embed self.rnn_config = rnn_config self.attention_config = attention_config - self.embed_dropout = embed_dropout self.hidden_dropout = hidden_dropout - self.weight_tying = weight_tying self.state_init = state_init self.context_gating = context_gating self.layer_normalization = layer_normalization self.attention_in_upper_layers = attention_in_upper_layers - self.weight_normalization = weight_normalization class RecurrentDecoder(Decoder): @@ -458,21 +412,16 @@ class RecurrentDecoder(Decoder): The architecture is based on Luong et al, 2015: Effective Approaches to Attention-based Neural Machine Translation. :param config: Configuration for recurrent decoder. - :param lexicon: Optional Lexicon. - :param embed_weight: Optionally use an existing embedding matrix instead of creating a new target embedding. :param prefix: Decoder symbol prefix. """ def __init__(self, config: RecurrentDecoderConfig, - lexicon: Optional[lexicons.Lexicon] = None, - embed_weight: Optional[mx.sym.Symbol] = None, prefix: str = C.RNN_DECODER_PREFIX) -> None: # TODO: implement variant without input feeding self.config = config self.rnn_config = config.rnn_config self.attention = rnn_attention.get_attention(config.attention_config, config.max_seq_len_source) - self.lexicon = lexicon self.prefix = prefix self.num_hidden = self.rnn_config.num_hidden @@ -512,25 +461,6 @@ def __init__(self, prefix="%shidden_norm" % prefix) \ if self.config.layer_normalization else None - # Embedding & output parameters - if embed_weight is None: - embed_weight = encoder.Embedding.get_embed_weight(self.config.vocab_size, - self.config.num_embed, - C.TARGET_EMBEDDING_PREFIX) - self.embedding = encoder.Embedding(self.config.num_embed, - self.config.vocab_size, - prefix=C.TARGET_EMBEDDING_PREFIX, - dropout=config.embed_dropout, - embed_weight=embed_weight) - - self.output_layer = layers.OutputLayer(num_hidden=self.num_hidden, - num_embed=self.config.num_embed, - vocab_size=self.config.vocab_size, - weight_tying=self.config.weight_tying, - embed_weight=embed_weight, - weight_normalization=self.config.weight_normalization, - prefix=prefix) - def _create_state_init_parameters(self): """ Creates parameters for encoder last state transformation into decoder layer initial states. @@ -552,32 +482,23 @@ def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int, - target: mx.sym.Symbol, - target_lengths: mx.sym.Symbol, - target_max_length: int, - source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int) -> mx.sym.Symbol: """ - Decodes given a known target sequence and returns logits - with batch size and target length dimensions collapsed. - Used for training. + Decodes a sequence of embedded target words and returns sequence of last decoder + representations for each time step. :param source_encoded: Encoded source: (source_encoded_max_length, batch_size, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. - :param source_lexicon: Lexical biases for current sentence. - Shape: (batch_size, target_vocab_size, source_seq_len) - :return: Logits of next-word predictions for target sequence. - Shape: (batch_size * target_max_length, target_vocab_size) - """ - # embed and slice target words - # target_embed: (batch_size, target_seq_len, num_target_embed) - target_embed, target_lengths, target_max_length = self.embedding.encode(target, target_lengths, - target_max_length) + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Dimension of the embedded target sequence. + :return: Decoder data. Shape: (batch_size, target_embed_max_length, decoder_depth). + """ # target_embed: target_seq_len * (batch_size, num_target_embed) - target_embed = mx.sym.split(data=target_embed, num_outputs=target_max_length, axis=1, squeeze_axis=True) + target_embed = mx.sym.split(data=target_embed, num_outputs=target_embed_max_length, axis=1, squeeze_axis=True) # get recurrent attention function conditioned on source source_encoded_batch_major = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1, name='source_encoded_batch_major') @@ -592,72 +513,45 @@ def decode_sequence(self, # hidden_all: target_seq_len * (batch_size, 1, rnn_num_hidden) hidden_all = [] - # TODO: possible alternative: feed back the context vector instead of the hidden (see lamtram) - - lexical_biases = [] - self.reset() - - for seq_idx in range(target_max_length): + for seq_idx in range(target_embed_max_length): # hidden: (batch_size, rnn_num_hidden) state, attention_state = self._step(target_embed[seq_idx], state, attention_func, attention_state, seq_idx) - # hidden_expanded: (batch_size, 1, rnn_num_hidden) hidden_all.append(mx.sym.expand_dims(data=state.hidden, axis=1)) - if source_lexicon is not None: - assert self.lexicon is not None, "source_lexicon should not be None if no lexicon available" - lexical_biases.append(self.lexicon.calculate_lex_bias(source_lexicon, attention_state.probs)) - # concatenate along time axis # hidden_concat: (batch_size, target_seq_len, rnn_num_hidden) hidden_concat = mx.sym.concat(*hidden_all, dim=1, name="%shidden_concat" % self.prefix) - # hidden_concat: (batch_size * target_seq_len, rnn_num_hidden) - hidden_concat = mx.sym.reshape(data=hidden_concat, shape=(-1, self.num_hidden)) - - # logits: (batch_size * target_seq_len, target_vocab_size) - logits = self.output_layer(hidden_concat) - - if source_lexicon is not None: - # lexical_biases_concat: (batch_size, target_seq_len, target_vocab_size) - lexical_biases_concat = mx.sym.concat(*lexical_biases, dim=1, name='lex_bias_concat') - # lexical_biases_concat: (batch_size * target_seq_len, target_vocab_size) - lexical_biases_concat = mx.sym.reshape(data=lexical_biases_concat, shape=(-1, self.config.vocab_size)) - logits = mx.sym.broadcast_add(lhs=logits, rhs=lexical_biases_concat, - name='%s_plus_lex_bias' % C.LOGITS_NAME) - - return logits + return hidden_concat def decode_step(self, - target: mx.sym.Symbol, - target_max_length: int, + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int, + target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, - *states: mx.sym.Symbol) \ - -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: + *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ - Decodes a single time step given the previous word ids in target and previous decoder states. - Returns logit inputs, logits, attention probabilities, and next decoder states. + Decodes a single time step given the embedded target sequence and previous decoder states. + Returns decoder representation for the next prediction, attention probabilities, and next decoder states. Implementations can maintain an arbitrary number of states. - :param target: Previous target word ids. Shape: (batch_size, target_max_length). - :param target_max_length: Size of time dimension in prev_word_ids. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. + :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. :param states: Arbitrary list of decoder states. - :return: logit inputs, logits, attention probabilities, next decoder states. + :return: logit inputs, attention probabilities, next decoder states. """ source_encoded, prev_dynamic_source, source_encoded_length, prev_hidden, *layer_states = states - # indices: (batch_size,) - indices = utils.compute_lengths(target) - 1 # type: mx.sym.Symbol - prev_word_id = mx.sym.pick(target, indices, axis=1) - - word_vec_prev, _, _ = self.embedding.encode(prev_word_id, None, 1) - attention_func = self.attention.on(source_encoded, source_encoded_length, source_encoded_max_length) prev_state = RecurrentDecoderState(prev_hidden, list(layer_states)) @@ -667,23 +561,17 @@ def decode_step(self, # state.hidden: (batch_size, rnn_num_hidden) # attention_state.dynamic_source: (batch_size, source_seq_len, coverage_num_hidden) # attention_state.probs: (batch_size, source_seq_len) - state, attention_state = self._step(word_vec_prev, + state, attention_state = self._step(target_embed_prev, prev_state, attention_func, prev_attention_state) - # logit inputs aka state.hidden: (batch_size, rnn_num_hidden) - logit_inputs = mx.sym.identity(state.hidden, name=C.LOGIT_INPUTS_NAME) - - # logits: (batch_size, target_vocab_size) - logits = self.output_layer(state.hidden) - new_states = [source_encoded, attention_state.dynamic_source, source_encoded_length, state.hidden] + state.layer_states - return logit_inputs, logits, attention_state.probs, new_states + return state.hidden, attention_state.probs, new_states def reset(self): """ @@ -701,6 +589,12 @@ def reset(self): cell.base_cell.reset() cell.reset() + def get_num_hidden(self) -> int: + """ + :return: The representation size of this decoder. + """ + return self.num_hidden + def init_states(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, @@ -942,44 +836,29 @@ class ConvolutionalDecoderConfig(Config): Convolutional decoder configuration. :param cnn_config: Configuration for the convolution block. - :param vocab_size: Target vocabulary size. :param max_seq_len_target: Maximum target sequence length. :param num_embed: Target word embedding size. :param encoder_num_hidden: Number of hidden units of the encoder. :param num_layers: The number of convolutional layers. :param positional_embedding_type: The type of positional embedding. - :param weight_tying: Whether to share embedding and prediction parameter matrices. - :param weight_normalization: Weight normalization. - :param embed_dropout: Dropout probability for target embeddings. :param hidden_dropout: Dropout probability on next decoder hidden state. """ def __init__(self, cnn_config: convolution.ConvolutionConfig, - vocab_size: int, max_seq_len_target: int, num_embed: int, encoder_num_hidden: int, num_layers: int, positional_embedding_type: str, - weight_tying: bool, - weight_normalization: bool = False, - embed_dropout: float = .0, hidden_dropout: float = .0) -> None: super().__init__() - if embed_dropout > 0 and hidden_dropout > 0: - logger.warning("Setting cnn encoder dropout AND hidden dropout > 0 leads to " - "two dropout layers on top of each other.") self.cnn_config = cnn_config - self.vocab_size = vocab_size self.max_seq_len_target = max_seq_len_target self.num_embed = num_embed self.encoder_num_hidden = encoder_num_hidden self.num_layers = num_layers self.positional_embedding_type = positional_embedding_type - self.weight_tying = weight_tying - self.weight_normalization = weight_normalization - self.embed_dropout = embed_dropout self.hidden_dropout = hidden_dropout @@ -999,14 +878,13 @@ class ConvolutionalDecoder(Decoder): several projection matrices) :param config: Configuration for convolutional decoder. - :param embed_weight: Optionally use an existing embedding matrix instead of creating a new target embedding. :param prefix: Name prefix for symbols of this decoder. """ def __init__(self, config: ConvolutionalDecoderConfig, - embed_weight: Optional[mx.sym.Symbol] = None, prefix: str = C.DECODER_PREFIX) -> None: + super().__init__() self.config = config self.prefix = prefix @@ -1015,19 +893,11 @@ def __init__(self, "We need to have the same number of hidden units in the decoder " "as we have in the encoder") - if embed_weight is None: - embed_weight = encoder.Embedding.get_embed_weight(self.config.vocab_size, - self.config.num_embed, - C.TARGET_EMBEDDING_PREFIX) - - self.embedding = encoder.Embedding(self.config.num_embed, - self.config.vocab_size, - prefix=C.TARGET_EMBEDDING_PREFIX, - embed_weight=embed_weight, - dropout=config.embed_dropout) self.pos_embedding = encoder.get_positional_embedding(config.positional_embedding_type, - config.num_embed, + num_embed=config.num_embed, max_seq_len=config.max_seq_len_target, + fixed_pos_embed_scale_up_input=False, + fixed_pos_embed_scale_down_positions=False, prefix=C.TARGET_POSITIONAL_EMBEDDING_PREFIX) self.layers = [convolution.ConvolutionBlock( @@ -1037,78 +907,56 @@ def __init__(self, self.i2h_weight = mx.sym.Variable('%si2h_weight' % prefix) - self.output_layer = layers.OutputLayer(num_hidden=self.config.cnn_config.num_hidden, - num_embed=self.config.num_embed, - vocab_size=self.config.vocab_size, - weight_tying=self.config.weight_tying, - embed_weight=embed_weight, - weight_normalization=self.config.weight_normalization, - prefix=prefix) - def decode_sequence(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, source_encoded_max_length: int, - target: mx.sym.Symbol, - target_lengths: mx.sym.Symbol, - target_max_length: int, - source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol: + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int) -> mx.sym.Symbol: """ - Decodes given a known target sequence and returns logits - with batch size and target length dimensions collapsed. - Used for training. + Decodes a sequence of embedded target words and returns sequence of last decoder + representations for each time step. :param source_encoded: Encoded source: (source_encoded_max_length, batch_size, encoder_depth). :param source_encoded_lengths: Lengths of encoded source sequences. Shape: (batch_size,). :param source_encoded_max_length: Size of encoder time dimension. - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. - :param source_lexicon: Lexical biases for current sentence. - Shape: (batch_size, target_vocab_size, source_seq_len) - :return: Logits of next-word predictions for target sequence. - Shape: (batch_size * target_max_length, target_vocab_size) + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Dimension of the embedded target sequence. + :return: Decoder data. Shape: (batch_size, target_embed_max_length, decoder_depth). """ - - check_condition(source_lexicon is None, "Source lexicon not supported.") - # (batch_size, source_encoded_max_length, encoder_depth). source_encoded_batch_major = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1, name='source_encoded_batch_major') + # (batch_size, target_seq_len, num_hidden) target_hidden = self._decode(source_encoded=source_encoded_batch_major, source_encoded_lengths=source_encoded_lengths, - target=target, - target_lengths=target_lengths, - target_max_length=target_max_length) + target_embed=target_embed, + target_embed_lengths=target_embed_lengths, + target_embed_max_length=target_embed_max_length) - # (batch_size * target_seq_len, num_hidden) - target_hidden = mx.sym.reshape(data=target_hidden, shape=(-3, 0)) - # (batch_size * target_seq_len, target_vocab_size) - logits = self.output_layer(target_hidden) - return logits + return target_hidden def _decode(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, - target: mx.sym.Symbol, - target_lengths: mx.sym.Symbol, - target_max_length: int) -> mx.sym.Symbol: + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int) -> mx.sym.Symbol: """ Decode the target and produce a sequence of hidden states. :param source_encoded: Shape: (batch_size, source_encoded_max_length, encoder_depth). :param source_encoded_lengths: Shape: (batch_size,). - :param target: Target sequence. Shape: (batch_size, target_max_length). - :param target_lengths: Lengths of target sequences. Shape: (batch_size,). - :param target_max_length: Size of target sequence dimension. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. :return: The target hidden states. Shape: (batch_size, target_seq_len, num_hidden). """ - # target_embed: (batch_size, target_seq_len, num_target_embed) - target_embed, target_lengths, target_max_length = self.embedding.encode(target, target_lengths, - target_max_length) - target_embed, target_lengths, target_max_length = self.pos_embedding.encode(target_embed, - target_lengths, - target_max_length) + target_embed, target_embed_lengths, target_embed_max_length = self.pos_embedding.encode(target_embed, + target_embed_lengths, + target_embed_max_length) # target_hidden: (batch_size, target_seq_len, num_hidden) target_hidden = mx.sym.FullyConnected(data=target_embed, num_hidden=self.config.cnn_config.num_hidden, @@ -1122,7 +970,7 @@ def _decode(self, for layer in self.layers: # (batch_size, target_seq_len, num_hidden) target_hidden = layer(mx.sym.Dropout(target_hidden, p=drop_prob) if drop_prob > 0 else target_hidden, - target_lengths, target_max_length) + target_embed_lengths, target_embed_max_length) # (batch_size, target_seq_len, num_embed) context = layers.dot_attention(queries=target_hidden, @@ -1136,26 +984,26 @@ def _decode(self, return target_hidden def decode_step(self, - target: mx.sym.Symbol, - target_max_length: int, + target_embed: mx.sym.Symbol, + target_embed_lengths: mx.sym.Symbol, + target_embed_max_length: int, + target_embed_prev: mx.sym.Symbol, source_encoded_max_length: int, - *states: mx.sym.Symbol) \ - -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: + *states: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, List[mx.sym.Symbol]]: """ - Decodes a single time step given the previous word ids in target and previous decoder states. - Returns logit inputs, logits, attention probabilities, and next decoder states. + Decodes a single time step given the embedded target sequence and previous decoder states. + Returns decoder representation for the next prediction, attention probabilities, and next decoder states. Implementations can maintain an arbitrary number of states. - :param target: Previous target word ids. Shape: (batch_size, target_max_length). - :param target_max_length: Size of time dimension in prev_word_ids. + :param target_embed: Embedded target sequence. Shape: (batch_size, target_embed_max_length, target_num_embed). + :param target_embed_lengths: Lengths of embedded target sequences. Shape: (batch_size,). + :param target_embed_max_length: Size of embedded target sequence dimension. + :param target_embed_prev: Previous target word embedding. Shape: (batch_size, target_num_embed). :param source_encoded_max_length: Length of encoded source time dimension. :param states: Arbitrary list of decoder states. - :return: logit inputs, logits, attention probabilities, next decoder states. + :return: logit inputs, attention probabilities, next decoder states. """ - - # (batch_size,) - target_lengths = utils.compute_lengths(target) - indices = target_lengths - 1 # type: mx.sym.Symbol + indices = target_embed_lengths - 1 # type: mx.sym.Symbol # Source_encoded: (batch_size, source_encoded_max_length, encoder_depth) source_encoded, source_encoded_lengths, *layer_states = states @@ -1170,18 +1018,11 @@ def decode_step(self, new_layer_states = [] - # (batch_size,) - prev_word_id = mx.sym.pick(target, indices, axis=1) - - # (batch_size, num_embed) - target_embed, _, target_max_length = self.embedding.encode(prev_word_id, - None, - target_max_length) # (batch_size, num_embed) - target_embed = self.pos_embedding.encode_positions(indices, target_embed) + target_embed_prev = self.pos_embedding.encode_positions(indices, target_embed_prev) # (batch_size, num_hidden) - target_hidden_step = mx.sym.FullyConnected(data=target_embed, + target_hidden_step = mx.sym.FullyConnected(data=target_embed_prev, num_hidden=self.config.cnn_config.num_hidden, no_bias=True, weight=self.i2h_weight) @@ -1227,16 +1068,17 @@ def decode_step(self, axis=2, begin=0, end=1), shape=(0, -1)) - # logit inputs aka target_hidden - logit_inputs = mx.sym.identity(target_hidden, name=C.LOGIT_INPUTS_NAME) - - # (batch_size, vocab_size) - logits = self.output_layer(target_hidden) - return logit_inputs, logits, attention_probs, [source_encoded, source_encoded_lengths] + new_layer_states + return target_hidden, attention_probs, [source_encoded, source_encoded_lengths] + new_layer_states def reset(self): pass + def get_num_hidden(self) -> int: + """ + :return: The representation size of this decoder. + """ + return self.config.cnn_config.num_hidden + def init_states(self, source_encoded: mx.sym.Symbol, source_encoded_lengths: mx.sym.Symbol, diff --git a/sockeye/encoder.py b/sockeye/encoder.py index b4c2e33c8..2f6aa6b22 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -32,13 +32,13 @@ logger = logging.getLogger(__name__) -def get_encoder(config: Config, fused: bool, embed_weight: Optional[mx.sym.Symbol] = None): +def get_encoder(config: Config): if isinstance(config, RecurrentEncoderConfig): - return get_recurrent_encoder(config, fused, embed_weight) + return get_recurrent_encoder(config) elif isinstance(config, transformer.TransformerConfig): - return get_transformer_encoder(config, embed_weight) + return get_transformer_encoder(config) elif isinstance(config, ConvolutionalEncoderConfig): - return get_convolutional_encoder(config, embed_weight) + return get_convolutional_encoder(config) else: raise ValueError("Unsupported encoder configuration") @@ -47,25 +47,16 @@ class RecurrentEncoderConfig(Config): """ Recurrent encoder configuration. - :param vocab_size: Source vocabulary size. - :param num_embed: Size of embedding layer. - :param embed_dropout: Dropout probability on embedding layer. :param rnn_config: RNN configuration. :param conv_config: Optional configuration for convolutional embedding. :param reverse_input: Reverse embedding sequence before feeding into RNN. """ def __init__(self, - vocab_size: int, - num_embed: int, - embed_dropout: float, rnn_config: rnn.RNNConfig, conv_config: Optional['ConvolutionalEmbeddingConfig'] = None, reverse_input: bool = False) -> None: super().__init__() - self.vocab_size = vocab_size - self.num_embed = num_embed - self.embed_dropout = embed_dropout self.rnn_config = rnn_config self.conv_config = conv_config self.reverse_input = reverse_input @@ -75,58 +66,42 @@ class ConvolutionalEncoderConfig(Config): """ Convolutional encoder configuration. - :param vocab_size: Source vocabulary size. - :param num_embed: Size of embedding layer. - :param embed_dropout: Dropout probability on embedding layer. :param cnn_config: CNN configuration. :param num_layers: The number of convolutional layers on top of the embeddings. :param positional_embedding_type: The type of positional embedding. """ def __init__(self, - vocab_size: int, num_embed: int, - embed_dropout: float, max_seq_len_source: int, cnn_config: convolution.ConvolutionConfig, num_layers: int, positional_embedding_type: str) -> None: super().__init__() - self.vocab_size = vocab_size self.num_embed = num_embed - self.embed_dropout = embed_dropout self.num_layers = num_layers self.cnn_config = cnn_config self.max_seq_len_source = max_seq_len_source self.positional_embedding_type = positional_embedding_type -def get_recurrent_encoder(config: RecurrentEncoderConfig, fused: bool, - embed_weight: Optional[mx.sym.Symbol] = None) -> 'Encoder': +def get_recurrent_encoder(config: RecurrentEncoderConfig) -> 'Encoder': """ - Returns a recurrent encoder with embedding, batch2time-major conversion, and bidirectional RNN. - If num_layers > 1, adds additional uni-directional RNNs. + Returns an encoder stack with a bi-directional RNN, and a variable number of uni-directional forward RNNs. :param config: Configuration for recurrent encoder. - :param fused: Whether to use FusedRNNCell (CuDNN). Only works with GPU context. - :param embed_weight: Optionally use an existing embedding matrix instead of creating a new one. :return: Encoder instance. """ # TODO give more control on encoder architecture encoders = list() # type: List[Encoder] - encoders.append(Embedding(num_embed=config.num_embed, - vocab_size=config.vocab_size, - prefix=C.SOURCE_EMBEDDING_PREFIX, - dropout=config.embed_dropout, - embed_weight=embed_weight)) - if config.conv_config is not None: - encoders.append(ConvolutionalEmbeddingEncoder(config.conv_config, - prefix=C.CHAR_SEQ_ENCODER_PREFIX)) + encoders.append(ConvolutionalEmbeddingEncoder(config.conv_config, prefix=C.CHAR_SEQ_ENCODER_PREFIX)) if config.conv_config.add_positional_encoding: # If specified, add positional encodings to segment embeddings - encoders.append(AddSinCosPositionalEmbeddings(num_embed=config.num_embed, + encoders.append(AddSinCosPositionalEmbeddings(num_embed=config.conv_config.num_embed, + scale_up_input=False, + scale_down_positions=False, prefix="%sadd_positional_encodings" % C.CHAR_SEQ_ENCODER_PREFIX)) encoders.append(BatchMajor2TimeMajor()) @@ -138,7 +113,6 @@ def get_recurrent_encoder(config: RecurrentEncoderConfig, fused: bool, utils.check_condition(config.rnn_config.first_residual_layer >= 2, "Residual connections on the first encoder layer are not supported") - encoder_class = FusedRecurrentEncoder if fused else RecurrentEncoder # One layer bi-directional RNN: encoders.append(BiDirectionalRNNEncoder(rnn_config=config.rnn_config.copy(num_layers=1), prefix=C.BIDIRECTIONALRNN_PREFIX, @@ -149,15 +123,14 @@ def get_recurrent_encoder(config: RecurrentEncoderConfig, fused: bool, # Because we already have a one layer bi-rnn we reduce the num_layers as well as the first_residual_layer. remaining_rnn_config = config.rnn_config.copy(num_layers=config.rnn_config.num_layers - 1, first_residual_layer=config.rnn_config.first_residual_layer - 1) - encoders.append(encoder_class(rnn_config=remaining_rnn_config, + encoders.append(RecurrentEncoder(rnn_config=remaining_rnn_config, prefix=C.STACKEDRNN_PREFIX, layout=C.TIME_MAJOR)) return EncoderSequence(encoders) -def get_convolutional_encoder(config: ConvolutionalEncoderConfig, - embed_weight: Optional[mx.sym.Symbol] = None) -> 'Encoder': +def get_convolutional_encoder(config: ConvolutionalEncoderConfig) -> 'Encoder': """ Creates a convolutional encoder. @@ -166,46 +139,32 @@ def get_convolutional_encoder(config: ConvolutionalEncoderConfig, :return: Encoder instance. """ encoders = list() # type: List[Encoder] - encoders.append(Embedding(num_embed=config.num_embed, - vocab_size=config.vocab_size, - prefix=C.SOURCE_EMBEDDING_PREFIX, - dropout=config.embed_dropout, - embed_weight=embed_weight)) - encoders.append(get_positional_embedding(config.positional_embedding_type, config.num_embed, max_seq_len=config.max_seq_len_source, + fixed_pos_embed_scale_up_input=False, + fixed_pos_embed_scale_down_positions=False, prefix=C.SOURCE_POSITIONAL_EMBEDDING_PREFIX)) - encoders.append(ConvolutionalEncoder(config=config)) encoders.append(BatchMajor2TimeMajor()) - return EncoderSequence(encoders) -def get_transformer_encoder(config: transformer.TransformerConfig, - embed_weight: Optional[mx.sym.Symbol] = None) -> 'Encoder': +def get_transformer_encoder(config: transformer.TransformerConfig) -> 'Encoder': """ Returns a Transformer encoder, consisting of an embedding layer with positional encodings and a TransformerEncoder instance. :param config: Configuration for transformer encoder. - :param embed_weight: Optionally use an existing embedding matrix instead of creating a new one. :return: Encoder instance. """ encoders = list() # type: List[Encoder] - # Note: Transformers use model_size as embedding size - # Note: Embedding vectors are scaled by transformer model size. - encoders.append(Embedding(num_embed=config.model_size, - vocab_size=config.vocab_size, - prefix=C.SOURCE_EMBEDDING_PREFIX, - dropout=config.dropout_embed, - embed_weight=embed_weight, - embed_scale=config.model_size ** 0.5)) encoders.append(get_positional_embedding(config.positional_embedding_type, config.model_size, config.max_seq_len_source, - C.SOURCE_POSITIONAL_EMBEDDING_PREFIX)) + fixed_pos_embed_scale_up_input=True, + fixed_pos_embed_scale_down_positions=False, + prefix=C.SOURCE_POSITIONAL_EMBEDDING_PREFIX)) if config.conv_config is not None: encoders.append(ConvolutionalEmbeddingEncoder(config.conv_config)) @@ -241,12 +200,6 @@ def get_num_hidden(self) -> int: """ raise NotImplementedError() - def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: - """ - :return: A list of RNNCells used by this encoder. - """ - return [] - def get_encoded_seq_len(self, seq_len: int) -> int: """ :return: The size of the encoded sequence. @@ -294,34 +247,37 @@ def encode(self, return data, data_length, seq_len +class EmbeddingConfig(Config): + + def __init__(self, + vocab_size: int, + num_embed: int, + dropout: float) -> None: + super().__init__() + self.vocab_size = vocab_size + self.num_embed = num_embed + self.dropout = dropout + + class Embedding(Encoder): """ Thin wrapper around MXNet's Embedding symbol. Works with both time- and batch-major data layouts. - :param num_embed: Embedding size. - :param vocab_size: Source vocabulary size. + :param config: Embedding config. :param prefix: Name prefix for symbols of this encoder. - :param dropout: Dropout probability. :param embed_weight: Optionally use an existing embedding matrix instead of creating a new one. - :param embed_scale: Optional fixed scaling factor for embeddings. """ def __init__(self, - num_embed: int, - vocab_size: int, + config: EmbeddingConfig, prefix: str, - dropout: float, - embed_weight: Optional[mx.sym.Symbol] = None, - embed_scale: Optional[float] = None) -> None: - self.num_embed = num_embed - self.vocab_size = vocab_size + embed_weight: Optional[mx.sym.Symbol] = None) -> None: + self.config = config self.prefix = prefix - self.dropout = dropout - if embed_weight is not None: - self.embed_weight = embed_weight - else: - self.embed_weight = self.get_embed_weight(vocab_size, num_embed, prefix) - self.embed_scale = embed_scale + self.embed_weight = embed_weight + if self.embed_weight is None: + self.embed_weight = mx.sym.Variable(prefix + "weight", + shape=(self.config.vocab_size, self.config.num_embed)) def encode(self, data: mx.sym.Symbol, @@ -336,28 +292,20 @@ def encode(self, :return: Encoded versions of input data (data, data_length, seq_len). """ embedding = mx.sym.Embedding(data=data, - input_dim=self.vocab_size, + input_dim=self.config.vocab_size, weight=self.embed_weight, - output_dim=self.num_embed, + output_dim=self.config.num_embed, name=self.prefix + "embed") - if self.dropout > 0: - embedding = mx.sym.Dropout(data=embedding, p=self.dropout, name="source_embed_dropout") - if self.embed_scale is not None and self.embed_scale != 1.0: - embedding = embedding * self.embed_scale + if self.config.dropout > 0: + embedding = mx.sym.Dropout(data=embedding, p=self.config.dropout, name="source_embed_dropout") + return embedding, data_length, seq_len def get_num_hidden(self) -> int: """ Return the representation size of this encoder. """ - return self.num_embed - - @staticmethod - def get_embed_weight(vocab_size: int, embed_size: int, prefix: str) -> mx.sym.Variable: - """ - Creates a variable for the embedding matrix. - """ - return mx.sym.Variable(prefix + "weight", shape=(vocab_size, embed_size)) + return self.config.num_embed class PositionalEncoder(Encoder): @@ -380,13 +328,19 @@ class AddSinCosPositionalEmbeddings(PositionalEncoder): :param num_embed: Embedding size. :param prefix: Name prefix for symbols of this encoder. + :param scale_up_input: If True, scales input data up by num_embed ** 0.5. + :param scale_down_positions: If True, scales positional embeddings down by num_embed ** -0.5. """ def __init__(self, num_embed: int, - prefix: str) -> None: + prefix: str, + scale_up_input: bool, + scale_down_positions: bool) -> None: utils.check_condition(num_embed % 2 == 0, "Positional embeddings require an even embedding size it " "is however %d." % num_embed) + self.scale_up_input = scale_up_input + self.scale_down_positions = scale_down_positions self.num_embed = num_embed self.prefix = prefix @@ -400,11 +354,19 @@ def encode(self, :param seq_len: sequence length. :return: (batch_size, source_seq_len, num_embed) """ - embedding = mx.sym.broadcast_add(data, - mx.sym.BlockGrad(mx.symbol.Custom(length=seq_len, - depth=self.num_embed, - name="%spositional_encodings" % self.prefix, - op_type='positional_encodings'))) + # add positional embeddings to data + if self.scale_up_input: + data = data * (self.num_embed ** 0.5) + + positions = mx.sym.BlockGrad(mx.symbol.Custom(length=seq_len, + depth=self.num_embed, + name="%spositional_encodings" % self.prefix, + op_type='positional_encodings')) + + if self.scale_down_positions: + positions = positions * (self.num_embed ** -0.5) + + embedding = mx.sym.broadcast_add(data, positions) return embedding, data_length, seq_len def encode_positions(self, @@ -431,6 +393,12 @@ def encode_positions(self, # (batch_size, num_embed/2) pos_embedding = mx.sym.concat(sin, cos, dim=1) + if self.scale_up_input: + data = data * (self.num_embed ** 0.5) + + if self.scale_down_positions: + pos_embedding = pos_embedding * (self.num_embed ** -0.5) + return mx.sym.broadcast_add(data, pos_embedding, name="%s_add" % self.prefix) def get_num_hidden(self) -> int: @@ -531,9 +499,17 @@ def get_num_hidden(self) -> int: return self.num_embed -def get_positional_embedding(positional_embedding_type, num_embed, max_seq_len, prefix) -> PositionalEncoder: +def get_positional_embedding(positional_embedding_type: str, + num_embed: int, + max_seq_len: int, + fixed_pos_embed_scale_up_input: bool = False, + fixed_pos_embed_scale_down_positions: bool = False, + prefix: str = '') -> PositionalEncoder: if positional_embedding_type == C.FIXED_POSITIONAL_EMBEDDING: - return AddSinCosPositionalEmbeddings(num_embed=num_embed, prefix=prefix) + return AddSinCosPositionalEmbeddings(num_embed=num_embed, + scale_up_input=fixed_pos_embed_scale_up_input, + scale_down_positions=fixed_pos_embed_scale_down_positions, + prefix=prefix) elif positional_embedding_type == C.LEARNED_POSITIONAL_EMBEDDING: return AddLearnedPositionalEmbeddings(num_embed=num_embed, max_seq_len=max_seq_len, @@ -581,16 +557,6 @@ def get_num_hidden(self) -> int: else: return self.encoders[-1].get_num_hidden() - def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]: - """ - Returns a list of RNNCells used by this encoder. - """ - cells = [] - for encoder in self.encoders: - for cell in encoder.get_rnn_cells(): - cells.append(cell) - return cells - def get_encoded_seq_len(self, seq_len: int) -> int: """ Returns the size of the encoded sequence. @@ -654,30 +620,6 @@ def get_num_hidden(self): return self.rnn_config.num_hidden -class FusedRecurrentEncoder(RecurrentEncoder): - """ - Uni-directional (multi-layered) recurrent encoder. - - :param rnn_config: RNN configuration. - :param prefix: Prefix. - :param layout: Data layout. - """ - - def __init__(self, - rnn_config: rnn.RNNConfig, - prefix: str = C.STACKEDRNN_PREFIX, - layout: str = C.TIME_MAJOR) -> None: - super().__init__(rnn_config, prefix, layout) - logger.warning("%s: FusedRNNCell uses standard MXNet Orthogonal initializer w/ rand_type=uniform", prefix) - self.rnn = mx.rnn.FusedRNNCell(self.rnn_config.num_hidden, - num_layers=self.rnn_config.num_layers, - mode=self.rnn_config.cell_type, - bidirectional=False, - dropout=self.rnn_config.dropout_inputs, - forget_bias=self.rnn_config.forget_bias, - prefix=prefix) - - class BiDirectionalRNNEncoder(Encoder): """ An encoder that runs a forward and a reverse RNN over input data. diff --git a/sockeye/inference.py b/sockeye/inference.py index 71d050617..48727aea2 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -41,7 +41,6 @@ class InferenceModel(model.SockeyeModel): :param model_folder: Folder to load model from. :param context: MXNet context to bind modules to. - :param fused: Whether to use FusedRNNCell (CuDNN). Only works with GPU context. :param beam_size: Beam size. :param batch_size: Batch size. :param checkpoint: Checkpoint to load. If None, finds best parameters in model_folder. @@ -55,7 +54,6 @@ class InferenceModel(model.SockeyeModel): def __init__(self, model_folder: str, context: mx.context.Context, - fused: bool, beam_size: int, batch_size: int, checkpoint: Optional[int] = None, @@ -80,7 +78,7 @@ def __init__(self, self.batch_size = batch_size self.context = context - self._build_model_components(fused) + self._build_model_components() self.max_input_length, self.get_max_output_length = get_max_input_output_length([self], max_output_length_num_stds) @@ -137,15 +135,15 @@ def initialize(self, max_input_length: int, get_max_output_length_function: Call self.decoder_module.init_params(arg_params=self.params, allow_missing=False) if self.cache_output_layer_w_b: - if self.decoder.output_layer.weight_normalization: + if self.output_layer.weight_normalization: # precompute normalized output layer weight imperatively - assert self.decoder.output_layer.weight_norm is not None - weight = self.params[self.decoder.output_layer.weight_norm.weight.name].as_in_context(self.context) - scale = self.params[self.decoder.output_layer.weight_norm.scale.name].as_in_context(self.context) - self.output_layer_w = self.decoder.output_layer.weight_norm(weight, scale) + assert self.output_layer.weight_norm is not None + weight = self.params[self.output_layer.weight_norm.weight.name].as_in_context(self.context) + scale = self.params[self.output_layer.weight_norm.scale.name].as_in_context(self.context) + self.output_layer_w = self.output_layer.weight_norm(weight, scale) else: - self.output_layer_w = self.params[self.decoder.output_layer.w.name].as_in_context(self.context) - self.output_layer_b = self.params[self.decoder.output_layer.b.name].as_in_context(self.context) + self.output_layer_w = self.params[self.output_layer.w.name].as_in_context(self.context) + self.output_layer_b = self.params[self.output_layer.b.name].as_in_context(self.context) def _get_encoder_module(self) -> Tuple[mx.mod.BucketingModule, int]: """ @@ -160,9 +158,19 @@ def sym_gen(source_seq_len: int): source = mx.sym.Variable(C.SOURCE_NAME) source_length = utils.compute_lengths(source) + # source embedding + (source_embed, + source_embed_length, + source_embed_seq_len) = self.embedding_source.encode(source, source_length, source_seq_len) + + # encoder + # source_encoded: (source_encoded_length, batch_size, encoder_depth) (source_encoded, source_encoded_length, - source_encoded_seq_len) = self.encoder.encode(source, source_length, source_seq_len) + source_encoded_seq_len) = self.encoder.encode(source_embed, + source_embed_length, + source_embed_seq_len) + # source_encoded: (batch_size, source_encoded_length, encoder_depth) # TODO(fhieber): Consider standardizing encoders to return batch-major data to avoid this line. source_encoded = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1) @@ -197,28 +205,47 @@ def sym_gen(bucket_key: Tuple[int, int]): Returns either softmax output (probs over target vocabulary) or inputs to logit computation, controlled by decoder_return_logit_inputs """ - source_max_len, target_max_len = bucket_key - source_encoded_seq_len = self.encoder.get_encoded_seq_len(source_max_len) + source_seq_len, target_seq_len = bucket_key + source_embed_seq_len = self.embedding_source.get_encoded_seq_len(source_seq_len) + source_encoded_seq_len = self.encoder.get_encoded_seq_len(source_embed_seq_len) self.decoder.reset() - prev_word_ids = mx.sym.Variable(C.TARGET_NAME) + target = mx.sym.Variable(C.TARGET_NAME) + target_lengths = utils.compute_lengths(target) states = self.decoder.state_variables() state_names = [state.name for state in states] - (logit_inputs, - logits, + # target embedding + # target_embed: (batch_size, target_embed_seq_len). + # TODO target embedding. Possible optimization: only embed the last and cache the previous target embed vectors by returning them in the states list. + (target_embed, + targed_embed_length, + target_embed_seq_len) = self.embedding_target.encode(target, target_lengths, target_seq_len) + + # embedding vector for previous word id + indices = target_lengths - 1 # type: mx.sym.Symbol + target_prev = mx.sym.pick(target, indices, axis=1) + target_embed_prev, _, _ = self.embedding_target.encode(target_prev, None, 1) + + # decoder + # target_decoded: (batch_size, decoder_depth) + (target_decoded, attention_probs, - states) = self.decoder.decode_step(prev_word_ids, - target_max_len, + states) = self.decoder.decode_step(target_embed, + targed_embed_length, + target_embed_seq_len, + target_embed_prev, source_encoded_seq_len, *states) - if self.softmax_temperature is not None: - logits /= self.softmax_temperature - # Return logit inputs or softmax over target vocab if self.decoder_return_logit_inputs: - outputs = logit_inputs + # skip output layer in graph + outputs = mx.sym.identity(target_decoded, name=C.LOGIT_INPUTS_NAME) else: + # logits: (batch_size, target_vocab_size) + logits = self.output_layer(target_decoded) + if self.softmax_temperature is not None: + logits /= self.softmax_temperature outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME) data_names = [C.TARGET_NAME] + state_names @@ -291,7 +318,7 @@ def run_decoder(self, """ Runs forward pass of the single-step decoder. - :return: Probability distribution over next word, attention scores, updated model state. + :return: Decoder stack output (logit inputs or probability distribution), attention scores, updated model state. """ batch = mx.io.DataBatch( data=[sequences.as_in_context(self.context)] + model_state.states, @@ -299,8 +326,8 @@ def run_decoder(self, bucket_key=bucket_key, provide_data=self._get_decoder_data_shapes(bucket_key)) self.decoder_module.forward(data_batch=batch, is_train=False) - probs, attention_probs, *model_state.states = self.decoder_module.get_outputs() - return probs, attention_probs, model_state + out, attention_probs, *model_state.states = self.decoder_module.get_outputs() + return out, attention_probs, model_state @property def training_max_seq_len_source(self) -> int: @@ -372,7 +399,6 @@ def load_models(context: mx.context.Context, target_vocabs.append(vocab.vocab_from_json_or_pickle(os.path.join(model_folder, C.VOCAB_TRG_NAME))) model = InferenceModel(model_folder=model_folder, context=context, - fused=False, beam_size=beam_size, batch_size=batch_size, softmax_temperature=softmax_temperature, @@ -859,7 +885,7 @@ def _decode_step(self, decoder_outputs, attention_probs, state = model.run_decoder(sequences, bucket_key, state) # Compute logits and softmax with restricted vocabulary if self.restrict_lexicon: - logits = model.decoder.output_layer(decoder_outputs, out_w, out_b) + logits = model.output_layer(decoder_outputs, out_w, out_b) probs = mx.nd.softmax(logits) else: # Otherwise decoder outputs are already target vocab probs diff --git a/sockeye/initializer.py b/sockeye/initializer.py index 4e553ef74..0e0b1e196 100644 --- a/sockeye/initializer.py +++ b/sockeye/initializer.py @@ -18,17 +18,15 @@ import numpy as np import sockeye.constants as C -from sockeye.lexicon import LexiconInitializer logger = logging.getLogger(__name__) def get_initializer(default_init_type: str, default_init_scale: float, default_init_xavier_factor_type: str, embed_init_type: str, embed_init_sigma: float, - rnn_init_type: str, - lexicon: Optional[mx.nd.NDArray] = None) -> mx.initializer.Initializer: + rnn_init_type: str) -> mx.initializer.Initializer: """ - Returns a mixed MXNet initializer given rnn_init_type and optional lexicon. + Returns a mixed MXNet initializer. :param default_init_type: The default weight initializer type. :param default_init_scale: The scale used for default weight initialization (only used with uniform initialization). @@ -36,7 +34,6 @@ def get_initializer(default_init_type: str, default_init_scale: float, default_i :param embed_init_type: Embedding matrix initialization type. :param embed_init_sigma: Sigma for normal initialization of embedding matrix. :param rnn_init_type: Initialization type for RNN h2h matrices. - :param lexicon: Optional lexicon. :return: Mixed initializer. """ # default initializer @@ -68,10 +65,7 @@ def get_initializer(default_init_type: str, default_init_scale: float, default_i else: raise ValueError('Unknown RNN initializer: %s' % rnn_init_type) - # lexicon initializer - lexicon_init = [(C.LEXICON_NAME, LexiconInitializer(lexicon))] if lexicon is not None else [] - - params_init_pairs = embed_init + rnn_init + lexicon_init + default_init + params_init_pairs = embed_init + rnn_init + default_init return mx.initializer.Mixed(*zip(*params_init_pairs)) diff --git a/sockeye/layers.py b/sockeye/layers.py index 7df28c5b8..e1ff7ff56 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -13,11 +13,11 @@ import logging from typing import Optional, Tuple, Union -from . import constants as C import mxnet as mx import numpy as np +from . import constants as C from . import utils logger = logging.getLogger(__name__) @@ -123,35 +123,25 @@ class OutputLayer: """ Defines the output layer of Sockeye decoders. Supports weight tying and weight normalization. - :param num_hidden: Number of hidden units in layer input (decoder representation). - :param num_embed: Target embedding size. + :param hidden_size: Decoder hidden size. :param vocab_size: Target vocabulary size. - :param weight_tying: Whether to use target embedding parameters as output layer parameters. - :param embed_weight: Optional embedding matrix. Required if weight_tying == True. :param weight_normalization: Whether to apply weight normalization. :param prefix: Prefix used for naming. """ def __init__(self, - num_hidden: int, - num_embed: int, + hidden_size: int, vocab_size: int, - weight_tying: bool, - embed_weight: Optional[mx.sym.Symbol], + weight: Optional[mx.sym.Symbol], weight_normalization: bool, - prefix: str = '') -> None: + prefix: str = C.DEFAULT_OUTPUT_LAYER_PREFIX) -> None: self.vocab_size = vocab_size + self.prefix = prefix - if weight_tying: - utils.check_condition(num_hidden == num_embed, - "Weight tying requires target embedding size and decoder hidden size " + - "to be equal: %d vs. %d" % (num_embed, num_hidden)) - - logger.info("Tying the target embeddings and prediction matrix.") - assert embed_weight is not None, "Must provide embed_weight if weight_tying == True" - self.w = embed_weight + if weight is None: + self.w = mx.sym.Variable("%sweight" % self.prefix, shape=(vocab_size, hidden_size)) else: - self.w = mx.sym.Variable("%scls_weight" % prefix, shape=(vocab_size, num_hidden)) + self.w = weight self.weight_normalization = weight_normalization if weight_normalization: @@ -159,10 +149,10 @@ def __init__(self, self.weight_norm = WeightNormalization(self.w, num_hidden=vocab_size, ndim=2, - prefix="%scls_" % prefix) + prefix=self.prefix) self.w = self.weight_norm() - self.b = mx.sym.Variable("%scls_bias" % prefix) + self.b = mx.sym.Variable("%sbias" % self.prefix) def __call__(self, hidden: Union[mx.sym.Symbol, mx.nd.NDArray], @@ -175,6 +165,7 @@ def __call__(self, :return: Logits. Shape(n, self.vocab_size). """ if isinstance(hidden, mx.sym.Symbol): + # TODO dropout? return mx.sym.FullyConnected(data=hidden, num_hidden=self.vocab_size, weight=self.w, @@ -186,6 +177,7 @@ def __call__(self, assert isinstance(hidden, mx.nd.NDArray) utils.check_condition(weight is not None and bias is not None, "OutputLayer NDArray implementation requires passing weight and bias NDArrays.") + return mx.nd.FullyConnected(data=hidden, num_hidden=bias.shape[0], weight=weight, diff --git a/sockeye/model.py b/sockeye/model.py index a6602b4f6..b7512b9b9 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -14,7 +14,7 @@ import copy import logging import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import mxnet as mx @@ -24,7 +24,7 @@ from . import data_io from . import decoder from . import encoder -from . import lexicon +from . import layers from . import loss from . import utils @@ -42,11 +42,11 @@ class ModelConfig(Config): :param max_seq_len_target: Maximum target sequence length to unroll during training. :param vocab_source_size: Source vocabulary size. :param vocab_target_size: Target vocabulary size. + :param config_embed_source: Embedding config for source. + :param config_embed_target: Embedding config for target. :param config_encoder: Encoder configuration. :param config_decoder: Decoder configuration. :param config_loss: Loss configuration. - :param lexical_bias: Use lexical biases. - :param learn_lexical_bias: Learn lexical biases during training. :param weight_tying: Enables weight tying if True. :param weight_tying_type: Determines which weights get tied. Must be set if weight_tying is enabled. """ @@ -56,33 +56,42 @@ def __init__(self, max_seq_len_target: int, vocab_source_size: int, vocab_target_size: int, + config_embed_source: Config, + config_embed_target: Config, config_encoder: Config, config_decoder: Config, config_loss: loss.LossConfig, - lexical_bias: bool = False, - learn_lexical_bias: bool = False, weight_tying: bool = False, - weight_tying_type: Optional[str] = C.WEIGHT_TYING_TRG_SOFTMAX) -> None: + weight_tying_type: Optional[str] = C.WEIGHT_TYING_TRG_SOFTMAX, + weight_normalization: bool = False) -> None: super().__init__() self.config_data = config_data self.max_seq_len_source = max_seq_len_source self.max_seq_len_target = max_seq_len_target self.vocab_source_size = vocab_source_size self.vocab_target_size = vocab_target_size + self.config_embed_source = config_embed_source + self.config_embed_target = config_embed_target self.config_encoder = config_encoder self.config_decoder = config_decoder self.config_loss = config_loss - self.lexical_bias = lexical_bias - self.learn_lexical_bias = learn_lexical_bias self.weight_tying = weight_tying + self.weight_tying_type = weight_tying_type + self.weight_normalization = weight_normalization if weight_tying and weight_tying_type is None: raise RuntimeError("weight_tying_type must be specified when using weight_tying.") - self.weight_tying_type = weight_tying_type class SockeyeModel: """ SockeyeModel shares components needed for both training and inference. + The main components of a Sockeye model are + 1) Source embedding + 2) Target embedding + 3) Encoder + 4) Decoder + 5) Output Layer + ModelConfig contains parameters and their values that are fixed at training time and must be re-used at inference time. @@ -93,10 +102,12 @@ def __init__(self, config: ModelConfig) -> None: self.config = copy.deepcopy(config) self.config.freeze() logger.info("%s", self.config) + self.embedding_source = None # type: Optional[encoder.Embedding] self.encoder = None # type: Optional[encoder.Encoder] + self.embedding_target = None # type: Optional[encoder.Embedding] self.decoder = None # type: Optional[decoder.Decoder] - self.rnn_cells = [] # type: List[mx.rnn.RNNCell] - self.built = False + self.output_layer = None # type: Optional[layers.OutputLayer] + self._is_built = False self.params = None # type: Optional[Dict] def save_config(self, folder: str): @@ -127,12 +138,8 @@ def save_params_to_file(self, fname: str): :param fname: Path to save parameters to. """ - assert self.built - params = self.params.copy() - # unpack rnn cell weights - for cell in self.rnn_cells: - params = cell.unpack_weights(params) - utils.save_params(params, fname) + assert self._is_built + utils.save_params(self.params.copy(), fname) logging.info('Saved params to "%s"', fname) def load_params_from_file(self, fname: str): @@ -141,15 +148,11 @@ def load_params_from_file(self, fname: str): :param fname: Path to load parameters from. """ - assert self.built + assert self._is_built utils.check_condition(os.path.exists(fname), "No model parameter file found under %s. " "This is either not a model directory or the first training " "checkpoint has not happened yet." % fname) - self.params, _ = utils.load_params(fname) - # pack rnn cell weights - for cell in self.rnn_cells: - self.params = cell.pack_weights(self.params) logger.info('Loaded params from "%s"', fname) @staticmethod @@ -163,30 +166,64 @@ def save_version(folder: str): with open(fname, "w") as out: out.write(__version__) - def _build_model_components(self, fused_encoder: bool): + def _get_embed_weights(self) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol]: """ - Instantiates model components. + Returns embedding parameters for source and target. - :param fused_encoder: Use FusedRNNCells in encoder. + :return: Tuple of source and target parameter symbols. """ - # we tie the source and target embeddings if both appear in the type - if self.config.weight_tying and C.WEIGHT_TYING_SRC in self.config.weight_tying_type \ - and C.WEIGHT_TYING_TRG in self.config.weight_tying_type: - logger.info("Tying the source and target embeddings.") - embed_weight = encoder.Embedding.get_embed_weight(vocab_size=self.config.vocab_source_size, - embed_size=0, # will get inferred - prefix=C.SHARED_EMBEDDING_PREFIX) - else: - embed_weight = None - - self.encoder = encoder.get_encoder(self.config.config_encoder, fused_encoder, embed_weight) - - self.lexicon = lexicon.Lexicon(self.config.vocab_source_size, - self.config.vocab_target_size, - self.config.learn_lexical_bias) if self.config.lexical_bias else None - - self.decoder = decoder.get_decoder(self.config.config_decoder, self.lexicon, embed_weight) - - self.rnn_cells = self.encoder.get_rnn_cells() + self.decoder.get_rnn_cells() - - self.built = True + assert isinstance(self.config.config_embed_source, encoder.EmbeddingConfig) + assert isinstance(self.config.config_embed_target, encoder.EmbeddingConfig) + w_embed_source = mx.sym.Variable(C.SOURCE_EMBEDDING_PREFIX + "weight", + shape=(self.config.config_embed_source.vocab_size, + self.config.config_embed_source.num_embed)) + w_embed_target = mx.sym.Variable(C.TARGET_EMBEDDING_PREFIX + "weight", + shape=(self.config.config_embed_target.vocab_size, + self.config.config_embed_target.num_embed)) + w_out_target = mx.sym.Variable("target_output_weight", + shape=(self.config.vocab_target_size, self.decoder.get_num_hidden())) + + if self.config.weight_tying: + if C.WEIGHT_TYING_SRC in self.config.weight_tying_type \ + and C.WEIGHT_TYING_TRG in self.config.weight_tying_type: + logger.info("Tying the source and target embeddings.") + w_embed_source = w_embed_target = mx.sym.Variable(C.SHARED_EMBEDDING_PREFIX + "weight", + shape=(self.config.config_embed_source.vocab_size, + self.config.config_embed_source.num_embed)) + + if C.WEIGHT_TYING_SOFTMAX in self.config.weight_tying_type: + logger.info("Tying the target embeddings and output layer parameters.") + utils.check_condition(self.config.config_embed_target.num_embed == self.decoder.get_num_hidden(), + "Weight tying requires target embedding size and decoder hidden size " + + "to be equal: %d vs. %d" % (self.config.config_embed_target.num_embed, + self.decoder.get_num_hidden())) + w_out_target = w_embed_target + + return w_embed_source, w_embed_target, w_out_target + + def _build_model_components(self): + """ + Instantiates model components. + """ + # encoder & decoder first (to know the decoder depth) + self.encoder = encoder.get_encoder(self.config.config_encoder) + self.decoder = decoder.get_decoder(self.config.config_decoder) + + # source & target embeddings + embed_weight_source, embed_weight_target, out_weight_target = self._get_embed_weights() + assert isinstance(self.config.config_embed_source, encoder.EmbeddingConfig) + assert isinstance(self.config.config_embed_target, encoder.EmbeddingConfig) + self.embedding_source = encoder.Embedding(self.config.config_embed_source, + prefix=C.SOURCE_EMBEDDING_PREFIX, + embed_weight=embed_weight_source) + self.embedding_target = encoder.Embedding(self.config.config_embed_target, + prefix=C.TARGET_EMBEDDING_PREFIX, + embed_weight=embed_weight_target) + + # output layer + self.output_layer = layers.OutputLayer(hidden_size=self.decoder.get_num_hidden(), + vocab_size=self.config.vocab_target_size, + weight=out_weight_target, + weight_normalization=self.config.weight_normalization) + + self._is_built = True diff --git a/sockeye/train.py b/sockeye/train.py index c7879c798..6422be1d3 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -36,7 +36,6 @@ from . import decoder from . import encoder from . import initializer -from . import lexicon from . import loss from . import lr_scheduler from . import model @@ -89,12 +88,19 @@ def check_arg_compatibility(args: argparse.Namespace): :param args: Arguments as returned by argparse. """ - if args.use_fused_rnn: - check_condition(not args.use_cpu, "GPU required for FusedRNN cells") - check_condition(args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics, "Must optimize either BLEU or one of tracked metrics (--metrics)") + if args.encoder == C.TRANSFORMER_TYPE: + check_condition(args.transformer_model_size == args.num_embed[0], + "Source embedding size must match transformer model size: %s vs. %s" + % (args.transformer_model_size, args.num_embed[0])) + if args.decoder == C.TRANSFORMER_TYPE: + check_condition(args.transformer_model_size == args.num_embed[1], + "Target embedding size must match transformer model size: %s vs. %s" + % (args.transformer_model_size, args.num_embed[1])) + + def check_resume(args: argparse.Namespace, output_folder: str) -> Tuple[bool, str]: """ @@ -268,20 +274,18 @@ def create_lr_scheduler(args: argparse.Namespace, resume_training: bool, return lr_scheduler_instance -def create_encoder_config(args: argparse.Namespace, vocab_source_size: int, +def create_encoder_config(args: argparse.Namespace, config_conv: Optional[encoder.ConvolutionalEmbeddingConfig]) -> Tuple[Config, int]: """ Create the encoder config. :param args: Arguments as returned by argparse. - :param vocab_source_size: The source vocabulary. :param config_conv: The config for the convolutional encoder (optional). :return: The encoder config and the number of hidden units of the encoder. """ encoder_num_layers, _ = args.num_layers max_seq_len_source, max_seq_len_target = args.max_seq_len num_embed_source, _ = args.num_embed - encoder_embed_dropout, _ = args.embed_dropout config_encoder = None # type: Optional[Config] if args.encoder in (C.TRANSFORMER_TYPE, C.TRANSFORMER_WITH_CONV_EMBED_TYPE): @@ -292,13 +296,9 @@ def create_encoder_config(args: argparse.Namespace, vocab_source_size: int, attention_heads=args.transformer_attention_heads, feed_forward_num_hidden=args.transformer_feed_forward_num_hidden, num_layers=encoder_num_layers, - vocab_size=vocab_source_size, - dropout_embed=encoder_embed_dropout, dropout_attention=args.transformer_dropout_attention, dropout_relu=args.transformer_dropout_relu, dropout_prepost=args.transformer_dropout_prepost, - weight_tying=args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type, - weight_normalization=False, positional_embedding_type=args.transformer_positional_embedding_type, preprocess_sequence=encoder_transformer_preprocess, postprocess_sequence=encoder_transformer_postprocess, @@ -312,9 +312,7 @@ def create_encoder_config(args: argparse.Namespace, vocab_source_size: int, num_hidden=args.cnn_num_hidden, act_type=args.cnn_activation_type, weight_normalization=args.weight_normalization) - config_encoder = encoder.ConvolutionalEncoderConfig(vocab_size=vocab_source_size, - num_embed=num_embed_source, - embed_dropout=encoder_embed_dropout, + config_encoder = encoder.ConvolutionalEncoderConfig(num_embed=num_embed_source, max_seq_len_source=max_seq_len_source, cnn_config=cnn_config, num_layers=encoder_num_layers, @@ -326,9 +324,6 @@ def create_encoder_config(args: argparse.Namespace, vocab_source_size: int, encoder_rnn_dropout_states, _ = args.rnn_dropout_states encoder_rnn_dropout_recurrent, _ = args.rnn_dropout_recurrent config_encoder = encoder.RecurrentEncoderConfig( - vocab_size=vocab_source_size, - num_embed=num_embed_source, - embed_dropout=encoder_embed_dropout, rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type, num_hidden=args.rnn_num_hidden, num_layers=encoder_num_layers, @@ -345,21 +340,16 @@ def create_encoder_config(args: argparse.Namespace, vocab_source_size: int, return config_encoder, encoder_num_hidden -def create_decoder_config(args: argparse.Namespace, vocab_target_size: int, encoder_num_hidden: int) -> Config: +def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int) -> Config: """ Create the config for the decoder. :param args: Arguments as returned by argparse. - :param vocab_target_size: The size of the target vocabulary. :return: The config for the decoder. """ _, decoder_num_layers = args.num_layers max_seq_len_source, max_seq_len_target = args.max_seq_len _, num_embed_target = args.num_embed - _, decoder_embed_dropout = args.embed_dropout - - decoder_weight_tying = args.weight_tying and C.WEIGHT_TYING_TRG in args.weight_tying_type \ - and C.WEIGHT_TYING_SOFTMAX in args.weight_tying_type config_decoder = None # type: Optional[Config] @@ -371,13 +361,9 @@ def create_decoder_config(args: argparse.Namespace, vocab_target_size: int, enco attention_heads=args.transformer_attention_heads, feed_forward_num_hidden=args.transformer_feed_forward_num_hidden, num_layers=decoder_num_layers, - vocab_size=vocab_target_size, - dropout_embed=decoder_embed_dropout, dropout_attention=args.transformer_dropout_attention, dropout_relu=args.transformer_dropout_relu, dropout_prepost=args.transformer_dropout_prepost, - weight_tying=decoder_weight_tying, - weight_normalization=args.weight_normalization, positional_embedding_type=args.transformer_positional_embedding_type, preprocess_sequence=decoder_transformer_preprocess, postprocess_sequence=decoder_transformer_postprocess, @@ -392,15 +378,11 @@ def create_decoder_config(args: argparse.Namespace, vocab_target_size: int, enco act_type=args.cnn_activation_type, weight_normalization=args.weight_normalization) config_decoder = decoder.ConvolutionalDecoderConfig(cnn_config=convolution_config, - vocab_size=vocab_target_size, max_seq_len_target=max_seq_len_target, num_embed=num_embed_target, encoder_num_hidden=encoder_num_hidden, num_layers=decoder_num_layers, positional_embedding_type=args.cnn_positional_embedding_type, - weight_tying=decoder_weight_tying, - embed_dropout=decoder_embed_dropout, - weight_normalization=args.weight_normalization, hidden_dropout=args.cnn_hidden_dropout) else: @@ -424,9 +406,7 @@ def create_decoder_config(args: argparse.Namespace, vocab_target_size: int, enco _, decoder_rnn_dropout_recurrent = args.rnn_dropout_recurrent config_decoder = decoder.RecurrentDecoderConfig( - vocab_size=vocab_target_size, max_seq_len_source=max_seq_len_source, - num_embed=num_embed_target, rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type, num_hidden=args.rnn_num_hidden, num_layers=decoder_num_layers, @@ -437,14 +417,11 @@ def create_decoder_config(args: argparse.Namespace, vocab_target_size: int, enco first_residual_layer=args.rnn_first_residual_layer, forget_bias=args.rnn_forget_bias), attention_config=config_attention, - embed_dropout=decoder_embed_dropout, hidden_dropout=args.rnn_decoder_hidden_dropout, - weight_tying=decoder_weight_tying, state_init=args.rnn_decoder_state_init, context_gating=args.rnn_context_gating, layer_normalization=args.layer_normalization, - attention_in_upper_layers=args.rnn_attention_in_upper_layers, - weight_normalization=args.weight_normalization) + attention_in_upper_layers=args.rnn_attention_in_upper_layers) return config_decoder @@ -483,7 +460,8 @@ def create_model_config(args: argparse.Namespace, :return: The model configuration. """ max_seq_len_source, max_seq_len_target = args.max_seq_len - num_embed_source, _ = args.num_embed + num_embed_source, num_embed_target = args.num_embed + embed_dropout_source, embed_dropout_target = args.embed_dropout check_encoder_decoder_args(args) @@ -496,8 +474,15 @@ def create_model_config(args: argparse.Namespace, num_highway_layers=args.conv_embed_num_highway_layers, dropout=args.conv_embed_dropout) - config_encoder, encoder_num_hidden = create_encoder_config(args, vocab_source_size, config_conv) - config_decoder = create_decoder_config(args, vocab_target_size, encoder_num_hidden) + config_encoder, encoder_num_hidden = create_encoder_config(args, config_conv) + config_decoder = create_decoder_config(args, encoder_num_hidden) + + config_embed_source = encoder.EmbeddingConfig(vocab_size=vocab_source_size, + num_embed=num_embed_source, + dropout=embed_dropout_source) + config_embed_target = encoder.EmbeddingConfig(vocab_size=vocab_target_size, + num_embed=num_embed_target, + dropout=embed_dropout_target) config_loss = loss.LossConfig(name=args.loss, vocab_size=vocab_target_size, @@ -509,13 +494,14 @@ def create_model_config(args: argparse.Namespace, max_seq_len_target=max_seq_len_target, vocab_source_size=vocab_source_size, vocab_target_size=vocab_target_size, + config_embed_source=config_embed_source, + config_embed_target=config_embed_target, config_encoder=config_encoder, config_decoder=config_decoder, config_loss=config_loss, - lexical_bias=args.lexical_bias, - learn_lexical_bias=args.learn_lexical_bias, weight_tying=args.weight_tying, - weight_tying_type=args.weight_tying_type if args.weight_tying else None,) + weight_tying_type=args.weight_tying_type if args.weight_tying else None, + weight_normalization=args.weight_normalization) return model_config @@ -541,7 +527,6 @@ def create_training_model(model_config: model.ModelConfig, training_model = training.TrainingModel(config=model_config, context=context, train_iter=train_iter, - fused=args.use_fused_rnn, bucketing=not args.no_bucketing, lr_scheduler=lr_scheduler_instance) @@ -626,17 +611,13 @@ def main(): context, train_iter, lr_scheduler_instance, resume_training, training_state_dir) - lexicon_array = lexicon.initialize_lexicon(args.lexical_bias, - vocab_source, vocab_target) if args.lexical_bias else None - weight_initializer = initializer.get_initializer( default_init_type=args.weight_init, default_init_scale=args.weight_init_scale, default_init_xavier_factor_type=args.weight_init_xavier_factor_type, embed_init_type=args.embed_weight_init, embed_init_sigma=vocab_source_size ** -0.5, # TODO - rnn_init_type=args.rnn_h2h_init, - lexicon=lexicon_array) + rnn_init_type=args.rnn_h2h_init) optimizer, optimizer_params, kvstore = define_optimizer(args, lr_scheduler_instance) diff --git a/sockeye/training.py b/sockeye/training.py index 25e26db04..f36dfcacb 100644 --- a/sockeye/training.py +++ b/sockeye/training.py @@ -68,8 +68,6 @@ class TrainingModel(model.SockeyeModel): :param config: Configuration object holding details about the model. :param context: The context(s) that MXNet will be run in (GPU(s)/CPU) :param train_iter: The iterator over the training data. - :param fused: If True fused RNN cells will be used (should be slightly more efficient, but is only available - on GPUs). :param bucketing: If True bucketing will be used, if False the computation graph will always be unrolled to the full length. :param lr_scheduler: The scheduler that lowers the learning rate during training. @@ -79,14 +77,13 @@ def __init__(self, config: model.ModelConfig, context: List[mx.context.Context], train_iter: data_io.ParallelBucketSentenceIter, - fused: bool, bucketing: bool, lr_scheduler) -> None: super().__init__(config) self.context = context self.lr_scheduler = lr_scheduler self.bucketing = bucketing - self._build_model_components(fused) + self._build_model_components() self.module = self._build_module(train_iter) self.training_monitor = None # type: Optional[callback.TrainingMonitor] @@ -113,18 +110,39 @@ def sym_gen(seq_lens): """ source_seq_len, target_seq_len = seq_lens + # source embedding + (source_embed, + source_embed_length, + source_embed_seq_len) = self.embedding_source.encode(source, source_length, source_seq_len) + + # target embedding + (target_embed, + target_embed_length, + target_embed_seq_len) = self.embedding_target.encode(target, target_length, target_seq_len) + + # encoder + # source_encoded: (source_encoded_length, batch_size, encoder_depth) (source_encoded, source_encoded_length, - source_encoded_seq_len) = self.encoder.encode(source, source_length, seq_len=source_seq_len) + source_encoded_seq_len) = self.encoder.encode(source_embed, + source_embed_length, + source_embed_seq_len) + + # decoder + # target_decoded: (batch-size, target_len, decoder_depth) + target_decoded = self.decoder.decode_sequence(source_encoded, source_encoded_length, source_encoded_seq_len, + target_embed, target_embed_length, target_embed_seq_len) - source_lexicon = self.lexicon.lookup(source) if self.lexicon else None + # target_decoded: (batch_size * target_seq_len, rnn_num_hidden) + target_decoded = mx.sym.reshape(data=target_decoded, shape=(-3, 0)) - logits = self.decoder.decode_sequence(source_encoded, source_encoded_length, source_encoded_seq_len, target, - target_length, target_seq_len, source_lexicon) + # output layer + # logits: (batch_size * target_seq_len, target_vocab_size) + logits = self.output_layer(target_decoded) - outputs = model_loss.get_loss(logits, labels) + probs = model_loss.get_loss(logits, labels) - return mx.sym.Group(outputs), data_names, label_names + return mx.sym.Group(probs), data_names, label_names if self.bucketing: logger.info("Using bucketing. Default max_seq_len=%s", train_iter.default_bucket_key) diff --git a/sockeye/transformer.py b/sockeye/transformer.py index 5b3e5dd45..125cd55d1 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -26,13 +26,9 @@ def __init__(self, attention_heads: int, feed_forward_num_hidden: int, num_layers: int, - vocab_size: int, - dropout_embed: int, dropout_attention: float, dropout_relu: float, dropout_prepost: float, - weight_tying: bool, - weight_normalization: bool, positional_embedding_type: str, preprocess_sequence: str, postprocess_sequence: str, @@ -44,13 +40,9 @@ def __init__(self, self.attention_heads = attention_heads self.feed_forward_num_hidden = feed_forward_num_hidden self.num_layers = num_layers - self.vocab_size = vocab_size - self.dropout_embed = dropout_embed self.dropout_attention = dropout_attention self.dropout_relu = dropout_relu self.dropout_prepost = dropout_prepost - self.weight_tying = weight_tying - self.weight_normalization = weight_normalization self.positional_embedding_type = positional_embedding_type self.preprocess_sequence = preprocess_sequence self.postprocess_sequence = postprocess_sequence diff --git a/test/integration/test_seq_copy_int.py b/test/integration/test_seq_copy_int.py index a76191d08..53e8984c6 100644 --- a/test/integration/test_seq_copy_int.py +++ b/test/integration/test_seq_copy_int.py @@ -47,8 +47,8 @@ "--beam-size 2", False), # Transformer encoder, GRU decoder, mhdot attention - ("--encoder transformer --num-layers 2:1 --rnn-cell-type gru --rnn-num-hidden 16 --num-embed 8" - " --transformer-attention-heads 2 --transformer-model-size 16" + ("--encoder transformer --num-layers 2:1 --rnn-cell-type gru --rnn-num-hidden 16 --num-embed 8:16" + " --transformer-attention-heads 2 --transformer-model-size 8" " --transformer-feed-forward-num-hidden 32" " --rnn-attention-type mhdot --rnn-attention-mhdot-heads 4 --rnn-attention-num-hidden 16 --batch-size 8 " " --max-updates 10 --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01" @@ -56,7 +56,7 @@ "--beam-size 2", False), # LSTM encoder, Transformer decoder - ("--encoder rnn --num-layers 2:2 --rnn-cell-type lstm --rnn-num-hidden 16 --num-embed 8" + ("--encoder rnn --decoder transformer --num-layers 2:2 --rnn-cell-type lstm --rnn-num-hidden 16 --num-embed 16" " --transformer-attention-heads 2 --transformer-model-size 16" " --transformer-feed-forward-num-hidden 32" " --batch-size 8 --max-updates 10" @@ -65,7 +65,7 @@ False), # Full transformer ("--encoder transformer --decoder transformer" - " --num-layers 3 --transformer-attention-heads 2 --transformer-model-size 16" + " --num-layers 3 --transformer-attention-heads 2 --transformer-model-size 16 --num-embed 16" " --transformer-feed-forward-num-hidden 32" " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr" " --weight-tying --weight-tying-type src_trg_softmax" diff --git a/test/system/test_seq_copy_sys.py b/test/system/test_seq_copy_sys.py index 860a1da7d..f4e33219b 100644 --- a/test/system/test_seq_copy_sys.py +++ b/test/system/test_seq_copy_sys.py @@ -55,7 +55,7 @@ "--encoder transformer --num-layers 2:1 --rnn-cell-type lstm --rnn-num-hidden 64 --num-embed 32" " --rnn-attention-type mhdot --rnn-attention-num-hidden 32 --batch-size 16 --rnn-attention-mhdot-heads 1" " --loss cross-entropy --optimized-metric perplexity --max-updates 6000" - " --transformer-attention-heads 4 --transformer-model-size 64" + " --transformer-attention-heads 4 --transformer-model-size 32" " --transformer-feed-forward-num-hidden 64" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", @@ -75,9 +75,8 @@ "--encoder transformer --decoder transformer" " --batch-size 16 --max-updates 4000" " --num-layers 2 --transformer-attention-heads 4 --transformer-model-size 32" - " --transformer-feed-forward-num-hidden 64" - " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001" - " --layer-normalization", + " --transformer-feed-forward-num-hidden 64 --num-embed 32" + " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", 1.01, 0.999), @@ -88,7 +87,7 @@ " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", 1.01, - 0.999) + 0.99) ]) def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_thresh): """Task: copy short sequences of digits""" @@ -126,37 +125,36 @@ def test_seq_copy(name, train_params, translate_params, perplexity_thresh, bleu_ " --optimized-metric perplexity --max-updates 5000 --checkpoint-frequency 1000 --optimizer adam " " --initial-learning-rate 0.001 --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0", "--beam-size 5", - 1.04, - 0.98), + 1.01, + 0.99), ("Sort:transformer:lstm", "--encoder transformer --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 64 --num-embed 32" " --rnn-attention-type mhdot --rnn-attention-num-hidden 32 --batch-size 16 --rnn-attention-mhdot-heads 2" " --loss cross-entropy --optimized-metric perplexity --max-updates 5000" - " --transformer-attention-heads 4 --transformer-model-size 64" + " --transformer-attention-heads 4 --transformer-model-size 32" " --transformer-feed-forward-num-hidden 64" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", - 1.03, - 0.98), + 1.02, + 0.99), ("Sort:lstm:transformer", "--encoder rnn --num-layers 1:2 --rnn-cell-type lstm --rnn-num-hidden 64 --num-embed 32" - " --decoder transformer --batch-size 16" + " --decoder transformer --batch-size 16 --transformer-model-size 32" " --loss cross-entropy --optimized-metric perplexity --max-updates 7000" - " --transformer-attention-heads 4 --transformer-model-size 64" + " --transformer-attention-heads 4" " --transformer-feed-forward-num-hidden 64" " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 5", - 1.03, - 0.98), + 1.02, + 0.99), ("Sort:transformer", "--encoder transformer --decoder transformer" " --batch-size 16 --max-updates 5000" - " --num-layers 2 --transformer-attention-heads 4 --transformer-model-size 32" + " --num-layers 2 --transformer-attention-heads 4 --transformer-model-size 32 --num-embed 32" " --transformer-feed-forward-num-hidden 64" - " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001" - " --layer-normalization", + " --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001", "--beam-size 1", - 1.03, + 1.01, 0.99), ("Sort:cnn", "--encoder cnn --decoder cnn " diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index eef99ff4b..9f5457aa1 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -76,8 +76,6 @@ def test_device_args(test_params, expected_params): rnn_attention_num_hidden=None, rnn_attention_coverage_type='count', rnn_attention_coverage_num_hidden=1, - lexical_bias=None, - learn_lexical_bias=False, weight_tying=False, weight_tying_type="trg_softmax", max_seq_len=(100, 100), @@ -151,7 +149,6 @@ def test_model_parameters(test_params, expected_params): learning_rate_schedule=None, learning_rate_decay_param_reset=False, learning_rate_decay_optimizer_states_reset='off', - use_fused_rnn=False, weight_init='xavier', weight_init_scale=2.34, weight_init_xavier_factor_type='in', diff --git a/test/unit/test_decoder.py b/test/unit/test_decoder.py index 8a4f19fbe..977c2735a 100644 --- a/test/unit/test_decoder.py +++ b/test/unit/test_decoder.py @@ -69,9 +69,7 @@ def test_step(cell_type, context_gating, residual=False, forget_bias=0.) - config_decoder = sockeye.decoder.RecurrentDecoderConfig(vocab_size=vocab_size, - max_seq_len_source=source_seq_len, - num_embed=num_embed, + config_decoder = sockeye.decoder.RecurrentDecoderConfig(max_seq_len_source=source_seq_len, rnn_config=config_rnn, attention_config=config_attention, context_gating=context_gating) diff --git a/test/unit/test_encoder.py b/test/unit/test_encoder.py index d056d704f..9cc210402 100644 --- a/test/unit/test_encoder.py +++ b/test/unit/test_encoder.py @@ -75,7 +75,10 @@ def test_sincos_positional_embeddings(): # Test that .encode() and .encode_positions() return the same values: data = mx.sym.Variable("data") positions = mx.sym.Variable("positions") - pos_encoder = sockeye.encoder.AddSinCosPositionalEmbeddings(num_embed=_NUM_EMBED, prefix="test") + pos_encoder = sockeye.encoder.AddSinCosPositionalEmbeddings(num_embed=_NUM_EMBED, + scale_up_input=False, + scale_down_positions=False, + prefix="test") encoded, _, __ = pos_encoder.encode(data, None, _SEQ_LEN) nd_encoded = encoded.eval(data=mx.nd.zeros((_BATCH_SIZE, _SEQ_LEN, _NUM_EMBED)))[0] # Take the first element in the batch to get (seq_len, num_embed)