From e99d92d024dbf6f6bff10a9c3954f326cf4a0cd3 Mon Sep 17 00:00:00 2001 From: Daniel Tiarks Date: Thu, 18 Jan 2024 19:07:26 +0100 Subject: [PATCH] Integrate `llama.cpp` via a logits processor --- .gitignore | 1 + docs/reference/models/llamacpp.md | 2 +- examples/llamacpp_example.py | 4 +- examples/llamacpp_processor.py | 50 +++ outlines/fsm/fsm.py | 4 +- outlines/fsm/json_schema.py | 16 +- outlines/generate/cfg.py | 19 + outlines/generate/choice.py | 6 +- outlines/generate/json.py | 8 +- outlines/generate/regex.py | 25 +- outlines/generate/text.py | 15 +- outlines/models/llamacpp.py | 410 +++++++----------- outlines/serve/vllm.py | 23 +- pyproject.toml | 2 +- tests/benchmark/test_benchmark_json_schema.py | 6 +- tests/fsm/test_json_schema.py | 14 +- tests/generate/test_integration_llamacpp.py | 342 +++++++++++++++ tests/models/test_llama_cpp.py | 51 +-- 18 files changed, 649 insertions(+), 349 deletions(-) create mode 100644 examples/llamacpp_processor.py create mode 100644 tests/generate/test_integration_llamacpp.py diff --git a/.gitignore b/.gitignore index 66ce54fc7..66dfb8612 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__ docs/build .coverage .idea/ +*.gguf diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md index 826a2cfec..c6a3838aa 100644 --- a/docs/reference/models/llamacpp.md +++ b/docs/reference/models/llamacpp.md @@ -11,5 +11,5 @@ Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the ```python from outlines import models, generate -model = models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu") +model = models.llamacpp("./phi-2.Q4_K_M.gguf") ``` diff --git a/examples/llamacpp_example.py b/examples/llamacpp_example.py index 1160c6f6d..0b478217f 100644 --- a/examples/llamacpp_example.py +++ b/examples/llamacpp_example.py @@ -30,8 +30,8 @@ class Character(BaseModel): if __name__ == "__main__": - # Download model from https://huggingface.co/TheBloke/phi-2-GGUF - model = outlines.models.llamacpp("./phi-2.Q3_K_M.gguf", device="cpu") + # curl -L -o mistral-7b-instruct-v0.2.Q5_K_M.gguf https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q5_K_M.gguf + model = outlines.models.llamacpp("./mistral-7b-instruct-v0.2.Q5_K_M.gguf") # Construct structured sequence generator generator = outlines.generate.json(model, Character) diff --git a/examples/llamacpp_processor.py b/examples/llamacpp_processor.py new file mode 100644 index 000000000..909df38fc --- /dev/null +++ b/examples/llamacpp_processor.py @@ -0,0 +1,50 @@ +from enum import Enum + +from llama_cpp import Llama, LogitsProcessorList +from pydantic import BaseModel, constr + +from outlines.generate.processors import JSONLogitsProcessor +from outlines.models.llamacpp import LlamaCppTokenizer + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + + +class Character(BaseModel): + name: constr(max_length=10) + age: int + armor: Armor + weapon: Weapon + strength: int + + +if __name__ == "__main__": + llama = Llama("./phi-2.Q4_K_M.gguf") + tokenizer = LlamaCppTokenizer(llama) + + prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:" + + logits_processor = JSONLogitsProcessor(Character, tokenizer) + + json_str = llama.create_completion( + prompt, + top_k=40, + top_p=0.95, + temperature=0.7, + max_tokens=100, + logits_processor=LogitsProcessorList([logits_processor]), + )["choices"][0]["text"] + + print(json_str) diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 74a972d9a..6ec3ff9e4 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -91,7 +91,7 @@ def copy(self) -> "StopAtEosFSM": class RegexFSM(FSM): """FSM to generate text that is in the language of a regular expression.""" - def __init__(self, regex_string: str, tokenizer: "Tokenizer"): + def __init__(self, regex_string: str, tokenizer): @cache() def create_states_mapping( regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]] @@ -190,7 +190,7 @@ def copy(self) -> "RegexFSM": class CFGFSM(FSM): """FSM to generate text that is in the language of a context-free grammar.""" - def __init__(self, cfg_string: str, tokenizer: "Tokenizer"): + def __init__(self, cfg_string: str, tokenizer): self.cfg_string = cfg_string self.tokenizer = tokenizer diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 2d947f372..847da6ce6 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,10 +1,10 @@ import inspect import json import re -from typing import Callable, Optional, Union +from typing import Callable, Optional from jsonschema.protocols import Validator -from pydantic import BaseModel, create_model +from pydantic import create_model from referencing import Registry, Resource from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 @@ -38,9 +38,7 @@ } -def build_regex_from_object( - object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None -): +def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. @@ -72,13 +70,7 @@ def build_regex_from_object( """ - if isinstance(object, type(BaseModel)): - schema = object.model_json_schema() - elif callable(object): - schema = get_schema_from_signature(object) - else: - schema = json.loads(object) - + schema = json.loads(schema) Validator.check_schema(schema) # Build reference resolver diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index 9bf92939f..2ecf5a3d5 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -3,6 +3,7 @@ from outlines.fsm.fsm import CFGFSM from outlines.generate.api import SequenceGenerator from outlines.models import OpenAI +from outlines.models.llamacpp import CFGLogitsProcessor, LlamaCpp from outlines.samplers import Sampler, multinomial @@ -31,6 +32,24 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera return generator +@cfg.register(LlamaCpp) +def cfg_llamacpp( + model: LlamaCpp, + cfg_str: str, + sampler: Sampler = multinomial(), +): + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The llama.cpp integration does not currently support any other sampling algorithm " + + "than the multinomial sampler." + ) + + logits_processor = CFGLogitsProcessor(cfg_str, model.tokenizer) + model.logits_processor = logits_processor + + return model + + @cfg.register(OpenAI) def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError( diff --git a/outlines/generate/choice.py b/outlines/generate/choice.py index a415fabe5..6718f26b2 100644 --- a/outlines/generate/choice.py +++ b/outlines/generate/choice.py @@ -13,7 +13,11 @@ def choice( model, choices: List[str], sampler: Sampler = multinomial() ) -> SequenceGenerator: regex_str = r"(" + r"|".join(choices) + r")" - return regex(model, regex_str, sampler) + + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: x + + return generator @choice.register(OpenAI) diff --git a/outlines/generate/json.py b/outlines/generate/json.py index cf5866340..3837f72b6 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from outlines.fsm.json_schema import build_regex_from_object, get_schema_from_signature +from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature from outlines.generate.api import SequenceGenerator from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -45,17 +45,17 @@ def json( """ if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - regex_str = build_regex_from_object(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) - regex_str = build_regex_from_object(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) elif isinstance(schema_object, str): schema = schema_object - regex_str = build_regex_from_object(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) else: diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 9d0b9ee87..d04db42b5 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -3,6 +3,7 @@ from outlines.fsm.fsm import RegexFSM from outlines.generate.api import SequenceGenerator from outlines.models import OpenAI +from outlines.models.llamacpp import LlamaCpp, RegexLogitsProcessor from outlines.samplers import Sampler, multinomial @@ -35,8 +36,30 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): return generator +@regex.register(LlamaCpp) +def regex_llamacpp( + model: LlamaCpp, + regex_str: str, + sampler: Sampler = multinomial(), +): + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The llama.cpp integration does not currently support any other sampling algorithm " + + "than the multinomial sampler." + ) + + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + model.logits_processor = logits_processor + + return model + + @regex.register(OpenAI) -def regex_openai(model, regex_str: str, sampler: Sampler = multinomial()): +def regex_openai( + model: OpenAI, + regex_str: str, + sampler: Sampler = multinomial(), +): raise NotImplementedError( "Cannot use regex-structured generation with an OpenAI model" + "due to the limitations of the OpenAI API." diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 606721dd3..fcb8c96dc 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -2,7 +2,7 @@ from outlines.fsm.fsm import StopAtEosFSM from outlines.generate import SequenceGenerator -from outlines.models import OpenAI +from outlines.models import LlamaCpp, OpenAI from outlines.samplers import Sampler, multinomial @@ -36,12 +36,23 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: return generator +@text.register(LlamaCpp) +def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()): + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The llama.cpp API does not support any other sampling algorithm " + + "than the multinomial sampler." + ) + + return model + + @text.register(OpenAI) def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI: if not isinstance(sampler, multinomial): raise NotImplementedError( r"The OpenAI API does not support any other sampling algorithm " - + "that the multinomial sampler." + + "than the multinomial sampler." ) return model diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index c51f600f8..511086dcf 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,304 +1,184 @@ -import ctypes -from typing import List, Optional, Tuple, Union +import math +from typing import List, Optional, Union import numpy as np import torch from numpy.typing import NDArray -from outlines.models.tokenizer import Tokenizer +from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM class LlamaCpp: """Represents a `llama_cpp` model.""" def __init__( - self, llama_instance, model, tokenizer, device, context_params, **kwargs + self, model_path, logits_processor: Optional["LogitsProcessor"] = None, **kwargs ): - self.device = device - self.llama_instance = llama_instance - self.tokenizer = tokenizer + from llama_cpp import Llama - # Note: the concept of padding does not exist in llama.cpp as a batched sequence is just - # a flat array of tokens that can be assigned to one or more sequences. - # To make it compatible with the transformers inspired tokenizer interface - # we need a padding token to homogenize to token_ids tensor. - self.pad_token_id = -1 + self.logits_processor = logits_processor + self.model = Llama(model_path, **kwargs) + self.tokenizer = LlamaCppTokenizer(self) - self.n_past = 0 - self.n_vocab = kwargs.pop("n_vocab") + def __call__( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional[torch.Generator] = None, + **model_kwargs, + ) -> Union[str, List[str]]: + from llama_cpp import LogitsProcessorList + + if isinstance(prompts, str): + prompts = [prompts] + + if rng is None: + rng = torch.Generator(device="cpu") + rng.seed() + + results = [] + for prompt in prompts: + processors = [] + if self.logits_processor is not None: + processors = [self.logits_processor.copy()] + + result = self.model.create_completion( + prompt, + max_tokens=max_tokens, + stop=stop_at, + seed=rng.initial_seed(), + logits_processor=LogitsProcessorList(processors), + **model_kwargs, + )["choices"][0]["text"] + results.append(result) + + self.model.reset() + + formatted = [self.format_sequence(sequence) for sequence in results] + + return formatted if len(formatted) > 1 else formatted[0] + + def format_sequence(self, sequence: str) -> str: + """Translate the generated sequence to another type. + + This method is for instance overridden when generating JSON to either + return a dictionnary or a Pydantic model. + + Parameters + ---------- + sequence + A generated sequences. + + Returns + ------- + The formatted sequence. + + """ + return sequence + + def stream( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional[torch.Generator] = None, + ): + raise NotImplementedError( + "Streaming is not implemented for the `llama.cpp` integration." + ) - self.ctx = llama_instance.llama_new_context_with_model(model, context_params) - def forward(self, input_ids: torch.LongTensor, *_): - """Compute a forward pass through the llama_cpp model.""" - if input_ids.ndim == 2: - seq_tensor = input_ids[:, self.n_past :] - elif input_ids.ndim == 1: - seq_tensor = input_ids.view(1, -1)[:, self.n_past :] - else: - raise Exception("Only one and two dimensional inputs allowed.") +class LlamaCppTokenizer: + def __init__(self, model, **kwargs): + self.eos_token_id = model.model.token_eos() + self.pad_token_id = self.eos_token_id + self.special_tokens = {} - tokens_total = torch.numel(seq_tensor[seq_tensor != self.pad_token_id]) - batch = self.llama_instance.llama_batch_init(tokens_total, 0, 1) + self.vocabulary = {} + for t in range(model.model.n_vocab()): + token_piece = model.model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t - seq_token_ids = [] - for seq_idx, seq in enumerate(seq_tensor): - for token_pos, token_id in enumerate(seq): - if token_id == self.pad_token_id: - break - batch.token[batch.n_tokens] = token_id.item() - batch.pos[batch.n_tokens] = token_pos - batch.seq_id[batch.n_tokens][0] = seq_idx - batch.n_seq_id[batch.n_tokens] = 1 - batch.logits[batch.n_tokens] = False + def convert_token_to_string(self, token: str) -> str: + return token - batch.n_tokens += 1 - self.n_past += 1 - batch.logits[batch.n_tokens - 1] = True - seq_token_ids.append(batch.n_tokens - 1) +def llamacpp( + model_name: str, + device: Optional[str] = None, + model_kwargs: dict = {}, +): + return LlamaCpp(model_name, **model_kwargs) - if self.llama_instance.llama_decode(self.ctx, batch) != 0: - print("Error decoding") - all_logits = [] - for seq_token in seq_token_ids: - logits = self.llama_instance.llama_get_logits_ith(self.ctx, seq_token) - logits_list = (ctypes.c_float * self.n_vocab)( - *[logits[token_id] for token_id in range(self.n_vocab)] - ) - logits_tensor = torch.tensor(logits_list) - all_logits.append(logits_tensor) +class LogitsProcessor: + def __init__(self, tokenizer: LlamaCppTokenizer, fsm: FSM): + """A FSM-based logits processor. - self.llama_instance.llama_batch_free(batch) + Parameters + ---------- + tokenizer + An instance of `Tokenizer` + fsm + An instance of `FSM` - stacked_logits = torch.stack(all_logits) - return stacked_logits, None + """ + self.tokenizer = tokenizer + self.fsm_state = FSMState(0) + self.fsm: FSM = fsm + self.is_first_token = True def __call__( - self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - past_key_values: Optional[Tuple] = None, - ) -> torch.FloatTensor: - logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) - next_token_logits = logits - - return next_token_logits, kv_cache - - -class LlamaCppTokenizer(Tokenizer): - def __init__(self, llama_instance, model, model_name: str, **kwargs): - self.model_name = model_name - self.llama_instance = llama_instance - self.is_llama = False - - self.model = model - self.n_vocab = kwargs.pop("n_vocab") - - self.eos_token_id = llama_instance.llama_token_eos(model) - self.eos_token = self._get_eos_token() - self.pad_token_id = -1 - self.bos_token_id = llama_instance.llama_token_eos(model) - self.nl_token_id = llama_instance.llama_token_nl(model) - self.vocabulary = {} - self._create_vocabulary() - - self.n_past = 0 - - self.special_tokens = { - self.eos_token_id, - self.pad_token_id, - self.bos_token_id, - self.nl_token_id, - } - - def _create_vocabulary(self): - for t in range(self.n_vocab): - size = 32 - buffer = (ctypes.c_char * size)() - n = self.llama_instance.llama_token_to_piece( - self.model, self.llama_instance.llama_token(t), buffer, size - ) - - try: - token_piece = buffer[:n].decode("utf-8") - self.vocabulary[token_piece] = t - except Exception as e: - print(f"Failed to convert token ({buffer[:n]}): {e}") - continue - - def encode( - self, prompt: Union[str, List[str]] - ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: - if isinstance(prompt, list): - prompts = prompt - else: - prompts = [prompt] - - max_len = 0 - token_ids = [] - for p in prompts: - embd_inp = (self.llama_instance.llama_token * (len(p) + 1))() - - n_of_tok = self.llama_instance.llama_tokenize( - model=self.model, - text=bytes(str(p), "utf-8"), - text_len=len(embd_inp), - tokens=embd_inp, - n_max_tokens=len(embd_inp), - add_bos=self.n_past == 0, - special=False, - ) - - self.n_past += n_of_tok - - if n_of_tok > max_len: - max_len = n_of_tok - - embd_inp = embd_inp[:n_of_tok] - token_ids.append(np.array(embd_inp)) - - max_len = np.max([len(a) for a in token_ids]) - padded = np.asarray( - [ - np.pad( - a, - (0, max_len - len(a)), - "constant", - constant_values=self.pad_token_id, - ) - for a in token_ids - ] - ) - - token_ids = torch.LongTensor(padded) - return token_ids, torch.ones_like(token_ids) - - def decode(self, token_ids: NDArray[np.int64]) -> List[str]: - if isinstance(token_ids, list): - token_ids = np.array(token_ids) - if token_ids.ndim == 1: - token_ids = [token_ids] + self, input_ids: NDArray[np.int64], scores: NDArray[np.float32] + ) -> NDArray[np.float32]: + """Use the FSM to bias the logits before sampling the next token.""" - pieces = [] - for tokens in token_ids: - seq = [] - for id in tokens: - size = 32 - buffer = (ctypes.c_char * size)() - n = self.llama_instance.llama_token_to_piece( - self.model, self.llama_instance.llama_token(id), buffer, size - ) - - token_piece = buffer[:n].decode("utf-8") # type: ignore + if self.is_first_token: + self.is_first_token = False + else: + last_token = input_ids[-1] + self.fsm_state = self.fsm.next_state(self.fsm_state, last_token) - seq.append(token_piece) + allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state) - pieces.append("".join(seq)) + mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy() + mask[allowed_tokens] = 0 + biased_scores = scores + mask - return pieces + return biased_scores - def _get_eos_token(self): - size = 32 - buffer = (ctypes.c_char * size)() - n = self.llama_instance.llama_token_to_piece( - self.model, self.llama_instance.llama_token(self.eos_token_id), buffer, size - ) + def copy(self): + return LogitsProcessor(self.tokenizer, self.fsm.copy()) - token_piece = buffer[:n].decode("utf-8") - return token_piece +class RegexLogitsProcessor(LogitsProcessor): + def __init__(self, regex_string: str, tokenizer: LlamaCppTokenizer): + """Compile the FSM that drives the regex-guided generation. - def convert_token_to_string(self, token: str) -> str: - return token + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + An instance of `Tokenizer` - def __eq__(self, other): - if isinstance(other, type(self)): - return other.model_name == self.model_name and other.kwargs == self.kwargs - return NotImplemented + """ + fsm = RegexFSM(regex_string, tokenizer) + super().__init__(tokenizer, fsm) - def __hash__(self): - return hash(self.model_name) +class CFGLogitsProcessor(LogitsProcessor): + def __init__(self, cfg_str: str, tokenizer: LlamaCppTokenizer): + """Compile the FSM that drives the CFG-guided generation. -def llamacpp( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, -): - try: - import llama_cpp - except ImportError: - raise ImportError( - "The `llama-cpp-python` library needs to be installed in order to use LlamaCpp." - ) - - if device is None: - device = "cpu" - - llama_cpp.llama_backend_init(numa=False) - - model_params = llama_cpp.llama_model_default_params() - - if "cuda" in device: - model_params.n_gpu_layers = 999 - else: - model_params.n_gpu_layers = model_kwargs.pop( - "n_gpu_layers", model_params.n_gpu_layers - ) + Parameters + ---------- + cfg_str + A string that represents a grammar + tokenizer + An instance of `Tokenizer` - if "tensor_split" in model_kwargs.keys(): - tensor_split = model_kwargs.get("tensor_split") - if isinstance(tensor_split, list): - tensor_split_arr = (ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES)( - *[t for t in tensor_split] - ) - model_params.tensor_split = tensor_split_arr - - context_params = llama_cpp.llama_context_default_params() - context_params.n_batch = model_kwargs.pop("n_batch", context_params.n_batch) - context_params.n_ctx = model_kwargs.pop("n_ctx", context_params.n_ctx) - context_params.n_threads = model_kwargs.pop("n_threads", context_params.n_threads) - context_params.n_threads_batch = model_kwargs.pop( - "n_threads_batch", context_params.n_threads_batch - ) - context_params.rope_scaling_type = model_kwargs.pop( - "rope_scaling_type", context_params.rope_scaling_type - ) - context_params.rope_freq_base = model_kwargs.pop( - "rope_freq_base", context_params.rope_freq_base - ) - context_params.rope_freq_scale = model_kwargs.pop( - "rope_freq_scale", context_params.rope_freq_scale - ) - context_params.yarn_ext_factor = model_kwargs.pop( - "yarn_ext_factor", context_params.yarn_ext_factor - ) - context_params.yarn_attn_factor = model_kwargs.pop( - "yarn_attn_factor", context_params.yarn_attn_factor - ) - context_params.yarn_beta_fast = model_kwargs.pop( - "yarn_beta_fast", context_params.yarn_beta_fast - ) - context_params.yarn_beta_slow = model_kwargs.pop( - "yarn_beta_slow", context_params.yarn_beta_slow - ) - context_params.yarn_orig_ctx = model_kwargs.pop( - "yarn_orig_ctx", context_params.yarn_orig_ctx - ) - context_params.offload_kqv = model_kwargs.pop( - "offload_kqv", context_params.offload_kqv - ) - - model = llama_cpp.llama_load_model_from_file( - model_name.encode("utf-8"), model_params - ) - - model_kwargs["n_vocab"] = llama_cpp.llama_n_vocab(model) - tokenizer_kwargs["n_vocab"] = model_kwargs.get("n_vocab") - - tokenizer = LlamaCppTokenizer(llama_cpp, model, model_name, **tokenizer_kwargs) - - return LlamaCpp(llama_cpp, model, tokenizer, "cpu", context_params, **model_kwargs) + """ + fsm = CFGFSM(cfg_str, tokenizer) + super().__init__(tokenizer, fsm) diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index 94bbc3c23..6843fe8de 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -2,12 +2,13 @@ import json import math from collections import defaultdict -from typing import DefaultDict, List, Optional +from typing import DefaultDict, Dict, List, Optional import torch +from pydantic import BaseModel from outlines.fsm.fsm import RegexFSM -from outlines.fsm.json_schema import build_regex_from_object +from outlines.fsm.json_schema import build_regex_from_schema class RegexLogitsProcessor: @@ -77,7 +78,7 @@ def convert_token_to_string(token: str) -> str: class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema, llm, whitespace_pattern: Optional[str] = None): + def __init__(self, schema: Dict, llm, whitespace_pattern: Optional[str] = None): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -90,7 +91,17 @@ def __init__(self, schema, llm, whitespace_pattern: Optional[str] = None): Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` """ - if isinstance(schema, dict): - schema = json.dumps(schema) - regex_string = build_regex_from_object(schema, whitespace_pattern) + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, Dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. The schema must be either " + + "a Pydantic object, a dictionary or a string that contains the JSON " + + "Schema specification" + ) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) super().__init__(regex_string, llm) diff --git a/pyproject.toml b/pyproject.toml index 0e0b2f03c..d131e8453 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ test = [ "beartype<0.16.0", "datasets", "responses", - "llama-cpp-python", + "llama-cpp-python>=0.2.42", "huggingface_hub" ] serve = [ diff --git a/tests/benchmark/test_benchmark_json_schema.py b/tests/benchmark/test_benchmark_json_schema.py index 8dc6ba0b5..3c3a4d3c3 100644 --- a/tests/benchmark/test_benchmark_json_schema.py +++ b/tests/benchmark/test_benchmark_json_schema.py @@ -5,7 +5,7 @@ outlines.disable_cache() from outlines.fsm.fsm import RegexFSM # noqa: E402 -from outlines.fsm.json_schema import build_regex_from_object # noqa: E402 +from outlines.fsm.json_schema import build_regex_from_schema # noqa: E402 simple_schema = """{ "$defs": { @@ -72,7 +72,7 @@ def test_benchmark_json_schema_to_regex(benchmark, ensure_numba_compiled, schema """Benchmark convert json schema to regex""" schema = schemas[schema_name] benchmark.pedantic( - build_regex_from_object, + build_regex_from_schema, args=(schema,), rounds=8, ) @@ -84,7 +84,7 @@ def test_benchmark_json_schema_to_fsm( ): """Benchmark compile json schema as FSM""" schema = schemas[schema_name] - regex = build_regex_from_object(schema) + regex = build_regex_from_schema(schema) benchmark.pedantic( RegexFSM, args=(regex, tokenizer), diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5afc19a80..4cd9331b2 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -17,7 +17,7 @@ TIME, UUID, WHITESPACE, - build_regex_from_object, + build_regex_from_schema, get_schema_from_signature, to_regex, ) @@ -53,7 +53,7 @@ class User(BaseModel): is_true: bool schema = json.dumps(User.model_json_schema()) - schedule = build_regex_from_object(schema) + schedule = build_regex_from_schema(schema) assert isinstance(schedule, str) @@ -516,7 +516,7 @@ def test_match_number(pattern, does_match): ) def test_match(schema, regex, examples): schema = json.dumps(schema) - test_regex = build_regex_from_object(schema) + test_regex = build_regex_from_schema(schema) assert test_regex == regex for string, does_match in examples: @@ -590,7 +590,7 @@ def test_match(schema, regex, examples): ) def test_format(schema, regex, examples): schema = json.dumps(schema) - test_regex = build_regex_from_object(schema) + test_regex = build_regex_from_schema(schema) assert test_regex == regex for string, does_match in examples: @@ -610,12 +610,14 @@ class MockModel(BaseModel): foo: int bar: str + schema = json.dumps(MockModel.model_json_schema()) + # assert any ws pattern can be used if whitespace_pattern == "abc": - build_regex_from_object(MockModel, whitespace_pattern) + build_regex_from_schema(schema, whitespace_pattern) return - pattern = build_regex_from_object(MockModel, whitespace_pattern) + pattern = build_regex_from_schema(schema, whitespace_pattern) mock_result_mult_ws = ( """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}""" diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py new file mode 100644 index 000000000..5468f259f --- /dev/null +++ b/tests/generate/test_integration_llamacpp.py @@ -0,0 +1,342 @@ +import datetime +import re +from enum import Enum +from typing import List, Union + +import pytest +import torch +from huggingface_hub import hf_hub_download +from pydantic import BaseModel, constr + +import outlines.generate as generate +from outlines.models import llamacpp + +TEST_MODEL = "./llama-test-model/TinyMistral-248M-v2-Instruct.Q4_K_M.gguf" + + +@pytest.fixture(scope="session") +def model(tmp_path_factory): + tmp_path_factory.mktemp("./llama-test-model") + hf_hub_download( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + local_dir="./llama-test-model", + local_dir_use_symlinks="auto", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) + + model = llamacpp(TEST_MODEL, "cpu") + return model + + +def test_llamacpp_integration_text(model): + model.model.reset() + sequence = generate.text(model)( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" + ) + assert isinstance(sequence, str) + + sequence = generate.text(model)( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n", + max_tokens=10, + stop_at=".", + ) + assert isinstance(sequence, str) + + prompts = [ + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\nAnd another one<|im_end|>\n<|im_start|>assistant\n", + ] + sequence = generate.text(model)(prompts, max_tokens=10, stop_at=[".", ","]) + assert isinstance(sequence, list) + assert len(sequence) == 2 + assert isinstance(sequence[0], str) + + +def test_llamacpp_integration_text_stop(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.text(model)(prompt, stop_at="a") + assert sequence[len(prompt) :].find("a") == -1 + + +def test_llamacpp_various_regexes(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite an email address<|im_end|>\n<|im_start|>assistant\n" + ) + regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" + generator = generate.regex(model, regex_str) + + # One prompt + sequence = generator(prompt) + assert re.fullmatch(regex_str, sequence) is not None + + +def test_llamacpp_various_regexes_prompt_list(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite an email address<|im_end|>\n<|im_start|>assistant\n" + ) + regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" + generator = generate.regex(model, regex_str) + + # Two prompts + sequence = generator([prompt, prompt]) + assert re.fullmatch(regex_str, sequence[0]) is not None + assert re.fullmatch(regex_str, sequence[1]) is not None + + +def test_llamacpp_integration_integer(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.format(model, int)(prompt, max_tokens=10) + + assert sequence != "" + int(sequence) + + +def test_llamacpp_integration_integer_array(model): + model.model.reset() + prompts = ["Give me a number", "And another one"] + sequence = generate.format(model, int)(prompts, max_tokens=10) + assert isinstance(sequence, list) + assert len(sequence) == 2 + int(sequence[0]) + int(sequence[1]) + + +def test_llamacpp_integration_float(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.format(model, float)(prompt, max_tokens=10) + + assert sequence != "" + float(sequence) + + +def test_llamacpp_integration_bool(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nIs this True or False?<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.format(model, bool)(prompt, max_tokens=10) + + assert sequence != "" + bool(sequence) + + +def test_llamacpp_integration_date(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.format(model, datetime.date)(prompt, max_tokens=10) + + assert isinstance(sequence, datetime.date) + + +def test_llamacpp_integration_time(model): + model.model.reset() + prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" + sequence = generate.format(model, datetime.time)(prompt, max_tokens=10) + + assert isinstance(sequence, datetime.time) + + +def test_llamacpp_integration_datetime(model): + model.model.reset() + prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" + sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20) + + assert isinstance(sequence, datetime.datetime) + + +def test_llamacpp_integration_choice(model): + model.model.reset() + prompt = ( + "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" + ) + sequence = generate.choice(model, ["test", "choice"])(prompt) + + assert sequence == "test" or sequence == "choice" + + +def test_llamacpp_json_basic(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" + + class Spam(BaseModel): + foo: int + bar: float + spam: constr(max_length=10) + fuzz: bool + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, Spam)( + prompt, max_tokens=1000, temperature=0.0, rng=rng + ) + assert isinstance(result, BaseModel) + assert isinstance(result.foo, int) + assert isinstance(result.bar, float) + assert isinstance(result.spam, str) + assert isinstance(result.fuzz, bool) + assert len(result.spam) <= 10 + + +def test_llamacpp_json_schema(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" + + schema = """{ + "title": "spam", + "type": "object", + "properties": { + "foo" : {"type": "integer"}, + "bar": {"type": "string", "maxLength": 4} + }, + "required": ["foo", "bar"] + } + """ + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, schema)( + prompt, max_tokens=500, temperature=0, rng=rng + ) + assert isinstance(result, dict) + assert isinstance(result["foo"], int) + assert isinstance(result["bar"], str) + + +def test_llamacpp_json_batch(model): + model.model.reset() + prompts = [ + "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n", + ] + + class Spam(BaseModel): + foo: int + bar: float + spam: constr(max_length=10) + fuzz: bool + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, Spam)( + prompts, max_tokens=500, temperature=0.0, rng=rng + ) + assert isinstance(result[0], BaseModel) + assert isinstance(result[1], BaseModel) + + +def test_llamacpp_json_str_enum(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n" + + class Name(str, Enum): + john = "John" + marc = "Marc" + michel = "Michel" + + class User(BaseModel): + id: int + name: Name + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, User)( + prompt, max_tokens=500, temperature=0.0, rng=rng + ) + assert isinstance(result, BaseModel) + assert isinstance(result.id, int) + assert result.name in ["John", "Marc", "Michel"] + + +def test_llamacpp_json_array(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n" + + class User(BaseModel): + id: int + value: List[float] + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, User)( + prompt, + max_tokens=500, + temperature=0.0, + rng=rng, + frequency_penalty=0.5, + ) + assert isinstance(result, BaseModel) + assert isinstance(result.id, int) + assert isinstance(result.value, list) + for value in result.value: + assert isinstance(value, float) or isinstance(value, int) + + +def test_llamacpp_json_int_enum(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n" + + class Id(int, Enum): + one = 1 + two = 2 + + class User(BaseModel): + id: Id + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, User)( + prompt, max_tokens=500, temperature=0.0, rng=rng + ) + assert isinstance(result, BaseModel) + assert isinstance(result.id, int) + assert result.id in [1, 2] + + +def test_llamacpp_json_union(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" + + class Spam(BaseModel): + foo: int + bar: Union[constr(max_length=10), float] + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + result = generate.json(model, Spam)( + prompt, max_tokens=100, temperature=0.0, rng=rng + ) + assert isinstance(result, BaseModel) + assert ( + isinstance(result.bar, int) + or isinstance(result.bar, float) + or isinstance(result.bar, str) + ) + + +def test_llamacpp_json_function(model): + model.model.reset() + prompt = "<|im_start|>user\nOutput arguments for the function<|im_end|>\n<|im_start|>assistant\n" + + def function(foo: int, bar: List[int]): + return foo + sum(bar) + + rng = torch.Generator(device="cpu") + rng.manual_seed(0) + sequence = generate.json(model, function)( + prompt, max_tokens=100, temperature=0.0, rng=rng + ) + assert isinstance(sequence, dict) + assert isinstance(function(**sequence), int) diff --git a/tests/models/test_llama_cpp.py b/tests/models/test_llama_cpp.py index 68e998239..4bca43378 100644 --- a/tests/models/test_llama_cpp.py +++ b/tests/models/test_llama_cpp.py @@ -1,6 +1,4 @@ -import numpy as np import pytest -import torch from huggingface_hub import hf_hub_download from outlines.models.llamacpp import llamacpp @@ -19,50 +17,17 @@ def model_download(tmp_path_factory): ) -def test_tokenizer(model_download): - model = llamacpp(TEST_MODEL, "cpu") - - tokenizer = model.tokenizer - assert tokenizer.eos_token_id == 2 - assert tokenizer.pad_token_id == -1 - - token_ids, attention_mask = tokenizer.encode(["Test", "test bla hallo"]) - assert token_ids.ndim == 2 - assert token_ids.shape[0] == 2 - assert token_ids[0, -1] == -1 - assert token_ids[1, -1] != -1 - assert isinstance(token_ids, torch.LongTensor) - assert token_ids.shape == attention_mask.shape - - token_ids, attention_mask = tokenizer.encode("Test") - assert token_ids.ndim == 2 - assert token_ids.shape[0] == 1 - assert isinstance(token_ids, torch.LongTensor) - assert token_ids.shape == attention_mask.shape - - text = tokenizer.decode(np.array([0, 1, 2])) - assert isinstance(text, list) - - def test_model(model_download): model = llamacpp(TEST_MODEL) - input_ids = torch.tensor([[0, 1, 2]]) - logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) - - assert logits.ndim == 2 - assert logits.shape[0] == 1 - - model.n_past = 0 - input_ids = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + completion = model("Some string") - assert logits.ndim == 2 - assert logits.shape[0] == 3 + assert isinstance(completion, str) - model.n_past = 0 - input_ids = torch.tensor([[0, 1, 2], [3, -1, -1]]) - logits, kv_cache = model(input_ids, torch.ones_like(input_ids)) + model.model.reset() + completions = model(["Some string", "Other string"]) - assert logits.ndim == 2 - assert logits.shape[0] == 2 + assert isinstance(completions, list) + assert len(completions) == 2 + assert isinstance(completions[0], str) + assert isinstance(completions[1], str)