From 1f27c2af5d9a5af90a7f91ec26c4c2b6c1e0dd52 Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Mon, 28 Jan 2019 17:43:50 +0100 Subject: [PATCH] Replaced expand_dims with reshape ops to avoid data copies (#630) --- CHANGELOG.md | 6 ++++++ sockeye/__init__.py | 2 +- sockeye/convolution.py | 2 +- sockeye/coverage.py | 21 ++++++++++----------- sockeye/decoder.py | 12 ++++++------ sockeye/encoder.py | 18 +++++++++--------- sockeye/inference.py | 2 +- sockeye/layers.py | 11 +++++++++-- sockeye/rnn_attention.py | 14 +++++++------- 9 files changed, 50 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f055c6e1c..9551d771f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.72] +### Changed +- Removed use of `expand_dims` in favor of `reshape` to save memory. + + ## [1.18.71] +### Fixed - Fixed default setting of source factor combination to be 'concat' for backwards compatibility. ## [1.18.70] diff --git a/sockeye/__init__.py b/sockeye/__init__.py index db1b372a5..52df5f13c 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.71' +__version__ = '1.18.72' diff --git a/sockeye/convolution.py b/sockeye/convolution.py index 050da1a3b..44dd58fbc 100644 --- a/sockeye/convolution.py +++ b/sockeye/convolution.py @@ -155,7 +155,7 @@ def step(self, data): bias=self.conv_bias, num_hidden=num_hidden) # (batch_size, num_hidden, 1) - data_conv = mx.sym.expand_dims(data_conv, axis=2) + data_conv = mx.sym.reshape(data_conv, shape=(-2, 1)) return self._post_convolution(data_conv) def _post_convolution(self, data_conv): diff --git a/sockeye/coverage.py b/sockeye/coverage.py index 83fba81ce..a367d90b8 100644 --- a/sockeye/coverage.py +++ b/sockeye/coverage.py @@ -129,7 +129,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol, :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). """ - return prev_coverage + mx.sym.expand_dims(attention_prob_scores, axis=2) + return prev_coverage + mx.sym.reshape(attention_prob_scores, shape=(-2, 1)) return update_coverage @@ -183,9 +183,8 @@ def update_coverage(prev_hidden: mx.sym.Symbol, """ # (batch_size, source_seq_len, 1) - expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores, - axis=2, - name="%sexpand_attention_scores" % self.prefix) + expanded_att_scores = mx.sym.reshape(attention_prob_scores, shape=(-2, 1), + name="%sexpand_attention_scores" % self.prefix) # (batch_size, source_seq_len, 1) new_coverage = scaled_fertility * expanded_att_scores @@ -237,13 +236,13 @@ def update_coverage(prev_hidden: mx.sym.Symbol, # (batch_size, source_seq_len, decoder_num_hidden) expanded_decoder = mx.sym.broadcast_axis( - data=mx.sym.expand_dims(data=prev_hidden, axis=1, name="%sexpand_decoder" % self.prefix), + data=mx.sym.reshape(data=prev_hidden, shape=(0, 1, -1), name="%sexpand_decoder" % self.prefix), axis=1, size=source_seq_len, name="%sbroadcast_decoder" % self.prefix) # (batch_size, source_seq_len, 1) - expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores, - axis=2, - name="%sexpand_attention_scores" % self.prefix) + expanded_att_scores = mx.sym.reshape(data=attention_prob_scores, + shape=(-2, 1), + name="%sexpand_attention_scores" % self.prefix) # (batch_size, source_seq_len, encoder_num_hidden + decoder_num_hidden + 1) # +1 for the attention_prob_score for the source word @@ -332,7 +331,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol, name="%sprevious_hidden_fc" % self.prefix) # (batch_size, source_seq_len, 1) - attention_prob_scores = mx.sym.expand_dims(attention_prob_scores, axis=2) + attention_prob_scores = mx.sym.reshape(attention_prob_scores, shape=(-2, 1)) # (batch_size, source_seq_len, coverage_num_hidden) attention_hidden = mx.sym.FullyConnected(data=attention_prob_scores, @@ -347,8 +346,8 @@ def update_coverage(prev_hidden: mx.sym.Symbol, num_hidden=self.num_hidden, name="%sdecoder_hidden") # (batch_size, 1, coverage_num_hidden) - prev_hidden = mx.sym.expand_dims(data=prev_hidden, axis=1, - name="%sinput_decoder_hidden_expanded" % self.prefix) + prev_hidden = mx.sym.reshape(data=prev_hidden, shape=(0, 1, -1), + name="%sinput_decoder_hidden_expanded" % self.prefix) # (batch_size, source_seq_len, coverage_num_hidden) intermediate = mx.sym.broadcast_add(lhs=source_hidden, rhs=prev_hidden, diff --git a/sockeye/decoder.py b/sockeye/decoder.py index 5c43e1782..aa1799509 100644 --- a/sockeye/decoder.py +++ b/sockeye/decoder.py @@ -257,7 +257,7 @@ def decode_sequence(self, fold_heads=True, name="%ssource_bias" % self.prefix) # (batch_size * heads, 1, max_length) - source_bias = mx.sym.expand_dims(source_bias, axis=1) + source_bias = mx.sym.reshape(source_bias, shape=(0, 1, -1)) # (1, target_max_length, target_max_length) target_bias = transformer.get_autoregressive_bias(target_embed_max_length, name="%starget_bias" % self.prefix) @@ -302,7 +302,7 @@ def decode_step(self, # (batch_size, num_embed) target_embed_prev = self.pos_embedding.encode_positions(indices, target_embed_prev) # (batch_size, 1, num_embed) - target = mx.sym.expand_dims(target_embed_prev, axis=1) + target = mx.sym.reshape(target_embed_prev, shape=(0, 1, -1)) # (batch_size * heads, max_length) source_bias = transformer.get_variable_length_bias(lengths=source_encoded_lengths, @@ -311,7 +311,7 @@ def decode_step(self, fold_heads=True, name="%ssource_bias" % self.prefix) # (batch_size * heads, 1, max_length) - source_bias = mx.sym.expand_dims(source_bias, axis=1) + source_bias = mx.sym.reshape(source_bias, shape=(0, 1, -1)) # auto-regressive bias for last position in sequence # (1, target_max_length, target_max_length) @@ -779,7 +779,7 @@ def get_initial_state(self, # we derive the shape of hidden and layer_states from some input to enable # shape inference for the batch dimension during inference. # (batch_size, 1) - zeros = mx.sym.expand_dims(mx.sym.zeros_like(source_encoded_length), axis=1) + zeros = mx.sym.reshape(mx.sym.zeros_like(source_encoded_length), shape=(-1, 1)) # last encoder state: (batch, num_hidden) source_encoded_last = mx.sym.SequenceLast(data=source_encoded, axis=1, @@ -807,7 +807,7 @@ def get_initial_state(self, elif self.config.state_init == C.RNN_DEC_INIT_AVG: # (batch_size, encoder_num_hidden) init = mx.sym.broadcast_div(mx.sym.sum(source_masked, axis=1, keepdims=False), - mx.sym.expand_dims(source_encoded_length, axis=1)) + mx.sym.reshape(source_encoded_length, shape=(-1, 1))) else: raise ValueError("Unknown decoder state init type '%s'" % self.config.state_init) @@ -1139,7 +1139,7 @@ def decode_step(self, weight=self.i2h_weight) # re-arrange outcoming layer to the dimensions of the output # (batch_size, 1, num_hidden) - target_hidden_step = mx.sym.expand_dims(target_hidden_step, axis=1) + target_hidden_step = mx.sym.reshape(target_hidden_step, shape=(0, 1, -1)) # (batch_size, kernel_width, num_hidden) target_hidden = mx.sym.concat(embed_layer_state, target_hidden_step, dim=1) diff --git a/sockeye/encoder.py b/sockeye/encoder.py index fd33a2504..c196c670e 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -548,11 +548,11 @@ def encode_positions(self, :return: (batch_size, num_embed) """ # (batch_size, 1) - positions = mx.sym.expand_dims(positions, axis=1) + positions = mx.sym.reshape(positions, shape=(-1, 1)) # (num_embed,) channels = mx.sym.arange(0, self.num_embed // 2) - # (1, num_embed,) - scaling = mx.sym.expand_dims(1. / mx.sym.pow(10000, (2 * channels) / self.num_embed), axis=0) + # (1, num_embed) + scaling = mx.sym.reshape(1. / mx.sym.pow(10000, (2 * channels) / self.num_embed), shape=(1, -1)) # (batch_size, num_embed/2) scaled_positions = mx.sym.dot(positions, scaling) @@ -614,7 +614,7 @@ def encode(self, """ # (1, source_seq_len) - positions = mx.sym.expand_dims(data=mx.sym.arange(start=0, stop=seq_len, step=1), axis=0) + positions = mx.sym.reshape(data=mx.sym.arange(start=0, stop=seq_len, step=1), shape=(1, -1)) # (1, source_seq_len, num_embed) pos_embedding = mx.sym.Embedding(data=positions, @@ -1043,11 +1043,11 @@ def encode(self, data = mx.sym.Dropout(data=data, p=self.config.dropout_prepost) # (batch_size * heads, 1, max_length) - bias = mx.sym.expand_dims(transformer.get_variable_length_bias(lengths=data_length, - max_length=seq_len, - num_heads=self.config.attention_heads, - fold_heads=True, - name="%sbias" % self.prefix), axis=1) + bias = mx.sym.reshape(transformer.get_variable_length_bias(lengths=data_length, + max_length=seq_len, + num_heads=self.config.attention_heads, + fold_heads=True, + name="%sbias" % self.prefix), shape=(0, 1, -1)) bias = utils.cast_conditionally(bias, self.dtype) for i, layer in enumerate(self.layers): # (batch_size, seq_len, config.model_size) diff --git a/sockeye/inference.py b/sockeye/inference.py index 769c8d8a9..68e159fbf 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -2265,7 +2265,7 @@ def hybrid_forward(self, F, best_word_indices, max_output_lengths, finished, sco # Update lengths of all items, except those that were already finished. This updates # the lengths for inactive items, too, but that doesn't matter since they are ignored anyway. - lengths = lengths + F.cast(1 - F.expand_dims(finished, axis=1), dtype='float32') + lengths = lengths + F.cast(1 - F.reshape(finished, shape=(-1, 1)), dtype='float32') # Now, recompute finished. Hypotheses are finished if they are # - extended with , or diff --git a/sockeye/layers.py b/sockeye/layers.py index 096bfb58a..771c39b96 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -272,8 +272,15 @@ def broadcast_to_heads(x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads: Shape: (batch * heads, d1 ... dn-1) if fold_heads == True, (batch, heads, d1 ... dn-1) else. """ dims = [0] * (ndim - 1) - # x: (batch, 1) - x = mx.sym.expand_dims(x, axis=1) + if ndim == 1: + # x: (batch, 1) + x = mx.sym.reshape(x, shape=(-1, 1)) + elif ndim == 2: + # x: (batch, 1, d1) + x = mx.sym.reshape(x, shape=(0, 1, -1)) + else: + # x: (batch, 1, d1 ... dn - 1) + x = mx.sym.reshape(x, shape=(0, 1, -2)) # x: (batch, heads, dims...) x = mx.sym.broadcast_to(x, shape=[0, num_heads] + dims) if fold_heads: diff --git a/sockeye/rnn_attention.py b/sockeye/rnn_attention.py index 88c6e6f20..ec07d3b53 100644 --- a/sockeye/rnn_attention.py +++ b/sockeye/rnn_attention.py @@ -269,7 +269,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta :return: Updated attention state. """ # (batch_size, decoder_num_hidden, 1) - query = mx.sym.expand_dims(att_input.query, axis=2) + query = mx.sym.reshape(att_input.query, shape=(-2, 1)) # in: (batch_size, source_seq_len, self.num_hidden) X (batch_size, self.num_hidden, 1) # out: (batch_size, source_seq_len, 1). @@ -368,7 +368,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta query = query * self.scale # (batch_size, decoder_num_hidden, 1) - expanded_decoder_state = mx.sym.expand_dims(query, axis=2) + expanded_decoder_state = mx.sym.reshape(query, shape=(-2, 1)) # batch_dot: (batch, M, K) X (batch, K, N) –> (batch, M, N). # (batch_size, seq_len, 1) @@ -479,7 +479,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta # combine heads # (batch*heads, 1, num_hidden/head) - context = mx.sym.expand_dims(context, axis=1) + context = mx.sym.reshape(context, shape=(0, 1, -1)) # (batch, 1, num_hidden) context = layers.combine_heads(context, self.num_hidden_per_head, heads=self.heads) # (batch, num_hidden) @@ -585,7 +585,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta end=source_seq_len) # attention_scores: (batch_size, seq_len, 1) - attention_scores = mx.sym.expand_dims(data=attention_scores, axis=2) + attention_scores = mx.sym.reshape(data=attention_scores, shape=(-2, 1)) context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores, self.dtype) @@ -687,9 +687,9 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta name="%squery_hidden" % self.prefix) # (batch_size, 1, attention_num_hidden) - query_hidden = mx.sym.expand_dims(data=query_hidden, - axis=1, - name="%squery_hidden_expanded" % self.prefix) + query_hidden = mx.sym.reshape(data=query_hidden, + shape=(0, 1, -1), + name="%squery_hidden_expanded" % self.prefix) attention_hidden_lhs = source_hidden if self.coverage: