From b81236f62442ada84059e45d07f9d7ff9dbc2c53 Mon Sep 17 00:00:00 2001 From: saattrupdan Date: Wed, 28 Feb 2024 17:52:47 +0100 Subject: [PATCH] Add integration for `transformers` via logits processors --- examples/transformers_integration.py | 25 +++ examples/vllm_integration.py | 20 +- outlines/fsm/guide.py | 7 +- outlines/fsm/types.py | 43 +++- outlines/generate/api.py | 18 +- outlines/generate/regex.py | 11 +- outlines/integrations/__init__.py | 1 + outlines/integrations/llamacpp.py | 205 ++++++++++++++++++++ outlines/integrations/transformers.py | 159 +++++++++++++++ outlines/integrations/utils.py | 99 ++++++++++ outlines/integrations/vllm.py | 153 +++++++++++++++ outlines/models/llamacpp.py | 150 +++----------- outlines/models/tokenizer.py | 9 +- outlines/serve/serve.py | 2 +- outlines/serve/vllm.py | 147 +------------- tests/generate/test_integration_llamacpp.py | 33 ++-- 16 files changed, 763 insertions(+), 319 deletions(-) create mode 100644 examples/transformers_integration.py create mode 100644 outlines/integrations/__init__.py create mode 100644 outlines/integrations/llamacpp.py create mode 100644 outlines/integrations/transformers.py create mode 100644 outlines/integrations/utils.py create mode 100644 outlines/integrations/vllm.py diff --git a/examples/transformers_integration.py b/examples/transformers_integration.py new file mode 100644 index 000000000..16e2b8ec3 --- /dev/null +++ b/examples/transformers_integration.py @@ -0,0 +1,25 @@ +"""Example of integrating `outlines` with `transformers`.""" + +from pydantic import BaseModel +from transformers import pipeline + +from outlines.integrations.transformers import JSONPrefixAllowedTokens + + +class Person(BaseModel): + first_name: str + surname: str + + +pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1") +prefix_allowed_tokens_fn = JSONPrefixAllowedTokens( + schema=Person, tokenizer_or_pipe=pipe, whitespace_pattern=r" ?" +) +results = pipe( + ["He is Tom Jones", "She saw Linda Smith"], + return_full_text=False, + do_sample=False, + max_new_tokens=50, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, +) +print(results) diff --git a/examples/vllm_integration.py b/examples/vllm_integration.py index cf5934774..5f833d19f 100644 --- a/examples/vllm_integration.py +++ b/examples/vllm_integration.py @@ -1,20 +1,24 @@ +"""Example of integrating `outlines` with `vllm`.""" + import vllm from pydantic import BaseModel -from outlines.serve.vllm import JSONLogitsProcessor +from outlines.integrations.vllm import JSONLogitsProcessor -class User(BaseModel): - id: int - name: str +class Person(BaseModel): + first_name: str + surname: str -llm = vllm.LLM(model="openai-community/gpt2") -logits_processor = JSONLogitsProcessor(schema=User, llm=llm) +llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=512) +logits_processor = JSONLogitsProcessor(schema=Person, llm=llm, whitespace_pattern=r" ?") result = llm.generate( - ["A prompt", "Another prompt"], + ["He is Tom Jones", "She saw Linda Smith"], sampling_params=vllm.SamplingParams( - max_tokens=100, logits_processors=[logits_processor] + temperature=0.0, + max_tokens=50, + logits_processors=[logits_processor], ), ) print(result) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 66b4388d0..0690bdf6e 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -62,6 +62,9 @@ def get_next_state(self, state: int, token_id: int) -> int: def is_final_state(self, state: int) -> bool: ... + def copy(self) -> "Guide": + ... 
+ class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" @@ -189,7 +192,9 @@ def get_next_state(self, state: int, token_id: int) -> int: """ if token_id == self.eos_token_id: return -1 - elif state in self.final_states: + elif ( + state in self.final_states + ): # Necessary because we keep generating EOS tokens when finished return state last_token_to_end_state = self.states_to_token_maps[state] diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index 3e337542f..bcf091854 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Callable, Tuple +from typing import Protocol, Tuple, Type, Union INTEGER = r"[+-]?(0|[1-9][0-9]*)" BOOLEAN = "(True|False)" @@ -9,26 +9,49 @@ DATETIME = rf"({DATE})(\s)({TIME})" -def python_types_to_regex(python_type: Any) -> Tuple[str, Callable[[str], Any]]: +class FormatFunction(Protocol): + def __call__( + self, sequence: str + ) -> Union[int, float, bool, datetime.date, datetime.time, datetime.datetime]: + ... + + +def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: if python_type == float: - float_format_fn = lambda x: float(x) + + def float_format_fn(sequence: str) -> float: + return float(sequence) + return FLOAT, float_format_fn elif python_type == int: - int_format_fn = lambda x: int(x) + + def int_format_fn(sequence: str) -> int: + return int(sequence) + return INTEGER, int_format_fn elif python_type == bool: - bool_format_fn = lambda x: bool(x) + + def bool_format_fn(sequence: str) -> bool: + return bool(sequence) + return BOOLEAN, bool_format_fn elif python_type == datetime.date: - date_format_fn = lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date() + + def date_format_fn(sequence: str) -> datetime.date: + return datetime.datetime.strptime(sequence, "%Y-%m-%d").date() + return DATE, date_format_fn elif python_type == datetime.time: - time_format_fn = lambda s: datetime.datetime.strptime(s, "%H:%M:%S").time() + + def time_format_fn(sequence: str) -> datetime.time: + return datetime.datetime.strptime(sequence, "%H:%M:%S").time() + return TIME, time_format_fn elif python_type == datetime.datetime: - datetime_format_fn = lambda s: datetime.datetime.strptime( - s, "%Y-%m-%d %H:%M:%S" - ) + + def datetime_format_fn(sequence: str) -> datetime.datetime: + return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") + return DATETIME, datetime_format_fn else: raise NotImplementedError( diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 97b9a981b..4713bdc5b 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -1,9 +1,14 @@ +import datetime from typing import Iterator, List, Optional, Union import torch from outlines.generate.generator import sequence_generator +FormattedOutput = Union[ + str, int, float, bool, datetime.date, datetime.time, datetime.datetime +] + class SequenceGenerator: def __init__( @@ -100,7 +105,7 @@ def strip_stop_sequences( return sequence - def format_sequence(self, sequence: str) -> str: + def format_sequence(self, sequence: str) -> FormattedOutput: """Translate the generated sequence to another type. 
This method is for instance overridden when generating JSON to either @@ -124,7 +129,7 @@ def __call__( max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, rng: Optional[torch.Generator] = None, - ) -> Union[str, List[str], List[List[str]]]: + ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]: """Generate the full text sequence. Since `SequenceGenerator.stream` calls the tokenizer at every step this @@ -148,8 +153,7 @@ def __call__( Returns ------- - A string or list of strings that contain the generated text. - + The generation(s), potentially cast to another type. """ if isinstance(prompts, str): @@ -222,7 +226,7 @@ def __call__( formatted = [self.format_sequence(sequence) for sequence in stripped] # We reshape the output to (batch_size, sample_size) - output = [] + output: List[List[FormattedOutput]] = list() for i in range(batch_size): output.append(formatted[i : i + num_samples]) @@ -242,7 +246,7 @@ def stream( max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, rng: Optional[torch.Generator] = None, - ) -> Iterator[Union[List[str], List[List[str]], str]]: + ) -> Iterator[Union[List[str], str, List[List[str]]]]: """Generate the text sequence one token at a time. Since `Tokenizer.decode` strips the whitespaces from the tokens we have no @@ -352,7 +356,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ] # We reshape the output to (batch_size, sample_size) - output = [] + output: List[List[str]] = list() for i in range(batch_size): output.append(next_tokens[i : i + num_samples]) diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index d53d9d3d7..e56f5de78 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -2,12 +2,9 @@ from outlines.fsm.guide import RegexGuide from outlines.generate.api import SequenceGenerator +from outlines.integrations.llamacpp import RegexLogitsProcessor from outlines.models import OpenAI -from outlines.models.llamacpp import ( - LlamaCpp, - LlamaSequenceGenerator, - RegexLogitsProcessor, -) +from outlines.models.llamacpp import LlamaCpp, LlamaSequenceGenerator from outlines.samplers import Sampler, multinomial @@ -52,8 +49,8 @@ def regex_llamacpp( + "than the multinomial sampler." ) - logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) - generator = LlamaSequenceGenerator(logits_processor, model) + logits_processor = RegexLogitsProcessor(regex_str, llm=model.model) + generator = LlamaSequenceGenerator(logits_processor=logits_processor, model=model) return generator diff --git a/outlines/integrations/__init__.py b/outlines/integrations/__init__.py new file mode 100644 index 000000000..b0a90d5ea --- /dev/null +++ b/outlines/integrations/__init__.py @@ -0,0 +1 @@ +"""Utility functions and classes used to integrate `outlines` into other packages.""" diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py new file mode 100644 index 000000000..1ab5e3091 --- /dev/null +++ b/outlines/integrations/llamacpp.py @@ -0,0 +1,205 @@ +"""Make LlamaCpp compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union + +import numpy as np +import torch +from numpy.typing import NDArray +from pydantic import BaseModel + +from outlines.fsm.guide import CFGGuide, Guide, RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import convert_json_schema_to_str + +if TYPE_CHECKING: + from llama_cpp import Llama + + +class LlamaCppTokenizer: + def __init__(self, model: "Llama"): + self.eos_token_id = model.token_eos() + self.pad_token_id = self.eos_token_id + self.special_tokens: Set[int] = set() + + self.vocabulary: Dict[str, int] = dict() + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t + + def convert_token_to_string(self, token: str) -> str: + return token + + +class LogitsProcessor: + """Bias LlamaCpp generation using a finite state machine. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide): + """A FSM-based logits processor. + + Parameters + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + self.tokenizer = tokenizer + self._fsm_state = 0 + self.fsm: Guide = fsm + self._is_first_token = True + + def __call__( + self, input_ids: NDArray[np.int64], scores: NDArray[np.float32] + ) -> NDArray[np.float32]: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The input token ids. + scores + The logits. + + Returns + ------- + NDArray[np.float32] + The biased logits. + """ + if self._is_first_token: + self._is_first_token = False + else: + last_token = input_ids[-1] + self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) + + allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens + + mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy() + mask[allowed_tokens] = 0 + biased_scores = scores + mask + + return biased_scores + + def copy(self) -> "LogitsProcessor": + """Return a copy of the logits processor.""" + return LogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy()) + + +class RegexLogitsProcessor(LogitsProcessor): + """Bias LlamaCpp generation based on a regular expression. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, regex_string: str, llm: "Llama"): + """Compile the FSM that drives the regex-guided generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + The Llama model. + """ + tokenizer = LlamaCppTokenizer(model=llm) + fsm = RegexGuide(regex_string, tokenizer) + super().__init__(tokenizer=tokenizer, fsm=fsm) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias LlamaCpp generation based on a JSON schema. 
+ + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + llm: "Llama", + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + llm + The Llama model. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, llm=llm) + + +class CFGLogitsProcessor(LogitsProcessor): + """Bias LlamaCpp generation based on a context-free grammar. + + Attributes + ---------- + llm + The Llama model. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, cfg_str: str, llm: "Llama"): + """Compile the FSM that drives the CFG-guided generation. + + Parameters + ---------- + cfg_str + A string that represents a grammar + llm + The Llama model. + """ + tokenizer = LlamaCppTokenizer(model=llm) + fsm = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) + super().__init__(tokenizer=tokenizer, fsm=fsm) diff --git a/outlines/integrations/transformers.py b/outlines/integrations/transformers.py new file mode 100644 index 000000000..f8f1af945 --- /dev/null +++ b/outlines/integrations/transformers.py @@ -0,0 +1,159 @@ +"""Make Hugging Face transformers compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict +from typing import DefaultDict, List, Optional, Type, Union + +import torch +from pydantic import BaseModel +from transformers import Pipeline, PreTrainedTokenizerBase + +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str + + +class RegexPrefixAllowedTokens: + """Bias transformers generation based on a regular expression. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + regex_string: str, + tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], + ): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression. + tokenizer_or_pipe + The tokenizer of the model, or the pipeline object. 
+ + Raises + ------ + ValueError + If the `tokenizer_or_pipe` parameter is not a tokenizer or a pipeline. + """ + if isinstance(tokenizer_or_pipe, Pipeline): + tokenizer = tokenizer_or_pipe.tokenizer + elif isinstance(tokenizer_or_pipe, PreTrainedTokenizerBase): + tokenizer = tokenizer_or_pipe + else: + raise ValueError( + "The tokenizer_or_pipe parameter must be a tokenizer or a pipeline." + ) + assert isinstance(tokenizer, PreTrainedTokenizerBase) + tokenizer = adapt_tokenizer(tokenizer=tokenizer) + self.fsm = RegexGuide(regex_string=regex_string, tokenizer=tokenizer) + self._fsm_state: DefaultDict[int, int] = defaultdict(int) + + # The generated text with `transformers` include the input token IDs as well, + # so we use this attribute to keep track of the input token IDs. This allows us + # to reset the FSM state when the input token IDs change, as well as to only + # apply the FSM to the generated tokens. + self._prefix = [-1] + + def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + batch_id + The index of the current batch. + sent + The tokens of the current sentence. + + Returns + ------- + List[int] + The indices of the tokens that are allowed to be sampled next. + """ + input_ids = sent.tolist() + + # If the prefix token IDs have changed we assume that we are dealing with a new + # sample and reset the FSM state + if input_ids[: len(self._prefix)] != self._prefix: + self._fsm_state = defaultdict(int) + self._prefix = input_ids + seq_id = hash(tuple([])) + + else: + # Remove the prefix token IDs from the input token IDs, as the FSM should + # only be applied to the generated tokens + input_ids = input_ids[len(self._prefix) :] + + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + seq_id = hash(tuple(input_ids)) + self._fsm_state[seq_id] = self.fsm.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token + ) + + allowed_tokens = self.fsm.get_next_instruction( + state=self._fsm_state[seq_id] + ).tokens + return allowed_tokens + + +class JSONPrefixAllowedTokens(RegexPrefixAllowedTokens): + """Bias transformers generation based on a JSON schema. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A schema that encodes the structure we want the model to generate. + tokenizer_or_pipe + The tokenizer of the model, or the pipeline object. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, tokenizer_or_pipe=tokenizer_or_pipe) diff --git a/outlines/integrations/utils.py b/outlines/integrations/utils.py new file mode 100644 index 000000000..3c92428c8 --- /dev/null +++ b/outlines/integrations/utils.py @@ -0,0 +1,99 @@ +"""Utility functions used in integrations with other packages. + + _______________________________ +/ Don't want to self-host? 
\ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +from typing import Type, Union + +from pydantic import BaseModel +from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase + + +def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: + """Adapt a tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of `transformers`. In + addition we need to handle the missing spaces to Llama's tokenizer to be able to + compile FSMs for this model. + + Parameters + ---------- + tokenizer + The tokenizer of the model. + + Returns + ------- + PreTrainedTokenizerBase + The adapted tokenizer. + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer + + +def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: + """Convert a JSON schema to a string. + + Parameters + ---------- + json_schema + The JSON schema. + + Returns + ------- + str + The JSON schema converted to a string. + + Raises + ------ + ValueError + If the schema is not a dictionary, a string or a Pydantic class. + """ + if isinstance(json_schema, dict): + schema_str = json.dumps(json_schema) + elif isinstance(json_schema, str): + schema_str = json_schema + elif issubclass(json_schema, BaseModel): + schema_str = json.dumps(json_schema.model_json_schema()) + else: + raise ValueError( + f"Cannot parse schema {json_schema}. The schema must be either " + + "a Pydantic class, a dictionary or a string that contains the JSON " + + "schema specification" + ) + return schema_str diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py new file mode 100644 index 000000000..a70f0e3e8 --- /dev/null +++ b/outlines/integrations/vllm.py @@ -0,0 +1,153 @@ +"""Make vLLM compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from collections import defaultdict +from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union + +import torch +from pydantic import BaseModel + +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str + +if TYPE_CHECKING: + from vllm import LLM + + +class RegexLogitsProcessor: + """Bias vLLM generation based on a regular expression. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, regex_string: str, llm: "LLM"): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression. + llm + The vLLM model. + + Raises + ------ + ValueError + If the provided LLM instance in `RegexLogitsProcessor` neither has a + `tokenizer` attribute or a `get_tokenizer` method. + """ + if hasattr(llm, "get_tokenizer"): + tokenizer = llm.get_tokenizer() + elif hasattr(llm, "tokenizer"): + if hasattr(llm.tokenizer, "tokenizer"): + tokenizer = llm.tokenizer.tokenizer + else: + tokenizer = llm.tokenizer + else: + raise ValueError( + "The provided LLM instance in `RegexLogitsProcessor` neither has a " + "`tokenizer` attribute or a `get_tokenizer` method." + ) + tokenizer = adapt_tokenizer(tokenizer=tokenizer) + self.fsm = RegexGuide(regex_string, tokenizer) + self._fsm_state: DefaultDict[int, int] = defaultdict(int) + + def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The tokens of the current sentence. + scores + The logits of the current sentence. + + Returns + ------- + torch.Tensor + The biased logits. + """ + seq_id = hash(tuple(input_ids)) + + # Initialize the FSM state dictionary if the input_ids are empty, as this means + # that the input_ids are the first tokens of the sequence. + if len(input_ids) == 0: + self._fsm_state = defaultdict(int) + else: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self._fsm_state[seq_id] = self.fsm.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token + ) + + allowed_tokens = self.fsm.get_next_instruction( + state=self._fsm_state[seq_id] + ).tokens + + mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask[allowed_tokens] = 0 + biased_scores = scores + mask + + return biased_scores + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias vLLM generation based on a JSON schema. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + llm: "LLM", + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + llm + The vLLM model. 
+ whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, llm=llm) diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index bcffefd8e..010aaafbf 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,20 +1,39 @@ -import math from typing import TYPE_CHECKING, List, Optional, Union -import numpy as np import torch -from numpy.typing import NDArray -from outlines.fsm.guide import CFGGuide, Guide, RegexGuide +from outlines.integrations.llamacpp import ( # noqa: F401 + CFGLogitsProcessor, + JSONLogitsProcessor, + LlamaCppTokenizer, + LogitsProcessor, + RegexLogitsProcessor, +) if TYPE_CHECKING: from llama_cpp import Llama +class LlamaCpp: + """Represents a `llama_cpp` model.""" + + def __init__(self, model: "Llama"): + self.model = model + self.tokenizer = LlamaCppTokenizer(model=model) + + +def llamacpp(model_path: str, device: Optional[str] = None, **model_kwargs) -> LlamaCpp: + from llama_cpp import Llama + + if device == "cuda": + model_kwargs["n_gpu_layers"].setdefault(-1) + + model = Llama(model_path, **model_kwargs) + return LlamaCpp(model=model) + + class LlamaSequenceGenerator: - def __init__( - self, logits_processor: Optional["LogitsProcessor"], model: "LlamaCpp" - ): + def __init__(self, logits_processor: Optional[LogitsProcessor], model: LlamaCpp): self.model = model.model self.logits_processor = logits_processor @@ -41,14 +60,16 @@ def __call__( if self.logits_processor is not None: processors = [self.logits_processor.copy()] - result = self.model.create_completion( + completion = self.model.create_completion( prompt, max_tokens=max_tokens, stop=stop_at, seed=rng.initial_seed(), logits_processor=LogitsProcessorList(processors), **model_kwargs, - )["choices"][0]["text"] + ) + assert isinstance(completion, dict) + result = completion["choices"][0]["text"] results.append(result) self.model.reset() @@ -71,7 +92,6 @@ def format_sequence(self, sequence: str) -> str: Returns ------- The formatted sequence. - """ return sequence @@ -85,113 +105,3 @@ def stream( raise NotImplementedError( "Streaming is not implemented for the `llama.cpp` integration." ) - - -class LlamaCpp: - """Represents a `llama_cpp` model.""" - - def __init__(self, model: "Llama", **kwargs): - self.model = model - self.tokenizer = LlamaCppTokenizer(model) - - -class LlamaCppTokenizer: - def __init__(self, model, **kwargs): - self.eos_token_id = model.token_eos() - self.pad_token_id = self.eos_token_id - self.special_tokens = {} - - self.vocabulary = {} - for t in range(model.n_vocab()): - token_piece = model.tokenizer().decode([t]) - self.vocabulary[token_piece] = t - - def convert_token_to_string(self, token: str) -> str: - return token - - -def llamacpp( - model_path: str, - device: Optional[str] = None, - **model_kwargs, -): - from llama_cpp import Llama - - if device == "cuda": - model_kwargs["n_gpu_layers"].setdefault(-1) - - model = Llama(model_path, **model_kwargs) - - return LlamaCpp(model) - - -class LogitsProcessor: - def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide): - """A FSM-based logits processor. 
- - Parameters - ---------- - tokenizer - An instance of `Tokenizer` - fsm - An instance of `FSM` - - """ - self.tokenizer = tokenizer - self.fsm_state = 0 - self.fsm: Guide = fsm - self.is_first_token = True - - def __call__( - self, input_ids: NDArray[np.int64], scores: NDArray[np.float32] - ) -> NDArray[np.float32]: - """Use the FSM to bias the logits before sampling the next token.""" - - if self.is_first_token: - self.is_first_token = False - else: - last_token = input_ids[-1] - self.fsm_state = self.fsm.get_next_state(self.fsm_state, last_token) - - allowed_tokens = self.fsm.get_next_instruction(self.fsm_state).tokens - - mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy() - mask[allowed_tokens] = 0 - biased_scores = scores + mask - - return biased_scores - - def copy(self): - return LogitsProcessor(self.tokenizer, self.fsm.copy()) - - -class RegexLogitsProcessor(LogitsProcessor): - def __init__(self, regex_string: str, tokenizer: LlamaCppTokenizer): - """Compile the FSM that drives the regex-guided generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression - tokenizer - An instance of `Tokenizer` - - """ - fsm = RegexGuide(regex_string, tokenizer) - super().__init__(tokenizer, fsm) - - -class CFGLogitsProcessor(LogitsProcessor): - def __init__(self, cfg_str: str, tokenizer: LlamaCppTokenizer): - """Compile the FSM that drives the CFG-guided generation. - - Parameters - ---------- - cfg_str - A string that represents a grammar - tokenizer - An instance of `Tokenizer` - - """ - fsm = CFGGuide(cfg_str, tokenizer) - super().__init__(tokenizer, fsm) diff --git a/outlines/models/tokenizer.py b/outlines/models/tokenizer.py index 72bdae0fe..949414a44 100644 --- a/outlines/models/tokenizer.py +++ b/outlines/models/tokenizer.py @@ -1,36 +1,31 @@ -from abc import abstractmethod from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union import numpy as np from numpy.typing import NDArray -class Tokenizer(Protocol, Hashable): +class Tokenizer(Hashable, Protocol): eos_token: str eos_token_id: int pad_token_id: int vocabulary: Dict[str, int] special_tokens: Set[int] - @abstractmethod def encode( self, prompt: Union[str, List[str]] ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: - """Translate the input prompts into NumPy arrays of token ids and attention mask.""" + """Translate the input prompts into arrays of token ids and attention mask.""" ... - @abstractmethod def decode(self, token_ids: NDArray[np.int64]) -> List[str]: """Translate an array of token ids to a string or list of strings.""" ... - @abstractmethod def convert_token_to_string(self, token: str) -> str: """Convert a token to its equivalent string. This is for instance useful for BPE tokenizers where whitespaces are represented by the special characted `Ġ`. This prevents matching a raw token that includes `Ġ` with a string. - """ ... diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index 135169cc1..fb8c80139 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -from .vllm import JSONLogitsProcessor, RegexLogitsProcessor +from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. 
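The hunk below reduces `outlines/serve/vllm.py` to a re-export of the new `outlines.integrations.vllm` module, so code importing from the old path keeps working. A minimal usage sketch of the relocated regex processor follows; it reuses the `INTEGER` pattern from `outlines/fsm/types.py`, while the model name and context length are illustrative assumptions mirroring the example above, not part of this patch.

import vllm

from outlines.fsm.types import INTEGER
from outlines.integrations.vllm import RegexLogitsProcessor

# After this patch the legacy import path is a thin re-export of the same class.
from outlines.serve.vllm import RegexLogitsProcessor as LegacyRegexLogitsProcessor

assert RegexLogitsProcessor is LegacyRegexLogitsProcessor

# Illustrative model and context length, as in examples/vllm_integration.py.
llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=512)

# Constrain the completion to a (possibly signed) integer.
logits_processor = RegexLogitsProcessor(regex_string=INTEGER, llm=llm)
result = llm.generate(
    ["The answer to 2 + 2 is "],
    sampling_params=vllm.SamplingParams(
        temperature=0.0, max_tokens=5, logits_processors=[logits_processor]
    ),
)
print(result)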
diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index f95a04ad2..ddc50b47d 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -1,143 +1,4 @@ -# Make vLLM compatible with Outlines' structured generation. -# -# _______________________________ -# / Don't want to self-host? \ -# \ Try .json at http://dottxt.co / -# ------------------------------- -# \ ^__^ -# \ (oo)\_______ -# (__)\ )\/\ -# ||----w | -# || || -# -# Copyright 2024- the Outlines developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import math -from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional - -import torch -from pydantic import BaseModel - -from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema - - -class RegexLogitsProcessor: - def __init__(self, regex_string, llm): - """Compile the FSM that drives the regex-structured generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression - llm - An instance of `vllm.LLM` - - """ - if hasattr(llm, "get_tokenizer"): - tokenizer = llm.get_tokenizer() - elif hasattr(llm, "tokenizer"): - if hasattr(llm.tokenizer, "tokenizer"): - tokenizer = llm.tokenizer.tokenizer - else: - tokenizer = llm.tokenizer - else: - raise ValueError( - "The provided LLM instance in `RegexLogitsProcessor` neither has a " - "`tokenizer` attribute or a `get_tokenizer` method." - ) - tokenizer = self.adapt_tokenizer(tokenizer=tokenizer) - - fsm = RegexGuide(regex_string, tokenizer) - self.fsm = fsm - - def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token.""" - - seq_id = hash(tuple(input_ids)) - - if len(input_ids) == 0: # Initialize the fsm states - self.fsm_state: DefaultDict[int, int] = defaultdict(int) - else: - last_token = input_ids[-1] - last_seq_id = hash(tuple(input_ids[:-1])) - self.fsm_state[seq_id] = self.fsm.next_state( - self.fsm_state[last_seq_id], last_token - ) - - allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) - - mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) - mask[allowed_tokens] = 0 - biased_scores = scores + mask - - return biased_scores - - def adapt_tokenizer(self, tokenizer): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. In addition we need to handle the missing spaces to - Llama's tokenizer to be able to compile FSMs for this model. 
- - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - tokenizer.convert_token_to_string = convert_token_to_string - - return tokenizer - - -class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema: Dict, llm, whitespace_pattern: Optional[str] = None): - """Compile the FSM that drives the JSON-guided generation. - - Parameters - ---------- - schema - A JSON schema that encodes the structure we want the model to generate - llm - An instance of `vllm.LLM` - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string literals) - Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` - """ - if isinstance(schema, type(BaseModel)): - schema_str = json.dumps(schema.model_json_schema()) - elif isinstance(schema, Dict): - schema_str = json.dumps(schema) - elif isinstance(schema, str): - schema_str = schema - else: - raise ValueError( - f"Cannot parse schema {schema}. The schema must be either " - + "a Pydantic object, a dictionary or a string that contains the JSON " - + "Schema specification" - ) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, llm) +from outlines.integrations.vllm import ( # noqa[F401] + JSONLogitsProcessor, + RegexLogitsProcessor, +) diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 908bf2ad2..090a18ced 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -23,9 +23,7 @@ def model(tmp_path_factory): local_dir_use_symlinks="auto", filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", ) - - model = llamacpp(TEST_MODEL, "cpu") - return model + return llamacpp(model_path=TEST_MODEL, device="cpu") def test_llamacpp_integration_text(model): @@ -58,6 +56,7 @@ def test_llamacpp_integration_text_stop(model): "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.text(model)(prompt, stop_at="a") + assert isinstance(sequence, str) assert sequence[len(prompt) :].find("a") == -1 @@ -70,8 +69,9 @@ def test_llamacpp_various_regexes(model): generator = generate.regex(model, regex_str) # One prompt - sequence = generator(prompt) - assert re.fullmatch(regex_str, sequence) is not None + sequence = generator(prompts=prompt) + assert isinstance(sequence, str) + assert re.fullmatch(pattern=regex_str, string=sequence) is not None def test_llamacpp_various_regexes_prompt_list(model): @@ -83,9 +83,12 @@ def test_llamacpp_various_regexes_prompt_list(model): generator = generate.regex(model, regex_str) # Two prompts - sequence = generator([prompt, prompt]) - assert re.fullmatch(regex_str, sequence[0]) is not None - assert re.fullmatch(regex_str, sequence[1]) is not None + sequence = generator(prompts=[prompt, prompt]) + assert isinstance(sequence, list) + assert len(sequence) == 2 + for s in sequence: + assert isinstance(s, str) + assert re.fullmatch(pattern=regex_str, string=s) is not None def test_llamacpp_integration_integer(model): @@ -94,7 +97,7 @@ def test_llamacpp_integration_integer(model): "<|im_start|>user\nWrite a 
short sentence<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.format(model, int)(prompt, max_tokens=10) - + assert isinstance(sequence, int) assert sequence != "" int(sequence) @@ -103,10 +106,12 @@ def test_llamacpp_integration_integer_array(model): model.model.reset() prompts = ["Give me a number", "And another one"] sequence = generate.format(model, int)(prompts, max_tokens=10) + assert isinstance(sequence, list) assert len(sequence) == 2 - int(sequence[0]) - int(sequence[1]) + for s in sequence: + assert isinstance(s, int) + int(s) def test_llamacpp_integration_float(model): @@ -115,6 +120,7 @@ def test_llamacpp_integration_float(model): "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.format(model, float)(prompt, max_tokens=10) + assert isinstance(sequence, float) assert sequence != "" float(sequence) @@ -126,6 +132,7 @@ def test_llamacpp_integration_bool(model): "<|im_start|>user\nIs this True or False?<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.format(model, bool)(prompt, max_tokens=10) + assert isinstance(sequence, bool) assert sequence != "" bool(sequence) @@ -137,7 +144,6 @@ def test_llamacpp_integration_date(model): "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.format(model, datetime.date)(prompt, max_tokens=10) - assert isinstance(sequence, datetime.date) @@ -145,7 +151,6 @@ def test_llamacpp_integration_time(model): model.model.reset() prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" sequence = generate.format(model, datetime.time)(prompt, max_tokens=10) - assert isinstance(sequence, datetime.time) @@ -153,7 +158,6 @@ def test_llamacpp_integration_datetime(model): model.model.reset() prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20) - assert isinstance(sequence, datetime.datetime) @@ -163,7 +167,6 @@ def test_llamacpp_integration_choice(model): "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" ) sequence = generate.choice(model, ["test", "choice"])(prompt) - assert sequence == "test" or sequence == "choice"
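The patch adds usage examples for the transformers and vLLM integrations but not for llama.cpp; a minimal sketch of driving `llama_cpp` directly with the new `outlines.integrations.llamacpp.JSONLogitsProcessor` is given below. The GGUF repository id, filename and prompt template are illustrative assumptions chosen to mirror the test fixture, not part of this patch.

from huggingface_hub import hf_hub_download
from llama_cpp import Llama, LogitsProcessorList
from pydantic import BaseModel

from outlines.integrations.llamacpp import JSONLogitsProcessor


class Person(BaseModel):
    first_name: str
    surname: str


# Illustrative model choice; any llama.cpp-compatible GGUF model should work.
model_path = hf_hub_download(
    repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
    filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
)
llm = Llama(model_path)

# Constrain the completion to JSON that matches the Person schema.
logits_processor = JSONLogitsProcessor(
    schema=Person, llm=llm, whitespace_pattern=r" ?"
)
completion = llm.create_completion(
    "<|im_start|>user\nHe is Tom Jones.<|im_end|>\n<|im_start|>assistant\n",
    max_tokens=50,
    temperature=0.0,
    logits_processor=LogitsProcessorList([logits_processor]),
)
print(completion["choices"][0]["text"])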