Commit

No longer relying on external translation id. (#508)
* No longer relying on external translation id.

* Making mypy happy.

* Removing chunk_id from input tests.

* Fix branch without source EOS.

* Update docstring.

* Sentence ids can now be strings.

* Addressed review comment.
tdomhan authored and fhieber committed Aug 17, 2018
1 parent bf628b8 commit dcbe1bc
Showing 6 changed files with 103 additions and 79 deletions.
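
Before the file-by-file diff, the core idea in isolation: each TranslatorInput is tagged with its position in the request (input_idx) and, if it gets split, with a chunk position (chunk_idx); results are restored to request order by sorting on those indices rather than on the user-supplied sentence id. The following is a minimal, self-contained sketch of that bookkeeping, using a string stand-in for the real Translation object and plain string joining instead of Sockeye's _concat_translations — it mirrors, but is not, the code added in sockeye/inference.py:

    import itertools
    from typing import List, NamedTuple

    class IndexedTranslation(NamedTuple):
        input_idx: int    # position of the original input in the request
        chunk_idx: int    # position of the chunk within that input
        translation: str  # stand-in for the real Translation object

    def restore_order(translated_chunks: List[IndexedTranslation]) -> List[str]:
        # Tuples sort lexicographically, so sorting recovers
        # (input order, chunk order) without consulting sentence ids.
        ordered = sorted(translated_chunks)
        results = []
        for _, group in itertools.groupby(ordered, key=lambda t: t.input_idx):
            results.append(' '.join(t.translation for t in group))
        return results

    # Chunks come back in batch-sorted (arbitrary) order:
    chunks = [IndexedTranslation(1, 0, 'b'),
              IndexedTranslation(0, 1, 'a2'),
              IndexedTranslation(0, 0, 'a1')]
    print(restore_order(chunks))  # ['a1 a2', 'b']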
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.47]
+### Changed
+- translate CLI: no longer rely on external, user-given input id for sorting translations. Also allow string ids for sentences.
+
 ## [1.18.46]
 ### Fixed
 - Fixed issue with `--num-words 0:0` in image captioning and another issue related to loading all features to memory with variable length.
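To illustrate the changelog entry, a hypothetical use of string sentence ids with the updated API (the printed form follows the new TranslatorInput.__str__ in the diff below; the id value here is made up):

    from sockeye.inference import make_input_from_plain_string

    # Ids no longer need to be integers, and output order no longer depends on them.
    inp = make_input_from_plain_string(sentence_id='doc3-sent7', string='hello world')
    print(inp)
    # TranslatorInput(doc3-sent7, ['hello', 'world'], factors=None, constraints=None, avoid=None)
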
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.46'
+__version__ = '1.18.47'
149 changes: 85 additions & 64 deletions sockeye/inference.py
@@ -564,6 +564,7 @@ def get_max_output_length(input_length: int):
 
 BeamHistory = Dict[str, List]
 Tokens = List[str]
+SentenceId = Union[int, str]
 
 
 class TranslatorInput:
@@ -574,28 +575,25 @@ class TranslatorInput:
     :param tokens: List of input tokens.
     :param factors: Optional list of additional factor sequences.
     :param constraints: Optional list of target-side constraints.
-    :param chunk_id: Chunk id. Defaults to -1.
     """
 
-    __slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'avoid_list', 'chunk_id')
+    __slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'avoid_list')
 
     def __init__(self,
-                 sentence_id: int,
+                 sentence_id: SentenceId,
                  tokens: Tokens,
                  factors: Optional[List[Tokens]] = None,
                  constraints: Optional[List[Tokens]] = None,
-                 avoid_list: Optional[List[Tokens]] = None,
-                 chunk_id: int = -1) -> None:
+                 avoid_list: Optional[List[Tokens]] = None) -> None:
         self.sentence_id = sentence_id
-        self.chunk_id = chunk_id
         self.tokens = tokens
         self.factors = factors
         self.constraints = constraints
         self.avoid_list = avoid_list
 
     def __str__(self):
-        return 'TranslatorInput(%d, %s, factors=%s, constraints=%s, avoid=%s, chunk_id=%d)' \
-               % (self.sentence_id, self.tokens, self.factors, self.constraints, self.avoid_list, self.chunk_id)
+        return 'TranslatorInput(%s, %s, factors=%s, constraints=%s, avoid=%s)' \
+               % (self.sentence_id, self.tokens, self.factors, self.constraints, self.avoid_list)
 
     def __len__(self):
         return len(self.tokens)
@@ -617,7 +615,7 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]:
 
         if len(self.tokens) > chunk_size and self.constraints is not None:
             logger.warning(
-                'Input %d has length (%d) that exceeds max input length (%d), '
+                'Input %s has length (%d) that exceeds max input length (%d), '
                 'triggering internal splitting. Placing all target-side constraints '
                 'with the first chunk, which is probably wrong.',
                 self.sentence_id, len(self.tokens), chunk_size)
@@ -631,8 +629,7 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]:
                                   tokens=self.tokens[i:i + chunk_size],
                                   factors=factors,
                                   constraints=constraints,
-                                  avoid_list=self.avoid_list,
-                                  chunk_id=chunk_id)
+                                  avoid_list=self.avoid_list)
 
     def with_eos(self) -> 'TranslatorInput':
         """
@@ -643,37 +640,36 @@ def with_eos(self) -> 'TranslatorInput':
                                factors=[factor + [C.EOS_SYMBOL] for factor in
                                         self.factors] if self.factors is not None else None,
                                constraints=self.constraints,
-                               avoid_list=self.avoid_list,
-                               chunk_id=self.chunk_id)
+                               avoid_list=self.avoid_list)
 
 
 class BadTranslatorInput(TranslatorInput):
 
-    def __init__(self, sentence_id, tokens):
-        super().__init__(sentence_id=sentence_id, tokens=tokens, chunk_id=-1, factors=None)
+    def __init__(self, sentence_id: SentenceId, tokens: Tokens) -> None:
+        super().__init__(sentence_id=sentence_id, tokens=tokens, factors=None)
 
 
-def _bad_input(sentence_id: int, reason: str = '') -> BadTranslatorInput:
-    logger.warning("Bad input (%d): '%s'. Will return empty output.", sentence_id, reason.strip())
+def _bad_input(sentence_id: SentenceId, reason: str = '') -> BadTranslatorInput:
+    logger.warning("Bad input (%s): '%s'. Will return empty output.", sentence_id, reason.strip())
     return BadTranslatorInput(sentence_id=sentence_id, tokens=[])
 
 
-def make_input_from_plain_string(sentence_id: int, string: str) -> TranslatorInput:
+def make_input_from_plain_string(sentence_id: SentenceId, string: str) -> TranslatorInput:
     """
     Returns a TranslatorInput object from a plain string.
 
-    :param sentence_id: An integer id.
+    :param sentence_id: Sentence id.
     :param string: An input string.
     :return: A TranslatorInput.
     """
     return TranslatorInput(sentence_id, tokens=list(data_io.get_tokens(string)), factors=None)
 
 
-def make_input_from_json_string(sentence_id: int, json_string: str) -> TranslatorInput:
+def make_input_from_json_string(sentence_id: SentenceId, json_string: str) -> TranslatorInput:
     """
     Returns a TranslatorInput object from a JSON object, serialized as a string.
 
-    :param sentence_id: An integer id.
+    :param sentence_id: Sentence id.
     :param json_string: A JSON object serialized as a string that must contain a key "text", mapping to the input text,
            and optionally a key "factors" that maps to a list of strings, each of which representing a factor sequence
            for the input text.
@@ -718,15 +714,15 @@ def make_input_from_json_string(sentence_id: int, json_string: str) -> Translato
         return _bad_input(sentence_id, reason=json_string)
 
 
-def make_input_from_factored_string(sentence_id: int,
+def make_input_from_factored_string(sentence_id: SentenceId,
                                     factored_string: str,
                                     translator: 'Translator',
                                     delimiter: str = C.DEFAULT_FACTOR_DELIMITER) -> TranslatorInput:
     """
     Returns a TranslatorInput object from a string with factor annotations on a token level, separated by delimiter.
     If translator does not require any source factors, the string is parsed as a plain token string.
 
-    :param sentence_id: An integer id.
+    :param sentence_id: Sentence id.
     :param factored_string: An input string with additional factors per token, separated by delimiter.
     :param translator: A translator object.
     :param delimiter: A factor delimiter. Default: '|'.
@@ -758,12 +754,12 @@ def make_input_from_factored_string(sentence_id: int,
     return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors)
 
 
-def make_input_from_multiple_strings(sentence_id: int, strings: List[str]) -> TranslatorInput:
+def make_input_from_multiple_strings(sentence_id: SentenceId, strings: List[str]) -> TranslatorInput:
     """
     Returns a TranslatorInput object from multiple strings, where the first element corresponds to the surface tokens
     and the remaining elements to additional factors. All strings must parse into token sequences of the same length.
 
-    :param sentence_id: An integer id.
+    :param sentence_id: Sentence id.
     :param strings: A list of strings representing a factored input sequence.
     :return: A TranslatorInput.
     """
@@ -782,24 +778,24 @@ class TranslatorOutput:
     """
     Output structure from Translator.
 
-    :param id: Id of input sentence.
+    :param sentence_id: Sentence id.
     :param translation: Translation string without sentence boundary tokens.
     :param tokens: List of translated tokens.
     :param attention_matrix: Attention matrix. Shape: (target_length, source_length).
     :param score: Negative log probability of generated translation.
     :param beam_histories: List of beam histories. The list will contain more than one
            history if it was split due to exceeding max_length.
     """
-    __slots__ = ('id', 'translation', 'tokens', 'attention_matrix', 'score', 'beam_histories')
+    __slots__ = ('sentence_id', 'translation', 'tokens', 'attention_matrix', 'score', 'beam_histories')
 
     def __init__(self,
-                 id: int,
+                 sentence_id: SentenceId,
                  translation: str,
                  tokens: List[str],
                  attention_matrix: np.ndarray,
                  score: float,
                  beam_histories: Optional[List[BeamHistory]] = None) -> None:
-        self.id = id
+        self.sentence_id = sentence_id
         self.translation = translation
         self.tokens = tokens
         self.attention_matrix = attention_matrix
@@ -828,16 +824,30 @@ def empty_translation() -> Translation:
     return Translation(target_ids=[], attention_matrix=np.asarray([[0]]), score=-np.inf)
 
 
-TranslatedChunk = NamedTuple('TranslatedChunk', [
-    ('id', int),
-    ('chunk_id', int),
+IndexedTranslatorInput = NamedTuple('IndexedTranslatorInput', [
+    ('input_idx', int),
+    ('chunk_idx', int),
+    ('translator_input', TranslatorInput)
+])
+"""
+Translation of a chunk of a sentence.
+
+:param input_idx: Internal index of translation requests to keep track of the correct order of translations.
+:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
+:param input: The translator input.
+"""
+
+
+IndexedTranslation = NamedTuple('IndexedTranslation', [
+    ('input_idx', int),
+    ('chunk_idx', int),
     ('translation', Translation),
 ])
 """
 Translation of a chunk of a sentence.
 
-:param id: Id of the sentence.
-:param chunk_id: Id of the chunk.
+:param input_idx: Internal index of translation requests to keep track of the correct order of translations.
+:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
 :param translation: The translation of the input chunk.
 """
 

Expand Down Expand Up @@ -1092,58 +1102,66 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu
:param trans_inputs: List of TranslatorInputs as returned by make_input().
:return: List of translation results.
"""
translated_chunks = [] # type: List[TranslatedChunk]
translated_chunks = [] # type: List[IndexedTranslation]

# split into chunks
input_chunks = [] # type: List[TranslatorInput]
for trans_input in trans_inputs:
input_chunks = [] # type: List[IndexedTranslatorInput]
for trans_input_idx, trans_input in enumerate(trans_inputs):
# bad input
if isinstance(trans_input, BadTranslatorInput):
translated_chunks.append(TranslatedChunk(id=trans_input.sentence_id, chunk_id=0, translation=empty_translation()))

translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0,
translation=empty_translation()))
# empty input
elif len(trans_input.tokens) == 0:
translated_chunks.append(TranslatedChunk(id=trans_input.sentence_id, chunk_id=0, translation=empty_translation()))
translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0,
translation=empty_translation()))
else:
# TODO(tdomhan): Remove branch without EOS with next major version bump, as future models will always be trained with source side EOS symbols
if self.source_with_eos:
max_input_length_without_eos = self.max_input_length - C.SPACE_FOR_XOS
# oversized input
if len(trans_input.tokens) > max_input_length_without_eos:
logger.debug(
"Input %d has length (%d) that exceeds max input length (%d). "
"Input %s has length (%d) that exceeds max input length (%d). "
"Splitting into chunks of size %d.",
trans_input.sentence_id, len(trans_input.tokens),
self.buckets_source[-1], max_input_length_without_eos)
input_chunks.extend([trans_input_chunk.with_eos()
for trans_input_chunk in
trans_input.chunks(max_input_length_without_eos)])
chunks = [trans_input_chunk.with_eos()
for trans_input_chunk in trans_input.chunks(max_input_length_without_eos)]
input_chunks.extend([IndexedTranslatorInput(trans_input_idx, chunk_idx, chunk_input)
for chunk_idx, chunk_input in enumerate(chunks)])
# regular input
else:
input_chunks.append(trans_input.with_eos())
input_chunks.append(IndexedTranslatorInput(trans_input_idx,
chunk_idx=0,
translator_input=trans_input.with_eos()))
else:
# oversized input
if len(trans_input.tokens) > self.max_input_length:
# oversized input
logger.debug(
"Input %d has length (%d) that exceeds max input length (%d). "
"Input %s has length (%d) that exceeds max input length (%d). "
"Splitting into chunks of size %d.",
trans_input.sentence_id, len(trans_input.tokens),
self.buckets_source[-1], self.max_input_length)
input_chunks.extend([trans_input_chunk
for trans_input_chunk in
trans_input.chunks(self.max_input_length)])
# regular input
chunks = [trans_input_chunk
for trans_input_chunk in
trans_input.chunks(self.max_input_length)]
input_chunks.extend([IndexedTranslatorInput(trans_input_idx, chunk_idx, chunk_input)
for chunk_idx, chunk_input in enumerate(chunks)])
else:
input_chunks.append(trans_input)
# regular input
input_chunks.append(IndexedTranslatorInput(trans_input_idx,
chunk_idx=0,
translator_input=trans_input))

if trans_input.constraints is not None:
logger.info("Input %d has %d %s: %s", trans_input.sentence_id,
logger.info("Input %s has %d %s: %s", trans_input.sentence_id,
len(trans_input.constraints),
"constraint" if len(trans_input.constraints) == 1 else "constraints",
", ".join(" ".join(x) for x in trans_input.constraints))

# Sort longest to shortest (to rather fill batches of shorter than longer sequences)
input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.tokens), reverse=True)
input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.translator_input.tokens), reverse=True)

# translate in batch-sized blocks over input chunks
for batch_id, batch in enumerate(utils.grouper(input_chunks, self.batch_size)):
@@ -1153,24 +1171,26 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu
             if rest > 0:
                 logger.debug("Extending the last batch to the full batch size (%d)", self.batch_size)
                 batch = batch + [batch[0]] * rest
-            batch_translations = self._translate_nd(*self._get_inference_input(batch))
+            translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch]
+            batch_translations = self._translate_nd(*self._get_inference_input(translator_inputs))
             # truncate to remove filler translations
             if rest > 0:
                 batch_translations = batch_translations[:-rest]
             for chunk, translation in zip(batch, batch_translations):
-                translated_chunks.append(TranslatedChunk(chunk.sentence_id, chunk.chunk_id, translation))
+                translated_chunks.append(IndexedTranslation(chunk.input_idx, chunk.chunk_idx, translation))
         # Sort by input idx and then chunk id
         translated_chunks = sorted(translated_chunks)
 
         # Concatenate results
         results = []  # type: List[TranslatorOutput]
-        chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.id)
-        for trans_input, (input_idx, chunks) in zip(trans_inputs, chunks_by_input_idx):
-            chunks = list(chunks)  # type: ignore
-            if len(chunks) == 1:  # type: ignore
-                translation = chunks[0].translation  # type: ignore
+        chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.input_idx)
+        for trans_input, (input_idx, translations_for_input_idx) in zip(trans_inputs, chunks_by_input_idx):
+            translations_for_input_idx = list(translations_for_input_idx)  # type: ignore
+            if len(translations_for_input_idx) == 1:  # type: ignore
+                translation = translations_for_input_idx[0].translation  # type: ignore
             else:
-                translations_to_concat = [translated_chunk.translation for translated_chunk in chunks]
+                translations_to_concat = [translated_chunk.translation
+                                          for translated_chunk in translations_for_input_idx]
                 translation = self._concat_translations(translations_to_concat)
 
             results.append(self._make_result(trans_input, translation))
@@ -1224,7 +1244,8 @@ def _get_inference_input(self,
                 raw_avoid_list[j] = [data_io.tokens2ids(phrase, self.vocab_target) for phrase in
                                      trans_input.avoid_list]
                 if any(self.unk_id in phrase for phrase in raw_avoid_list[j]):
-                    logger.warning("Sentence %d: %s was found in the list of phrases to avoid; this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL)
+                    logger.warning("Sentence %s: %s was found in the list of phrases to avoid; "
+                                   "this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL)
 
         return source, bucket_key, raw_constraints, raw_avoid_list, mx.nd.array(max_output_lengths, ctx=self.context, dtype='int32')

@@ -1248,7 +1269,7 @@ def _make_result(self,
             tok for target_id, tok in zip(target_ids, target_tokens) if target_id not in self.strip_ids)
         attention_matrix = attention_matrix[:, :len(trans_input.tokens)]
 
-        return TranslatorOutput(id=trans_input.sentence_id,
+        return TranslatorOutput(sentence_id=trans_input.sentence_id,
                                 translation=target_string,
                                 tokens=target_tokens,
                                 attention_matrix=attention_matrix,
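Taken together, the inference.py changes mean Translator.translate() returns outputs positionally aligned with its inputs, whatever the ids look like. A usage sketch under the assumption that a Translator instance named translator has already been loaded (setup not shown):

    inputs = [make_input_from_plain_string(sentence_id=sid, string=text)
              for sid, text in [('first', 'ein kleiner Test .'), ('second', 'noch ein Test .')]]
    outputs = translator.translate(inputs)
    for trans_input, trans_output in zip(inputs, outputs):
        # Safe even with duplicate or string ids: ordering is kept by the
        # internal input index, and the original id is echoed on the output.
        assert trans_input.sentence_id == trans_output.sentence_id
        print(trans_output.sentence_id, trans_output.translation)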
(Diffs for the remaining 3 changed files not shown.)
