From 00052d52290ece908b2b514d39bd027548ddd434 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Sat, 22 Jun 2024 13:47:08 -0500
Subject: [PATCH] Improve outlines.processors, add integration tests to test_generate.py

---
 benchmarks/bench_processors.py               |  44 +++++
 outlines/__init__.py                         |   1 +
 outlines/models/mlxlm.py                     |   6 +-
 outlines/processors/__init__.py              |   2 +-
 outlines/processors/base_logits_processor.py | 145 +++++++++++-----
 outlines/processors/structured.py            |  39 +++--
 tests/generate/test_generate.py              | 169 ++++++++++++++++---
 7 files changed, 324 insertions(+), 82 deletions(-)
 create mode 100644 benchmarks/bench_processors.py

diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py
new file mode 100644
index 000000000..96b74b78f
--- /dev/null
+++ b/benchmarks/bench_processors.py
@@ -0,0 +1,44 @@
+import mlx.core as mx
+import numpy as np
+import torch
+
+from outlines.processors import OutlinesLogitsProcessor
+
+
+class HalvingLogitsProcessor(OutlinesLogitsProcessor):
+    """Simply halve the passed logits"""
+
+    def process_logits(self, input_ids, logits):
+        return logits / 2
+
+
+class LogitsProcessorBenchmark:
+    params = ["torch", "numpy"]
+    if mx.metal.is_available():
+        params += ["mlx"]
+
+    def setup(self, array_library):
+        self.logits_processor = HalvingLogitsProcessor()
+
+        # logits: (4, 30,000 ) dtype=float
+        # input_ids shape: (4, 2048) dtype=int
+        if array_library == "torch":
+            self.logits = torch.rand((4, 30000), dtype=torch.float)
+            self.input_ids = torch.randint(
+                low=0, high=30000, size=(4, 2048), dtype=torch.int
+            )
+        elif array_library == "numpy":
+            self.logits = np.random.rand(4, 30000).astype(np.float32)
+            self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
+        elif array_library == "mlx":
+            self.logits = mx.random.uniform(
+                low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
+            )
+            self.input_ids = mx.random.randint(
+                low=0, high=30000, shape=(4, 2048), dtype=mx.int32
+            )
+        else:
+            raise ValueError
+
+    def time_logits_processor(self, array_library):
+        self.logits_processor(self.input_ids, self.logits)
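+
+
+# NOTE (illustrative addition, not from the original patch): `params`, `setup`,
+# and the `time_*` method follow airspeed velocity (asv) benchmark conventions,
+# so each array library listed in `params` is timed separately. The same class
+# can also be exercised by hand, e.g. roughly:
+#
+#     bench = LogitsProcessorBenchmark()
+#     bench.setup("torch")
+#     bench.time_logits_processor("torch")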
""" import mlx.core as mx import mlx_lm diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py index 5c6a697ed..22c10d905 100644 --- a/outlines/processors/__init__.py +++ b/outlines/processors/__init__.py @@ -1,7 +1,7 @@ from .structured import ( - BaseLogitsProcessor, CFGLogitsProcessor, FSMLogitsProcessor, JSONLogitsProcessor, + OutlinesLogitsProcessor, RegexLogitsProcessor, ) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index dabfd91b0..f829844a6 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -1,23 +1,30 @@ from abc import abstractmethod -from typing import List, Protocol, Union +from typing import TYPE_CHECKING, List, Protocol, Type, Union import numpy as np import torch from numpy.typing import NDArray +if TYPE_CHECKING: + import mlx.core as mx -def is_mlx_array(logits): + +Array = Union[NDArray, torch.Tensor, List, "mx.array"] + + +def is_mlx_array_type(array_type): try: import mlx.core as mx except ImportError: return False - return isinstance(logits, mx.array) + return issubclass(array_type, mx.array) -class BaseLogitsProcessor(Protocol): +class OutlinesLogitsProcessor(Protocol): """ Base class for logits processors which normalizes types of logits: - ndarray (used by llama-cpp-python), converted to torch.Tensor + - mlx.core.array (used by mlx-lm), converted to torch.Tensor - torch.Tensor (used by everything else) Normalization of types and conversion to torch.Tensor @@ -29,50 +36,100 @@ class BaseLogitsProcessor(Protocol): @abstractmethod def process_logits( - self, input_ids: List[int], logits: torch.Tensor + self, input_ids: List[List[int]], logits: torch.Tensor ) -> torch.Tensor: - ... + """ + input_ids and logits are always 2D tensors for handling a batch of sequences. 
+
+        - input_ids -> List[List[tokens]]
+        - logits.shape[0] -> 2D_Tensor[logits]
+
+        Important to keep in mind when designing universal logits processors
+        - logits processors are only used once and never re-applied for a new sequence generator
+        - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids
+        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
+        """
+        pass
 
+    @torch.no_grad()
     def __call__(
         self,
-        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
-        logits: Union[NDArray[np.float32], torch.Tensor],
-    ) -> Union[NDArray[np.int64], torch.Tensor]:
+        input_ids: Array,
+        logits: Array,
+    ) -> Array:
         """
         Apply logits processor
-        Unify type
-        - convert input_ids: either ndarray, List[int], or Tensor -> List[int]
-        - convert logits: either ndarray, mlx array, Tensor -> Tensor
-        Call process_logits() to perform business logic
+
+        1) Unify type
+        - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
+        - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
+        2) Unify shape, ensure logits and input_ids are 2D
+        3) Call self.process_logits() to perform business logic
+        4) Cast logits back to original array library type
         """
-        with torch.no_grad():
-            if not isinstance(input_ids, list):
-                input_ids = input_ids.tolist()
-
-            if isinstance(logits, np.ndarray):
-                # Unify type, convert numpy array to Tensor
-                # from_numpy and .numpy() don't copy the data, it uses the same memory address
-                torch_logits = torch.from_numpy(logits)
-                processed_torch_logits = self.process_logits(input_ids, torch_logits)
-                return processed_torch_logits.detach().numpy()
-
-            elif isinstance(logits, torch.Tensor):
-                return self.process_logits(input_ids, logits)
-
-            elif is_mlx_array(logits):
-                # mlx -> torch -> mlx conversion docs:
-                # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
-                import mlx.core as mx
-
-                torch_logits = torch.from_dlpack(logits)
-                processed_torch_logits = self.process_logits(input_ids, torch_logits)
-
-                # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
-                logits_float32_numpy = processed_torch_logits.float().numpy()
-                return mx.array(logits_float32_numpy)
-
-            else:
-                raise TypeError(
-                    "LogitsProcessor must be called with either np.NDArray"
-                    ", torch.Tensor, or mlx.core.array typed logits"
-                )
+
+        # ensure logits are torch Tensors
+        torch_logits = self._to_torch(logits)
+
+        assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1]
+
+        # ensure input_ids are List
+        if not isinstance(input_ids, list):
+            input_ids = input_ids.tolist()  # compatible with numpy, torch, and mlx
+
+        # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
+        if len(torch_logits.shape) == 2:
+            processed_logits = self.process_logits(input_ids, torch_logits)
+        elif len(torch_logits.shape) == 1:
+            processed_logits = self.process_logits(
+                [input_ids], torch_logits.unsqueeze(0)
+            ).squeeze(0)
+
+        # return logits as passed array type
+        return self._from_torch(processed_logits, type(logits))
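+
+    # NOTE (illustrative addition, not from the original patch): a concrete
+    # subclass only implements `process_logits`; `__call__` handles the array
+    # type round trip. For example, the `HalvingLogitsProcessor` defined in
+    # benchmarks/bench_processors.py can be called with numpy, torch, or mlx
+    # inputs and returns the same array type it was given, e.g. roughly:
+    #
+    #     processor = HalvingLogitsProcessor()
+    #     out = processor(np.array([[1, 2]]), np.array([[4.0, 8.0]], dtype=np.float32))
+    #     # out is a np.ndarray equal to [[2.0, 4.0]]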
+
+    @staticmethod
+    def _to_torch(tensor_like: Array) -> torch.Tensor:
+        """Convert various types to torch.Tensor."""
+        if isinstance(tensor_like, torch.Tensor):
+            return tensor_like
+
+        elif isinstance(tensor_like, np.ndarray):
+            return torch.from_numpy(tensor_like)
+
+        elif isinstance(tensor_like, list):
+            return torch.tensor(tensor_like)
+
+        elif is_mlx_array_type(type(tensor_like)):
+            # mlx -> torch -> mlx conversion docs:
+            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
+            return torch.from_dlpack(tensor_like)
+
+        else:
+            raise TypeError(
+                "LogitsProcessor must be called with either np.NDArray, "
+                "torch.Tensor, list, or mlx.core.array typed logits"
+            )
+
+    @staticmethod
+    def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
+        """Convert torch.Tensor to the specified target type."""
+        if target_type == torch.Tensor:
+            return tensor
+
+        elif target_type == np.ndarray:
+            return tensor.detach().numpy()
+
+        elif target_type == list:
+            return tensor.detach().tolist()
+
+        elif is_mlx_array_type(target_type):
+            import mlx.core as mx
+
+            # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
+            return mx.array(tensor.float().numpy())
+
+        else:
+            raise TypeError(
+                f"Failed to convert torch tensors to target_type `{target_type}`"
+            )
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index b8ef5b2da..d037c679f 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -24,24 +24,22 @@
 limitations under the License.
 """
 import math
-from typing import TYPE_CHECKING, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
 
-import numpy as np
 import torch
-from numpy.typing import NDArray
 from pydantic import BaseModel
 
 from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
 
-from .base_logits_processor import BaseLogitsProcessor
+from .base_logits_processor import OutlinesLogitsProcessor
 
 if TYPE_CHECKING:
     from outlines.models.tokenizer import Tokenizer
 
 
-class FSMLogitsProcessor(BaseLogitsProcessor):
+class FSMLogitsProcessor(OutlinesLogitsProcessor):
     """Bias generation using a finite state machine.
 
     Attributes
@@ -63,13 +61,14 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_state = 0
+        self._fsm_states: Dict[int, int] = {}
         self.fsm: Guide = fsm
         self._is_first_token = True
+        self._seq_start_idx: Optional[int] = None
 
     def process_logits(
-        self, input_ids: List[int], logits: torch.Tensor
-    ) -> NDArray[np.float32]:
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token.
 
         Parameters
@@ -84,17 +83,31 @@ def process_logits(
         torch.Tensor
             The biased logits.
         """
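+        # NOTE (explanatory comment, not from the original patch): unlike the old
+        # single `_fsm_state`, states are now tracked per sequence in
+        # `self._fsm_states`, keyed by `hash(tuple(generated_token_ids))`. Each
+        # row of a batch (or a reordered beam) can therefore look up the FSM
+        # state that produced its prefix and advance it with its newest token.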
""" + sequence_states: List[int] = [] # vector of states corresponding to `input_ids` + if self._is_first_token: self._is_first_token = False + self._seq_start_idx = len(input_ids[0]) + + self._fsm_states = {hash(tuple([])): 0} + sequence_states = [0] * len(input_ids) + else: - last_token = input_ids[-1] - self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) + for seq_ids in input_ids: + prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1])) + prev_state = self._fsm_states[prev_state_key] - allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens - allowed_tokens = torch.tensor(allowed_tokens, device=logits.device) + curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :])) + curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1]) + + self._fsm_states[curr_state_key] = curr_state + sequence_states.append(curr_state) mask = torch.full_like(logits, -math.inf) - mask[allowed_tokens] = logits[allowed_tokens] + for i, fsm_state in enumerate(sequence_states): + allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens + mask[i, allowed_tokens] = logits[i, allowed_tokens] + return mask def copy(self) -> "FSMLogitsProcessor": diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 1f1a3aea2..111f8f93d 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -1,9 +1,11 @@ +import contextlib import re import pytest import outlines.generate as generate import outlines.models as models +import outlines.samplers as samplers @pytest.fixture(scope="session") @@ -20,35 +22,160 @@ def model_mlxlm(tmp_path_factory): @pytest.fixture(scope="session") -def model_transformers(tmp_path_factory): - return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu") +def model_transformers_random(tmp_path_factory): + return models.transformers("hf-internal-testing/tiny-random-gpt2", device="cpu") -@pytest.mark.parametrize( - "model_fixture", - ("model_llamacpp", "model_mlxlm", "model_transformers"), +@pytest.fixture(scope="session") +def model_transformers_opt125m(tmp_path_factory): + return models.transformers("facebook/opt-125m", device="cpu") + + +ALL_MODEL_FIXTURES = ( + "model_llamacpp", + "model_mlxlm", + "model_transformers_random", + "model_transformers_opt125m", ) -def test_generate_text(request, model_fixture): + + +NOT_IMPLEMENTED = { + "batch": ["model_llamacpp"], + "stream": ["model_vllm"], + "beam_search": ["model_llamacpp"], + "multiple_samples": ["model_llamacpp"], +} + + +def enforce_not_implemented(model_fixture, *task_names): + """ + Per `NOT_IMPLEMENTED`, mapping, if a model hasn't implemented a task, + assert an NotImplementedError is raised. 
+
+
+REGEX_PATTERNS = [
+    "(123456789)|(abcdefghijklmnop)",
+    "abc*",
+    "\\+?[1-9][0-9]{7,14}",
+    r"([a-z]{10})@([a-z]{5})\.([a-z]{3})",
+]
+
+
+@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search"))
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text(request, model_fixture, sampler_name):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model, getattr(samplers, sampler_name)())
+    with enforce_not_implemented(model_fixture, sampler_name):
+        res = generator("test", max_tokens=10)
+        assert isinstance(res, str)
+
+
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_batch_text(request, model_fixture):
     model = request.getfixturevalue(model_fixture)
     generator = generate.text(model)
-    res = generator("test", max_tokens=10)
-    assert isinstance(res, str)
+    with enforce_not_implemented(model_fixture, "batch"):
+        res = generator(["test", "test2"], max_tokens=10)
+        assert isinstance(res, list)
+        assert isinstance(res[0], str)
 
 
-@pytest.mark.parametrize(
-    "model_fixture",
-    ("model_llamacpp", "model_mlxlm", "model_transformers"),
-)
-@pytest.mark.parametrize(
-    "pattern",
-    (
-        "[0-9]",
-        "abc*",
-        "\\+?[1-9][0-9]{7,14}",
-    ),
-)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_text_stream(request, model_fixture):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.text(model)
+    with enforce_not_implemented(model_fixture, "stream"):
+        for token in generator.stream("a b c ", max_tokens=10):
+            assert isinstance(token, str)
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
 def test_generate_regex(request, model_fixture, pattern):
     model = request.getfixturevalue(model_fixture)
     generator = generate.regex(model, pattern)
     res = generator("foobarbaz", max_tokens=20)
-    assert re.match(pattern, res) is not None, res
+    assert re.fullmatch(pattern, res) is not None, res
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_stream(request, model_fixture, pattern):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern)
+    with enforce_not_implemented(model_fixture, "stream"):
+        output = ""
+        for token in generator.stream("output:", max_tokens=20):
+            output += token
+        assert re.fullmatch(pattern, output) is not None, output
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_batch_stream(request, model_fixture, pattern):
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern)
+    with enforce_not_implemented(model_fixture, "batch", "stream"):
+        outputs = ["", ""]
+        for tokens in generator.stream(["input 0", "input 1"], max_tokens=20):
+            outputs[0] += tokens[0]
+            outputs[1] += tokens[1]
+        for output in outputs:
+            assert re.fullmatch(pattern, output) is not None, output
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_batch(request, model_fixture, pattern):
+    """Ensure batch requests work and fsm order is maintained"""
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern)
+    with enforce_not_implemented(model_fixture, "batch"):
+        outputs = generator(["abc", "123", "123bce", "33aa"], max_tokens=20)
+        for output in outputs:
+            assert re.fullmatch(pattern, output) is not None, output
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_single_multinomial(request, model_fixture, pattern):
+    """Ensure batch requests work and fsm order is maintained"""
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
+    with enforce_not_implemented(model_fixture, "multiple_samples"):
+        output_sample_groups = generator("single input", max_tokens=40)
+        for output in output_sample_groups:
+            assert re.fullmatch(pattern, output) is not None, output
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_batch_multinomial(request, model_fixture, pattern):
+    """Ensure batch requests work and fsm order is maintained"""
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern, sampler=samplers.multinomial(4))
+    with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
+        output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
+        for output_sample_groups in output_batch_groups:
+            for output in output_sample_groups:
+                assert re.fullmatch(pattern, output) is not None, output
+
+
+@pytest.mark.parametrize("pattern", REGEX_PATTERNS)
+@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+def test_generate_regex_batch_beam_search(request, model_fixture, pattern):
+    """Ensure batch requests work and fsm order is maintained"""
+    model = request.getfixturevalue(model_fixture)
+    generator = generate.regex(model, pattern, sampler=samplers.beam_search(4))
+    with enforce_not_implemented(model_fixture, "batch", "multiple_samples"):
+        output_batch_groups = generator(["abc", "123", "123bce", "33aa"], max_tokens=40)
+        for output_sample_groups in output_batch_groups:
+            for output in output_sample_groups:
+                assert re.fullmatch(pattern, output) is not None, output