Skip to content

Commit

Permalink
Implement text.generate.format constrained generation
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 13, 2023
1 parent 96773e1 commit 4594704
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 80 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,27 @@ print(parsed)

The method works with union types, optional types, arrays, nested schemas, etc. Some field constraints are [not supported yet](https://github.com/outlines-dev/outlines/issues/215), but everything else should work.

### Open functions

Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`:

```python
from outlines import models
from outlines import text

def add(a: int, b: int):
    return a + b

model = models.transformers("mistralai/Mistral-7B")
generator = text.generate.json(model, add)
result = generator("Return two integers named a and b respectively. a is odd and b even.")

print(add(**result))
# 3
```

A great advantage of passing functions directly to specify the structure is that the structure of the LLM's output will change with the function's definition. No need to change the code in several places!

## Prompting

Writing prompts by concatenating strings in pure Python quickly becomes
Expand Down
2 changes: 1 addition & 1 deletion outlines/text/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .continuation import continuation
from .regex import choice, float, integer, json, regex
from .regex import choice, format, json, regex
51 changes: 7 additions & 44 deletions outlines/text/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from outlines.text.fsm import create_fsm_index_tokenizer, make_deterministic_fsm
from outlines.text.generate.continuation import Continuation
from outlines.text.json_schema import build_regex_from_object, get_schema_from_signature
from outlines.text.types import python_types_to_regex

if TYPE_CHECKING:
from outlines.text.generate.sample import Sampler
Expand Down Expand Up @@ -266,8 +267,9 @@ def regex(
)


def integer(
def format(
model,
python_type,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
Expand All @@ -288,6 +290,8 @@ def integer(
----------
model
The language model to use to compute the next-token logits.
python_type
The format in which the output is expected, defined as a Python type.
max_tokens
The maximum number of tokens to generate.
sampler
Expand All @@ -299,51 +303,10 @@ def integer(
Allow sampling of tokens corresponding to empty strings.
"""
regex_str = python_types_to_regex(python_type)
return Regex(
model,
r"[-+]?\d+",
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def float(
model,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Generate floating-point numbers.
The regex used to constrain the generation optionally matches plus or minus
signs, and forbids leading zeros (even if the `float` function in Python
allows them).
.. note:
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.
Parameters
----------
model
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(
model,
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))",
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
Expand Down
28 changes: 28 additions & 0 deletions outlines/text/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import datetime
from typing import Any

# Regular expressions describing the textual form of the supported Python
# types. Each pattern is meant to match the whole generated string.

# Optional sign, then a lone zero or a number without leading zeros.
INTEGER = r"[+-]?(0|[1-9][0-9]*)"
# The two Python boolean literals.
BOOLEAN = "(True|False)"
# An INTEGER, an optional fractional part and an optional exponent. When an
# exponent is present its sign is mandatory.
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
# YYYY-MM-DD with month 01-12 and day 01-31. Month lengths are not checked,
# so e.g. "02-30" still matches.
DATE = r"(\d{4})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01])"
# HH:MM:SS on a 24-hour clock: hours 00-23, minutes and seconds 00-59.
TIME = r"([01][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"
# A DATE and a TIME separated by a single whitespace character.
DATETIME = rf"({DATE})(\s)({TIME})"


def python_types_to_regex(python_type: Any) -> str:
    """Return the regular expression that describes a Python type.

    Parameters
    ----------
    python_type
        One of `float`, `int`, `bool`, `datetime.date`, `datetime.time` or
        `datetime.datetime`.

    Returns
    -------
    The regex string (one of the module-level constants) that matches the
    textual representation of `python_type`.

    Raises
    ------
    NotImplementedError
        If `python_type` is not one of the supported types.

    """
    # Equality is used rather than a dict lookup so that unhashable
    # arguments still end up in the `NotImplementedError` branch.
    if python_type == float:
        return FLOAT
    if python_type == int:
        return INTEGER
    if python_type == bool:
        return BOOLEAN
    if python_type == datetime.date:
        return DATE
    if python_type == datetime.time:
        return TIME
    if python_type == datetime.datetime:
        return DATETIME

    raise NotImplementedError(
        f"The Python type {python_type} is not supported. Please open an issue."
    )
59 changes: 56 additions & 3 deletions tests/text/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import re
from enum import Enum
from typing import List, Union
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_transformers_integration_integer():
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name)
prompt = "Write a short sentence"
sequence = generate.integer(model, max_tokens=10)(prompt, rng=rng)
sequence = generate.format(model, int, max_tokens=10)(prompt, rng=rng)

assert sequence[0] != 0
int(sequence)
Expand All @@ -89,7 +90,7 @@ def test_transformers_integration_integer_array():
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name)
prompts = ["Give me a number", "And another one"]
sequence = generate.integer(model, max_tokens=10)(prompts, rng=rng)
sequence = generate.format(model, int, max_tokens=10)(prompts, rng=rng)
assert isinstance(sequence, list)
assert len(sequence) == 2
int(sequence[0])
Expand All @@ -103,12 +104,64 @@ def test_transformers_integration_float():
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name)
prompt = "Write a short sentence"
sequence = generate.float(model, max_tokens=10)(prompt, rng=rng)
sequence = generate.format(model, float, max_tokens=10)(prompt, rng=rng)

assert sequence[0] != 0
float(sequence)


def test_transformers_integration_bool():
    """Check that `generate.format(model, bool)` produces a boolean literal."""
    rng = torch.Generator()
    rng.manual_seed(0)

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name)
    prompt = "Is this True or False?"
    sequence = generate.format(model, bool, max_tokens=10)(prompt, rng=rng)

    # `sequence[0] != 0` compared a character with an int and `bool(sequence)`
    # is truthy for any non-empty string, so neither validated the output.
    # Check the generated text against the two literals directly.
    assert sequence in ("True", "False")


def test_transformers_integration_date():
    """Check that `generate.format(model, datetime.date)` produces a parsable date."""
    rng = torch.Generator()
    rng.manual_seed(0)

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name)
    prompt = "What day is it today?"
    sequence = generate.format(model, datetime.date, max_tokens=10)(prompt, rng=rng)

    # The former `sequence[0] != 0` check compared a character with an int and
    # could never fail; parsing is the actual validation (raises ValueError).
    datetime.datetime.strptime(sequence, "%Y-%m-%d")


def test_transformers_integration_time():
    """Check that `generate.format(model, datetime.time)` produces a parsable time."""
    rng = torch.Generator()
    rng.manual_seed(0)

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name)
    prompt = "What time is it?"
    sequence = generate.format(model, datetime.time, max_tokens=10)(prompt, rng=rng)

    # The former `sequence[0] != 0` check compared a character with an int and
    # could never fail; parsing is the actual validation (raises ValueError).
    datetime.datetime.strptime(sequence, "%H:%M:%S")


def test_transformers_integration_datetime():
    """Check that `generate.format(model, datetime.datetime)` produces a parsable timestamp."""
    rng = torch.Generator()
    rng.manual_seed(0)

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name)
    prompt = "What time is it?"
    sequence = generate.format(model, datetime.datetime, max_tokens=20)(prompt, rng=rng)

    # The former `sequence[0] != 0` check compared a character with an int and
    # could never fail; parsing is the actual validation (raises ValueError).
    datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")


def test_transformers_integration_choice():
rng = torch.Generator()
rng.manual_seed(0)
Expand Down
39 changes: 7 additions & 32 deletions tests/text/generate/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_regex_no_valid_transition():
)
def test_integer_proposal(input_ids, proposal):
model = Model()
generator = generate.integer(model)
generator = generate.format(model, int)

logits = torch.ones(len(model.tokenizer.vocabulary))
result = generator.create_proposal(torch.tensor(input_ids), logits)
Expand Down Expand Up @@ -155,45 +155,20 @@ def test_choice_proposal():
)


@pytest.mark.parametrize(
"input_ids, proposal",
[
([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf]]),
([[3]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]]),
],
)
def test_float_proposal(input_ids, proposal):
model = Model()
generator = generate.float(model)

logits = torch.ones(len(model.tokenizer.vocabulary))
result = generator.create_proposal(torch.tensor(input_ids), logits)
assert torch.equal(
result,
torch.tensor(proposal),
)


@pytest.mark.parametrize(
"input_ids, proposal, with_empty",
[
([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, 1]], True),
(
[[]],
[[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, -math.inf]],
False,
),
([[3]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, 1]], True),
([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf]], False),
(
[[3]],
[[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, -math.inf]],
False,
[[-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]],
True,
),
],
)
def test_empty_strings(input_ids, proposal, with_empty):
model = ModelWithEmpty()
generator = generate.float(model, allow_empty_tokens=with_empty)
def test_float_proposal(input_ids, proposal, with_empty):
model = Model()
generator = generate.format(model, float, allow_empty_tokens=with_empty)

logits = torch.ones(len(model.tokenizer.vocabulary))
result = generator.create_proposal(torch.tensor(input_ids), logits)
Expand Down
29 changes: 29 additions & 0 deletions tests/text/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import datetime

import pytest

from outlines.text.types import (
BOOLEAN,
DATE,
DATETIME,
FLOAT,
INTEGER,
TIME,
python_types_to_regex,
)


@pytest.mark.parametrize(
    "python_type,regex",
    [
        (int, INTEGER),
        (float, FLOAT),
        (bool, BOOLEAN),
        (datetime.date, DATE),
        (datetime.time, TIME),
        (datetime.datetime, DATETIME),
    ],
)
def test_python_types(python_type, regex):
    """Each supported Python type maps to its module-level regex constant."""
    assert regex == python_types_to_regex(python_type)

0 comments on commit 4594704

Please sign in to comment.