diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index 93b59dd28..3e337542f 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -1,5 +1,5 @@ import datetime -from typing import Any +from typing import Any, Callable, Tuple INTEGER = r"[+-]?(0|[1-9][0-9]*)" BOOLEAN = "(True|False)" @@ -9,19 +9,27 @@ DATETIME = rf"({DATE})(\s)({TIME})" -def python_types_to_regex(python_type: Any) -> str: +def python_types_to_regex(python_type: Any) -> Tuple[str, Callable[[str], Any]]: if python_type == float: - return FLOAT + float_format_fn = lambda x: float(x) + return FLOAT, float_format_fn elif python_type == int: - return INTEGER + int_format_fn = lambda x: int(x) + return INTEGER, int_format_fn elif python_type == bool: - return BOOLEAN + bool_format_fn = lambda x: bool(x) + return BOOLEAN, bool_format_fn elif python_type == datetime.date: - return DATE + date_format_fn = lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date() + return DATE, date_format_fn elif python_type == datetime.time: - return TIME + time_format_fn = lambda s: datetime.datetime.strptime(s, "%H:%M:%S").time() + return TIME, time_format_fn elif python_type == datetime.datetime: - return DATETIME + datetime_format_fn = lambda s: datetime.datetime.strptime( + s, "%Y-%m-%d %H:%M:%S" + ) + return DATETIME, datetime_format_fn else: raise NotImplementedError( f"The Python type {python_type} is not supported. Please open an issue." diff --git a/outlines/generate/format.py b/outlines/generate/format.py index 5afc65bb7..d87a3fe70 100644 --- a/outlines/generate/format.py +++ b/outlines/generate/format.py @@ -10,8 +10,31 @@ @singledispatch def format(model, python_type, sampler: Sampler = multinomial()) -> SequenceGenerator: - regex_str = python_types_to_regex(python_type) - return regex(model, regex_str, sampler) + """Generate structured data that can be parsed as a Python type. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + python_type: + A Python type. The output of the generator must be parseable into + this type. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the Python type + and translates this text into the corresponding type. + + """ + regex_str, format_fn = python_types_to_regex(python_type) + generator = regex(model, regex_str, sampler) + generator.format_sequence = format_fn + + return generator @format.register(OpenAI) diff --git a/outlines/generate/json.py b/outlines/generate/json.py index b81c438a3..cf5866340 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -30,8 +30,6 @@ def json( schema_object: The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable that returns a JSON schema. - max_tokens: - The maximum number of tokens to generate. sampler: The sampling algorithm to use to generate token ids from the logits distribution. @@ -43,6 +41,7 @@ def json( ------- A `SequenceGenerator` instance that generates text constrained by the schema_object and transforms the result if BaseModel is used. + """ if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index edb8f807e..9d0b9ee87 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -8,6 +8,25 @@ @singledispatch def regex(model, regex_str: str, sampler: Sampler = multinomial()): + """Generate structured text in the language of a regular expression. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + regex_str: + The regular expression that the output must follow. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the + regular expression. + + """ fsm = RegexFSM(regex_str, model.tokenizer) device = model.device diff --git a/tests/fsm/test_types.py b/tests/fsm/test_types.py index cee586fa9..d5450434c 100644 --- a/tests/fsm/test_types.py +++ b/tests/fsm/test_types.py @@ -25,5 +25,5 @@ ], ) def test_python_types(python_type, regex): - test_regex = python_types_to_regex(python_type) + test_regex, _ = python_types_to_regex(python_type) assert regex == test_regex diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py index 35019b238..f2de27c39 100644 --- a/tests/generate/test_integration_transfomers.py +++ b/tests/generate/test_integration_transfomers.py @@ -226,8 +226,7 @@ def test_transformers_integration_integer(): prompt = "Write a short sentence" sequence = generate.format(model, int)(prompt, max_tokens=10, rng=rng) - assert sequence != "" - int(sequence) + assert isinstance(sequence, int) def test_transformers_integration_integer_array(): @@ -240,8 +239,8 @@ def test_transformers_integration_integer_array(): sequence = generate.format(model, int)(prompts, max_tokens=10, rng=rng) assert isinstance(sequence, list) assert len(sequence) == 2 - int(sequence[0]) - int(sequence[1]) + assert isinstance(sequence[0], int) + assert isinstance(sequence[1], int) def test_transformers_integration_float(): @@ -254,7 +253,7 @@ def test_transformers_integration_float(): sequence = generate.format(model, float)(prompt, max_tokens=10, rng=rng) assert sequence != "" - float(sequence) + assert isinstance(sequence, float) def test_transformers_integration_bool(): @@ -267,7 +266,7 @@ def test_transformers_integration_bool(): sequence = generate.format(model, bool)(prompt, max_tokens=10, rng=rng) assert sequence != "" - bool(sequence) + assert isinstance(sequence, bool) def test_transformers_integration_date(): @@ -280,7 +279,7 @@ def test_transformers_integration_date(): sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, rng=rng) assert sequence != "" - datetime.datetime.strptime(sequence, "%Y-%m-%d") + assert isinstance(sequence, datetime.date) def test_transformers_integration_time(): @@ -293,7 +292,7 @@ def test_transformers_integration_time(): sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, rng=rng) assert sequence != "" - datetime.datetime.strptime(sequence, "%H:%M:%S") + assert isinstance(sequence, datetime.time) def test_transformers_integration_datetime(): @@ -306,7 +305,7 @@ def test_transformers_integration_datetime(): sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, rng=rng) assert sequence != 0 - datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") + assert isinstance(sequence, datetime.datetime) def test_transformers_integration_choice():