From dcbe1bca389884261a079b0ab81bc527c9d87ea2 Mon Sep 17 00:00:00 2001 From: Tobias Domhan Date: Fri, 17 Aug 2018 13:52:36 +0200 Subject: [PATCH] No longer relying on external translation id. (#508) * No longer relying on external translation id. * Making mypy happy. * removing chunk_id from input tests. * fix branch without source eos * update docstring * Sentence ids can now be strings * addressed comment --- CHANGELOG.md | 4 + sockeye/__init__.py | 2 +- sockeye/inference.py | 149 ++++++++++++++++++------------- sockeye/output_handler.py | 6 +- test/unit/test_inference.py | 3 - test/unit/test_output_handler.py | 18 ++-- 6 files changed, 103 insertions(+), 79 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e1c3e53..1ae814a1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.47] +### Changed +- translate CLI: no longer rely on external, user-given input id for sorting translations. Also allow string ids for sentences. + ## [1.18.46] ### Fixed - Fixed issue with `--num-words 0:0` in image captioning and another issue related to loading all features to memory with variable length. diff --git a/sockeye/__init__.py b/sockeye/__init__.py index b0ab29576..8ea2065b1 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.46' +__version__ = '1.18.47' diff --git a/sockeye/inference.py b/sockeye/inference.py index bdec9f468..328fec33c 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -564,6 +564,7 @@ def get_max_output_length(input_length: int): BeamHistory = Dict[str, List] Tokens = List[str] +SentenceId = Union[int, str] class TranslatorInput: @@ -574,28 +575,25 @@ class TranslatorInput: :param tokens: List of input tokens. :param factors: Optional list of additional factor sequences. :param constraints: Optional list of target-side constraints. - :param chunk_id: Chunk id. Defaults to -1. 
""" - __slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'avoid_list', 'chunk_id') + __slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'avoid_list') def __init__(self, - sentence_id: int, + sentence_id: SentenceId, tokens: Tokens, factors: Optional[List[Tokens]] = None, constraints: Optional[List[Tokens]] = None, - avoid_list: Optional[List[Tokens]] = None, - chunk_id: int = -1) -> None: + avoid_list: Optional[List[Tokens]] = None) -> None: self.sentence_id = sentence_id - self.chunk_id = chunk_id self.tokens = tokens self.factors = factors self.constraints = constraints self.avoid_list = avoid_list def __str__(self): - return 'TranslatorInput(%d, %s, factors=%s, constraints=%s, avoid=%s, chunk_id=%d)' \ - % (self.sentence_id, self.tokens, self.factors, self.constraints, self.avoid_list, self.chunk_id) + return 'TranslatorInput(%s, %s, factors=%s, constraints=%s, avoid=%s)' \ + % (self.sentence_id, self.tokens, self.factors, self.constraints, self.avoid_list) def __len__(self): return len(self.tokens) @@ -617,7 +615,7 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: if len(self.tokens) > chunk_size and self.constraints is not None: logger.warning( - 'Input %d has length (%d) that exceeds max input length (%d), ' + 'Input %s has length (%d) that exceeds max input length (%d), ' 'triggering internal splitting. Placing all target-side constraints ' 'with the first chunk, which is probably wrong.', self.sentence_id, len(self.tokens), chunk_size) @@ -631,8 +629,7 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]: tokens=self.tokens[i:i + chunk_size], factors=factors, constraints=constraints, - avoid_list=self.avoid_list, - chunk_id=chunk_id) + avoid_list=self.avoid_list) def with_eos(self) -> 'TranslatorInput': """ @@ -643,37 +640,36 @@ def with_eos(self) -> 'TranslatorInput': factors=[factor + [C.EOS_SYMBOL] for factor in self.factors] if self.factors is not None else None, constraints=self.constraints, - avoid_list=self.avoid_list, - chunk_id=self.chunk_id) + avoid_list=self.avoid_list) class BadTranslatorInput(TranslatorInput): - def __init__(self, sentence_id, tokens): - super().__init__(sentence_id=sentence_id, tokens=tokens, chunk_id=-1, factors=None) + def __init__(self, sentence_id: SentenceId, tokens: Tokens) -> None: + super().__init__(sentence_id=sentence_id, tokens=tokens, factors=None) -def _bad_input(sentence_id: int, reason: str = '') -> BadTranslatorInput: - logger.warning("Bad input (%d): '%s'. Will return empty output.", sentence_id, reason.strip()) +def _bad_input(sentence_id: SentenceId, reason: str = '') -> BadTranslatorInput: + logger.warning("Bad input (%s): '%s'. Will return empty output.", sentence_id, reason.strip()) return BadTranslatorInput(sentence_id=sentence_id, tokens=[]) -def make_input_from_plain_string(sentence_id: int, string: str) -> TranslatorInput: +def make_input_from_plain_string(sentence_id: SentenceId, string: str) -> TranslatorInput: """ Returns a TranslatorInput object from a plain string. - :param sentence_id: An integer id. + :param sentence_id: Sentence id. :param string: An input string. :return: A TranslatorInput. 
""" return TranslatorInput(sentence_id, tokens=list(data_io.get_tokens(string)), factors=None) -def make_input_from_json_string(sentence_id: int, json_string: str) -> TranslatorInput: +def make_input_from_json_string(sentence_id: SentenceId, json_string: str) -> TranslatorInput: """ Returns a TranslatorInput object from a JSON object, serialized as a string. - :param sentence_id: An integer id. + :param sentence_id: Sentence id. :param json_string: A JSON object serialized as a string that must contain a key "text", mapping to the input text, and optionally a key "factors" that maps to a list of strings, each of which representing a factor sequence for the input text. @@ -718,7 +714,7 @@ def make_input_from_json_string(sentence_id: int, json_string: str) -> Translato return _bad_input(sentence_id, reason=json_string) -def make_input_from_factored_string(sentence_id: int, +def make_input_from_factored_string(sentence_id: SentenceId, factored_string: str, translator: 'Translator', delimiter: str = C.DEFAULT_FACTOR_DELIMITER) -> TranslatorInput: @@ -726,7 +722,7 @@ def make_input_from_factored_string(sentence_id: int, Returns a TranslatorInput object from a string with factor annotations on a token level, separated by delimiter. If translator does not require any source factors, the string is parsed as a plain token string. - :param sentence_id: An integer id. + :param sentence_id: Sentence id. :param factored_string: An input string with additional factors per token, separated by delimiter. :param translator: A translator object. :param delimiter: A factor delimiter. Default: '|'. @@ -758,12 +754,12 @@ def make_input_from_factored_string(sentence_id: int, return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors) -def make_input_from_multiple_strings(sentence_id: int, strings: List[str]) -> TranslatorInput: +def make_input_from_multiple_strings(sentence_id: SentenceId, strings: List[str]) -> TranslatorInput: """ Returns a TranslatorInput object from multiple strings, where the first element corresponds to the surface tokens and the remaining elements to additional factors. All strings must parse into token sequences of the same length. - :param sentence_id: An integer id. + :param sentence_id: Sentence id. :param strings: A list of strings representing a factored input sequence. :return: A TranslatorInput. """ @@ -782,7 +778,7 @@ class TranslatorOutput: """ Output structure from Translator. - :param id: Id of input sentence. + :param sentence_id: Sentence id. :param translation: Translation string without sentence boundary tokens. :param tokens: List of translated tokens. :param attention_matrix: Attention matrix. Shape: (target_length, source_length). @@ -790,16 +786,16 @@ class TranslatorOutput: :param beam_histories: List of beam histories. The list will contain more than one history if it was split due to exceeding max_length. 
""" - __slots__ = ('id', 'translation', 'tokens', 'attention_matrix', 'score', 'beam_histories') + __slots__ = ('sentence_id', 'translation', 'tokens', 'attention_matrix', 'score', 'beam_histories') def __init__(self, - id: int, + sentence_id: SentenceId, translation: str, tokens: List[str], attention_matrix: np.ndarray, score: float, beam_histories: Optional[List[BeamHistory]] = None) -> None: - self.id = id + self.sentence_id = sentence_id self.translation = translation self.tokens = tokens self.attention_matrix = attention_matrix @@ -828,16 +824,30 @@ def empty_translation() -> Translation: return Translation(target_ids=[], attention_matrix=np.asarray([[0]]), score=-np.inf) -TranslatedChunk = NamedTuple('TranslatedChunk', [ - ('id', int), - ('chunk_id', int), +IndexedTranslatorInput = NamedTuple('IndexedTranslatorInput', [ + ('input_idx', int), + ('chunk_idx', int), + ('translator_input', TranslatorInput) +]) +""" +Translation of a chunk of a sentence. + +:param input_idx: Internal index of translation requests to keep track of the correct order of translations. +:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks. +:param input: The translator input. +""" + + +IndexedTranslation = NamedTuple('IndexedTranslation', [ + ('input_idx', int), + ('chunk_idx', int), ('translation', Translation), ]) """ Translation of a chunk of a sentence. -:param id: Id of the sentence. -:param chunk_id: Id of the chunk. +:param input_idx: Internal index of translation requests to keep track of the correct order of translations. +:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks. :param translation: The translation of the input chunk. """ @@ -1092,18 +1102,19 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu :param trans_inputs: List of TranslatorInputs as returned by make_input(). :return: List of translation results. """ - translated_chunks = [] # type: List[TranslatedChunk] + translated_chunks = [] # type: List[IndexedTranslation] # split into chunks - input_chunks = [] # type: List[TranslatorInput] - for trans_input in trans_inputs: + input_chunks = [] # type: List[IndexedTranslatorInput] + for trans_input_idx, trans_input in enumerate(trans_inputs): # bad input if isinstance(trans_input, BadTranslatorInput): - translated_chunks.append(TranslatedChunk(id=trans_input.sentence_id, chunk_id=0, translation=empty_translation())) - + translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0, + translation=empty_translation())) # empty input elif len(trans_input.tokens) == 0: - translated_chunks.append(TranslatedChunk(id=trans_input.sentence_id, chunk_id=0, translation=empty_translation())) + translated_chunks.append(IndexedTranslation(input_idx=trans_input_idx, chunk_idx=0, + translation=empty_translation())) else: # TODO(tdomhan): Remove branch without EOS with next major version bump, as future models will always be trained with source side EOS symbols if self.source_with_eos: @@ -1111,39 +1122,46 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu # oversized input if len(trans_input.tokens) > max_input_length_without_eos: logger.debug( - "Input %d has length (%d) that exceeds max input length (%d). " + "Input %s has length (%d) that exceeds max input length (%d). 
" "Splitting into chunks of size %d.", trans_input.sentence_id, len(trans_input.tokens), self.buckets_source[-1], max_input_length_without_eos) - input_chunks.extend([trans_input_chunk.with_eos() - for trans_input_chunk in - trans_input.chunks(max_input_length_without_eos)]) + chunks = [trans_input_chunk.with_eos() + for trans_input_chunk in trans_input.chunks(max_input_length_without_eos)] + input_chunks.extend([IndexedTranslatorInput(trans_input_idx, chunk_idx, chunk_input) + for chunk_idx, chunk_input in enumerate(chunks)]) # regular input else: - input_chunks.append(trans_input.with_eos()) + input_chunks.append(IndexedTranslatorInput(trans_input_idx, + chunk_idx=0, + translator_input=trans_input.with_eos())) else: - # oversized input if len(trans_input.tokens) > self.max_input_length: + # oversized input logger.debug( - "Input %d has length (%d) that exceeds max input length (%d). " + "Input %s has length (%d) that exceeds max input length (%d). " "Splitting into chunks of size %d.", trans_input.sentence_id, len(trans_input.tokens), self.buckets_source[-1], self.max_input_length) - input_chunks.extend([trans_input_chunk - for trans_input_chunk in - trans_input.chunks(self.max_input_length)]) - # regular input + chunks = [trans_input_chunk + for trans_input_chunk in + trans_input.chunks(self.max_input_length)] + input_chunks.extend([IndexedTranslatorInput(trans_input_idx, chunk_idx, chunk_input) + for chunk_idx, chunk_input in enumerate(chunks)]) else: - input_chunks.append(trans_input) + # regular input + input_chunks.append(IndexedTranslatorInput(trans_input_idx, + chunk_idx=0, + translator_input=trans_input)) if trans_input.constraints is not None: - logger.info("Input %d has %d %s: %s", trans_input.sentence_id, + logger.info("Input %s has %d %s: %s", trans_input.sentence_id, len(trans_input.constraints), "constraint" if len(trans_input.constraints) == 1 else "constraints", ", ".join(" ".join(x) for x in trans_input.constraints)) # Sort longest to shortest (to rather fill batches of shorter than longer sequences) - input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.tokens), reverse=True) + input_chunks = sorted(input_chunks, key=lambda chunk: len(chunk.translator_input.tokens), reverse=True) # translate in batch-sized blocks over input chunks for batch_id, batch in enumerate(utils.grouper(input_chunks, self.batch_size)): @@ -1153,24 +1171,26 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu if rest > 0: logger.debug("Extending the last batch to the full batch size (%d)", self.batch_size) batch = batch + [batch[0]] * rest - batch_translations = self._translate_nd(*self._get_inference_input(batch)) + translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch] + batch_translations = self._translate_nd(*self._get_inference_input(translator_inputs)) # truncate to remove filler translations if rest > 0: batch_translations = batch_translations[:-rest] for chunk, translation in zip(batch, batch_translations): - translated_chunks.append(TranslatedChunk(chunk.sentence_id, chunk.chunk_id, translation)) + translated_chunks.append(IndexedTranslation(chunk.input_idx, chunk.chunk_idx, translation)) # Sort by input idx and then chunk id translated_chunks = sorted(translated_chunks) # Concatenate results results = [] # type: List[TranslatorOutput] - chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.id) - for trans_input, (input_idx, chunks) in zip(trans_inputs, 
chunks_by_input_idx): - chunks = list(chunks) # type: ignore - if len(chunks) == 1: # type: ignore - translation = chunks[0].translation # type: ignore + chunks_by_input_idx = itertools.groupby(translated_chunks, key=lambda translation: translation.input_idx) + for trans_input, (input_idx, translations_for_input_idx) in zip(trans_inputs, chunks_by_input_idx): + translations_for_input_idx = list(translations_for_input_idx) # type: ignore + if len(translations_for_input_idx) == 1: # type: ignore + translation = translations_for_input_idx[0].translation # type: ignore else: - translations_to_concat = [translated_chunk.translation for translated_chunk in chunks] + translations_to_concat = [translated_chunk.translation + for translated_chunk in translations_for_input_idx] translation = self._concat_translations(translations_to_concat) results.append(self._make_result(trans_input, translation)) @@ -1224,7 +1244,8 @@ def _get_inference_input(self, raw_avoid_list[j] = [data_io.tokens2ids(phrase, self.vocab_target) for phrase in trans_input.avoid_list] if any(self.unk_id in phrase for phrase in raw_avoid_list[j]): - logger.warning("Sentence %d: %s was found in the list of phrases to avoid; this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL) + logger.warning("Sentence %s: %s was found in the list of phrases to avoid; " + "this may indicate improper preprocessing.", trans_input.sentence_id, C.UNK_SYMBOL) return source, bucket_key, raw_constraints, raw_avoid_list, mx.nd.array(max_output_lengths, ctx=self.context, dtype='int32') @@ -1248,7 +1269,7 @@ def _make_result(self, tok for target_id, tok in zip(target_ids, target_tokens) if target_id not in self.strip_ids) attention_matrix = attention_matrix[:, :len(trans_input.tokens)] - return TranslatorOutput(id=trans_input.sentence_id, + return TranslatorOutput(sentence_id=trans_input.sentence_id, translation=target_string, tokens=target_tokens, attention_matrix=attention_matrix, diff --git a/sockeye/output_handler.py b/sockeye/output_handler.py index adf7e7168..4b877b28d 100644 --- a/sockeye/output_handler.py +++ b/sockeye/output_handler.py @@ -182,7 +182,7 @@ def handle(self, :param t_output: Translator output. :param t_walltime: Total wall-clock time for translation. 
""" - line = "{sent_id:d} ||| {target} ||| {score:f} ||| {source} ||| {source_len:d} ||| {target_len:d}\n" + line = "{sent_id} ||| {target} ||| {score:f} ||| {source} ||| {source_len:d} ||| {target_len:d}\n" self.stream.write(line.format(sent_id=t_input.sentence_id, target=" ".join(t_output.tokens), score=t_output.score, @@ -244,7 +244,7 @@ def handle(self, plot_attention(t_output.attention_matrix, t_input.tokens, t_output.tokens, - "%s_%d.png" % (self.plot_prefix, t_input.sentence_id)) + "%s_%s.png" % (self.plot_prefix, t_input.sentence_id)) class AlignTextHandler(OutputHandler): @@ -297,6 +297,6 @@ def handle(self, # Add the number of steps in each beam h["number_steps"] = len(h["predicted_tokens"]) # type: ignore # Some outputs can have more than one beam, add the id for bookkeeping - h["id"] = t_output.id # type: ignore + h["id"] = t_output.sentence_id # type: ignore self.stream.write("%s\n" % json.dumps(h, sort_keys=True)) self.stream.flush() diff --git a/test/unit/test_inference.py b/test/unit/test_inference.py index b616d1af5..0f54b1d36 100644 --- a/test/unit/test_inference.py +++ b/test/unit/test_inference.py @@ -143,13 +143,11 @@ def test_translator_input(sentence_id, sentence, factors, chunk_size): if factors is not None: for factor in trans_input.factors: assert len(factor) == len(tokens) - assert trans_input.chunk_id == -1 chunked_inputs = list(trans_input.chunks(chunk_size)) assert len(chunked_inputs) == ceil(len(tokens) / chunk_size) for chunk_id, chunk_input in enumerate(chunked_inputs): assert chunk_input.sentence_id == sentence_id - assert chunk_input.chunk_id == chunk_id assert chunk_input.tokens == trans_input.tokens[chunk_id * chunk_size: (chunk_id + 1) * chunk_size] if factors: assert len(chunk_input.factors) == len(factors) @@ -228,7 +226,6 @@ def test_make_input_from_factored_string(sentence, num_expected_factors, delimit translator=translator, delimiter=delimiter) assert isinstance(inp, sockeye.inference.TranslatorInput) assert inp.sentence_id == sentence_id - assert inp.chunk_id == -1 assert inp.tokens == expected_tokens assert inp.factors == expected_factors if num_expected_factors > 1: diff --git a/test/unit/test_output_handler.py b/test/unit/test_output_handler.py index 62cae71ac..8554b6341 100644 --- a/test/unit/test_output_handler.py +++ b/test/unit/test_output_handler.py @@ -12,28 +12,30 @@ # permissions and limitations under the License. 
import io -import pytest + import numpy as np -from sockeye.inference import TranslatorInput, TranslatorOutput +import pytest + import sockeye.output_handler +from sockeye.inference import TranslatorInput, TranslatorOutput stream_handler_tests = [(sockeye.output_handler.StringOutputHandler(io.StringIO()), TranslatorInput(sentence_id=0, tokens=[], factors=[], constraints=[]), - TranslatorOutput(id=0, translation="ein Test", tokens=None, + TranslatorOutput(sentence_id=0, translation="ein Test", tokens=None, attention_matrix=None, score=0.), 0., "ein Test\n"), (sockeye.output_handler.StringOutputHandler(io.StringIO()), TranslatorInput(sentence_id=0, tokens=[], factors=[]), - TranslatorOutput(id=0, translation="", tokens=None, + TranslatorOutput(sentence_id=0, translation="", tokens=None, attention_matrix=None, score=0.), 0., "\n"), (sockeye.output_handler.StringWithAlignmentsOutputHandler(io.StringIO(), threshold=0.5), TranslatorInput(sentence_id=0, tokens="a test".split(), factors=[]), - TranslatorOutput(id=0, translation="ein Test", tokens=None, + TranslatorOutput(sentence_id=0, translation="ein Test", tokens=None, attention_matrix=np.asarray([[1, 0], [0, 1]]), score=0.), @@ -41,7 +43,7 @@ "ein Test\t0-0 1-1\n"), (sockeye.output_handler.StringWithAlignmentsOutputHandler(io.StringIO(), threshold=0.5), TranslatorInput(sentence_id=0, tokens="a test".split(), factors=[]), - TranslatorOutput(id=0, translation="ein Test !", tokens=None, + TranslatorOutput(sentence_id=0, translation="ein Test !", tokens=None, attention_matrix=np.asarray([[0.4, 0.6], [0.8, 0.2], [0.5, 0.5]]), @@ -50,14 +52,14 @@ "ein Test !\t0-1 1-0\n"), (sockeye.output_handler.BenchmarkOutputHandler(io.StringIO()), TranslatorInput(sentence_id=0, tokens=["a", "test"], factors=[]), - TranslatorOutput(id=0, translation="ein Test", tokens=["ein", "Test"], + TranslatorOutput(sentence_id=0, translation="ein Test", tokens=["ein", "Test"], attention_matrix=None, score=0.), 0.5, "input=a test\toutput=ein Test\tinput_tokens=2\toutput_tokens=2\ttranslation_time=0.5000\n"), (sockeye.output_handler.BeamStoringHandler(io.StringIO()), TranslatorInput(sentence_id=0, tokens=["What"]), - TranslatorOutput(id=0, translation="Was", tokens=["Was"], + TranslatorOutput(sentence_id=0, translation="Was", tokens=["Was"], attention_matrix=None, score=0., beam_histories=[ {"predicted_ids": [[258, 137, 31],
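
The bookkeeping introduced by this patch can be summarized as follows: translate() enumerates the incoming TranslatorInputs, tags every chunk with an (input_idx, chunk_idx) pair, sorts the chunks by length for batching, and afterwards restores the original request order by sorting and grouping on input_idx. The user-supplied sentence_id (now an int or a string) is only carried along for logging and output, never for ordering. The snippet below is a minimal, self-contained sketch of that flow, not part of the patch itself: IndexedInput, IndexedResult and fake_translate are simplified stand-ins for sockeye's IndexedTranslatorInput, IndexedTranslation and Translator._translate_nd.

import itertools
from typing import List, NamedTuple

IndexedInput = NamedTuple('IndexedInput', [
    ('input_idx', int),    # position of the request in the original input list
    ('chunk_idx', int),    # position of the chunk within that request
    ('tokens', List[str]),
])

IndexedResult = NamedTuple('IndexedResult', [
    ('input_idx', int),
    ('chunk_idx', int),
    ('translation', str),
])


def fake_translate(tokens):
    # stand-in for Translator._translate_nd: "translates" by upper-casing the tokens
    return " ".join(tok.upper() for tok in tokens)


requests = [("doc1#3", "a fairly long source sentence".split()),
            ("doc1#1", "short".split())]  # sentence ids may be strings; they are never used for ordering

chunk_size = 3
chunks = [IndexedInput(input_idx, chunk_idx, tokens[i:i + chunk_size])
          for input_idx, (_, tokens) in enumerate(requests)
          for chunk_idx, i in enumerate(range(0, len(tokens), chunk_size))]

# batch longest chunks first, as Translator.translate does
chunks = sorted(chunks, key=lambda chunk: len(chunk.tokens), reverse=True)
results = [IndexedResult(c.input_idx, c.chunk_idx, fake_translate(c.tokens)) for c in chunks]

# restore the original request order and stitch chunks back together
results = sorted(results)  # tuple order: (input_idx, chunk_idx)
for (sent_id, _), (_, group) in zip(requests, itertools.groupby(results, key=lambda r: r.input_idx)):
    print(sent_id, "->", " ".join(r.translation for r in group))

Running the sketch prints the translations in the order the requests were given ("doc1#3" first, then "doc1#1"), even though the chunks were translated longest-first, which is exactly the guarantee the internal input_idx now provides without relying on the external id.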