Skip to content

Commit

Permalink
Make generate.format return the corresponding object
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 10, 2024
1 parent 9c74d7c commit e36065c
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 22 deletions.
24 changes: 16 additions & 8 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Any
from typing import Any, Callable, Tuple

INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
Expand All @@ -9,19 +9,27 @@
DATETIME = rf"({DATE})(\s)({TIME})"


def python_types_to_regex(python_type: Any) -> Tuple[str, Callable[[str], Any]]:
    """Return the regex matching a Python type and a parser for the matched text.

    Parameters
    ----------
    python_type:
        One of `float`, `int`, `bool`, `datetime.date`, `datetime.time` or
        `datetime.datetime`.

    Returns
    -------
    A tuple `(regex_str, format_fn)` where `regex_str` is the module-level
    regular expression for the type and `format_fn` converts a string that
    matches it into an instance of `python_type`.

    Raises
    ------
    NotImplementedError
        If `python_type` is not one of the supported types.

    """
    if python_type == float:
        return FLOAT, float
    elif python_type == int:
        return INTEGER, int
    elif python_type == bool:
        # `bool("False")` is True because every non-empty string is truthy;
        # compare against the literal so that "False" maps to False.
        return BOOLEAN, lambda x: x == "True"
    elif python_type == datetime.date:
        return DATE, lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date()
    elif python_type == datetime.time:
        return TIME, lambda s: datetime.datetime.strptime(s, "%H:%M:%S").time()
    elif python_type == datetime.datetime:
        return DATETIME, lambda s: datetime.datetime.strptime(
            s, "%Y-%m-%d %H:%M:%S"
        )
    else:
        raise NotImplementedError(
            f"The Python type {python_type} is not supported. Please open an issue."
        )
Expand Down
27 changes: 25 additions & 2 deletions outlines/generate/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,31 @@

@singledispatch
def format(model, python_type, sampler: Sampler = multinomial()) -> SequenceGenerator:
    """Generate structured data that can be parsed as a Python type.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    python_type:
        A Python type. The output of the generator must be parseable into
        this type.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGenerator` instance that generates text constrained by the Python type
    and translates this text into the corresponding type.

    """
    regex_str, format_fn = python_types_to_regex(python_type)
    generator = regex(model, regex_str, sampler)
    # Install the type-specific parser as the generator's `format_sequence`
    # hook; without it the caller would receive the raw matched string.
    generator.format_sequence = format_fn

    return generator


@format.register(OpenAI)
Expand Down
3 changes: 1 addition & 2 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def json(
schema_object:
The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable
that returns a JSON schema.
max_tokens:
The maximum number of tokens to generate.
sampler:
The sampling algorithm to use to generate token ids from the logits
distribution.
Expand All @@ -43,6 +41,7 @@ def json(
-------
A `SequenceGenerator` instance that generates text constrained by the schema_object and
transforms the result if BaseModel is used.
"""
if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
Expand Down
19 changes: 19 additions & 0 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@

@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
"""Generate structured text in the language of a regular expression.
Parameters
----------
model:
An instance of `Transformer` that represents a model from the
`transformers` library.
regex_str:
The regular expression that the output must follow.
sampler:
The sampling algorithm to use to generate token ids from the logits
distribution.
Returns
-------
A `SequenceGenerator` instance that generates text constrained by the
regular expression.
"""
fsm = RegexFSM(regex_str, model.tokenizer)

device = model.device
Expand Down
2 changes: 1 addition & 1 deletion tests/fsm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
],
)
def test_python_types(python_type, regex):
test_regex = python_types_to_regex(python_type)
test_regex, _ = python_types_to_regex(python_type)
assert regex == test_regex
17 changes: 8 additions & 9 deletions tests/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def test_transformers_integration_integer():
prompt = "Write a short sentence"
sequence = generate.format(model, int)(prompt, max_tokens=10, rng=rng)

assert sequence != ""
int(sequence)
assert isinstance(sequence, int)


def test_transformers_integration_integer_array():
Expand All @@ -240,8 +239,8 @@ def test_transformers_integration_integer_array():
sequence = generate.format(model, int)(prompts, max_tokens=10, rng=rng)
assert isinstance(sequence, list)
assert len(sequence) == 2
int(sequence[0])
int(sequence[1])
assert isinstance(sequence[0], int)
assert isinstance(sequence[1], int)


def test_transformers_integration_float():
Expand All @@ -254,7 +253,7 @@ def test_transformers_integration_float():
sequence = generate.format(model, float)(prompt, max_tokens=10, rng=rng)

assert sequence != ""
float(sequence)
assert isinstance(sequence, float)


def test_transformers_integration_bool():
Expand All @@ -267,7 +266,7 @@ def test_transformers_integration_bool():
sequence = generate.format(model, bool)(prompt, max_tokens=10, rng=rng)

assert sequence != ""
bool(sequence)
assert isinstance(sequence, bool)


def test_transformers_integration_date():
Expand All @@ -280,7 +279,7 @@ def test_transformers_integration_date():
sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, rng=rng)

assert sequence != ""
datetime.datetime.strptime(sequence, "%Y-%m-%d")
assert isinstance(sequence, datetime.date)


def test_transformers_integration_time():
Expand All @@ -293,7 +292,7 @@ def test_transformers_integration_time():
sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, rng=rng)

assert sequence != ""
datetime.datetime.strptime(sequence, "%H:%M:%S")
assert isinstance(sequence, datetime.time)


def test_transformers_integration_datetime():
Expand All @@ -306,7 +305,7 @@ def test_transformers_integration_datetime():
sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, rng=rng)

assert sequence != 0
datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
assert isinstance(sequence, datetime.datetime)


def test_transformers_integration_choice():
Expand Down

0 comments on commit e36065c

Please sign in to comment.