From 842ef19a6ffbf8ce2b0d26a4c591f2772fac1572 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 3 May 2024 15:09:41 -0400 Subject: [PATCH 01/35] Fix typo in docs (#860) --- docs/reference/models/vllm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md index 7fc29f00c..b5221e582 100644 --- a/docs/reference/models/vllm.md +++ b/docs/reference/models/vllm.md @@ -123,7 +123,7 @@ from outlines import models, generate model = models.vllm("mistralai/Mistral-7b-v0.1") generator = generate.text(model) -params = SamplingParams(n=2, frequence_penalty=1., min_tokens=2) +params = SamplingParams(n=2, frequency_penalty=1., min_tokens=2) answer = generator("A prompt", sampling_params=params) ``` From a101c1c8066f7ce92cdbcbd46a269d06163eded9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 3 May 2024 22:01:07 +0200 Subject: [PATCH 02/35] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4c3443639..e6356e2f3 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ First time here? Go to our [setup guide](https://outlines-dev.github.io/outlines - [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/vllm), with official Docker image, [`outlinesdev/outlines`](https://hub.docker.com/r/outlinesdev/outlines)! -Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][twitter] to stay up to date! +Outlines 〰 has new releases and features coming every week. Make sure to ⭐ star and 👀 watch this repository, follow [@dottxtai][dottxt-twitter] to stay up to date! ## Why should I use structured generation? From b3b29fe7f61be8b603fe9da9511bea73a138aeb9 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Sun, 5 May 2024 05:50:56 -0400 Subject: [PATCH 03/35] Fix code rendering (#864) Make the code render as Python. --- docs/reference/models/openai.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md index 5ddd4a457..07357a360 100644 --- a/docs/reference/models/openai.md +++ b/docs/reference/models/openai.md @@ -19,7 +19,7 @@ print(type(model)) Outlines also supports Azure OpenAI models: -``` +```python from outlines import models model = models.azure_openai( @@ -30,7 +30,7 @@ model = models.azure_openai( More generally, you can use any API client compatible with the OpenAI interface by passing an instance of the client, a configuration, and optionally the corresponding tokenizer (if you want to be able to use `outlines.generate.choice`): -``` +```python from openai import AsyncOpenAI import tiktoken From db83c089987b7fb5ddb9aeb4b834f903301dae47 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Mon, 6 May 2024 01:47:42 +0800 Subject: [PATCH 04/35] ignore import warnings from huggingface_hub & pyairports --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 0b6ae6352..41c306b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,9 @@ filterwarnings = [ "ignore::numba.core.errors.NumbaPendingDeprecationWarning", "ignore::pydantic.warnings.PydanticDeprecatedSince20", "ignore::FutureWarning:transformers.*", + "ignore::FutureWarning:huggingface_hub.*", "ignore::UserWarning", + "ignore::DeprecationWarning:pyairports.*", ] [tool.mypy] From b7d876e2378bcb3bf2023adfb684f4be4bce9a7c Mon Sep 17 00:00:00 2001 From: Sherlock113 Date: Mon, 6 May 2024 10:45:45 +0800 Subject: [PATCH 05/35] Fix format in the BentoML doc Signed-off-by: Sherlock113 --- docs/cookbook/deploy-using-bentoml.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/cookbook/deploy-using-bentoml.md b/docs/cookbook/deploy-using-bentoml.md index f3770d07d..6bee77441 100644 --- a/docs/cookbook/deploy-using-bentoml.md +++ b/docs/cookbook/deploy-using-bentoml.md @@ -50,13 +50,14 @@ $ bentoml models list Tag Module Size Creation Time mistralai--mistral-7b-v0.1:m7lmf5ac2cmubnnz 13.49 GiB 2024-04-25 06:52:39 +``` + ## Define a BentoML Service As the model is ready, we can define a [BentoML Service](https://docs.bentoml.com/en/latest/guides/services.html) to wrap the capabilities of the model. We will run the JSON-structured generation example [in the README](https://github.com/outlines-dev/outlines?tab=readme-ov-file#efficient-json-generation-following-a-json-schema), with the following schema: - ```python DEFAULT_SCHEMA = """{ "title": "Character", @@ -206,6 +207,8 @@ Expected output: "weapon": "sword", "strength": 20 } +``` + ## Deploy to BentoCloud After the Service is ready, you can deploy it to [BentoCloud](https://docs.bentoml.com/en/latest/bentocloud/get-started.html) for better management and scalability. [Sign up](https://cloud.bentoml.com/signup) if you haven't got a BentoCloud account. From 353cebb276c4b90f1300f7ce514bc68a3f5ad6c9 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Mon, 6 May 2024 00:52:21 +0800 Subject: [PATCH 06/35] fix CFG Generation --- outlines/generate/generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index edda617d0..e506aa035 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -86,8 +86,9 @@ def sequence_generator( token_ids = update_token_ids(token_ids, next_token_ids, ancestors) attention_masks = update_attention_masks(attention_masks, ancestors) kv_cache = reorder_kv_cache(kv_cache, ancestors) - fsms = reorder_fsms(fsms, ancestors) - fsm_states = reorder_fsm_states(fsm_states, ancestors) + if len(ancestors) > 1: + fsms = reorder_fsms(fsms, ancestors) + fsm_states = reorder_fsm_states(fsm_states, ancestors) fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids) is_finished = is_generation_finished(fsms, fsm_states) From 4f8433d8d6633b0780c3a6c27981f9adffbe49f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 6 May 2024 10:02:45 +0200 Subject: [PATCH 07/35] Localize types --- docs/reference/types.md | 28 ++++++++++++++++++++++------ outlines/types/__init__.py | 3 +-- outlines/types/locales.py | 21 +++++++++++++++++++++ outlines/types/phone_numbers.py | 2 +- outlines/types/zip_codes.py | 2 +- tests/test_types.py | 33 +++++++++++++++++++++++++++------ 6 files changed, 73 insertions(+), 16 deletions(-) create mode 100644 outlines/types/locales.py diff --git a/docs/reference/types.md b/docs/reference/types.md index 645249263..6fa093fef 100644 --- a/docs/reference/types.md +++ b/docs/reference/types.md @@ -2,11 +2,8 @@ Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions: - | Category | Type | Import | Description | |:--------:|:----:|:-------|:------------| -| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes | -| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers | | ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct | | Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] | | Country | alpha-2 code | `outlines.types.airports.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] | @@ -15,6 +12,23 @@ Outlines provides custom Pydantic types so you can focus on your use case rather | | name | `outlines.types.countries.Name` | Valid country names | | | flag | `outlines.types.countries.Flag` | Valid flag emojis | +Some types require localization. We currently only support US types, but please don't hesitate to create localized versions of the different types and open a Pull Request. Localized types are specified using `types.locale` in the following way: + +```python +from outlines import types + +types.locale("us").ZipCode +types.locale("us").PhoneNumber +``` + +Here are the localized types that are currently available: + +| Category | Locale | Import | Description | +|:--------:|:----:|:-------|:------------| +| Zip code | US | `ZipCode` | Generate US Zip(+4) codes | +| Phone number | US | `PhoneNumber` | Generate valid US phone numbers | + + You can use these types in Pydantic schemas for JSON-structured generation: ```python @@ -22,11 +36,13 @@ from pydantic import BaseModel from outlines import models, generate, types +# Specify the locale for types +locale = types.locale("us") class Client(BaseModel): name: str - phone_number: types.PhoneNumber - zip_code: types.ZipCode + phone_number: locale.PhoneNumber + zip_code: locale.ZipCode model = models.transformers("mistralai/Mistral-7B-v0.1") @@ -47,7 +63,7 @@ from outlines import models, generate, types model = models.transformers("mistralai/Mistral-7B-v0.1") -generator = generate.format(model, types.PhoneNumber) +generator = generate.format(model, types.locale("us").PhoneNumber) result = generator( "Return a US Phone number: " ) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 266d3a68e..bf6eceabc 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,4 +1,3 @@ from . import airports, countries from .isbn import ISBN -from .phone_numbers import PhoneNumber -from .zip_codes import ZipCode +from .locales import locale diff --git a/outlines/types/locales.py b/outlines/types/locales.py new file mode 100644 index 000000000..c5d251bae --- /dev/null +++ b/outlines/types/locales.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from outlines.types.phone_numbers import USPhoneNumber +from outlines.types.zip_codes import USZipCode + + +@dataclass +class US: + ZipCode = USZipCode + PhoneNumber = USPhoneNumber + + +def locale(locale_str: str): + locales = {"us": US} + + if locale_str not in locales: + raise NotImplementedError( + f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for you locale and open a Pull Request." + ) + + return locales[locale_str] diff --git a/outlines/types/phone_numbers.py b/outlines/types/phone_numbers.py index 0b27c7890..618687e75 100644 --- a/outlines/types/phone_numbers.py +++ b/outlines/types/phone_numbers.py @@ -10,7 +10,7 @@ US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" -PhoneNumber = Annotated[ +USPhoneNumber = Annotated[ str, WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), ] diff --git a/outlines/types/zip_codes.py b/outlines/types/zip_codes.py index 981efdd03..67d994d5c 100644 --- a/outlines/types/zip_codes.py +++ b/outlines/types/zip_codes.py @@ -10,4 +10,4 @@ US_ZIP_CODE = r"\d{5}(?:-\d{4})?" -ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] +USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/tests/test_types.py b/tests/test_types.py index 2391ccc18..b4232c84a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -10,12 +10,12 @@ @pytest.mark.parametrize( "custom_type,test_string,should_match", [ - (types.PhoneNumber, "12", False), - (types.PhoneNumber, "(123) 123-1234", True), - (types.PhoneNumber, "123-123-1234", True), - (types.ZipCode, "12", False), - (types.ZipCode, "12345", True), - (types.ZipCode, "12345-1234", True), + (types.phone_numbers.USPhoneNumber, "12", False), + (types.phone_numbers.USPhoneNumber, "(123) 123-1234", True), + (types.phone_numbers.USPhoneNumber, "123-123-1234", True), + (types.zip_codes.USZipCode, "12", False), + (types.zip_codes.USZipCode, "12345", True), + (types.zip_codes.USZipCode, "12345-1234", True), (types.ISBN, "ISBN 0-1-2-3-4-5", False), (types.ISBN, "ISBN 978-0-596-52068-7", True), # (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit @@ -42,6 +42,27 @@ class Model(BaseModel): assert does_match is should_match +def test_locale_not_implemented(): + with pytest.raises(NotImplementedError): + types.locale("fr") + + +@pytest.mark.parametrize( + "locale_str,base_types,locale_types", + [ + ( + "us", + ["ZipCode", "PhoneNumber"], + [types.zip_codes.USZipCode, types.phone_numbers.USPhoneNumber], + ) + ], +) +def test_locale(locale_str, base_types, locale_types): + for base_type, locale_type in zip(base_types, locale_types): + type = getattr(types.locale(locale_str), base_type) + assert type == locale_type + + @pytest.mark.parametrize( "custom_type,test_string,should_match", [ From a84d78ce446aff80962e58adcd48a15e1519e1f2 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 8 May 2024 08:02:27 -0400 Subject: [PATCH 08/35] Add `Email` type (#870) --- docs/reference/types.md | 1 + outlines/types/__init__.py | 1 + outlines/types/email.py | 11 +++++++++++ tests/test_types.py | 6 ++++++ 4 files changed, 19 insertions(+) create mode 100644 outlines/types/email.py diff --git a/docs/reference/types.md b/docs/reference/types.md index 6fa093fef..2e02f45f1 100644 --- a/docs/reference/types.md +++ b/docs/reference/types.md @@ -11,6 +11,7 @@ Outlines provides custom Pydantic types so you can focus on your use case rather | | numeric code | `outlines.types.countries.Numeric` | Valid [country numeric codes][wiki-country-numeric] | | | name | `outlines.types.countries.Name` | Valid country names | | | flag | `outlines.types.countries.Flag` | Valid flag emojis | +| | email | `outlines.types.Email` | Valid email address | Some types require localization. We currently only support US types, but please don't hesitate to create localized versions of the different types and open a Pull Request. Localized types are specified using `types.locale` in the following way: diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index bf6eceabc..f4d2b8cd3 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,3 +1,4 @@ from . import airports, countries +from .email import Email from .isbn import ISBN from .locales import locale diff --git a/outlines/types/email.py b/outlines/types/email.py new file mode 100644 index 000000000..45f8c4b2c --- /dev/null +++ b/outlines/types/email.py @@ -0,0 +1,11 @@ +"""Email Address types.""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# Taken from StackOverflow +# https://stackoverflow.com/a/201378/14773537 +EMAIL_REGEX = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" +Email = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": EMAIL_REGEX}), +] diff --git a/tests/test_types.py b/tests/test_types.py index b4232c84a..5e60348b2 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -24,6 +24,12 @@ (types.ISBN, "9780596520687", True), (types.ISBN, "ISBN-10: 0-596-52068-9", True), (types.ISBN, "0-596-52068-9", True), + (types.Email, "eitan@gmail.com", True), + (types.Email, "99@yahoo.com", True), + (types.Email, "eitan@.gmail.com", False), + (types.Email, "myemail", False), + (types.Email, "eitan@gmail", False), + (types.Email, "eitan@my.custom.domain", True), ], ) def test_type_regex(custom_type, test_string, should_match): From 99e684efa9a3f5cc350f994146d0195f59810944 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 8 May 2024 08:03:56 -0400 Subject: [PATCH 09/35] Fix installation instructions (#877) --- docs/community/contribute.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/community/contribute.md b/docs/community/contribute.md index 1df15084a..fb67576e4 100644 --- a/docs/community/contribute.md +++ b/docs/community/contribute.md @@ -39,7 +39,7 @@ source .venv/bin/activate Then install the dependencies in editable mode, and install the pre-commit hooks: ```python -pip install -e .[test] +pip install -e ".[test]" pre-commit install ``` From 97ec37d9038750101152582e5df3d7315b2759b5 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 8 May 2024 11:07:27 -0400 Subject: [PATCH 10/35] Extract function name in `get_schema_from_signature` (#878) --- outlines/fsm/json_schema.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 647c95a22..d96597d4c 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,6 +1,7 @@ import inspect import json import re +import warnings from typing import Callable, Optional from jsonschema.protocols import Validator @@ -375,6 +376,14 @@ def get_schema_from_signature(fn: Callable) -> str: else: arguments[name] = (arg.annotation, ...) - model = create_model("Arguments", **arguments) + try: + fn_name = fn.__name__ + except Exception as e: + fn_name = "Arguments" + warnings.warn( + f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}", + category=UserWarning, + ) + model = create_model(fn_name, **arguments) return model.model_json_schema() From 78852b0169e7c4c6f3eaf6b2b2e6209e41edf98c Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Sat, 11 May 2024 22:14:12 +0200 Subject: [PATCH 11/35] Remove broken final state loop (#874) Fixes #856 The code this PR removes introduces an artificial and erroneous loop transition in every final state that is always traversed, regardless of the generation. The comment doesn't make sense in my opinion, as the `if` above just handles exactly this case. Removing this piece of code fixes the bug that surfaced in the upgrade of outlines in the vLLM integration. --- outlines/fsm/guide.py | 6 +----- tests/fsm/test_fsm.py | 4 +--- tests/fsm/test_guide.py | 4 +--- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 5c7b56326..2833fce1a 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -193,12 +193,8 @@ def get_next_state(self, state: int, token_id: int) -> int: The new state of the guide. """ - if token_id == self.eos_token_id: + if token_id == self.eos_token_id or state not in self.states_to_token_maps: return -1 - elif ( - state in self.final_states - ): # Necessary because we keep generating EOS tokens when finished - return state last_token_to_end_state = self.states_to_token_maps[state] next_state = last_token_to_end_state.get(token_id) diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index 30047b83d..97d120f20 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -82,9 +82,7 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) state = fsm.next_state(state=5, token_id=103) - assert state == 5 - - assert fsm.is_final_state(-1) + assert fsm.is_final_state(state) def test_cfg(): diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 28645f012..68f7e33ee 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -180,9 +180,7 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) state = fsm.get_next_state(state=5, token_id=103) - assert state == 5 - - assert fsm.is_final_state(-1) + assert fsm.is_final_state(state) def test_cfg(): From 159d1ec6d6ee34515755424019abf996e061f993 Mon Sep 17 00:00:00 2001 From: Isamu Isozaki Date: Fri, 17 May 2024 18:05:17 +0900 Subject: [PATCH 12/35] Fixing stream stopping at wrong location (#898) Fixes https://github.com/outlines-dev/outlines/issues/896 --- outlines/generate/api.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 3f4f182d2..51a995664 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: return generated_token_ids = sequence.token_ids[:, -num_generated:] generated_sequences = self.tokenizer.decode(generated_token_ids) - next_tokens = [ - token[len(sequence) :] if not stop else "" - for token, sequence, stop in zip( - generated_sequences, - previously_generated_sequences, - is_stop_at_reached, - ) - ] - previously_generated_sequences = generated_sequences if stop_sequences: is_stop_at_reached = [ stop @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ) ] + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() for i in range(batch_size): From 499d19dd3078e5e21cf68c7916a162d5e8ce0990 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 17 May 2024 09:09:03 +0000 Subject: [PATCH 13/35] Prevent Illegal Look-Around for OneOf in JSONSchema (#897) Fixes #823 This comment details the issues error: https://github.com/outlines-dev/outlines/issues/823#issuecomment-2116490949 The reproduction code provided results in a json schema with `OneOf[pets]`: ``` class Model(BaseModel): pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') ``` Before this PR: `OneOf` uses negative lookaheads to assert that only one schema member is included. This is illegal in `interegular`, more details available here: https://github.com/outlines-dev/outlines/issues/456 After `OneOf` uses or-joined non-capturing groups which don't have the same issues with `interegular`. --- outlines/fsm/json_schema.py | 8 +------- tests/fsm/test_json_schema.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index d96597d4c..2c53fd240 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -195,13 +195,7 @@ def to_regex( to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] ] - xor_patterns = [] - # json schema validation ensured there is no overlapping schemas in oneOf - for subregex in subregexes: - other_subregexes = filter(lambda r: r != subregex, subregexes) - other_subregexes_str = "|".join([f"{s}" for s in other_subregexes]) - negative_lookahead = f"(?!.*({other_subregexes_str}))" - xor_patterns.append(f"({subregex}){negative_lookahead}") + xor_patterns = [f"(?:{subregex})" for subregex in subregexes] return rf"({'|'.join(xor_patterns)})" diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5b3ad9e39..b992f7aa5 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,9 +1,10 @@ import json import re -from typing import List +from typing import List, Literal, Union +import interegular import pytest -from pydantic import BaseModel, constr +from pydantic import BaseModel, Field, constr from outlines.fsm.json_schema import ( BOOLEAN, @@ -321,7 +322,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], }, - rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))", + rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', [ ("12.3", True), ("true", True), @@ -750,3 +751,27 @@ class MockModel(BaseModel): assert match_default_ws is None assert re.fullmatch(pattern, mock_result_maybe_ws) + + +def test_one_of_doesnt_produce_illegal_lookaround(): + """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823""" + + class Cat(BaseModel): + pet_type: Literal["cat"] + meows: int + + class Dog(BaseModel): + pet_type: Literal["dog"] + barks: float + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") + n: int + + json_schema = Model.schema_json() + + json_schema = Model.schema_json() + pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) + + # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() + interegular.parse_pattern(pattern).to_fsm() From 315d531b9b8cf4c1a87179531b479001e08a4d8e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 17 May 2024 09:33:45 -0500 Subject: [PATCH 14/35] circumvent broken llama.cpp pre-tokenizer --- outlines/integrations/llamacpp.py | 16 ++++++++--- outlines/models/llamacpp.py | 11 +++++++ tests/generate/test_integration_llamacpp.py | 32 +++++++++++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 8c000a6e5..4041c54fb 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -49,11 +49,19 @@ def __init__(self, model: "Llama"): self.special_tokens: Set[int] = set() self.vocabulary: Dict[str, int] = dict() - for t in range(model.n_vocab()): - token_piece = model.tokenizer().decode([t]) - self.vocabulary[token_piece] = t - self.decode = model.tokenizer().decode + tokenizer = model.tokenizer() + + self.decode = tokenizer.decode + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + try: + self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() + except AttributeError: + # ### + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t def convert_token_to_string(self, token: str) -> str: return token diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 8a6a53a27..5920f08d6 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,4 +1,5 @@ import dataclasses +import warnings from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union from typing_extensions import Unpack @@ -288,6 +289,16 @@ def llamacpp( if "verbose" not in llamacpp_model_params: llamacpp_model_params["verbose"] = False + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + if "tokenizer" not in llamacpp_model_params: + warnings.warn( + "The pre-tokenizer in `llama.cpp` handles unicode improperly " + + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" + + "Outlines may raise a `RuntimeError` when building the regex index.\n" + + "To circumvent this error when using `models.llamacpp()` you may pass the argument" + + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" + ) + model = Llama.from_pretrained(repo_id, filename, **llamacpp_model_params) return LlamaCpp(model) diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index d036b560f..75d0e4cef 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -247,3 +247,35 @@ def test_llamacpp_cfg(model): prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) assert isinstance(result, str) + + +@pytest.mark.parametrize( + "repo,model_path,hf_tokenizer_uri", + [ + ("Qwen/Qwen1.5-0.5B-Chat-GGUF", "*q2*.gguf", "Qwen/Qwen1.5-0.5B-Chat"), + ("TheBloke/phi-2-GGUF", "*Q2*.gguf", "microsoft/phi-2"), + ], +) +def test_byte_tokenizer_regression(repo, model_path, hf_tokenizer_uri): + """Reproduce https://github.com/outlines-dev/outlines/issues/820""" + import llama_cpp + + model = models.llamacpp( + repo, + model_path, + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + hf_tokenizer_uri + ), + ) + generator = generate.choice(model, ["skirt", "dress", "pen", "jacket"]) + generator("Pick the odd word out: skirt, dress, pen, jacket") + + +def test_llama_cpp_pre_tokenizer_remains_broken(): + """If fails, llama.cpp pre-tokenizer is fixed -> revert #892, remove `with pytest.raises`""" + repo = "Qwen/Qwen1.5-0.5B-Chat-GGUF" + model_path = "*q2*.gguf" + + model = models.llamacpp(repo, model_path) + with pytest.raises(RuntimeError): + generate.choice(model, ["skirt", "dress", "pen", "jacket"]) From 3e291b1357326d1664326b9bb9780ea71ee2d236 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 17 May 2024 17:22:58 -0400 Subject: [PATCH 15/35] Add args to Jinja filters (#902) In the outlines docs, we have the example ```python import outlines def my_tool(arg1: str, arg2: int): """Tool description. The rest of the docstring """ pass @outlines.prompt def tool_prompt(question, tool): """{{ question }} COMMANDS 1. {{ tool | name }}: {{ tool | description }}, args: {{ tool | args }} {{ tool | source }} """ prompt = tool_prompt("Can you do something?", my_tool) print(prompt) ``` However, when I tried running this code, it did not work because the `args` filter used in `{{ tool | args }}` was not implemented. I implemented the `args` filter so now this example works. Now the args filter will output all of the arguments with the type annotations and default values (if they are provided). Example: ```python from typing import List def foo(x, y: str, z: List[int]=[1, 2, 3]): pass @outlines.prompt def tool_prompt(fn): """My args: {{ fn | args }}""" prompt = tool_prompt(foo) print(prompt) ``` which outputs ```python My args: x, y: str, z: List[int] = [1, 2, 3] ``` --- docs/reference/prompting.md | 2 +- outlines/prompts.py | 13 ++++++++ tests/test_prompts.py | 62 ++++++++++++++++++++++++++++++++++++- 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/docs/reference/prompting.md b/docs/reference/prompting.md index a7731ba0f..34860fce0 100644 --- a/docs/reference/prompting.md +++ b/docs/reference/prompting.md @@ -223,7 +223,7 @@ Several projects (e.g.[Toolformer](https://arxiv.org/abs/2302.04761), [ViperGPT] Can you do something? COMMANDS - 1. my_tool: Tool description, args: arg1:str, arg2:int + 1. my_tool: Tool description., args: arg1: str, arg2: int def my_tool(arg1: str, arg2: int): """Tool description. diff --git a/outlines/prompts.py b/outlines/prompts.py index b4e7288bb..01e900c96 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -207,6 +207,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: env.filters["source"] = get_fn_source env.filters["signature"] = get_fn_signature env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args jinja_template = env.from_string(cleaned_template) @@ -226,6 +227,18 @@ def get_fn_name(fn: Callable): return name +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + arg_str_list = [str(param) for param in signature.parameters.values()] + arg_str = ", ".join(arg_str_list) + return arg_str + + def get_fn_description(fn: Callable): """Returns the first line of a callable's docstring.""" if not callable(fn): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 65eeb2022..a0433c0e5 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import pytest from pydantic import BaseModel, Field @@ -252,3 +252,63 @@ def source_ppt(model): prompt = source_ppt(response) assert prompt == '{\n "one": "a description",\n "two": ""\n}' + + +def test_prompt_args(): + def no_args(): + pass + + def with_args(x, y, z): + pass + + def with_annotations(x: bool, y: str, z: Dict[int, List[str]]): + pass + + def with_defaults(x=True, y="Hi", z={4: ["I", "love", "outlines"]}): + pass + + def with_annotations_and_defaults( + x: bool = True, + y: str = "Hi", + z: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): + pass + + def with_all( + x1, + y1, + z1, + x2: bool, + y2: str, + z2: Dict[int, List[str]], + x3=True, + y3="Hi", + z3={4: ["I", "love", "outlines"]}, + x4: bool = True, + y4: str = "Hi", + z4: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): + pass + + @outlines.prompt + def args_prompt(fn): + """args: {{ fn | args }}""" + + assert args_prompt(no_args) == "args: " + assert args_prompt(with_args) == "args: x, y, z" + assert ( + args_prompt(with_annotations) + == "args: x: bool, y: str, z: Dict[int, List[str]]" + ) + assert ( + args_prompt(with_defaults) + == "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_annotations_and_defaults) + == "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_all) + == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) From 7863f8e8bbaeb71c9d2434636a2d63bfe6dd7d39 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 18 May 2024 02:50:25 -0500 Subject: [PATCH 16/35] Allow Parenthesis in `STRING_INNER` (#899) Fix #838 https://github.com/outlines-dev/outlines/commit/06d565496966f0dbe184dd619b62ea276035f562 erroneously disallowed parenthesis in strings. This PR allows parenthesis in strings. --- outlines/fsm/json_schema.py | 2 +- tests/fsm/test_json_schema.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 2c53fd240..c57cea7cd 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -10,7 +10,7 @@ from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 -STRING_INNER = r'([^("\\\x00-\x1f\x7f-\x9f)]|\\\\)' +STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)' STRING = f'"{STRING_INNER}*"' INTEGER = r"(-)?(0|[1-9][0-9]*)" NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index b992f7aa5..edc061bec 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -119,6 +119,8 @@ def test_match_number(pattern, does_match): STRING, [ ("unquotedstring", False), + ('"(parenthesized_string)"', True), + ('"malformed) parenthesis (((() string"', True), ('"quoted_string"', True), (r'"escape_\character"', False), (r'"double_\\escape"', True), From 6f655ca8f00d6ea42c72eb699d97c78a10a826ab Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 21 May 2024 00:25:33 -0500 Subject: [PATCH 17/35] Allow objects in json schemas without additionalProperties set --- outlines/fsm/json_schema.py | 24 +++++++++++++++++++++++- tests/fsm/test_json_schema.py | 13 +++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index c57cea7cd..dbd2baa40 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -312,8 +312,30 @@ def to_regex( allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" + additional_properties = instance.get("additionalProperties") + + if additional_properties is None or additional_properties is True: + # JSON Schema behavior: If the additionalProperties of an object is + # unset or True, it is unconstrained object. + # We handle this by setting additionalProperties to anyOf: {all types} + + legal_values = [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "null"} + # { "type": "array" }, # TODO: enable arrays within object-types + ] + + # We set the object depth to 2 to keep the expression finite, but the "depth" + # key is not a true component of the JSON Schema specification. + depth = instance.get("depth", 2) + if depth > 0: + legal_values.append({"type": "object", "depth": depth - 1}) + additional_properties = {"anyOf": legal_values} + value_pattern = to_regex( - resolver, instance["additionalProperties"], whitespace_pattern + resolver, additional_properties, whitespace_pattern ) key_value_pattern = ( f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index edc061bec..b12f9576e 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -710,6 +710,19 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), + # Unconstrained Object + ( + { + "title": "Foo", + "type": "object", + }, + [ + ("{}", True), + ('{"a": 1, "b": null}', True), + ('{"a": {"z": {"g": 4}}, "b": null}', True), + ("1234", False), # not an object + ], + ), ], ) def test_format_without_regex(schema, examples): From ffab2ac4fa273b1b63702d63d3e42035be023017 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 23 May 2024 01:43:46 -0500 Subject: [PATCH 18/35] Use TQDM to track index compilation progress --- .pre-commit-config.yaml | 1 + outlines/fsm/regex.py | 13 ++++++++++++- pyproject.toml | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8fe83f89a..b528f0e8e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,3 +30,4 @@ repos: - id: mypy args: [--allow-redefinition] exclude: ^examples/ + additional_dependencies: [types-tqdm] diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 0941bbb9f..b68e31897 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -26,6 +26,7 @@ anything_else, ) from numba.typed.typedobjectutils import _nonoptional +from tqdm import tqdm if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -692,6 +693,12 @@ def create_fsm_index_end_to_end( seen: Set[int] = set() next_states = {fsm_info.initial} + pbar = tqdm( + total=len(set(fsm_info.transitions.values())) + + 1, # all transitions plus initial + desc="Compiling FSM index for all state transitions", + ) + while next_states: start_state = next_states.pop() @@ -713,7 +720,11 @@ def create_fsm_index_end_to_end( if end_state not in seen: next_states.add(end_state) - seen.add(start_state) + if start_state not in seen: + pbar.update(1) + seen.add(start_state) + + pbar.close() return states_to_token_subsets diff --git a/pyproject.toml b/pyproject.toml index 41c306b14..b18036ffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "referencing", "jsonschema", "requests", + "tqdm" ] dynamic = ["version"] From d7c970766e0f670ba2f55b5a126b2327d2aabde7 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 22 May 2024 14:54:33 -0500 Subject: [PATCH 19/35] Refactor cache decorator These changes make the `cache` decorator operate more like `diskcache`'s existing `memoize` method. They also remove the use of hash value as store keys. --- outlines/caching.py | 129 ++++++++++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 45 deletions(-) diff --git a/outlines/caching.py b/outlines/caching.py index 68207a0e4..52d66af74 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -1,15 +1,40 @@ import asyncio import functools -import hashlib import os from typing import Callable, Optional import cloudpickle -from diskcache import Cache +from diskcache import Cache, Disk +from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name _caching_enabled = True +class CloudpickleDisk(Disk): + def __init__(self, directory, compress_level=1, **kwargs): + self.compress_level = compress_level + super().__init__(directory, **kwargs) + + def put(self, key): + data = cloudpickle.dumps(key) + return super().put(data) + + def get(self, key, raw): + data = super().get(key, raw) + return cloudpickle.loads(data) + + def store(self, value, read, key=UNKNOWN): + if not read: + value = cloudpickle.dumps(value) + return super().store(value, read, key=key) + + def fetch(self, mode, filename, value, read): + data = super().fetch(mode, filename, value, read) + if not read: + data = cloudpickle.loads(data) + return data + + @functools.lru_cache(1) def get_cache(): """Get the context object that contains previously-computed return values. @@ -26,7 +51,12 @@ def get_cache(): home_dir = os.path.expanduser("~") cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") - memory = Cache(cache_dir, eviction_policy="none", cull_limit=0) + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) # ensure if version upgrade occurs, old cache is pruned if outlines_version != memory.get("__version__"): @@ -36,63 +66,72 @@ def get_cache(): return memory -def hash_arguments(*args, **kwargs) -> str: - """Create a hash out of the args and kwargs provided""" - result = hashlib.md5() - for item in list(args) + sorted(kwargs.items()): - result.update(cloudpickle.dumps(item)) - return result.hexdigest() - - -def cache(key_function: Optional[Callable] = None): +def cache(expire: Optional[float] = None, typed=False, ignore=()): """Caching decorator for memoizing function calls. + The cache key is created based on the values returned by the key_function callable if provided or based on the arguments of the decorated function directly otherwise + + This is based on `diskcache`'s `memoize`. + Parameters ---------- - key_function - A callable function used to generate a unique key for each function call. It's - called with the arguments of the decorated function as arguments + expire + Seconds until arguments expire. + typed + Cache different types separately. + ignore + Positional or keyword arguments to ignore. + Returns ------- - A decorator function that can be applied to other functions. + A decorator function that can be applied to other functions. """ def decorator(cached_function: Callable): memory = get_cache() - def wrapper(*args, **kwargs): - if not _caching_enabled: - return cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = cached_function(*args, **kwargs) - memory[cache_key] = result - return result - - async def async_wrapper(*args, **kwargs): - if not _caching_enabled: - return await cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = await cached_function(*args, **kwargs) - memory[cache_key] = result - return result + base = (full_name(cached_function),) if asyncio.iscoroutinefunction(cached_function): - return async_wrapper + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + else: - return wrapper + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper return decorator From ba7affd92883aaa5ceacfba7ef5abd5a27b2eb26 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 21 May 2024 18:36:17 -0500 Subject: [PATCH 20/35] Use a persistent Tokenizer hash for create_states_mapping cache --- outlines/fsm/guide.py | 56 +++++++-------- outlines/models/transformers.py | 11 ++- pyproject.toml | 4 +- .../generate/test_integration_transformers.py | 68 ++++++++++++++++++- tests/models/test_transformers.py | 14 +++- 5 files changed, 117 insertions(+), 36 deletions(-) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 2833fce1a..d247db62b 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -105,44 +105,44 @@ def copy(self): return self +@cache() +def create_states_mapping( + regex_string: str, tokenizer: "Tokenizer" +) -> Tuple[dict, set, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + regex_pattern = interegular.parse_pattern(regex_string) + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + return states_to_token_maps, empty_token_ids, regex_fsm.finals + + class RegexGuide(Guide): """Guide to generate text in the language of a regular expression.""" initial_state = 0 def __init__(self, regex_string: str, tokenizer): - @cache() - def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]: - """Create the variables related to the mapping between states and tokens - The parameters of the function are used for caching purpose - """ - regex_pattern = interegular.parse_pattern(regex_string) - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( - regex_fsm, tokenizer - ) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) - for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - return states_to_token_maps, empty_token_ids, regex_fsm.finals - ( self.states_to_token_maps, self.empty_token_ids, fsm_finals, - ) = create_states_mapping(regex_string) + ) = create_states_mapping(regex_string, tokenizer) self.eos_token_id = tokenizer.eos_token_id self.final_states = fsm_finals | {-1} diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 3bc59412e..fae9b8e74 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from datasets.fingerprint import Hasher + from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: @@ -109,10 +111,15 @@ def __eq__(self, other): return NotImplemented def __hash__(self): - from datasets.fingerprint import Hasher - return hash(Hasher.hash(self.tokenizer)) + def __getstate__(self): + state = {"tokenizer": self.tokenizer} + return state + + def __setstate__(self, state): + self.__init__(state["tokenizer"]) + class Transformers: """Represents a `transformers` model.""" diff --git a/pyproject.toml b/pyproject.toml index b18036ffc..0b310c44b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ dependencies = [ "referencing", "jsonschema", "requests", - "tqdm" + "tqdm", + "datasets", ] dynamic = ["version"] @@ -50,7 +51,6 @@ test = [ "diff-cover", "accelerate", "beartype<0.16.0", - "datasets", "responses", "llama-cpp-python", "huggingface_hub", diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 38525a076..cee3ca312 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,6 +1,7 @@ import datetime import re from enum import Enum +from importlib import reload from typing import List, Union import pytest @@ -11,7 +12,28 @@ import outlines.models as models from outlines.fsm.regex import reduced_vocabulary from outlines.models.transformers import Transformers, TransformerTokenizer -from outlines.samplers import beam_search, multinomial +from outlines.samplers import beam_search, greedy, multinomial + + +@pytest.fixture +def temp_cache_dir(): + import os + import tempfile + + import outlines.caching + import outlines.fsm.guide + + with tempfile.TemporaryDirectory() as tempdir: + os.environ["OUTLINES_CACHE_DIR"] = tempdir + outlines.caching.get_cache.cache_clear() + reload(outlines) + reload(outlines.fsm.guide) + cache_status = outlines.caching._caching_enabled + try: + outlines.caching._caching_enabled = True + yield + finally: + outlines.caching._caching_enabled = cache_status def test_transformers_integration_text(): @@ -632,3 +654,47 @@ def test_transformers_use_existing_model_and_tokenizer(): model = Transformers(hf_model, hf_tokenizer) sequence = generate.text(model)("Write a short sentence ", rng=rng) assert isinstance(sequence, str) + + +def test_RegexGuide_caching(temp_cache_dir): + import outlines.caching + from outlines.fsm.guide import create_states_mapping + + assert outlines.caching._caching_enabled + + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + prompt = "What is the IP address of the Google DNS servers? " + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + model = models.transformers( + "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM" + ) + generator = generate.regex(model, regex, sampler=greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM") + generator_2 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps + + generator_3 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (1, 2) + assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b4e410096..f4596a2df 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -107,6 +107,14 @@ def test_tokenizer_eq_hash(): tokenizer_hf = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer_hf) - tokenizer2 = TransformerTokenizer(tokenizer_hf) - assert tokenizer == tokenizer2 - assert hash(tokenizer) == hash(tokenizer2) + tokenizer_2 = TransformerTokenizer(tokenizer_hf) + + assert tokenizer == tokenizer_2 + assert hash(tokenizer) == hash(tokenizer_2) + + tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2") + tokenizer_hf_2.add_tokens(["test_token"]) + + tokenizer_3 = TransformerTokenizer(tokenizer_hf_2) + assert tokenizer != tokenizer_3 + assert hash(tokenizer) != hash(tokenizer_3) From 411eaaf2d9426a56f7fdb99a1d7073dadd463806 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 24 May 2024 01:12:31 -0500 Subject: [PATCH 21/35] Use less problematic whitespace token (#916) Fixes #839 #908 #690 #450 ## Problem A major problem, especially with smaller language models, is the repetition problem. For example, let's say a model is generating json and must provide 12 space tokens for indentation in json output. Often a language model will assign a high probability to a 13th space token, and do the same for a 14th space, and then enter an infinite space generation loop. This is a problem with NLG that has been known for half a decade, but only has mitigations (mirostat, repetition penalty, using hundreds of billions of weights, etc), no absolute solutions (except for **structured generation**) ## Solution For structured json generation, we set a sane default whitespace pattern of `r"[ ]?"`. This removes all newlines and indentation. It disallows any syntactic whitespace beyond a single space separator. Users can still set the argument `whitespace_pattern=` if they want different behavior --- docs/reference/json.md | 4 ++-- outlines/fsm/json_schema.py | 2 +- tests/fsm/test_json_schema.py | 45 ++++++++++++----------------------- 3 files changed, 18 insertions(+), 33 deletions(-) diff --git a/docs/reference/json.md b/docs/reference/json.md index 3b5976f19..85e1a846a 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -36,10 +36,10 @@ print(result) !!! Note "JSON and whitespaces" - By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string: + By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: ```python - generator = generate.json(model, User, whitespace_pattern="") + generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` !!! Note "Performance" diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index dbd2baa40..0e0d25bfc 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -16,7 +16,7 @@ NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[\n ]*" +WHITESPACE = r"[ ]?" type_to_regex = { "string": STRING, diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index b12f9576e..bc836ac8b 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -215,8 +215,8 @@ def test_match_number(pattern, does_match): "properties": {"count": {"title": "Count", "type": "integer"}}, "required": ["count"], }, - '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}', - [('{\n "count": 100\n}', True)], + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', + [('{ "count": 100 }', True)], ), # array ( @@ -277,7 +277,7 @@ def test_match_number(pattern, does_match): rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar"\n}}""", True), + ("""{ "test_dict":{"foo":"bar" }}""", True), ("""{ "test_dict":{}}""", True), ("""{ "WRONG_KEY":{}}""", False), ("""{ "test_dict":{"wrong_type" 1}}""", False), @@ -369,8 +369,8 @@ def test_match_number(pattern, does_match): }, "required": ["fuzz"], }, - f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', - [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], + f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', + [('{ "fuzz": { "spam": 100 }}', True)], ), # Schema with a reference ( @@ -384,7 +384,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "a"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], ), ( @@ -399,7 +399,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "name2"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], ), ( @@ -441,7 +441,7 @@ def test_match_number(pattern, does_match): } }, }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', [ ( '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', @@ -462,7 +462,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "weapon" : "sword" }', True), @@ -482,7 +482,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" , "weapon" : "sword" }', True), ( @@ -506,7 +506,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', [ ( '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', @@ -530,7 +530,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), @@ -710,19 +710,6 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), - # Unconstrained Object - ( - { - "title": "Foo", - "type": "object", - }, - [ - ("{}", True), - ('{"a": 1, "b": null}', True), - ('{"a": {"z": {"g": 4}}, "b": null}', True), - ("1234", False), # not an object - ], - ), ], ) def test_format_without_regex(schema, examples): @@ -737,7 +724,7 @@ def test_format_without_regex(schema, examples): assert match is None -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"]) +@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) def test_json_schema_custom_whitespace_pattern(whitespace_pattern): """assert whitespace_pattern setting respected""" @@ -759,13 +746,11 @@ class MockModel(BaseModel): ) mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - match_default_ws = re.fullmatch(pattern, mock_result_mult_ws) + match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) if whitespace_pattern is None: assert match_default_ws else: - assert match_default_ws is None - - assert re.fullmatch(pattern, mock_result_maybe_ws) + assert re.fullmatch(pattern, mock_result_mult_ws) def test_one_of_doesnt_produce_illegal_lookaround(): From e5c39e25f7b77dbb8c4140e23135650bc4d769e8 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 22 May 2024 17:38:25 -0500 Subject: [PATCH 22/35] Enable Tuples / prefixItems in build_regex_from_schema() --- outlines/fsm/json_schema.py | 9 +++++++++ tests/fsm/test_json_schema.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 0e0d25bfc..37f04f610 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -199,6 +199,15 @@ def to_regex( return rf"({'|'.join(xor_patterns)})" + # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx + elif "prefixItems" in instance: + element_patterns = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] + ] + comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" + tuple_inner = comma_split_pattern.join(element_patterns) + return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" + # The enum keyword is used to restrict a value to a fixed set of values. It # must be an array with at least one element, where each element is unique. elif "enum" in instance: diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index bc836ac8b..2a5a311d2 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -354,6 +354,15 @@ def test_match_number(pattern, does_match): rf"({STRING}{INTEGER})", [('"a"1', True), ('"a"', False), ('"1"', False)], ), + # Tuple / prefixItems + ( + { + "title": "Foo", + "prefixItems": [{"type": "string"}, {"type": "integer"}], + }, + rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", + [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], + ), # Nested schema ( { From 538f77a3363e092f06c407b68ae51dfcca2a79f0 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 25 May 2024 01:17:56 -0500 Subject: [PATCH 23/35] Fix invalid regex in unconstrained arrays for json_schema.py --- outlines/fsm/json_schema.py | 23 +++++++++++++------- tests/fsm/test_json_schema.py | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 37f04f610..ac819bc07 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -297,15 +297,22 @@ def to_regex( # Here we need to make the choice to exclude generating list of objects # if the specification of the object is not given, even though a JSON # object that contains an object here would be valid under the specification. - types = [ + legal_types = [ {"type": "boolean"}, {"type": "null"}, {"type": "number"}, {"type": "integer"}, {"type": "string"}, ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] - return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]" + depth = instance.get("depth", 2) + if depth > 0: + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + + regexes = [ + to_regex(resolver, t, whitespace_pattern) for t in legal_types + ] + return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" elif instance_type == "object": # pattern for json object with values defined by instance["additionalProperties"] @@ -328,20 +335,20 @@ def to_regex( # unset or True, it is unconstrained object. # We handle this by setting additionalProperties to anyOf: {all types} - legal_values = [ + legal_types = [ {"type": "string"}, {"type": "number"}, {"type": "boolean"}, - {"type": "null"} - # { "type": "array" }, # TODO: enable arrays within object-types + {"type": "null"}, ] # We set the object depth to 2 to keep the expression finite, but the "depth" # key is not a true component of the JSON Schema specification. depth = instance.get("depth", 2) if depth > 0: - legal_values.append({"type": "object", "depth": depth - 1}) - additional_properties = {"anyOf": legal_values} + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + additional_properties = {"anyOf": legal_types} value_pattern = to_regex( resolver, additional_properties, whitespace_pattern diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 2a5a311d2..efc5c054f 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -719,6 +719,46 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), + # Unconstrained Object + ( + { + "title": "Foo", + "type": "object", + }, + [ + ("{}", True), + ('{"a": 1, "b": null}', True), + ('{"a": {"z": {"g": 4}}, "b": null}', True), + ("1234", False), # not an object + ('["a", "a"]', False), # not an array + ], + ), + # Unconstrained Array + ( + { + "type": "array", + }, + [ + ("[1, {}, false]", True), + ("[{}]", True), + ('[{"a": {"z": "q"}, "b": null}]', True), + ('[{"a": [1, 2, true], "b": null}]', True), + ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True), + # too deep, default unconstrained depth limit = 2 + ( + '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]', + False, + ), + ('[{"a": {"z": {"g": 4}}, "b": null}]', False), + ("[[[[1]]]]", False), + # not an array + ("{}", False), + ('{"a": 1, "b": null}', False), + ('{"a": {"z": {"g": 4}}, "b": null}', False), + ("1234", False), # not an array + ('{"a": "a"}', False), # not an array + ], + ), ], ) def test_format_without_regex(schema, examples): From 7723ce8d091230249db8840463c5cdefdb408b8e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 30 May 2024 10:14:11 -0500 Subject: [PATCH 24/35] allow json schema of {}, resulting in unconstrained json value --- outlines/fsm/json_schema.py | 17 ++++++++++++++++- tests/fsm/test_json_schema.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index ac819bc07..3bd4816a9 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -126,7 +126,22 @@ def to_regex( if whitespace_pattern is None: whitespace_pattern = WHITESPACE - if "properties" in instance: + if instance == {}: + # JSON Schema Spec: Empty object means unconstrained, any json type is legal + types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + {"type": "array"}, + {"type": "object"}, + ] + regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + regexes = [rf"({r})" for r in regexes] + return rf"{'|'.join(regexes)}" + + elif "properties" in instance: regex = "" regex += r"\{" properties = instance["properties"] diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index efc5c054f..e691db374 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -759,6 +759,20 @@ def test_format(schema, regex, examples): ('{"a": "a"}', False), # not an array ], ), + # No schema / unconstrained value + ( + {}, + [ + ('"aaabbuecuh"', True), # string + ("5.554", True), # number + ("true", True), # boolean + ("null", True), # null + ("5999", True), # integer + ('["a", "b"]', True), # array + ('{"key": {"k2": "value"}}', True), # nested object + ("this isnt valid json", False), + ], + ), ], ) def test_format_without_regex(schema, examples): From 6696cb564e81d85a7f049dbc983e6ffe52661135 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 29 May 2024 12:31:47 -0500 Subject: [PATCH 25/35] Fix llamacpp caching by making LlamaCppTokenizer pickleable --- outlines/integrations/llamacpp.py | 10 ++++++++++ tests/generate/test_integration_llamacpp.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 4041c54fb..74498726d 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -66,6 +66,16 @@ def __init__(self, model: "Llama"): def convert_token_to_string(self, token: str) -> str: return token + def __getstate__(self): + """Allow tokenizer to be used as hash key by excluding self.decode""" + return ( + self.vocabulary.items(), + self.eos_token_id, + self.eos_token, + self.pad_token_id, + sorted(self.special_tokens), + ) + class LogitsProcessor: """Bias LlamaCpp generation using a finite state machine. diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 75d0e4cef..531bf8fb9 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -279,3 +279,11 @@ def test_llama_cpp_pre_tokenizer_remains_broken(): model = models.llamacpp(repo, model_path) with pytest.raises(RuntimeError): generate.choice(model, ["skirt", "dress", "pen", "jacket"]) + + +def test_create_states_mapping_llamacpp_tokenizer_regression(model): + """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping""" + from outlines.fsm.guide import create_states_mapping + from outlines.integrations.llamacpp import LlamaCppTokenizer + + create_states_mapping("a", LlamaCppTokenizer(model.model)) From 3a7d83b89afcf6a3ecd53b134bf226c5041d674d Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 31 May 2024 11:28:44 -0500 Subject: [PATCH 26/35] make LlamaCppTokenizer an outlines Tokenizer --- outlines/integrations/llamacpp.py | 39 +------- outlines/models/llamacpp.py | 89 ++++++++++++++++++- tests/generate/conftest.py | 24 +++++ tests/generate/test_integration_llamacpp.py | 55 +++++++++++- .../generate/test_integration_transformers.py | 22 ----- 5 files changed, 165 insertions(+), 64 deletions(-) create mode 100644 tests/generate/conftest.py diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 74498726d..8e18c33e7 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -26,7 +26,7 @@ """ import math -from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Optional, Type, Union import numpy as np import torch @@ -36,47 +36,12 @@ from outlines.fsm.guide import CFGGuide, Guide, RegexGuide from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import convert_json_schema_to_str +from outlines.models.llamacpp import LlamaCppTokenizer if TYPE_CHECKING: from llama_cpp import Llama -class LlamaCppTokenizer: - def __init__(self, model: "Llama"): - self.eos_token_id = model.token_eos() - self.eos_token = model.tokenizer().decode([self.eos_token_id]) - self.pad_token_id = self.eos_token_id - self.special_tokens: Set[int] = set() - - self.vocabulary: Dict[str, int] = dict() - - tokenizer = model.tokenizer() - - self.decode = tokenizer.decode - - # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved - try: - self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() - except AttributeError: - # ### - for t in range(model.n_vocab()): - token_piece = model.tokenizer().decode([t]) - self.vocabulary[token_piece] = t - - def convert_token_to_string(self, token: str) -> str: - return token - - def __getstate__(self): - """Allow tokenizer to be used as hash key by excluding self.decode""" - return ( - self.vocabulary.items(), - self.eos_token_id, - self.eos_token, - self.pad_token_id, - sorted(self.special_tokens), - ) - - class LogitsProcessor: """Bias LlamaCpp generation using a finite state machine. diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 5920f08d6..840e1364f 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,15 +1,102 @@ import dataclasses +import pickle import warnings -from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union +from typing import ( + TYPE_CHECKING, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TypedDict, + Union, +) from typing_extensions import Unpack from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: from llama_cpp import Llama, LogitsProcessorList +class LlamaCppTokenizer(Tokenizer): + def __init__(self, model: "Llama"): + self.eos_token_id = model.token_eos() + self.eos_token = model.tokenizer().decode([self.eos_token_id]) + self.pad_token_id = self.eos_token_id + self.special_tokens: Set[int] = set() + + self.vocabulary: Dict[str, int] = dict() + + self.tokenizer = model.tokenizer() + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + try: + self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() + except AttributeError: + # ### + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t + + # ensure stable ordering of vocabulary + self.vocabulary = { + tok: tok_id + for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1]) + } + + self._hash = None + + def decode(self, token_ids: List[int]) -> List[str]: + decoded_bytes = self.tokenizer.detokenize(token_ids) + return [decoded_bytes.decode("utf-8", errors="ignore")] + + def encode( + self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True + ) -> Tuple[List[int], List[int]]: + if isinstance(prompt, list): + raise NotImplementedError( + "llama-cpp-python tokenizer doesn't support batch tokenization" + ) + token_ids = self.tokenizer.tokenize( + prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special + ) + # generate attention mask, missing from llama-cpp-python + attention_mask = [ + 1 if token_id != self.pad_token_id else 0 for token_id in token_ids + ] + return token_ids, attention_mask + + def convert_token_to_string(self, token: str) -> str: + return token + + def __eq__(self, other): + if not isinstance(other, LlamaCppTokenizer): + return False + return self.__getstate__() == other.__getstate__() + + def __hash__(self): + if self._hash is None: + self._hash = hash(pickle.dumps(self)) + return self._hash + + def __getstate__(self): + """Create a stable representation for outlines.caching""" + return ( + self.vocabulary, + self.eos_token_id, + self.eos_token, + self.pad_token_id, + sorted(self.special_tokens), + ) + + def __setstate__(self, state): + raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") + + class LlamaCppParams(TypedDict, total=False): suffix: Optional[str] temperature: float diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py new file mode 100644 index 000000000..ef7e40eed --- /dev/null +++ b/tests/generate/conftest.py @@ -0,0 +1,24 @@ +from importlib import reload + +import pytest + + +@pytest.fixture +def temp_cache_dir(): + import os + import tempfile + + import outlines.caching + import outlines.fsm.guide + + with tempfile.TemporaryDirectory() as tempdir: + os.environ["OUTLINES_CACHE_DIR"] = tempdir + outlines.caching.get_cache.cache_clear() + reload(outlines) + reload(outlines.fsm.guide) + cache_status = outlines.caching._caching_enabled + try: + outlines.caching._caching_enabled = True + yield + finally: + outlines.caching._caching_enabled = cache_status diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 531bf8fb9..b7eb8b3cb 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -281,9 +281,56 @@ def test_llama_cpp_pre_tokenizer_remains_broken(): generate.choice(model, ["skirt", "dress", "pen", "jacket"]) -def test_create_states_mapping_llamacpp_tokenizer_regression(model): - """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping""" +def test_RegexGuide_caching(model, temp_cache_dir): + import llama_cpp + + import outlines.caching from outlines.fsm.guide import create_states_mapping - from outlines.integrations.llamacpp import LlamaCppTokenizer - create_states_mapping("a", LlamaCppTokenizer(model.model)) + assert outlines.caching._caching_enabled + + regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + prompt = "What is the IP address of the Google DNS servers? " + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + generator = generate.regex(model, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.llamacpp( + "Qwen/Qwen1.5-0.5B-Chat-GGUF", + "*q2*.gguf", + tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( + "Qwen/Qwen1.5-0.5B-Chat" + ), + ) + generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert ( + generator.logits_processor.fsm.states_to_token_maps + != generator_2.logits_processor.fsm.states_to_token_maps + ) + + generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) + assert cache.stats() == (1, 2) + assert ( + generator_2.logits_processor.fsm.states_to_token_maps + == generator_3.logits_processor.fsm.states_to_token_maps + ) + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index cee3ca312..da08bed71 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,7 +1,6 @@ import datetime import re from enum import Enum -from importlib import reload from typing import List, Union import pytest @@ -15,27 +14,6 @@ from outlines.samplers import beam_search, greedy, multinomial -@pytest.fixture -def temp_cache_dir(): - import os - import tempfile - - import outlines.caching - import outlines.fsm.guide - - with tempfile.TemporaryDirectory() as tempdir: - os.environ["OUTLINES_CACHE_DIR"] = tempdir - outlines.caching.get_cache.cache_clear() - reload(outlines) - reload(outlines.fsm.guide) - cache_status = outlines.caching._caching_enabled - try: - outlines.caching._caching_enabled = True - yield - finally: - outlines.caching._caching_enabled = cache_status - - def test_transformers_integration_text(): rng = torch.Generator() rng.manual_seed(10000) # Choosen so is generated From 7e5822620d295965e0756162036f8930d96e5ac6 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 1 Jun 2024 15:45:02 -0500 Subject: [PATCH 27/35] Verify Wheel Build and SDist Test in PR Tests Workflow --- .github/scripts/build_sdist_and_wheel.sh | 22 +++++++++++++++++++++ .github/workflows/release_pypi.yaml | 25 ++---------------------- .github/workflows/tests.yml | 8 ++++++++ 3 files changed, 32 insertions(+), 23 deletions(-) create mode 100755 .github/scripts/build_sdist_and_wheel.sh diff --git a/.github/scripts/build_sdist_and_wheel.sh b/.github/scripts/build_sdist_and_wheel.sh new file mode 100755 index 000000000..ca770f5b7 --- /dev/null +++ b/.github/scripts/build_sdist_and_wheel.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Build sdist and wheel +python -m pip install -U pip +python -m pip install build +python -m build + +# Check sdist install and imports +mkdir -p test-sdist +cd test-sdist +python -m venv venv-sdist +venv-sdist/bin/python -m pip install ../dist/outlines-*.tar.gz +venv-sdist/bin/python -c "import outlines" +cd .. + +# Check wheel install and imports +mkdir -p test-wheel +cd test-wheel +python -m venv venv-wheel +venv-wheel/bin/python -m pip install ../dist/outlines-*.whl +venv-wheel/bin/python -c "import outlines" +cd .. diff --git a/.github/workflows/release_pypi.yaml b/.github/workflows/release_pypi.yaml index 0006e74f2..95dc1891a 100644 --- a/.github/workflows/release_pypi.yaml +++ b/.github/workflows/release_pypi.yaml @@ -11,32 +11,11 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: "3.10" - - name: Build sdist and wheel - run: | - python -m pip install -U pip - python -m pip install build - python -m build + - name: Build SDist and Wheel + run: ./.github/scripts/build_sdist_and_wheel.sh - name: Check that the package version matches the Release name run: | grep -Rq "^Version: ${GITHUB_REF:10}$" outlines.egg-info/PKG-INFO - - name: Check sdist install and imports - run: | - mkdir -p test-sdist - cd test-sdist - python -m venv venv-sdist - venv-sdist/bin/python -m pip install ../dist/outlines-*.tar.gz - venv-sdist/bin/python -c "import outlines" - - name: Check wheel install and imports - run: | - mkdir -p test-wheel - cd test-wheel - python -m venv venv-wheel - venv-wheel/bin/python -m pip install ../dist/outlines-*.whl - venv-wheel/bin/python -c "import outlines" - name: Publish to PyPi uses: pypa/gh-action-pypi-publish@v1.4.2 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b8d4208af..10879c78f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -88,3 +88,11 @@ jobs: name: html-report path: htmlcov if: ${{ failure() }} + + build-wheel: + name: Build Wheel and Test SDist + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build SDist and Wheel + run: ./.github/scripts/build_sdist_and_wheel.sh From 81354c8a2302bdaa49b460714584542f9724ef5d Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 1 Jun 2024 15:45:08 -0500 Subject: [PATCH 28/35] add pycountry and pyairports to required dependencies --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0b310c44b..c6f72d3e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dependencies = [ "requests", "tqdm", "datasets", + "pycountry", + "pyairports", ] dynamic = ["version"] @@ -58,8 +60,6 @@ test = [ "vllm", "torch", "transformers", - "pycountry", - "pyairports", ] serve = [ "vllm>=0.3.0", From 95f108e0824b8135c270087d4d09e25290efe619 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 1 Jun 2024 15:56:31 -0500 Subject: [PATCH 29/35] Set up Python step in PyPi Release workflow --- .github/workflows/release_pypi.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/release_pypi.yaml b/.github/workflows/release_pypi.yaml index 95dc1891a..9f78cfc43 100644 --- a/.github/workflows/release_pypi.yaml +++ b/.github/workflows/release_pypi.yaml @@ -11,6 +11,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.10" - name: Build SDist and Wheel run: ./.github/scripts/build_sdist_and_wheel.sh - name: Check that the package version matches the Release name From 3029b289c151ce32aa868e28c162827959ca8b00 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 31 May 2024 17:04:34 -0500 Subject: [PATCH 30/35] ASV PR bench workflow, pytest-bench -> ASV, add peakmem tests --- .github/workflows/asv_benchmark_pr.yml | 52 +++++++++++++++++++ .gitignore | 1 + benchmarks/__init__.py | 0 benchmarks/asv.conf.json | 20 +++++++ .../bench_json_schema.py | 43 +++++++-------- benchmarks/bench_numba_compile.py | 37 +++++++++++++ .../bench_regex_guide.py | 39 +++++++++----- .../conftest.py => benchmarks/common.py | 10 ++-- docs/community/contribute.md | 34 ++++++++++-- .../benchmark/test_benchmark_numba_compile.py | 33 ------------ 10 files changed, 191 insertions(+), 78 deletions(-) create mode 100644 .github/workflows/asv_benchmark_pr.yml create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/asv.conf.json rename tests/benchmark/test_benchmark_json_schema.py => benchmarks/bench_json_schema.py (70%) create mode 100644 benchmarks/bench_numba_compile.py rename tests/benchmark/test_benchmark_regex_fsm.py => benchmarks/bench_regex_guide.py (68%) rename tests/benchmark/conftest.py => benchmarks/common.py (74%) delete mode 100644 tests/benchmark/test_benchmark_numba_compile.py diff --git a/.github/workflows/asv_benchmark_pr.yml b/.github/workflows/asv_benchmark_pr.yml new file mode 100644 index 000000000..09786b72c --- /dev/null +++ b/.github/workflows/asv_benchmark_pr.yml @@ -0,0 +1,52 @@ +name: Benchmark PR + +on: + pull_request: + branches: [main] + workflow_dispatch: +env: + PYTHON_VERSION: "3.10" + WORKING_DIR: ${{ github.workspace }}/benchmarks + BENCHMARKS_OUTPUT: ${{ github.workspace }}/benchmarks_output + +jobs: + benchmark-pr: + runs-on: ubuntu-latest + if: contains(github.event.pull_request.labels.*.name, 'run_benchmarks') || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_run' + + defaults: + run: + working-directory: ${{ env.WORKING_DIR }} + + steps: + + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install asv virtualenv lf-asv-formatter + + - name: Create ASV machine config file + run: asv machine --machine gh-runner --yes + + - name: Run Benchmarks - `PR HEAD` vs `main` + run: | + # prepare main branch for comparison + git remote add upstream https://github.com/${{ github.repository }}.git + git fetch upstream main + + # Run benchmarks, allow errors, they will be caught in the next step + asv continuous upstream/main HEAD \ + --no-stats --interleave-rounds -a repeat=3 || true + + - name: BENCHMARK RESULTS + run: asv compare --factor=1.1 --no-stats --split upstream/main HEAD diff --git a/.gitignore b/.gitignore index 9e95a8732..9add6d8c4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ docs/build .idea/ *.gguf .venv +benchmarks/results diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json new file mode 100644 index 000000000..287bff98f --- /dev/null +++ b/benchmarks/asv.conf.json @@ -0,0 +1,20 @@ +{ + "version": 1, + "project": "Outlines", + "project_url": "https://outlines-dev.github.io/outlines/", + "repo": "..", + "branches": [ + "HEAD" + ], + "build_command": [ + "python -mpip install .[test]", + "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}", + ], + "environment_type": "virtualenv", + "show_commit_url": "https://github.com/lapp0/outlines/commit/", + "benchmark_dir": ".", + "env_dir": "env", + "results_dir": "results", + "html_dir": "html", + "build_cache_size": 8 +} diff --git a/tests/benchmark/test_benchmark_json_schema.py b/benchmarks/bench_json_schema.py similarity index 70% rename from tests/benchmark/test_benchmark_json_schema.py rename to benchmarks/bench_json_schema.py index 33f3f5b16..daa77510b 100644 --- a/tests/benchmark/test_benchmark_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,5 +1,3 @@ -import pytest - import outlines outlines.disable_cache() @@ -7,6 +5,12 @@ from outlines.fsm.guide import RegexGuide # noqa: E402 from outlines.fsm.json_schema import build_regex_from_schema # noqa: E402 +from .common import ( # noqa: E402 + clear_outlines_cache, + ensure_numba_compiled, + setup_tokenizer, +) + simple_schema = """{ "$defs": { "Armor": { @@ -63,30 +67,21 @@ "required": ["id", "work", "recording_artists"] }""" - schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema) -@pytest.mark.parametrize("schema_name", schemas.keys()) -def test_benchmark_json_schema_to_regex(benchmark, ensure_numba_compiled, schema_name): - """Benchmark convert json schema to regex""" - schema = schemas[schema_name] - benchmark.pedantic( - build_regex_from_schema, - args=(schema,), - rounds=8, - ) +class JsonSchemaBenchmark: + params = schemas.keys() + + def setup(self, schema_name): + clear_outlines_cache() + self.tokenizer = setup_tokenizer() + self.schema = schemas[schema_name] + ensure_numba_compiled(self.tokenizer) + def time_json_schema_to_regex(self, schema_name): + build_regex_from_schema(self.schema) -@pytest.mark.parametrize("schema_name", schemas.keys()) -def test_benchmark_json_schema_to_fsm( - benchmark, tokenizer, ensure_numba_compiled, schema_name -): - """Benchmark compile json schema as FSM""" - schema = schemas[schema_name] - regex = build_regex_from_schema(schema) - benchmark.pedantic( - RegexGuide, - args=(regex, tokenizer), - rounds=8, - ) + def time_json_schema_to_fsm(self, schema_name): + regex = build_regex_from_schema(self.schema) + RegexGuide(regex, self.tokenizer) diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py new file mode 100644 index 000000000..c0e9d87c4 --- /dev/null +++ b/benchmarks/bench_numba_compile.py @@ -0,0 +1,37 @@ +import importlib + +import interegular +import numba + +import outlines + +from .common import clear_outlines_cache, setup_tokenizer + +outlines.disable_cache() + + +class NumbaCompileBenchmark: + def setup(self): + clear_outlines_cache() + from outlines.fsm import regex + + self.tokenizer = setup_tokenizer() + self.regex = regex + original_njit = numba.njit + + def mock_njit(*args, **kwargs): + kwargs["cache"] = False + return original_njit(*args, **kwargs) + + self.original_njit = original_njit + numba.njit = mock_njit + importlib.reload(self.regex) + self.regex_pattern, _ = self.regex.make_deterministic_fsm( + interegular.parse_pattern("a").to_fsm().reduce() + ) + + def teardown(self): + numba.njit = self.original_njit + + def time_compile_numba(self): + self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer) diff --git a/tests/benchmark/test_benchmark_regex_fsm.py b/benchmarks/bench_regex_guide.py similarity index 68% rename from tests/benchmark/test_benchmark_regex_fsm.py rename to benchmarks/bench_regex_guide.py index e9e45052a..efaea9e1f 100644 --- a/tests/benchmark/test_benchmark_regex_fsm.py +++ b/benchmarks/bench_regex_guide.py @@ -1,7 +1,7 @@ -import pytest - import outlines +from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer + outlines.disable_cache() from outlines.fsm.guide import RegexGuide # noqa: E402 @@ -19,14 +19,27 @@ } -@pytest.mark.parametrize("regex_name", regex_samples.keys()) -def test_benchmark_regex_to_fsm( - benchmark, tokenizer, ensure_numba_compiled, regex_name -): - """Benchmark converting regex to FSM""" - regex_str = regex_samples[regex_name] - benchmark.pedantic( - RegexGuide, - args=(regex_str, tokenizer), - rounds=8, - ) +class RegexGuideBenchmark: + params = regex_samples.keys() + + def setup(self, pattern_name): + clear_outlines_cache() + self.tokenizer = setup_tokenizer() + ensure_numba_compiled(self.tokenizer) + self.pattern = regex_samples[pattern_name] + + def time_regex_to_guide(self, pattern_name): + RegexGuide(self.pattern, self.tokenizer) + + +class MemoryRegexGuideBenchmark: + params = ["simple_phone", "complex_span_constrained_relation_extraction"] + + def setup(self, pattern_name): + clear_outlines_cache() + self.tokenizer = setup_tokenizer() + ensure_numba_compiled(self.tokenizer) + self.pattern = regex_samples[pattern_name] + + def peakmem_regex_to_guide(self, pattern_name): + RegexGuide(self.pattern, self.tokenizer) diff --git a/tests/benchmark/conftest.py b/benchmarks/common.py similarity index 74% rename from tests/benchmark/conftest.py rename to benchmarks/common.py index 902d5d6eb..e0fe36f14 100644 --- a/tests/benchmark/conftest.py +++ b/benchmarks/common.py @@ -1,17 +1,19 @@ -import pytest from transformers import AutoTokenizer +import outlines.caching from outlines.fsm.guide import RegexGuide from outlines.models.transformers import TransformerTokenizer -@pytest.fixture -def tokenizer(): +def clear_outlines_cache(): + outlines.caching.clear_cache() + + +def setup_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("gpt2") return TransformerTokenizer(tokenizer) -@pytest.fixture def ensure_numba_compiled(tokenizer): RegexGuide("a", tokenizer) return True diff --git a/docs/community/contribute.md b/docs/community/contribute.md index fb67576e4..b336eacad 100644 --- a/docs/community/contribute.md +++ b/docs/community/contribute.md @@ -57,12 +57,38 @@ And run the code style checks: pre-commit run --all-files ``` -When modifying the code related to the index compilation, we kindly ask you to -post benchmarks before and after your changes. You can run benchmarks using: +### Benchmarking -```python -pytest --benchmark-only +Outlines uses [asv](https://asv.readthedocs.io) for automated benchmark testing. Benchmarks are run automatically before pull requests are merged to prevent performance degredation. + +You can run the benchmark test suite locally with the following command: +``` +asv run --config benchmarks/asv.conf.json +``` + +Run a specific test: +``` +asv run --config benchmarks/asv.conf.json -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm +``` + +Profile a specific test: ``` +asv run --config benchmarks/asv.conf.json --profile -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm +``` + +Compare to `origin/main` +``` +get fetch origin +asv continuous origin/main HEAD --config benchmarks/asv.conf.json +``` + +#### ASV PR Behavior + +- **View ASV Benchmark Results:** Open the workflow, view `BENCHMARK RESULTS` section. +- Merging is blocked unless benchmarks are run for the latest commit. +- Benchmarks fail if performance degrades by more than 10% for any individual benchmark. +- The "Benchmark PR" workflow runs when its manually dispatched, or if the `run_benchmarks` label is added to the PR they run for every commit. + ### Contribute to the documentation diff --git a/tests/benchmark/test_benchmark_numba_compile.py b/tests/benchmark/test_benchmark_numba_compile.py deleted file mode 100644 index 827d561bd..000000000 --- a/tests/benchmark/test_benchmark_numba_compile.py +++ /dev/null @@ -1,33 +0,0 @@ -import importlib - -import interegular -import numba - -import outlines - -outlines.disable_cache() - - -def test_benchmark_compile_numba(benchmark, tokenizer, mocker): - """Compile a basic regex to benchmark the numba compilation time""" - - def setup(): - from outlines.fsm import regex - - original_njit = numba.njit - - def mock_njit(*args, **kwargs): - kwargs["cache"] = False - return original_njit(*args, **kwargs) - - mocker.patch("numba.njit", new=mock_njit) - importlib.reload(regex) - - regex_pattern, _ = regex.make_deterministic_fsm( - interegular.parse_pattern("a").to_fsm().reduce() - ) - return (regex, regex_pattern, tokenizer), {} - - benchmark.pedantic( - lambda r, *args: r.create_fsm_index_tokenizer(*args), rounds=2, setup=setup - ) From 6fe76d5a9f4a003dbf735218d9f277f9b4b87422 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 31 May 2024 19:41:25 -0500 Subject: [PATCH 31/35] ensure workflow fails if benchmark degredation >10% --- .github/workflows/asv_benchmark_pr.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/asv_benchmark_pr.yml b/.github/workflows/asv_benchmark_pr.yml index 09786b72c..90fb47423 100644 --- a/.github/workflows/asv_benchmark_pr.yml +++ b/.github/workflows/asv_benchmark_pr.yml @@ -49,4 +49,9 @@ jobs: --no-stats --interleave-rounds -a repeat=3 || true - name: BENCHMARK RESULTS - run: asv compare --factor=1.1 --no-stats --split upstream/main HEAD + run: | + asv compare --factor=1.1 --no-stats --split upstream/main HEAD | tee ${{ env.BENCHMARKS_OUTPUT }} + if grep -q "Benchmarks that have got worse" "${{ env.BENCHMARKS_OUTPUT }}"; then + echo "Performance degradation detected!" + exit 1 + fi From 9b2b9233062802166b6f18c034cc98e218270325 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 4 Jun 2024 01:32:32 -0500 Subject: [PATCH 32/35] disable outlines cache localized to the benchmarks scope --- benchmarks/bench_json_schema.py | 18 ++++++------------ benchmarks/bench_numba_compile.py | 11 ++++------- benchmarks/bench_regex_guide.py | 13 +++++-------- benchmarks/common.py | 5 ----- outlines/caching.py | 13 +++++++++++++ tests/test_cache.py | 31 +++++++++++++++++++++++++++++++ 6 files changed, 59 insertions(+), 32 deletions(-) diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index daa77510b..8d1ceeb24 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,15 +1,8 @@ -import outlines +from outlines.caching import cache_disabled +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema -outlines.disable_cache() - -from outlines.fsm.guide import RegexGuide # noqa: E402 -from outlines.fsm.json_schema import build_regex_from_schema # noqa: E402 - -from .common import ( # noqa: E402 - clear_outlines_cache, - ensure_numba_compiled, - setup_tokenizer, -) +from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402 simple_schema = """{ "$defs": { @@ -74,14 +67,15 @@ class JsonSchemaBenchmark: params = schemas.keys() def setup(self, schema_name): - clear_outlines_cache() self.tokenizer = setup_tokenizer() self.schema = schemas[schema_name] ensure_numba_compiled(self.tokenizer) + @cache_disabled() def time_json_schema_to_regex(self, schema_name): build_regex_from_schema(self.schema) + @cache_disabled() def time_json_schema_to_fsm(self, schema_name): regex = build_regex_from_schema(self.schema) RegexGuide(regex, self.tokenizer) diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py index c0e9d87c4..2713707e5 100644 --- a/benchmarks/bench_numba_compile.py +++ b/benchmarks/bench_numba_compile.py @@ -3,18 +3,14 @@ import interegular import numba -import outlines +from outlines.caching import cache_disabled +from outlines.fsm import regex -from .common import clear_outlines_cache, setup_tokenizer - -outlines.disable_cache() +from .common import setup_tokenizer class NumbaCompileBenchmark: def setup(self): - clear_outlines_cache() - from outlines.fsm import regex - self.tokenizer = setup_tokenizer() self.regex = regex original_njit = numba.njit @@ -33,5 +29,6 @@ def mock_njit(*args, **kwargs): def teardown(self): numba.njit = self.original_njit + @cache_disabled() def time_compile_numba(self): self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index efaea9e1f..099f94df2 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,10 +1,7 @@ -import outlines +from outlines.caching import cache_disabled +from outlines.fsm.guide import RegexGuide -from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer - -outlines.disable_cache() - -from outlines.fsm.guide import RegexGuide # noqa: E402 +from .common import ensure_numba_compiled, setup_tokenizer regex_samples = { "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -23,11 +20,11 @@ class RegexGuideBenchmark: params = regex_samples.keys() def setup(self, pattern_name): - clear_outlines_cache() self.tokenizer = setup_tokenizer() ensure_numba_compiled(self.tokenizer) self.pattern = regex_samples[pattern_name] + @cache_disabled() def time_regex_to_guide(self, pattern_name): RegexGuide(self.pattern, self.tokenizer) @@ -36,10 +33,10 @@ class MemoryRegexGuideBenchmark: params = ["simple_phone", "complex_span_constrained_relation_extraction"] def setup(self, pattern_name): - clear_outlines_cache() self.tokenizer = setup_tokenizer() ensure_numba_compiled(self.tokenizer) self.pattern = regex_samples[pattern_name] + @cache_disabled() def peakmem_regex_to_guide(self, pattern_name): RegexGuide(self.pattern, self.tokenizer) diff --git a/benchmarks/common.py b/benchmarks/common.py index e0fe36f14..7d999ea9b 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1,14 +1,9 @@ from transformers import AutoTokenizer -import outlines.caching from outlines.fsm.guide import RegexGuide from outlines.models.transformers import TransformerTokenizer -def clear_outlines_cache(): - outlines.caching.clear_cache() - - def setup_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("gpt2") return TransformerTokenizer(tokenizer) diff --git a/outlines/caching.py b/outlines/caching.py index 52d66af74..95392c7e8 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import functools import os from typing import Callable, Optional @@ -164,3 +165,15 @@ def clear_cache(): """Erase the cache completely.""" memory = get_cache() memory.clear() + + +@contextlib.contextmanager +def cache_disabled(): + # outlines.caching._caching_enabled + global _caching_enabled + original_state = _caching_enabled + _caching_enabled = False + try: + yield + finally: + _caching_enabled = original_state diff --git a/tests/test_cache.py b/tests/test_cache.py index 5a2de778e..eb4ec406e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,6 @@ import os import tempfile +import unittest import diskcache import pytest @@ -157,3 +158,33 @@ def foo(): # assert with version upgrade, old cache is invalidated and new cache is used a, b = foo() + + +def test_cache_disabled_decorator(test_cache): + """Ensure cache can be disabled in a local scope""" + + from outlines.caching import cache_disabled + + mock = unittest.mock.MagicMock() + + @test_cache + def fn(): + mock() + return 1 + + # first call isn't cached + fn() + assert mock.call_count == 1 + + # second call doesn't run fn, uses cache + fn() + assert mock.call_count == 1 + + # cache_disabled decorator disables cache within scope + with cache_disabled(): + fn() + assert mock.call_count == 2 # called once in cache_disabled scope + + # scope has exited, cache is enabled again + fn() + assert mock.call_count == 2 From b5a2073fcb3a2d2bb5c159c51bd27307c0fa4f62 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 4 Jun 2024 01:33:09 -0500 Subject: [PATCH 33/35] use outlines-dev/outlines for asv.conf.json show_commit_url --- benchmarks/asv.conf.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 287bff98f..f57db9a0b 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -11,7 +11,7 @@ "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}", ], "environment_type": "virtualenv", - "show_commit_url": "https://github.com/lapp0/outlines/commit/", + "show_commit_url": "https://github.com/outlines-dev/outlines/commit/", "benchmark_dir": ".", "env_dir": "env", "results_dir": "results", From 9780992c1faccc5f2731a4734b0dd273d456f059 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 3 Jun 2024 07:04:03 -0500 Subject: [PATCH 34/35] Add Documentation on Outlines Versioning and Releases --- docs/versioning.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 docs/versioning.md diff --git a/docs/versioning.md b/docs/versioning.md new file mode 100644 index 000000000..84aab63a5 --- /dev/null +++ b/docs/versioning.md @@ -0,0 +1,19 @@ +The Outlines project follows a structured versioning scheme designed to provide clarity and minimize risk for downstream dependents. + +Each part of the version number (`major.minor.patch`) conveys information about the nature and impact of the changes included in the release. + +- **Major Releases** includes compatibility-breaking changes to core interfaces, such as `LogitsProcessor`s and `Guides`. +- **Minor Releases** introduce changes of substance to internal or unexposed functionality. These changes are well tested and intended to maintain compatability with existing use of core interfaces. +- **Patch Releases** address bug fixes and incorporate low-risk changes to improve stability and performance. + +## Releases + +Releases along with release notes can be found on the [Outlines Releases GitHub Page](https://github.com/outlines-dev/outlines/releases). + +## Version Pinning Recommendations + +Here are our recommendations for managing dependencies on the Outlines package: + +**Small, Risk-Tolerant Projects:** Pin to a specific major version. + +**Large, Conservative Projects:** Pin to a specific minor version. From 0b4d12b0b9998a26e9dbde3bd558e695c51b75be Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 4 Jun 2024 15:14:57 -0500 Subject: [PATCH 35/35] add versioning guide to community tab --- docs/{ => community}/versioning.md | 7 +++++++ mkdocs.yml | 1 + 2 files changed, 8 insertions(+) rename docs/{ => community}/versioning.md (95%) diff --git a/docs/versioning.md b/docs/community/versioning.md similarity index 95% rename from docs/versioning.md rename to docs/community/versioning.md index 84aab63a5..d64a56e7f 100644 --- a/docs/versioning.md +++ b/docs/community/versioning.md @@ -1,3 +1,10 @@ +--- +title: Versioning Guide +--- + +# Versioning Guide + + The Outlines project follows a structured versioning scheme designed to provide clarity and minimize risk for downstream dependents. Each part of the version number (`major.minor.patch`) conveys information about the nature and impact of the changes included in the release. diff --git a/mkdocs.yml b/mkdocs.yml index c4a9edcbc..01e8506ab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -143,4 +143,5 @@ nav: - Chat with us ☕: https://discord.com/invite/R9DSu34mGd - How to contribute 🏗️: community/contribute.md - Your projects 👏: community/examples.md + - Versioning Guide 📌: community/versioning.md - Blog: blog/index.md