From ba7affd92883aaa5ceacfba7ef5abd5a27b2eb26 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Tue, 21 May 2024 18:36:17 -0500
Subject: [PATCH] Use a persistent Tokenizer hash for create_states_mapping
 cache
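
Move `create_states_mapping` from a closure inside `RegexGuide.__init__` to
module scope, so that the `@cache()` key includes the tokenizer argument
instead of silently capturing it. To make that key stable across objects and
processes, `TransformerTokenizer` hashes by content via
`datasets.fingerprint.Hasher` (the import moves to module level, and
`datasets` becomes a runtime dependency) and gains
`__getstate__`/`__setstate__` so instances can be pickled.

A sketch of the resulting behavior (the model name below is one of the tiny
test models used in the new test):

    import outlines.models as models
    from outlines.fsm.guide import create_states_mapping

    model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")

    # The first call builds the FSM index and persists it under a key that
    # covers both the regex and the tokenizer; repeating the call, even from
    # a fresh process, should be a cache hit because the tokenizer hash
    # depends on its content rather than on object identity.
    states_to_token_maps, empty_token_ids, final_states = create_states_mapping(
        r"[0-9]+", model.tokenizer
    )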
---
 outlines/fsm/guide.py                          | 56 +++++++--------
 outlines/models/transformers.py                | 11 ++-
 pyproject.toml                                 |  4 +-
 .../generate/test_integration_transformers.py  | 68 ++++++++++++++++++-
 tests/models/test_transformers.py              | 14 +++-
 5 files changed, 117 insertions(+), 36 deletions(-)

diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index 2833fce1a..d247db62b 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -105,44 +105,44 @@ def copy(self):
         return self
 
 
+@cache()
+def create_states_mapping(
+    regex_string: str, tokenizer: "Tokenizer"
+) -> Tuple[dict, set, set]:
+    """Create the variables related to the mapping between states and tokens
+    The parameters of the function are used for caching purposes
+    """
+    regex_pattern = interegular.parse_pattern(regex_string)
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+    # We make sure that it is possible to generate strings in the language
+    # of the regular expression with the tokens present in the model's
+    # vocabulary.
+    if not any(
+        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
+    ):
+        raise ValueError(
+            "The vocabulary does not allow us to build a sequence that matches the input regex"
+        )
+
+    return states_to_token_maps, empty_token_ids, regex_fsm.finals
+
+
 class RegexGuide(Guide):
     """Guide to generate text in the language of a regular expression."""
 
     initial_state = 0
 
     def __init__(self, regex_string: str, tokenizer):
-        @cache()
-        def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            regex_pattern = interegular.parse_pattern(regex_string)
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
-
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
-        ) = create_states_mapping(regex_string)
+        ) = create_states_mapping(regex_string, tokenizer)
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 3bc59412e..fae9b8e74 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from datasets.fingerprint import Hasher
+
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -109,10 +111,15 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
-        from datasets.fingerprint import Hasher
-
         return hash(Hasher.hash(self.tokenizer))
 
+    def __getstate__(self):
+        state = {"tokenizer": self.tokenizer}
+        return state
+
+    def __setstate__(self, state):
+        self.__init__(state["tokenizer"])
+
 
 class Transformers:
     """Represents a `transformers` model."""
diff --git a/pyproject.toml b/pyproject.toml
index b18036ffc..0b310c44b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
     "referencing",
     "jsonschema",
     "requests",
-    "tqdm"
+    "tqdm",
+    "datasets",
 ]
 dynamic = ["version"]
 
@@ -50,7 +51,6 @@ test = [
    "diff-cover",
    "accelerate",
    "beartype<0.16.0",
-    "datasets",
    "responses",
    "llama-cpp-python",
    "huggingface_hub",
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 38525a076..cee3ca312 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -1,6 +1,7 @@
 import datetime
 import re
 from enum import Enum
+from importlib import reload
 from typing import List, Union
 
 import pytest
@@ -11,7 +12,28 @@
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
 from outlines.models.transformers import Transformers, TransformerTokenizer
-from outlines.samplers import beam_search, multinomial
+from outlines.samplers import beam_search, greedy, multinomial
+
+
+@pytest.fixture
+def temp_cache_dir():
+    import os
+    import tempfile
+
+    import outlines.caching
+    import outlines.fsm.guide
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        os.environ["OUTLINES_CACHE_DIR"] = tempdir
+        outlines.caching.get_cache.cache_clear()
+        reload(outlines)
+        reload(outlines.fsm.guide)
+        cache_status = outlines.caching._caching_enabled
+        try:
+            outlines.caching._caching_enabled = True
+            yield
+        finally:
+            outlines.caching._caching_enabled = cache_status
 
 
 def test_transformers_integration_text():
@@ -632,3 +654,47 @@ def test_transformers_use_existing_model_and_tokenizer():
     model = Transformers(hf_model, hf_tokenizer)
     sequence = generate.text(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
+
+
+def test_RegexGuide_caching(temp_cache_dir):
+    import outlines.caching
+    from outlines.fsm.guide import create_states_mapping
+
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
" + + cache = outlines.caching.get_cache() + + # Returns (hits, misses) + _ = cache.stats(enable=True) + assert cache.statistics + + assert create_states_mapping.__memory__ is cache + + model = models.transformers( + "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM" + ) + generator = generate.regex(model, regex, sampler=greedy()) + assert cache.stats() == (0, 1) + + model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM") + generator_2 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (0, 2) + + # These two different models and tokenizers should not have the same state + # mapping results + assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps + + generator_3 = generate.regex(model_2, regex, sampler=greedy()) + assert cache.stats() == (1, 2) + assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps + + # Just for fun... + structured = generator(prompt, max_tokens=30) + structured_2 = generator_2(prompt, max_tokens=30) + + assert re.fullmatch(regex, structured) + assert re.fullmatch(regex, structured_2) + assert structured != structured_2 diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b4e410096..f4596a2df 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -107,6 +107,14 @@ def test_tokenizer_eq_hash(): tokenizer_hf = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer_hf) - tokenizer2 = TransformerTokenizer(tokenizer_hf) - assert tokenizer == tokenizer2 - assert hash(tokenizer) == hash(tokenizer2) + tokenizer_2 = TransformerTokenizer(tokenizer_hf) + + assert tokenizer == tokenizer_2 + assert hash(tokenizer) == hash(tokenizer_2) + + tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2") + tokenizer_hf_2.add_tokens(["test_token"]) + + tokenizer_3 = TransformerTokenizer(tokenizer_hf_2) + assert tokenizer != tokenizer_3 + assert hash(tokenizer) != hash(tokenizer_3)