Commit c1b1da8 (#662)

* Disabled dynamic batch sizing by default in Translator.translate due to increased memory usage
fhieber authored Mar 14, 2019
1 parent 2e39633 commit c1b1da8
Showing 3 changed files with 24 additions and 5 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -11,7 +11,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

 ## [1.18.85]
-### Fixed
+### Changed
+- Disabled dynamic batching for `Translator.translate()` by default due to increased memory usage. The default is to
+  fill up batches to `Translator.max_batch_size`.
+  Dynamic batching can still be enabled if `fill_up_batches` is set to False.
+### Added
 - Added parameter to force training to stop after a given number of checkpoints. Useful when forced to share limited GPU resources.

 ## [1.18.84]
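A quick usage sketch of the entry above (my example, not code from the commit), assuming `translator` is an already-constructed `sockeye.inference.Translator` and `inputs` is a `List[TranslatorInput]` built via `make_input()`:

```python
# New default: underfilled batches are padded up to translator.max_batch_size.
outputs = translator.translate(inputs)

# Opt back in to dynamic batch sizing (the old behavior), at the cost of memory.
outputs = translator.translate(inputs, fill_up_batches=False)
```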
4 changes: 2 additions & 2 deletions sockeye/image_captioning/inference.py
@@ -105,15 +105,15 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutput]:
         :param trans_inputs: List of TranslatorInputs as returned by make_input().
         :return: List of translation results.
         """
-        batch_size = len(trans_inputs)
+        batch_size = self.max_batch_size
         # translate in batch-sized blocks over input chunks
         translations = []
         for batch_id, batch in enumerate(utils.grouper(trans_inputs, batch_size)):
             logger.debug("Translating batch %d", batch_id)
             # underfilled batch will be filled to a full batch size with copies of the 1st input
             rest = batch_size - len(batch)
             if rest > 0:
-                logger.debug("Extending the last batch to the full batch size (%d)", self.batch_size)
+                logger.debug("Extending the last batch to the full batch size (%d)", batch_size)
                 batch = batch + [batch[0]] * rest
             batch_translations = self._translate_nd(*self._get_inference_input(batch))
             # truncate to remove filler translations
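The hunk above relies on a pad-then-truncate pattern: inputs are grouped into fixed-size batches, the final underfilled batch is padded with copies of its first input, and the filler outputs are dropped afterwards. Below is a minimal, self-contained sketch of that pattern; `grouper` mimics what `utils.grouper` is assumed to do, and the model call is replaced by a trivial stand-in rather than Sockeye's actual implementation:

```python
from itertools import islice
from typing import Iterable, Iterator, List

def grouper(iterable: Iterable[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `size` items (assumed semantics of utils.grouper)."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk

def translate_batched(inputs: List[str], batch_size: int) -> List[str]:
    outputs = []  # type: List[str]
    for batch in grouper(inputs, batch_size):
        rest = batch_size - len(batch)
        if rest > 0:
            # pad the underfilled final batch with copies of its first input
            batch = batch + [batch[0]] * rest
        results = [s.upper() for s in batch]  # stand-in for the real model call
        if rest > 0:
            results = results[:-rest]  # truncate to remove the filler outputs
        outputs.extend(results)
    return outputs

assert translate_batched(["a", "b", "c", "d", "e"], batch_size=2) == ["A", "B", "C", "D", "E"]
```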
19 changes: 17 additions & 2 deletions sockeye/inference.py
@@ -1369,19 +1369,21 @@ def _log_linear_interpolation(predictions):
         # pylint: disable=invalid-unary-operand-type
         return -log_probs.log_softmax()

-    def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutput]:
+    def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = True) -> List[TranslatorOutput]:
         """
         Batch-translates a list of TranslatorInputs, returns a list of TranslatorOutputs.
         Empty or bad inputs are skipped.
         Splits inputs longer than Translator.max_input_length into segments of size max_input_length,
         and then groups segments into batches of at most Translator.max_batch_size.
         Too-long segments that were split are reassembled into a single output after translation.
+        If fill_up_batches is set to True, underfilled batches are padded to Translator.max_batch_size, otherwise
+        dynamic batch sizing is used, which comes at increased memory usage.

         :param trans_inputs: List of TranslatorInputs as returned by make_input().
+        :param fill_up_batches: If True, underfilled batches are padded to Translator.max_batch_size.
         :return: List of translation results.
         """
         num_inputs = len(trans_inputs)
-        batch_size = min(num_inputs, self.max_batch_size)
         translated_chunks = []  # type: List[IndexedTranslation]

         # split into chunks
@@ -1446,11 +1448,24 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutput]:
         input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.translator_input.tokens), reverse=True)

         # translate in batch-sized blocks over input chunks
+        batch_size = self.max_batch_size if fill_up_batches else min(len(input_chunks), self.max_batch_size)
+
         num_batches = 0
         for batch_id, batch in enumerate(utils.grouper(input_chunks, batch_size)):
             logger.debug("Translating batch %d", batch_id)
+
+            rest = batch_size - len(batch)
+            if fill_up_batches and rest > 0:
+                logger.debug("Padding batch of size %d to full batch size (%d)", len(batch), batch_size)
+                batch = batch + [batch[0]] * rest
+
             translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch]
             batch_translations = self._translate_nd(*self._get_inference_input(translator_inputs))
+
+            # truncate to remove filler translations
+            if fill_up_batches and rest > 0:
+                batch_translations = batch_translations[:-rest]
+
             for chunk, translation in zip(batch, batch_translations):
                 translated_chunks.append(IndexedTranslation(chunk.input_idx, chunk.chunk_idx, translation))
             num_batches += 1
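Why padding can be cheaper on memory: with `fill_up_batches=True` every batch presented to the model has the same shape, while dynamic sizing lets the last batch through at its natural, smaller size. The commit only states that dynamic batching increases memory usage; a plausible reason (my assumption, not stated in the commit) is that each distinct batch shape forces the backend to allocate additional buffers. A small illustration of the batch sizes produced under each mode:

```python
from typing import List

def batch_sizes(num_inputs: int, max_batch_size: int, fill_up_batches: bool) -> List[int]:
    """Sizes of the batches sent to the model (illustration of the logic above)."""
    full, rest = divmod(num_inputs, max_batch_size)
    sizes = [max_batch_size] * full
    if rest > 0:
        sizes.append(max_batch_size if fill_up_batches else rest)
    return sizes

assert batch_sizes(10, 4, fill_up_batches=True) == [4, 4, 4]   # a single batch shape
assert batch_sizes(10, 4, fill_up_batches=False) == [4, 4, 2]  # one extra, smaller shape
```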
