From f5ae15eaf2f7b7aeb578089006cb587b6ee8bc86 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 21 Jun 2024 09:49:00 -0500 Subject: [PATCH] Use LogitsProcessors for models.transformers -> outlines.generate.* --- README.md | 5 +- docs/reference/models/transformers.md | 53 ++- docs/reference/text.md | 5 +- examples/llamacpp_example.py | 6 +- outlines/__init__.py | 1 + outlines/generate/cfg.py | 42 +-- outlines/generate/regex.py | 6 +- outlines/generate/text.py | 5 +- outlines/models/__init__.py | 2 +- outlines/models/mlxlm.py | 6 +- outlines/models/transformers.py | 187 +++++++++- outlines/processors/__init__.py | 2 +- outlines/processors/base_logits_processor.py | 125 +++++-- outlines/processors/structured.py | 35 +- tests/generate/test_generate.py | 169 +++++++-- tests/generate/test_integration_llamacpp.py | 9 +- .../generate/test_integration_transformers.py | 320 +++++------------- tests/models/test_transformers.py | 3 - 18 files changed, 624 insertions(+), 357 deletions(-) diff --git a/README.md b/README.md index aeb1126b6..97cd1020b 100644 --- a/README.md +++ b/README.md @@ -191,10 +191,9 @@ model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2") generator = outlines.generate.json(model, Character) # Draw a sample -rng = torch.Generator(device="cuda") -rng.manual_seed(789001) +seed = 789001 -character = generator("Give me a character description", rng=rng) +character = generator("Give me a character description", seed=seed) print(repr(character)) # Character(name='Anderson', age=28, armor=, weapon=, strength=8) diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index 286df4367..7c1febd02 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model ```python from outlines import models -model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda") +model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda") ``` If you need more fine-grained control you can also initialize the model and tokenizer separately: @@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2") model = models.Transformers(llm, tokenizer) ``` +# Using Logits Processors + +There are two ways to use Outlines Structured Generation with HuggingFace Transformers: +- 1) Use Outlines generation wrapper, `outlines.models.transformers` +- 2) Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM` + +Outlines supports a myriad of logits processors for structured generation. In these example, we will use the `RegexLogitsProcessor` which guarantees generated text matches the specified pattern. + +## Example: `outlines.models.transformers` + +``` +import outlines + +time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?" + +model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda") +generator = outlines.generate.regex(model, time_regex_pattern) + +output = generator("The the best time to visit a dentist is at ") +print(output) +# 2:30 pm +``` + +## Example: Direct `transformers` library use + +``` +import outlines +import transformers + + +model_uri = "microsoft/Phi-3-mini-4k-instruct" + +outlines_tokenizer = outlines.models.TransformerTokenizer( + transformers.AutoTokenizer.from_pretrained(model_uri) +) +phone_number_logits_processor = outlines.processors.RegexLogitsProcessor( + "\\+?[1-9][0-9]{7,14}", # phone number pattern + outlines_tokenizer, +) + +generator = transformers.pipeline('text-generation', model=model_uri) + +output = generator( + "Jenny gave me her number it's ", + logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor]) +) +print(output) +# [{'generated_text': "Jenny gave me her number it's 2125550182"}] +# not quite 8675309 what we expected, but it is a valid phone number +``` + [transformers]: https://github.com/huggingface/transformers diff --git a/docs/reference/text.md b/docs/reference/text.md index f364c3d2e..423a26ba6 100644 --- a/docs/reference/text.md +++ b/docs/reference/text.md @@ -80,8 +80,7 @@ from outlines import models, generate model = models.transformers("mistralai/Mistral-7B-v0.1") -rng = torch.Generator(device="cuda") -rng.manual_seed(789001) +seed = 789001 -answer = generator("What is 2+2?", rng=rng) +answer = generator("What is 2+2?", seed=seed) ``` diff --git a/examples/llamacpp_example.py b/examples/llamacpp_example.py index 0b478217f..22d0da3ba 100644 --- a/examples/llamacpp_example.py +++ b/examples/llamacpp_example.py @@ -1,6 +1,5 @@ from enum import Enum -import torch from pydantic import BaseModel, constr import outlines @@ -37,10 +36,9 @@ class Character(BaseModel): generator = outlines.generate.json(model, Character) # Draw a sample - rng = torch.Generator(device="cpu") - rng.manual_seed(789005) + seed = 789005 prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:" - sequence = generator(prompt, rng=rng, max_tokens=512) + sequence = generator(prompt, seed=seed, max_tokens=512) print(sequence) diff --git a/outlines/__init__.py b/outlines/__init__.py index 3eb6a2f94..307d2ba6f 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -2,6 +2,7 @@ import outlines.generate import outlines.grammars import outlines.models +import outlines.processors import outlines.types from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index e473c26a6..0df833067 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -1,16 +1,14 @@ from functools import singledispatch -from outlines.fsm.guide import CFGGuide -from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter +from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI -from outlines.models.llamacpp import LlamaCpp -from outlines.models.mlxlm import MLXLM -from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @singledispatch -def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator: +def cfg( + model, cfg_str: str, sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: """Generate text in the language of a Context-Free Grammar Arguments @@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera Returns ------- - A `SequenceGenerator` instance that generates text. + A `SequenceGeneratorAdapter` instance that generates text. """ - fsm = CFGGuide(cfg_str, model.tokenizer) - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - - return generator - - -@cfg.register(MLXLM) -@cfg.register(VLLM) -def cfg_unimplemented( - model, - cfg_str: str, - sampler: Sampler = multinomial(), -): raise NotImplementedError( - f"The CFG Logits processor is not available for {type(model)}." + f"The CFG Logits processor is not available for {type(model)}. " + + "Please subscribe to https://github.com/outlines-dev/outlines/issues/684" + + " for updates on the fix." ) -@cfg.register(LlamaCpp) -def cfg_llamacpp( - model: LlamaCpp, - cfg_str: str, - sampler: Sampler = multinomial(), -): - from outlines.integrations.llamacpp import CFGLogitsProcessor - - logits_processor = CFGLogitsProcessor(cfg_str, model.model) - return SequenceGeneratorAdapter(model, logits_processor, sampler) - - @cfg.register(OpenAI) def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError( diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 6b6656fe9..cdf64a21f 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -5,6 +5,7 @@ from outlines.models import OpenAI from outlines.models.llamacpp import LlamaCpp from outlines.models.mlxlm import MLXLM +from outlines.models.transformers import Transformers from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): @regex.register(MLXLM) -def regex_mlxlm( - model: MLXLM, +@regex.register(Transformers) +def regex_unified( + model, regex_str: str, sampler: Sampler = multinomial(), ): diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 081ba0920..b8feb7659 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -2,7 +2,7 @@ from outlines.fsm.guide import StopAtEOSGuide from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter -from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI +from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers from outlines.samplers import Sampler, multinomial @@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: @text.register(MLXLM) -def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()): +@text.register(Transformers) +def text_unified(model, sampler: Sampler = multinomial()): return SequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 4d74ebf60..fde913e2c 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -12,7 +12,7 @@ from .mamba import Mamba, mamba from .mlxlm import MLXLM, mlxlm from .openai import OpenAI, azure_openai, openai -from .transformers import Transformers, transformers +from .transformers import Transformers, TransformerTokenizer, transformers from .vllm import VLLM, vllm LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM] diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index f561f269d..57aa6f596 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -9,7 +9,7 @@ from transformers import PreTrainedTokenizer from outlines.generate.api import GenerationParameters, SamplingParameters - from outlines.processors import BaseLogitsProcessor + from outlines.processors import OutlinesLogitsProcessor class MLXLM: @@ -120,7 +120,7 @@ def generate_step( temp: Optional[float], top_p: Optional[float], sampler: str, - logits_processor: "BaseLogitsProcessor", + logits_processor: "OutlinesLogitsProcessor", ) -> Generator[Tuple[int, float], None, None]: """ Adapted from @@ -135,7 +135,7 @@ def generate_step( top_p (float, optional): Nulceus sampling, higher means model considers more less likely words. sampler (str): The sampler string defined by SequenceGeneratorAdapter - logits_processor (BaseLogitsProcessor): Augment logits before sampling. + logits_processor (OutlinesLogitsProcessor): Augment logits before sampling. """ import mlx.core as mx import mlx_lm diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index fae9b8e74..491435e8a 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,13 +1,17 @@ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import dataclasses +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union from datasets.fingerprint import Hasher +from outlines.generate.api import GenerationParameters, SamplingParameters from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: import torch from transformers import PreTrainedModel, PreTrainedTokenizer + from outlines.processors import OutlinesLogitsProcessor + __all__ = ["transformers"] @@ -129,7 +133,6 @@ def __init__( model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", ): - self.device = model.device self.model = model self.tokenizer = TransformerTokenizer(tokenizer) @@ -179,6 +182,10 @@ def forward( return output.logits, output.past_key_values + @property + def device(self): + return self.model.device + def __call__( self, input_ids: "torch.LongTensor", @@ -190,6 +197,182 @@ def __call__( return next_token_logits, kv_cache + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor: Optional["OutlinesLogitsProcessor"], + sampling_parameters: SamplingParameters, + ) -> Union[str, List[str], List[List[str]]]: + """Generate text using `transformers`. + + Arguments + --------- + prompts + A prompt or list of prompts. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + + Returns + ------- + The generated text + """ + if isinstance(prompts, str): + # convert to 2d + input_ids, attention_mask = self.tokenizer.encode([prompts]) + else: + input_ids, attention_mask = self.tokenizer.encode(prompts) + inputs = { + "input_ids": input_ids.to(self.model.device), + "attention_mask": attention_mask.to(self.model.device), + } + + generation_kwargs = self._get_generation_kwargs( + prompts, + generation_parameters, + logits_processor, + sampling_parameters, + ) + generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) + + # if single str input and single sample per input, convert to a 1D output + if isinstance(prompts, str): + generated_ids = generated_ids.squeeze(0) + + return self._decode_generation(generated_ids) + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor: Optional["OutlinesLogitsProcessor"], + sampling_parameters: SamplingParameters, + ) -> Iterator[Union[str, List[str]]]: + """ + Temporary stream stand-in which implements stream() signature + and equivalent behaviour but isn't iterable. + """ + if isinstance(prompts, str): + # convert to 2d + input_ids, attention_mask = self.tokenizer.encode([prompts]) + else: + input_ids, attention_mask = self.tokenizer.encode(prompts) + inputs = { + "input_ids": input_ids.to(self.model.device), + "attention_mask": attention_mask.to(self.model.device), + } + + generation_kwargs = self._get_generation_kwargs( + prompts, + generation_parameters, + logits_processor, + sampling_parameters, + ) + generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) + + # if single str input and single sample per input, convert to a 1D output + if isinstance(prompts, str): + generated_ids = generated_ids.squeeze(0) + + for i in range(generated_ids.size(-1)): + output_group_ids = generated_ids.select(-1, i).unsqueeze(-1) + yield self._decode_generation(output_group_ids) + + def _get_generation_kwargs( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor: Optional["OutlinesLogitsProcessor"], + sampling_parameters: SamplingParameters, + ) -> dict: + """ + Conert outlines generation parameters into model.generate kwargs + """ + from transformers import GenerationConfig, LogitsProcessorList, set_seed + + max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( + sampling_parameters + ) + if max_new_tokens is None: + max_new_tokens = int(1e9) + + # global seed, not desirable + if seed is not None: + set_seed(seed) + + if logits_processor is not None: + logits_processor_list = LogitsProcessorList([logits_processor]) + else: + logits_processor_list = None + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + max_length=int(1e9), + stop_strings=stop_at, + num_return_sequences=(num_samples or 1), + top_p=top_p, + top_k=top_k, + temperature=temperature, + do_sample=(sampler == "multinomial"), + num_beams=(num_samples if sampler == "beam_search" else 1), + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + return dict( + logits_processor=logits_processor_list, + generation_config=generation_config, + tokenizer=self.tokenizer.tokenizer, + ) + + def _generate_output_seq( + self, prompts, inputs, generation_config, **generation_kwargs + ): + input_ids = inputs["input_ids"] + output_ids = self.model.generate( + generation_config=generation_config, **inputs, **generation_kwargs + ) + + # encoder-decoder returns output_ids only, decoder-only returns full seq ids + if self.model.config.is_encoder_decoder: + generated_ids = output_ids + else: + generated_ids = output_ids[:, input_ids.shape[1] :] + + # if batch list inputs AND multiple samples per input, convert generated_id to 3D view + num_samples = generation_config.num_return_sequences or 1 + + if num_samples > 1 and isinstance(prompts, list): + batch_size = input_ids.size(0) + num_return_sequences = generation_config.num_return_sequences or 1 + generated_ids = generated_ids.view(batch_size, num_return_sequences, -1) + + return generated_ids + + def _decode_generation(self, generated_ids: "torch.Tensor"): + if len(generated_ids.shape) == 1: + return self.tokenizer.decode([generated_ids])[0] + elif len(generated_ids.shape) == 2: + return self.tokenizer.decode(generated_ids) + elif len(generated_ids.shape) == 3: + return [ + self.tokenizer.decode(generated_ids[i]) + for i in range(len(generated_ids)) + ] + else: + raise TypeError( + f"Generated outputs aren't 1D, 2D or 3D, but instead are {generated_ids.shape}" + ) + def transformers( model_name: str, diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py index 5c6a697ed..22c10d905 100644 --- a/outlines/processors/__init__.py +++ b/outlines/processors/__init__.py @@ -1,7 +1,7 @@ from .structured import ( - BaseLogitsProcessor, CFGLogitsProcessor, FSMLogitsProcessor, JSONLogitsProcessor, + OutlinesLogitsProcessor, RegexLogitsProcessor, ) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index dabfd91b0..3484106db 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Protocol, Union +from typing import Any, List, Protocol, Union import numpy as np import torch @@ -14,10 +14,11 @@ def is_mlx_array(logits): return isinstance(logits, mx.array) -class BaseLogitsProcessor(Protocol): +class OutlinesLogitsProcessor(Protocol): """ Base class for logits processors which normalizes types of logits: - ndarray (used by llama-cpp-python), converted to torch.Tensor + - mlx.core.array (used by mlx-lm), converted to torch.Tensor - torch.Tensor (used by everything else) Normalization of types and conversion to torch.Tensor @@ -29,10 +30,22 @@ class BaseLogitsProcessor(Protocol): @abstractmethod def process_logits( - self, input_ids: List[int], logits: torch.Tensor + self, input_ids: List[List[int]], logits: torch.Tensor ) -> torch.Tensor: - ... + """ + input_ids and logits are always 2D tensors for handling a batch of sequences. + + - input_ids.shape[1] -> contains the sequence of int tokens + - logits.shape[0] -> Dimension 1 contains one + + Important to keep in mind when designing universal logits processors + - logits processors are only used once and never re-applied for a new sequence generator + - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids + - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM + """ + pass + @torch.no_grad() def __call__( self, input_ids: Union[NDArray[np.int64], List[int], torch.Tensor], @@ -40,39 +53,77 @@ def __call__( ) -> Union[NDArray[np.int64], torch.Tensor]: """ Apply logits processor + Unify type - - convert input_ids: either ndarray, List[int], or Tensor -> List[int] - - convert logits: either ndarray, mlx array, Tensor -> Tensor + - convert input_ids: either ndarray, List[int], or Tensor -> 2D tensor + - convert logits: either ndarray, mlx array, Tensor -> 2D Tensor + Call process_logits() to perform business logic """ - with torch.no_grad(): - if not isinstance(input_ids, list): - input_ids = input_ids.tolist() - - if isinstance(logits, np.ndarray): - # Unify type, convert numpy array to Tensor - # from_numpy and .numpy() don't copy the data, it uses the same memory address - torch_logits = torch.from_numpy(logits) - processed_torch_logits = self.process_logits(input_ids, torch_logits) - return processed_torch_logits.detach().numpy() - - elif isinstance(logits, torch.Tensor): - return self.process_logits(input_ids, logits) - - elif is_mlx_array(logits): - # mlx -> torch -> mlx conversion docs: - # https://ml-explore.github.io/mlx/build/html/usage/numpy.html - import mlx.core as mx - - torch_logits = torch.from_dlpack(logits) - processed_torch_logits = self.process_logits(input_ids, torch_logits) - - # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch - logits_float32_numpy = processed_torch_logits.float().numpy() - return mx.array(logits_float32_numpy) - - else: - raise TypeError( - "LogitsProcessor must be called with either np.NDArray" - ", torch.Tensor, or mlx.core.array typed logits" - ) + + # ensure logits are torch Tensors + torch_logits = self._to_torch(logits) + + assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1] + + # ensure input_ids are List + if not isinstance(input_ids, list): + input_ids = input_ids.tolist() # compatible with numpy, torch, and mlx + + # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape + if len(torch_logits.shape) == 2: + processed_logits = self.process_logits(input_ids, logits) + elif len(torch_logits.shape) == 1: + processed_logits = self.process_logits( + [input_ids], torch_logits.unsqueeze(0) + ).squeeze(0) + + # return logits as passed array type + return self._from_torch(processed_logits, type(logits)) + + @staticmethod + def _to_torch(tensor_like: Any) -> torch.Tensor: + """Convert various types to torch.Tensor.""" + if isinstance(tensor_like, torch.Tensor): + return tensor_like + + elif isinstance(tensor_like, np.ndarray): + return torch.from_numpy(tensor_like) + + elif isinstance(tensor_like, list): + return torch.tensor(tensor_like) + + elif is_mlx_array(tensor_like): + # mlx -> torch -> mlx conversion docs: + # https://ml-explore.github.io/mlx/build/html/usage/numpy.html + return torch.from_dlpack(tensor_like) + + else: + raise TypeError( + "LogitsProcessor must be called with either np.NDArray" + ", torch.Tensor, list, or mlx.core.array typed logits" + ) + + @staticmethod + def _from_torch(tensor: torch.Tensor, target_type: Any) -> Any: + """Convert torch.Tensor to the specified target type.""" + if target_type == torch.Tensor: + return tensor + + elif target_type == np.ndarray: + return tensor.detach().numpy() + + elif target_type == list: + return tensor.detach().tolist() + + elif target_type == "mlx_array": + import mlx.core as mx + + # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch + return mx.array(tensor.float().numpy()) + + else: + raise RuntimeError( + "Failed to convert torch tensors back to original dtype. {tensor}" + f"tensor={tensor}, target_type={target_type}" + ) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index b8ef5b2da..271ea97ce 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -24,7 +24,7 @@ limitations under the License. """ import math -from typing import TYPE_CHECKING, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union import numpy as np import torch @@ -35,13 +35,13 @@ from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import convert_json_schema_to_str -from .base_logits_processor import BaseLogitsProcessor +from .base_logits_processor import OutlinesLogitsProcessor if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer -class FSMLogitsProcessor(BaseLogitsProcessor): +class FSMLogitsProcessor(OutlinesLogitsProcessor): """Bias generation using a finite state machine. Attributes @@ -63,12 +63,13 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide): The finite state machine which is used to bias the logits. """ self.tokenizer = tokenizer - self._fsm_state = 0 + self._fsm_states: Dict[int, int] = {} self.fsm: Guide = fsm self._is_first_token = True + self._seq_start_idx: Optional[int] = None def process_logits( - self, input_ids: List[int], logits: torch.Tensor + self, input_ids: List[List[int]], logits: torch.Tensor ) -> NDArray[np.float32]: """Use the FSM to bias the logits before sampling the next token. @@ -84,17 +85,31 @@ def process_logits( torch.Tensor The biased logits. """ + sequence_states: List[int] = [] # vector of states corresponding to `input_ids` + if self._is_first_token: self._is_first_token = False + self._seq_start_idx = len(input_ids[0]) + + self._fsm_states = {hash(tuple([])): 0} + sequence_states = [0] * len(input_ids) + else: - last_token = input_ids[-1] - self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) + for seq_ids in input_ids: + prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1])) + prev_state = self._fsm_states[prev_state_key] - allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens - allowed_tokens = torch.tensor(allowed_tokens, device=logits.device) + curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :])) + curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1]) + + self._fsm_states[curr_state_key] = curr_state + sequence_states.append(curr_state) mask = torch.full_like(logits, -math.inf) - mask[allowed_tokens] = logits[allowed_tokens] + for i, fsm_state in enumerate(sequence_states): + allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens + mask[i, allowed_tokens] = logits[i, allowed_tokens] + return mask def copy(self) -> "FSMLogitsProcessor": diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 1f1a3aea2..a8f59f1f9 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -1,9 +1,11 @@ +import contextlib import re import pytest import outlines.generate as generate import outlines.models as models +import outlines.samplers as samplers @pytest.fixture(scope="session") @@ -20,35 +22,160 @@ def model_mlxlm(tmp_path_factory): @pytest.fixture(scope="session") -def model_transformers(tmp_path_factory): - return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu") +def model_transformers_random(tmp_path_factory): + return models.transformers("hf-internal-testing/tiny-random-gpt2", device="cpu") -@pytest.mark.parametrize( - "model_fixture", - ("model_llamacpp", "model_mlxlm", "model_transformers"), +@pytest.fixture(scope="session") +def model_transformers_distilgpt2(tmp_path_factory): + return models.transformers("distilbert/distilgpt2", device="cpu") + + +ALL_MODEL_FIXTURES = ( + "model_llamacpp", + "model_mlxlm", + "model_transformers_random", + "model_transformers_distilgpt2", ) -def test_generate_text(request, model_fixture): + + +NOT_IMPLEMENTED = { + "batch": ["model_llamacpp"], + "stream": ["model_vllm"], + "beam_search": ["model_llamacpp"], + "multiple_samples": ["model_llamacpp"], +} + + +def enforce_not_implemented(model_fixture, *task_names): + """ + Per `NOT_IMPLEMENTED`, mapping, if a model hasn't implemented a task, + assert an NotImplementedError is raised. Otherwise, run normally + """ + for task_name in task_names: + if model_fixture in NOT_IMPLEMENTED.get(task_name, []): + return pytest.raises(NotImplementedError) + else: + return contextlib.nullcontext() + + +REGEX_PATTERNS = [ + "(123456789)|(abcdefghijklmnop)", + "abc*", + "\\+?[1-9][0-9]{7,14}", + r"([a-z]{10})@([a-z]{5})\.([a-z]{3})", +] + + +@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search")) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_text(request, model_fixture, sampler_name): + model = request.getfixturevalue(model_fixture) + generator = generate.text(model, getattr(samplers, sampler_name)()) + with enforce_not_implemented(model_fixture, sampler_name): + res = generator("test", max_tokens=10) + assert isinstance(res, str) + + +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_batch_text(request, model_fixture): model = request.getfixturevalue(model_fixture) generator = generate.text(model) - res = generator("test", max_tokens=10) - assert isinstance(res, str) + with enforce_not_implemented(model_fixture, "batch"): + res = generator(["test", "test2"], max_tokens=10) + assert isinstance(res, list) + assert isinstance(res[0], str) -@pytest.mark.parametrize( - "model_fixture", - ("model_llamacpp", "model_mlxlm", "model_transformers"), -) -@pytest.mark.parametrize( - "pattern", - ( - "[0-9]", - "abc*", - "\\+?[1-9][0-9]{7,14}", - ), -) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_text_stream(request, model_fixture): + model = request.getfixturevalue(model_fixture) + generator = generate.text(model) + with enforce_not_implemented(model_fixture, "stream"): + for token in generator.stream("a b c ", max_tokens=10): + assert isinstance(token, str) + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) def test_generate_regex(request, model_fixture, pattern): model = request.getfixturevalue(model_fixture) generator = generate.regex(model, pattern) res = generator("foobarbaz", max_tokens=20) - assert re.match(pattern, res) is not None, res + assert re.fullmatch(pattern, res) is not None, res + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_stream(request, model_fixture, pattern): + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + with enforce_not_implemented(model_fixture, "stream"): + output = "" + for token in generator.stream("output:", max_tokens=20): + output += token + assert re.fullmatch(pattern, output) is not None, output + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_batch_stream(request, model_fixture, pattern): + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + with enforce_not_implemented(model_fixture, "batch", "stream"): + outputs = ["", ""] + for tokens in generator.stream(["input 0", "input 1"], max_tokens=20): + outputs[0] += tokens[0] + outputs[1] += tokens[1] + for output in outputs: + assert re.fullmatch(pattern, output) is not None, output + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_batch(request, model_fixture, pattern): + """Ensure batch requests work and fsm order is maintained""" + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + with enforce_not_implemented(model_fixture, "batch"): + outputs = generator(["abc", "123", "123bce", "33aa"], max_tokens=20) + for output in outputs: + assert re.fullmatch(pattern, output) is not None, output + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_single_multinomial(request, model_fixture, pattern): + """Ensure batch requests work and fsm order is maintained""" + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern, sampler=samplers.multinomial(4)) + with enforce_not_implemented(model_fixture, "multiple_samples"): + output_sample_groups = generator("single input", max_tokens=40) + for output in output_sample_groups: + assert re.fullmatch(pattern, output) is not None, output + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_batch_multinomial(request, model_fixture, pattern): + """Ensure batch requests work and fsm order is maintained""" + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern, sampler=samplers.multinomial(4)) + with enforce_not_implemented(model_fixture, "batch", "multiple_samples"): + output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40) + for output_sample_groups in output_batch_groups: + for output in output_sample_groups: + assert re.fullmatch(pattern, output) is not None, output + + +@pytest.mark.parametrize("pattern", REGEX_PATTERNS) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_batch_beam_search(request, model_fixture, pattern): + """Ensure batch requests work and fsm order is maintained""" + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern, sampler=samplers.beam_search(4)) + with enforce_not_implemented(model_fixture, "batch", "multiple_samples"): + output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40) + for output_sample_groups in output_batch_groups: + for output in output_sample_groups: + assert re.fullmatch(pattern, output) is not None, output diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index b7eb8b3cb..72070919a 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -25,7 +25,7 @@ def model(tmp_path_factory): ( (generate.text, []), (generate.regex, ("[0-9]",)), - (generate.cfg, (grammars.arithmetic,)), + # (generate.cfg, (grammars.arithmetic,)), # Awaiting CFG fix ), ) def test_llamacpp_generation_api(model, generator_type, params): @@ -245,8 +245,11 @@ def test_llamacpp_json_schema(model): def test_llamacpp_cfg(model): prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" - result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) - assert isinstance(result, str) + + # remove this statement once cfg is implemented + with pytest.raises(NotImplementedError): + result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) + assert isinstance(result, str) @pytest.mark.parametrize( diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index da08bed71..f3fb9682e 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -14,44 +14,50 @@ from outlines.samplers import beam_search, greedy, multinomial -def test_transformers_integration_text(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated +@pytest.fixture(scope="session") +def model(tmp_path_factory): + return models.transformers( + "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu" + ) - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") - sequence = generate.text(model)("Write a short sentence ", rng=rng) + +def test_transformers_integration_text(model): + sequence = generate.text(model)( + "Write a short sentence ", seed=10000, max_tokens=10 + ) assert isinstance(sequence, str) assert model.tokenizer.eos_token not in sequence sequence = generate.text(model)( - "Write a short sentence ", max_tokens=10, stop_at=".", rng=rng + "Write a short sentence ", + max_tokens=20, + stop_at="a", + seed=10000, ) assert isinstance(sequence, str) prompts = ["Write a short sentence ", "And another one "] - sequence = generate.text(model)(prompts, max_tokens=10, stop_at=[".", ","], rng=rng) + sequence = generate.text(model)( + prompts, max_tokens=10, stop_at=[".", ","], seed=10000 + ) assert isinstance(sequence, list) assert len(sequence) == 2 assert isinstance(sequence[0], str) -def test_transformers_integration_text_multiple_samples(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_text_multiple_samples(model): sampler = multinomial(2) - sequence = generate.text(model, sampler=sampler)("Write a short sentence ", rng=rng) + sequence = generate.text(model, sampler=sampler)( + "Write a short sentence ", seed=10000 + ) assert isinstance(sequence, list) assert len(sequence) == 2 assert model.tokenizer.eos_token not in sequence prompts = ["Write a short sentence ", "And another one "] sequence = generate.text(model, sampler=sampler)( - prompts, max_tokens=10, stop_at=[".", ","], rng=rng + prompts, max_tokens=10, stop_at=[".", ","], seed=10000 ) assert isinstance(sequence, list) assert len(sequence) == 2 @@ -60,14 +66,9 @@ def test_transformers_integration_text_multiple_samples(): assert isinstance(sequence[0][0], str) -def test_transformers_integration_streaming(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_streaming(model): sequence = generate.text(model).stream( - "Write a short sentence ", max_tokens=10, stop_at=[".", ","], rng=rng + "Write a short sentence ", max_tokens=10, stop_at=[".", ","], seed=10000 ) token = next(sequence) @@ -77,7 +78,7 @@ def test_transformers_integration_streaming(): assert isinstance(remaining, str) sequence = generate.text(model).stream( - ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], rng=rng + ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], seed=10000 ) tokens = next(sequence) assert isinstance(tokens, list) @@ -85,19 +86,11 @@ def test_transformers_integration_streaming(): assert isinstance(tokens[1], str) -def test_transformers_integration_streaming_batch_samples(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_streaming_batch_samples(model): sampler = multinomial(samples=2) sequence = generate.text(model, sampler=sampler).stream( - ["Prompt1", "Prompt2"], - max_tokens=10, - stop_at=[".", ","], - rng=rng, + ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], seed=10000 ) tokens = next(sequence) assert isinstance(tokens, list) @@ -108,19 +101,11 @@ def test_transformers_integration_streaming_batch_samples(): assert len(tokens[1]) == 2 -def test_transformers_integration_streaming_batch_beam_search(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_streaming_batch_beam_search(model): sampler = beam_search(beams=2) - sequence = generate.text(model, sampler=sampler).stream( - ["Prompt1", "Prompt2"], - max_tokens=10, - stop_at=[".", ","], - rng=rng, + sequence = generate.regex(model, r"ab[cd]e", sampler=sampler).stream( + ["Prompt1", "Prompt2"], max_tokens=10, stop_at=["c", "d"], seed=10000 ) tokens = next(sequence) assert isinstance(tokens, list) @@ -131,61 +116,41 @@ def test_transformers_integration_streaming_batch_beam_search(): assert len(tokens[1]) == 2 -def test_transformers_integration_text_stop(): - rng = torch.Generator() - rng.manual_seed(10000) # Choosen so is generated - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_text_stop(model): prompt = "Write a short sentence " - sequence = generate.text(model)(prompt, stop_at="a", rng=rng) + sequence = generate.text(model)(prompt, stop_at="a", seed=10000) assert sequence[len(prompt) :].find("a") == -1 -def test_transformers_various_regexes(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_various_regexes(model): prompt = "Write an email address" regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" generator = generate.regex(model, regex_str) # One prompt - sequence = generator(prompt, rng=rng) + sequence = generator(prompt, seed=0) assert re.fullmatch(regex_str, sequence) is not None -def test_transformers_various_regexes_prompt_list(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_various_regexes_prompt_list(model): prompt = "Write an email address" regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" generator = generate.regex(model, regex_str) # Two prompts - sequence = generator([prompt, prompt], rng=rng) + sequence = generator([prompt, prompt], seed=0) assert re.fullmatch(regex_str, sequence[0]) is not None assert re.fullmatch(regex_str, sequence[1]) is not None -def test_transformers_various_regexes_prompt_list_multiple_samples(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_various_regexes_prompt_list_multiple_samples(model): sampler = multinomial(samples=2) prompt = "Write an email address" regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" generator = generate.regex(model, regex_str, sampler=sampler) # Two prompts - sequence = generator([prompt, prompt], rng=rng) + sequence = generator([prompt, prompt], seed=0, max_tokens=500) assert isinstance(sequence, list) assert len(sequence) == 2 assert re.fullmatch(regex_str, sequence[0][0]) is not None @@ -194,12 +159,7 @@ def test_transformers_various_regexes_prompt_list_multiple_samples(): assert re.fullmatch(regex_str, sequence[1][1]) is not None -def test_transformers_various_regexes_prompt_list_beam_search(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_various_regexes_prompt_list_beam_search(model): sampler = beam_search(5) prompt_1 = "Write an email address" prompt_2 = "Random" @@ -207,7 +167,7 @@ def test_transformers_various_regexes_prompt_list_beam_search(): generator = generate.regex(model, regex_str, sampler=sampler) # Two prompts - sequence = generator([prompt_1, prompt_2], rng=rng) + sequence = generator([prompt_1, prompt_2], seed=0) assert isinstance(sequence, list) assert len(sequence) == 2 assert len(sequence[0]) == 5 @@ -217,105 +177,65 @@ def test_transformers_various_regexes_prompt_list_beam_search(): assert re.fullmatch(regex_str, sequence[1][1]) is not None -def test_transformers_integration_integer(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_integer(model): prompt = "Write a short sentence" - sequence = generate.format(model, int)(prompt, max_tokens=10, rng=rng) + sequence = generate.format(model, int)(prompt, max_tokens=10, seed=0) assert isinstance(sequence, int) -def test_transformers_integration_integer_array(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_integer_array(model): prompts = ["Give me a number", "And another one"] - sequence = generate.format(model, int)(prompts, max_tokens=10, rng=rng) + sequence = generate.format(model, int)(prompts, max_tokens=10, seed=0) assert isinstance(sequence, list) assert len(sequence) == 2 assert isinstance(sequence[0], int) assert isinstance(sequence[1], int) -def test_transformers_integration_float(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_float(model): prompt = "Write a short sentence" - sequence = generate.format(model, float)(prompt, max_tokens=10, rng=rng) + sequence = generate.format(model, float)(prompt, max_tokens=10, seed=0) assert sequence != "" assert isinstance(sequence, float) -def test_transformers_integration_bool(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_bool(model): prompt = "Is this True or False?" - sequence = generate.format(model, bool)(prompt, max_tokens=10, rng=rng) + sequence = generate.format(model, bool)(prompt, max_tokens=10, seed=0) assert sequence != "" assert isinstance(sequence, bool) -def test_transformers_integration_date(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_date(model): prompt = "What day is it today?" - sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, rng=rng) + sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, seed=0) assert sequence != "" assert isinstance(sequence, datetime.date) -def test_transformers_integration_time(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_time(model): prompt = "What time is it?" - sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, rng=rng) + sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, seed=0) assert sequence != "" assert isinstance(sequence, datetime.time) -def test_transformers_integration_datetime(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_integration_datetime(model): prompt = "What time is it?" - sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, rng=rng) + sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, seed=0) assert sequence != 0 assert isinstance(sequence, datetime.datetime) -def test_transformers_integration_choice(): - rng = torch.Generator() - rng.manual_seed(0) - - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_integration_choice(model): prompt = "Write a short sentence " - sequence = generate.choice(model, ["test", "choice"])(prompt, rng=rng) + sequence = generate.choice(model, ["test", "choice"])(prompt, seed=0) assert sequence == "test" or sequence == "choice" @@ -327,9 +247,7 @@ def test_transformers_integration_with_pad_token(): assert model.tokenizer.pad_token == "" -def test_transformers_json_basic(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_basic(model): prompt = "Output some JSON " class Spam(BaseModel): @@ -338,10 +256,7 @@ class Spam(BaseModel): spam: constr(max_length=10) fuzz: bool - rng = torch.Generator() - rng.manual_seed(0) # make sure that `bar` is not an int - - result = generate.json(model, Spam)(prompt, max_tokens=500, rng=rng) + result = generate.json(model, Spam)(prompt, max_tokens=500, seed=0) assert isinstance(result, BaseModel) assert isinstance(result.foo, int) assert isinstance(result.bar, float) @@ -350,9 +265,7 @@ class Spam(BaseModel): assert len(result.spam) <= 10 -def test_transformers_json_schema(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_schema(model): prompt = "Output some JSON " schema = """{ @@ -366,18 +279,13 @@ def test_transformers_json_schema(): } """ - rng = torch.Generator() - rng.manual_seed(0) # make sure that `bar` is not an int - - result = generate.json(model, schema)(prompt, max_tokens=500, rng=rng) + result = generate.json(model, schema)(prompt, max_tokens=500, seed=0) assert isinstance(result, dict) assert isinstance(result["foo"], int) assert isinstance(result["bar"], str) -def test_transformers_json_batch(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_batch(model): prompts = ["Output some JSON ", "Output more JSON"] class Spam(BaseModel): @@ -386,17 +294,12 @@ class Spam(BaseModel): spam: constr(max_length=10) fuzz: bool - rng = torch.Generator() - rng.manual_seed(0) # make sure that `bar` is not an int - - result = generate.json(model, Spam)(prompts, max_tokens=500, rng=rng) + result = generate.json(model, Spam)(prompts, max_tokens=500, seed=0) assert isinstance(result[0], BaseModel) assert isinstance(result[1], BaseModel) -def test_transformers_json_batch_multiple_samples(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_batch_multiple_samples(model): sampler = multinomial(samples=2) prompts = ["Output some JSON ", "Output more JSON"] @@ -406,11 +309,8 @@ class Spam(BaseModel): spam: constr(max_length=10) fuzz: bool - rng = torch.Generator() - rng.manual_seed(0) # make sure that `bar` is not an int - result = generate.json(model, Spam, sampler=sampler)( - prompts, max_tokens=500, rng=rng + prompts, max_tokens=500, seed=0 ) assert isinstance(result, list) assert len(result) == 2 @@ -420,14 +320,9 @@ class Spam(BaseModel): assert isinstance(result[1][1], BaseModel) -def test_transformers_json_str_enum(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_str_enum(model): prompt = "Output some JSON " - rng = torch.Generator() - rng.manual_seed(0) - class Name(str, Enum): john = "John" marc = "Marc" @@ -437,20 +332,15 @@ class User(BaseModel): user_id: int name: Name - result = generate.json(model, User)(prompt, rng=rng) + result = generate.json(model, User)(prompt, seed=0) assert isinstance(result, BaseModel) assert isinstance(result.user_id, int) assert result.name in ["John", "Marc", "Michel"] -def test_transformers_json_int_enum(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_int_enum(model): prompt = "Output some JSON " - rng = torch.Generator() - rng.manual_seed(0) - class Id(int, Enum): one = 1 two = 2 @@ -458,25 +348,20 @@ class Id(int, Enum): class User(BaseModel): user_id: Id - result = generate.json(model, User)(prompt, rng=rng) + result = generate.json(model, User)(prompt, seed=0) assert isinstance(result, BaseModel) assert isinstance(result.user_id, int) assert result.user_id in [1, 2] -def test_transformers_json_array(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_array(model): prompt = "Output some JSON " class User(BaseModel): user_id: int value: List[float] - rng = torch.Generator() - rng.manual_seed(0) - - result = generate.json(model, User)(prompt, rng=rng) + result = generate.json(model, User)(prompt, seed=0) assert isinstance(result, BaseModel) assert isinstance(result.user_id, int) assert isinstance(result.value, list) @@ -485,19 +370,14 @@ class User(BaseModel): @pytest.mark.xfail(reason="The implementation of `anyOf` is incorrect") -def test_transformers_json_union(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_union(model): prompt = "Output some JSON " class Spam(BaseModel): foo: int bar: Union[constr(max_length=10), float] - rng = torch.Generator() - rng.manual_seed(4) - - result = generate.json(model, Spam)(prompt, max_tokens=100, rng=rng) + result = generate.json(model, Spam)(prompt, max_tokens=100, seed=4) assert isinstance(result, BaseModel) assert ( isinstance(result.bar, int) @@ -506,26 +386,18 @@ class Spam(BaseModel): ) -def test_transformers_json_function(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name) +def test_transformers_json_function(model): prompt = "Output arguments for the function" def function(foo: int, bar: List[int]): return foo + sum(bar) - rng = torch.Generator() - rng.manual_seed(4) - - sequence = generate.json(model, function)(prompt, max_tokens=100, rng=rng) + sequence = generate.json(model, function)(prompt, max_tokens=100, seed=4) assert isinstance(sequence, dict) assert isinstance(function(**sequence), int) -def test_transformers_logits_vocab_size(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") - +def test_transformers_logits_vocab_size(model): # Artificially increase the weights/logits size relative # to the vocabulary model.model.resize_token_embeddings(pad_to_multiple_of=3) @@ -535,16 +407,11 @@ def test_transformers_logits_vocab_size(): generator = generate.choice(model, ["True", "False"]) - rng = torch.Generator() - rng.manual_seed(101) - - sequence = generator("blah", rng=rng) + sequence = generator("blah", seed=101) assert sequence == "False" -def test_transformers_json_custom_ws(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - model = models.transformers(model_name, device="cpu") +def test_transformers_json_custom_ws(model): prompt = "Output some JSON with newlines" # try to force model to use newlines schema = """{ @@ -558,12 +425,9 @@ def test_transformers_json_custom_ws(): } """ - rng = torch.Generator() - rng.manual_seed(0) - generator = generate.json(model, schema, whitespace_pattern=r"[ ]?") generator.format_sequence = lambda x: x # patch to return raw text - assert "\n" not in generator(prompt, max_tokens=500, rng=rng) + assert "\n" not in generator(prompt, max_tokens=500, seed=0) def test_transformers_reduced_vocabulary_caching(): @@ -582,11 +446,8 @@ def test_transformers_reduced_vocabulary_caching(): assert vocab2 is vocab -def test_custom_sampler(): - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" - - model = models.transformers(model_name) - +@pytest.mark.skip(reason="Custom Sampler Disabled in Transformers Integration") +def test_custom_sampler(model): seen = False target_token_ids = model.tokenizer.encode(["c"])[0] @@ -623,14 +484,11 @@ def __call__( def test_transformers_use_existing_model_and_tokenizer(): from transformers import AutoModelForCausalLM, AutoTokenizer - rng = torch.Generator() - rng.manual_seed(10000) - model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM" hf_tokenizer = AutoTokenizer.from_pretrained(model_name) hf_model = AutoModelForCausalLM.from_pretrained(model_name) model = Transformers(hf_model, hf_tokenizer) - sequence = generate.text(model)("Write a short sentence ", rng=rng) + sequence = generate.text(model)("Write a short sentence ", seed=10000) assert isinstance(sequence, str) @@ -652,22 +510,30 @@ def test_RegexGuide_caching(temp_cache_dir): assert create_states_mapping.__memory__ is cache model = models.transformers( - "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM" + "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu" ) generator = generate.regex(model, regex, sampler=greedy()) assert cache.stats() == (0, 1) - model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM") + model_2 = models.transformers( + "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu" + ) generator_2 = generate.regex(model_2, regex, sampler=greedy()) assert cache.stats() == (0, 2) # These two different models and tokenizers should not have the same state # mapping results - assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps + assert ( + generator.logits_processor.fsm.states_to_token_maps + != generator_2.logits_processor.fsm.states_to_token_maps + ) generator_3 = generate.regex(model_2, regex, sampler=greedy()) assert cache.stats() == (1, 2) - assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps + assert ( + generator_2.logits_processor.fsm.states_to_token_maps + == generator_3.logits_processor.fsm.states_to_token_maps + ) # Just for fun... structured = generator(prompt, max_tokens=30) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index f4596a2df..097858fd0 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -63,9 +63,6 @@ def test_llama_tokenizer(): def test_model(): - with pytest.raises(ValueError, match="When passing device_map as a string"): - transformers(TEST_MODEL, device="non_existent") - model = transformers(TEST_MODEL, device="cpu") assert isinstance(model.tokenizer, TransformerTokenizer) assert model.device.type == "cpu"