from typing import List, Optional

import numpy as np
from numpy.typing import NDArray

from outlines.text.generate.sequence import Sequence


class Continuation(Sequence):
    """Represents a continuation generation model.

    `Continuation` instances are unconstrained generation models that stop
    when an EOS token has been found or when the maximum number of tokens
    has been reached.

    >> import outlines.text as text
    >> sequence = text.generate.continuation(model)("Say something")

    """

    def __init__(self, model, max_tokens: Optional[int]):
        super().__init__(model, max_tokens)

    def is_finished(self, token_ids: NDArray[np.int64]) -> NDArray[np.bool_]:
        """Determine whether the sequences reached maximum length or end with
        an EOS token.

        In practice, `Sequence`'s `__call__` method only passes the `token_ids`
        of the sequences that haven't been marked as finished already, which is
        why we only need to look for the EOS token in the last element rather
        than in the whole sequence.

        Parameters
        ----------
        token_ids
            The input sequences.

        Returns
        -------
        A boolean array with one entry per sequence, `True` where the last
        generated token is the EOS token.

        """
        is_finished = np.zeros((token_ids.shape[0],), dtype=np.bool_)
        is_finished[token_ids[:, -1] == self.model.tokenizer.eos_token_id] = True

        return is_finished

    def postprocess_completions(self, completions: List[str]) -> List[str]:
        """Remove the EOS token from each completion."""
        return [
            completion.replace(self.model.tokenizer.eos_token, "")
            for completion in completions
        ]


def continuation(model, max_tokens: Optional[int] = None):
    """Build a `Continuation` generator for `model`.

    Parameters
    ----------
    model
        The language model used for generation; must expose a `tokenizer`
        with `eos_token` and `eos_token_id` attributes.
    max_tokens
        The maximum number of tokens to generate; generation is bounded only
        by the EOS token when `None`.

    """
    return Continuation(model, max_tokens)
import numpy as np
from numpy.testing import assert_array_equal

import outlines.models as models
from outlines.text.generate.continuation import Continuation, continuation


class Tokenizer:
    # Minimal stub exposing only the tokenizer attributes `Continuation` reads.
    # NOTE(review): the special-token literals below were reconstructed from a
    # garbled source (angle brackets stripped); "<EOS>" matches the assertion
    # that it is removed from "Here<EOS>" — confirm against the repository.
    eos_token = "<EOS>"
    eos_token_id = 0
    pad_token_id = -1  # was `pad_token_ids`: typo — the interface is `pad_token_id`


class Model:
    # Minimal model stub: `Continuation` only touches `model.tokenizer`.
    tokenizer = Tokenizer()


def test_continuation_is_finished():
    """A sequence is finished iff its *last* token is the EOS token."""
    model = continuation(Model(), 10)
    assert isinstance(model, Continuation)

    token_ids = np.array([[3, 2]])
    result = model.is_finished(token_ids)
    assert_array_equal(result, [False])

    token_ids = np.array([[3, 2, 0]])
    result = model.is_finished(token_ids)
    assert_array_equal(result, [True])

    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
    result = model.is_finished(token_ids)
    assert_array_equal(result, [False, True])

    # EOS anywhere but the last position does not finish the sequence.
    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
    result = model.is_finished(token_ids)
    assert_array_equal(result, [True, False])


def test_continuation_postprocess():
    """The EOS token is stripped from decoded completions."""
    model = continuation(Model())
    result = model.postprocess_completions(["Here<EOS>"])
    assert len(result) == 1
    assert result[0] == "Here"


# --- integration tests (tests/text/generate/test_integration_transfomers.py) ---


def test_transformers_integration_completion():
    """End-to-end continuation with a tiny HF model on CPU."""
    rng = np.random.default_rng(0)

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    sequence = continuation(model)("prompt", rng=rng)
    assert isinstance(sequence, str)
    # The EOS token must have been stripped by `postprocess_completions`.
    assert model.tokenizer.eos_token not in sequence

    sequence = continuation(model, max_tokens=10)("prompt", rng=rng)
    assert isinstance(sequence, str)


def test_transformers_integration_with_pad_token():
    """Models with a distinct pad token expose it on the tokenizer."""
    model_name = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
    model = models.transformers(model_name, device="cpu")
    assert model.tokenizer.pad_token_id == 1
    # NOTE(review): literal reconstructed from garbled source — XLM-RoberTa's
    # pad token is "<pad>"; verify against the tokenizer config.
    assert model.tokenizer.pad_token == "<pad>"
a/tests/text/sequences/test_sequence.py +++ b/tests/text/generate/test_sequence.py @@ -4,7 +4,7 @@ import pytest from numpy.testing import assert_array_equal -from outlines.text.sequences.sequence import Sequence, vectorized_random_choice +from outlines.text.generate.sequence import Sequence, vectorized_random_choice def test_vectorized_random_choice():