diff --git a/examples/transformers_integration.py b/examples/transformers_integration.py
index a5f1dbf9e..05acbfa14 100644
--- a/examples/transformers_integration.py
+++ b/examples/transformers_integration.py
@@ -6,17 +6,20 @@
 from outlines.integrations.transformers import JSONPrefixAllowedTokens
 
 
-class User(BaseModel):
-    id: int
-    name: str
+class Person(BaseModel):
+    first_name: str
+    surname: str
 
 
 pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1")
-prefix_allowed_tokens_fn = JSONPrefixAllowedTokens(schema=User, tokenizer_or_pipe=pipe)
+prefix_allowed_tokens_fn = JSONPrefixAllowedTokens(
+    schema=Person, tokenizer_or_pipe=pipe, whitespace_pattern=" ?"
+)
 results = pipe(
-    ["Tom Jones", "Linda Smith"],
+    ["He is Tom Jones", "She saw Linda Smith"],
     return_full_text=False,
     do_sample=False,
+    max_new_tokens=50,
     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
 )
 print(results)
diff --git a/examples/vllm_integration.py b/examples/vllm_integration.py
index a30d4bf6f..5a9c04f01 100644
--- a/examples/vllm_integration.py
+++ b/examples/vllm_integration.py
@@ -6,17 +6,19 @@
 from outlines.integrations.vllm import JSONLogitsProcessor
 
 
-class User(BaseModel):
-    id: int
-    name: str
+class Person(BaseModel):
+    first_name: str
+    surname: str
 
 
-llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1")
-logits_processor = JSONLogitsProcessor(schema=User, llm=llm)
+llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=512)
+logits_processor = JSONLogitsProcessor(schema=Person, llm=llm, whitespace_pattern=" ?")
 result = llm.generate(
-    ["Tom Jones", "Linda Smith"],
+    ["He is Tom Jones", "She saw Linda Smith"],
     sampling_params=vllm.SamplingParams(
-        temperature=0.0, logits_processors=[logits_processor]
+        temperature=0.0,
+        max_tokens=50,
+        logits_processors=[logits_processor],
     ),
 )
 print(result)
diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py
index 5544d4925..e15f0dc7f 100644
--- a/outlines/fsm/fsm.py
+++ b/outlines/fsm/fsm.py
@@ -78,7 +78,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.
 
         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state == self.final_state:
             return self.final_state
 
         return self.first_state
@@ -172,7 +172,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.
 
         """
-        if token_id == self.eos_token_id:
+        if token_id == self.eos_token_id or state == self.final_state:
             return self.final_state
 
         last_token_to_end_state = self.states_to_token_maps[state]
@@ -354,7 +354,7 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         -------
         The new state of the FSM.
         """
-        if token_id == self.tokenizer.eos_token_id:
+        if token_id == self.tokenizer.eos_token_id or state == self.final_state:
             return self.final_state
 
         self.generation += self.tokenizer.decode([token_id])[0]
diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py
index 847da6ce6..ce329e064 100644
--- a/outlines/fsm/json_schema.py
+++ b/outlines/fsm/json_schema.py
@@ -138,7 +138,7 @@ def to_regex(
         if any(is_required):
             last_required_pos = max([i for i, value in enumerate(is_required) if value])
             for i, (name, value) in enumerate(properties.items()):
-                subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
+                subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}'
                 subregex += to_regex(resolver, value, whitespace_pattern)
                 if i < last_required_pos:
                     subregex = f"{subregex}{whitespace_pattern},"
@@ -216,6 +216,14 @@ def to_regex(
 
         return f"({'|'.join(choices)})"
 
+    elif "const" in instance:
+        const = instance["const"]
+        if type(const) in [int, float, bool, None]:
+            const = re.escape(str(const))
+        elif type(const) == str:
+            const = f'"{re.escape(const)}"'
+        return const
+
     elif "$ref" in instance:
         path = f"{instance['$ref']}"
         instance = resolver.lookup(path).contents
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 38e7c0992..8d3a88b94 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -279,13 +279,9 @@ def stream(
         stop_sequences = stop_at
         num_samples = self.num_samples
 
-        if rng is None:
-            rng = torch.Generator(device=self.device)
-            rng.seed()
-
         prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
         prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(self.device)
+        attention_masks = attention_masks.to(prompt_token_ids.device)
 
         # To draw multiple samples we repeat the prompt as many times
         # as there are samples. We copy the FSMs and initialize the
@@ -298,9 +294,15 @@ def stream(
         fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
         fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
         weights = torch.zeros(
-            (batch_size * num_samples), dtype=torch.float, device=self.device
+            (batch_size * num_samples),
+            dtype=torch.float,
+            device=prompt_token_ids.device,
         )
 
+        if rng is None:
+            rng = torch.Generator(device=prompt_token_ids.device)
+            rng.seed()
+
         states = sequence_generator(
             self.model,
             self.sampler,
diff --git a/outlines/integrations/utils.py b/outlines/integrations/utils.py
index 2c06513a6..330d9a357 100644
--- a/outlines/integrations/utils.py
+++ b/outlines/integrations/utils.py
@@ -77,7 +77,7 @@ def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -
     Returns
     -------
     str
-        The regular expression.
+        The JSON schema converted to a string.
 
     Raises
     ------
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 9e333bba0..da02642d2 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -5,7 +5,7 @@
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 __all__ = ["transformers"]
 
@@ -55,13 +55,87 @@ class CodeLlamaTokenizerFast:  # type: ignore
     )
 
 
+class TransformerTokenizer(Tokenizer):
+    """Represents a tokenizer for models in the `transformers` library."""
+
+    def __init__(
+        self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs
+    ):
+        if isinstance(tokenizer_or_model_name, str):
+            from transformers import AutoTokenizer
+
+            kwargs.setdefault("padding_side", "left")
+            self.model_name = tokenizer_or_model_name
+            # TODO: Do something to make this hashable?
+            self.kwargs = kwargs
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_or_model_name, **kwargs
+            )
+        else:
+            self.tokenizer = tokenizer_or_model_name
+
+        self.eos_token_id = self.tokenizer.eos_token_id
+        self.eos_token = self.tokenizer.eos_token
+
+        if not self.tokenizer.pad_token_id:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+            self.pad_token_id = self.eos_token_id
+        else:
+            self.pad_token_id = self.tokenizer.pad_token_id
+            self.pad_token = self.tokenizer.pad_token
+
+        self.special_tokens = set(self.tokenizer.all_special_tokens)
+
+        self.vocabulary = self.tokenizer.get_vocab()
+        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
+
+    def encode(
+        self, prompt: Union[str, List[str]], **kwargs
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        kwargs["padding"] = True
+        kwargs["return_tensors"] = "pt"
+        output = self.tokenizer(prompt, **kwargs)
+        return output["input_ids"], output["attention_mask"]
+
+    def decode(self, token_ids: torch.LongTensor) -> List[str]:
+        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+        return text
+
+    def convert_token_to_string(self, token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = self.tokenizer.convert_tokens_to_string([token])
+
+        if self.is_llama:
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+        return string
+
+    def __eq__(self, other):
+        if isinstance(other, type(self)):
+            if hasattr(self, "model_name") and hasattr(self, "kwargs"):
+                return (
+                    other.model_name == self.model_name and other.kwargs == self.kwargs
+                )
+            else:
+                return other.tokenizer == self.tokenizer
+        return NotImplemented
+
+    def __hash__(self):
+        from datasets.fingerprint import Hasher
+
+        return hash(Hasher.hash(self.tokenizer))
+
+
 class Transformer:
     """Represents a `transformers` model."""
 
     def __init__(
         self,
         model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: TransformerTokenizer,
     ):
         self.device = model.device
         self.model = model
@@ -119,67 +193,6 @@ def __call__(
         return next_token_logits, kv_cache
 
 
-class TransformerTokenizer(Tokenizer):
-    """Represents a tokenizer for models in the `transformers` library."""
-
-    def __init__(self, model_name: str, **kwargs):
-        from transformers import AutoTokenizer
-
-        kwargs.setdefault("padding_side", "left")
-        self.model_name = model_name
-        # TODO: Do something to make this hashable?
-        self.kwargs = kwargs
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = self.tokenizer.eos_token
-
-        if not self.tokenizer.pad_token_id:
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-            self.pad_token_id = self.eos_token_id
-        else:
-            self.pad_token_id = self.tokenizer.pad_token_id
-            self.pad_token = self.tokenizer.pad_token
-
-        self.special_tokens = set(self.tokenizer.all_special_tokens)
-
-        self.vocabulary = self.tokenizer.get_vocab()
-        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
-
-    def encode(
-        self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        kwargs["padding"] = True
-        kwargs["return_tensors"] = "pt"
-        output = self.tokenizer(prompt, **kwargs)
-        return output["input_ids"], output["attention_mask"]
-
-    def decode(self, token_ids: torch.LongTensor) -> List[str]:
-        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
-        return text
-
-    def convert_token_to_string(self, token: str) -> str:
-        from transformers.file_utils import SPIECE_UNDERLINE
-
-        string = self.tokenizer.convert_tokens_to_string([token])
-
-        if self.is_llama:
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-        return string
-
-    def __eq__(self, other):
-        if isinstance(other, type(self)):
-            return other.model_name == self.model_name and other.kwargs == self.kwargs
-        return NotImplemented
-
-    def __hash__(self):
-        from datasets.fingerprint import Hasher
-
-        return hash(Hasher.hash(self.tokenizer))
-
-
 def transformers(
     model_name: str,
     device: Optional[str] = None,
diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py
index 4cd9331b2..ea0575dc0 100644
--- a/tests/fsm/test_json_schema.py
+++ b/tests/fsm/test_json_schema.py
@@ -163,6 +163,24 @@ def test_match_number(pattern, does_match):
                 ("0", False),
             ],
         ),
+        # Const string
+        (
+            {"title": "Foo", "const": "Marc", "type": "string"},
+            '"Marc"',
+            [('"Marc"', True), ('"Jean"', False), ('"John"', False)],
+        ),
+        # Make sure strings are escaped
+        (
+            {"title": "Foo", "const": ".*", "type": "string"},
+            r'"\.\*"',
+            [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)],
+        ),
+        # Const integer
+        (
+            {"title": "Foo", "const": 0, "type": "integer"},
+            "0",
+            [("0", True), ("1", False), ("a", False)],
+        ),
         # Enum string
         (
             {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"},
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index f2de27c39..32bef74dd 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -10,7 +10,7 @@
 import outlines.generate as generate
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
-from outlines.models.transformers import TransformerTokenizer
+from outlines.models.transformers import Transformer, TransformerTokenizer
 from outlines.samplers import beam_search, multinomial
 
 
@@ -615,3 +615,18 @@ def __call__(
     )
 
     assert sequence == "c"
+
+
+def test_transformers_use_existing_model_and_tokenizer():
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    rng = torch.Generator()
+    rng.manual_seed(10000)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = TransformerTokenizer(hf_tokenizer)
+    model = Transformer(hf_model, tokenizer)
+    sequence = generate.text(model)("Write a short sentence ", rng=rng)
+    assert isinstance(sequence, str)
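Taken together, the transformers changes above let outlines reuse a model and tokenizer that have already been loaded with `transformers`, instead of always re-instantiating them from a model name. The sketch below mirrors the new `test_transformers_use_existing_model_and_tokenizer` test; the `generate.json` call at the end is an assumption about the surrounding outlines API (it is not touched by this diff), and the `Person` schema is taken from the updated examples.

# Minimal usage sketch (not part of the diff): wrap pre-loaded Hugging Face objects.
import torch
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import outlines.generate as generate
from outlines.models.transformers import Transformer, TransformerTokenizer


class Person(BaseModel):
    first_name: str
    surname: str


model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Reuse the existing objects instead of letting outlines load them again by name.
tokenizer = TransformerTokenizer(hf_tokenizer)
model = Transformer(hf_model, tokenizer)

rng = torch.Generator()
rng.manual_seed(10000)

# Unconstrained text generation, as exercised by the new test.
print(generate.text(model)("Write a short sentence ", rng=rng))

# Assumed outlines API: JSON-constrained generation against the Person schema,
# using the compact whitespace pattern from the updated examples.
print(generate.json(model, Person, whitespace_pattern=" ?")("He is Tom Jones", rng=rng))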