diff --git a/copynet_tf/search/beam_search.py b/copynet_tf/search/beam_search.py index df7407b..7b1e74b 100644 --- a/copynet_tf/search/beam_search.py +++ b/copynet_tf/search/beam_search.py @@ -85,6 +85,25 @@ def search(self, # shape: (batch_size*beam_width, beam_width) both topkpreds, topkidx = tf.math.top_k( predicted_proba, self.beam_width) + # shape: (batch_size*beam_width, 1) + was_end = tf.expand_dims(last_predictions == self._end_index, 1) + # shape: (batch_size, beam_width*beam_width) + has_no_end = tf.reshape( + tf.repeat(~was_end, self.beam_width, axis=-1), + (batch_size, self.beam_width*self.beam_width)) + # shape: (batch_size, beam_width*beam_width) + has_no_end = tf.cast(has_no_end, tf.float32) + + # shape: (batch_size*beam_width, beam_width) + beam_mask = tf.repeat(~was_end, self.beam_width, axis=-1) + beam_mask = tf.concat( + [was_end | ~was_end, beam_mask[:, 1:]], axis=-1) + + # shape: (batch_size, beam_width*beam_width) + beam_mask = tf.cast(tf.reshape( + beam_mask, (batch_size, self.beam_width*self.beam_width)), + tf.float32) + # shape: (batch_size, beam_width*beam_width) topkidx = tf.reshape( tf.cast(topkidx, tf.int32), @@ -92,19 +111,17 @@ def search(self, # shape: (batch_size, beam_width*beam_width) topkpreds = tf.reshape( topkpreds, (batch_size, self.beam_width*self.beam_width)) - # shape: (batch_size*beam_width, ) - has_no_end = ~tf.reduce_any( - predictions[:, :timestep] == self._end_index, axis=-1) - # shape: (batch_size, beam_width*beam_width) - has_no_end = tf.reshape(tf.cast( - tf.repeat(has_no_end, self.beam_width, axis=0), tf.float32), - (batch_size, self.beam_width*self.beam_width)) # shape: (batch_size, beam_width*beam_width) step_log_probs = tf.reshape( tf.repeat(log_probabilities, self.beam_width, axis=0), (batch_size, self.beam_width*self.beam_width)) step_log_probs = step_log_probs + topkpreds*has_no_end + # we don't want to repeatedly select beams from end token which + # have all same probability. So we make all but one of such beams' + # log probability highly negative to never get it in top k beams + step_log_probs += tf.math.log(beam_mask + 1e-35) + # shape: (batch_size, beam_width) both topsteplog, topsteplogidx = tf.math.top_k( step_log_probs, self.beam_width)