improve beam search variance in predictions
pavanchhatpar committed May 10, 2020
1 parent 7a076ee commit 8b4ae66
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions copynet_tf/search/beam_search.py
@@ -85,26 +85,43 @@ def search(self,
# shape: (batch_size*beam_width, beam_width) both
topkpreds, topkidx = tf.math.top_k(
predicted_proba, self.beam_width)
# shape: (batch_size*beam_width, 1)
was_end = tf.expand_dims(last_predictions == self._end_index, 1)
# shape: (batch_size, beam_width*beam_width)
has_no_end = tf.reshape(
tf.repeat(~was_end, self.beam_width, axis=-1),
(batch_size, self.beam_width*self.beam_width))
# shape: (batch_size, beam_width*beam_width)
has_no_end = tf.cast(has_no_end, tf.float32)

# shape: (batch_size*beam_width, beam_width)
beam_mask = tf.repeat(~was_end, self.beam_width, axis=-1)
beam_mask = tf.concat(
[was_end | ~was_end, beam_mask[:, 1:]], axis=-1)
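# `was_end | ~was_end` evaluates to all True, so the first of the
# beam_width identical copies of a finished beam stays unmasked while
# the remaining copies are masked out below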

# shape: (batch_size, beam_width*beam_width)
beam_mask = tf.cast(tf.reshape(
beam_mask, (batch_size, self.beam_width*self.beam_width)),
tf.float32)

# shape: (batch_size, beam_width*beam_width)
topkidx = tf.reshape(
tf.cast(topkidx, tf.int32),
(batch_size, self.beam_width*self.beam_width))
# shape: (batch_size, beam_width*beam_width)
topkpreds = tf.reshape(
topkpreds, (batch_size, self.beam_width*self.beam_width))
# shape: (batch_size*beam_width, )
has_no_end = ~tf.reduce_any(
predictions[:, :timestep] == self._end_index, axis=-1)
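# a beam counts as live only if it has never produced the end token
# at any earlier timestep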
# shape: (batch_size, beam_width*beam_width)
has_no_end = tf.reshape(tf.cast(
tf.repeat(has_no_end, self.beam_width, axis=0), tf.float32),
(batch_size, self.beam_width*self.beam_width))
# shape: (batch_size, beam_width*beam_width)
step_log_probs = tf.reshape(
tf.repeat(log_probabilities, self.beam_width, axis=0),
(batch_size, self.beam_width*self.beam_width))
step_log_probs = step_log_probs + topkpreds*has_no_end
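# finished beams contribute zero step log probability, so their totals
# stay frozen at the value they had when they emitted the end token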

# we don't want to repeatedly select copies of a finished beam, which
# all carry the same log probability; so we make all but one such
# copy's log probability highly negative so it never enters the top-k
step_log_probs += tf.math.log(beam_mask + 1e-35)
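# log(1e-35) is roughly -80.6, effectively -inf next to real log
# probabilities, while log(1 + 1e-35) is roughly 0 for unmasked beams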

# shape: (batch_size, beam_width) both
topsteplog, topsteplogidx = tf.math.top_k(
step_log_probs, self.beam_width)
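
For intuition, here is a minimal standalone sketch (not part of the repository; the toy batch of one example, beam width of 3, and all log-probability values below are made up for illustration) of why adding tf.math.log(beam_mask + 1e-35) keeps exactly one copy of a finished beam in the top-k selection:

    import tensorflow as tf

    # Toy setup: batch_size = 1, beam_width = 3; beam 0 already emitted
    # the end token, beams 1 and 2 are still alive.
    beam_width = 3
    was_end = tf.constant([[True], [False], [False]])  # (beam_width, 1)

    # Each beam expands into beam_width candidates. A finished beam's
    # candidates are identical copies with the same total log probability.
    step_log_probs = tf.constant(
        [[-1.0, -1.0, -1.0],    # finished beam: three identical copies
         [-2.0, -2.5, -3.0],    # live beam 1: its top-3 continuations
         [-1.5, -2.0, -4.0]])   # live beam 2: its top-3 continuations

    # Mask built as in the commit: `was_end | ~was_end` is all True, so
    # column 0 (one copy of every beam) always stays unmasked.
    beam_mask = tf.repeat(~was_end, beam_width, axis=-1)
    beam_mask = tf.concat([was_end | ~was_end, beam_mask[:, 1:]], axis=-1)
    beam_mask = tf.cast(
        tf.reshape(beam_mask, (1, beam_width * beam_width)), tf.float32)

    flat = tf.reshape(step_log_probs, (1, beam_width * beam_width))
    masked = flat + tf.math.log(beam_mask + 1e-35)  # masked copies -> ~-80.6

    topsteplog, topsteplogidx = tf.math.top_k(masked, beam_width)
    print(topsteplogidx.numpy())  # [[0 6 3]]: one copy of the finished beam
    print(topsteplog.numpy())     # [[-1.  -1.5 -2. ]]

Only index 0 of the three identical copies of the finished beam survives; the other two are pushed to roughly -80 in log space and can never outrank a genuine continuation, which is what restores variance among the selected beams.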