Skip to content

Commit

Permalink
Enable --beam-search-stop first with --batch-size > 1 (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhieber authored Jul 12, 2018
1 parent aa4c736 commit 7629fae
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.34]
### Added
- Added support for `--beam-search-stop first` for decoding jobs with `--batch-size > 1`.

## [1.18.33]
### Added
- Now supports negative constraints, which are phrases that must *not* appear in the output.
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.33'
__version__ = '1.18.34'
2 changes: 1 addition & 1 deletion sockeye/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def __init__(self,
max_seq_len_source: int,
rnn_config: rnn.RNNConfig,
attention_config: rnn_attention.AttentionConfig,
hidden_dropout: float = .0, # TODO: move this dropout functionality to OutputLayer
hidden_dropout: float = .0,
state_init: str = C.RNN_DEC_INIT_LAST,
state_init_lhuc: bool = False,
context_gating: bool = False,
Expand Down
8 changes: 4 additions & 4 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,12 +1587,12 @@ def _beam_search(self,
beam_histories[sent]["normalized_scores"].append(
normalized_scores[rows].asnumpy().flatten().tolist())

if self.beam_search_stop == C.BEAM_SEARCH_STOP_FIRST and self.batch_size == 1:
# TODO: extend to work with batch_size > 1 (i.e., one stopped for each sentence)
if mx.nd.sum(finished).asscalar() > 0:
if self.beam_search_stop == C.BEAM_SEARCH_STOP_FIRST:
at_least_one_finished = finished.reshape((self.batch_size, self.beam_size)).sum(axis=1) > 0
if at_least_one_finished.sum().asscalar() == self.batch_size:
break
else:
if mx.nd.sum(finished).asscalar() == self.batch_size * self.beam_size: # all finished
if finished.sum().asscalar() == self.batch_size * self.beam_size: # all finished
break

# (8) update models' state with winning hypotheses (ascending)
Expand Down
4 changes: 0 additions & 4 deletions sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ def run_translate(args: argparse.Namespace):
if args.checkpoints is not None:
check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

if args.beam_search_stop == C.BEAM_SEARCH_STOP_FIRST:
check_condition(args.batch_size == 1,
"Early stopping (--beam-search-stop %s) not supported with batching" % (C.BEAM_SEARCH_STOP_FIRST))

log_basic_info(args)

output_handler = get_output_handler(args.output_type,
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test_seq_copy_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
" --rnn-decoder-state-init avg --rnn-encoder-reverse-input --rnn-dropout-recurrent 0.1:0.0"
" --rnn-h2h-init orthogonal_stacked --batch-type sentence --decode-and-evaluate 0"
" --learning-rate-decay-param-reset --weight-normalization --source-factors-num-embed 5",
"--beam-size 2",
"--beam-size 2 --beam-search-stop first",
False, True, True, False),
# Convolutional embedding encoder + LSTM encoder-decoder with attention
("--encoder rnn-with-conv-embed --decoder rnn --conv-embed-max-filter-width 3 --conv-embed-num-filters 4:4:8"
Expand Down Expand Up @@ -119,7 +119,7 @@
" --weight-init-scale=3.0 --weight-init-xavier-factor-type=avg --embed-weight-init=normal"
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 0"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --lhuc all",
"--beam-size 2",
"--beam-size 2 --beam-prune 1",
True, False, False, False)]


Expand Down
2 changes: 1 addition & 1 deletion test/system/test_seq_copy_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
" --batch-size 80 --batch-type word "
" --max-updates 5000 "
" --rnn-dropout-states 0.0:0.1 --embed-dropout 0.1:0.0 --layer-normalization" + COMMON_TRAINING_PARAMS,
"--beam-size 5 --batch-size 2 --beam-prune 1",
"--beam-size 5 --batch-size 2 --beam-prune 1 --beam-search-stop first",
True,
1.01,
0.99),
Expand Down

0 comments on commit 7629fae

Please sign in to comment.