From a62ff007299a1634ff2ff0ecd1f6aa250d9a6b55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Mon, 4 Mar 2024 15:41:28 +0100
Subject: [PATCH] Simplify the `transformers` and `llamacpp` interfaces

---
 docs/reference/models/llamacpp.md              | 37 +++++++++++++++++--
 docs/reference/models/transformers.md          | 29 +++++++++++++++
 outlines/models/__init__.py                    |  4 +-
 outlines/models/exllamav2.py                   | 12 ++++--
 outlines/models/llamacpp.py                    | 32 ++++++++++------
 outlines/models/mamba.py                       |  7 +++-
 outlines/models/transformers.py                | 34 ++++++-----------
 tests/benchmark/conftest.py                    |  4 +-
 tests/fsm/test_regex.py                        |  7 +++-
 .../generate/test_integration_transformers.py  | 12 +++---
 tests/models/test_transformers.py              | 16 +++++---
 11 files changed, 134 insertions(+), 60 deletions(-)
 create mode 100644 docs/reference/models/transformers.md

diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md
index c6a3838aa..4a82584a7 100644
--- a/docs/reference/models/llamacpp.md
+++ b/docs/reference/models/llamacpp.md
@@ -4,12 +4,41 @@
 
     You need to install the `llama-cpp-python` library to be able to use these models in Outlines.
 
-Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/llama.cpp) using the [llama-cpp-python library](https://github.com/abetlen/llama-cpp-python). Llamacpp allows to run quantized models on machines with limited compute.
+Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/llama.cpp) using the [llama-cpp-python library][llamacpp]. Llamacpp allows you to run quantized models on machines with limited compute.
 
-Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory:
+You can initialize the model by passing the path to the weights on your machine. Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory:
 
 ```python
-from outlines import models, generate
+from outlines import models
 
-model = models.llamacpp("./phi-2.Q4_K_M.gguf")
+model = models.llamacpp("./phi-2.Q4_K_M.gguf", device="cuda")
 ```
+
+If you need more control, you can pass the same keyword arguments to the model as you would pass to the [llama-cpp-python library][llamacpp]:
+
+```python
+from outlines import models
+
+model = models.llamacpp(
+    "./phi-2.Q4_K_M.gguf",
+    n_gpu_layers=-1,  # to use GPU acceleration
+    seed=1337,  # to set a specific seed
+)
+```
+
+Please see the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io/) for a list of available keyword arguments. Finally, if for some reason you would like to initialize `llama_cpp.Llama` separately, you can convert it to an Outlines model using:
+
+```python
+from llama_cpp import Llama
+from outlines import models
+
+llm = Llama.from_pretrained(
+    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
+    filename="*q8_0.gguf",
+    verbose=False
+)
+model = models.LlamaCpp(llm)
+```
+
+
+[llamacpp]: https://github.com/abetlen/llama-cpp-python
diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
new file mode 100644
index 000000000..2d9880a6b
--- /dev/null
+++ b/docs/reference/models/transformers.md
@@ -0,0 +1,29 @@
+# transformers
+
+
+!!! Installation
+
+    You need to install the `transformers` and `datasets` libraries to be able to use these models in Outlines.
+
+
+Outlines provides an integration with the `torch` implementation of causal models in the [transformers][transformers] library. You can initialize the model by passing its name:
+
+```python
+from outlines import models
+
+model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda")
+```
+
+If you need more fine-grained control, you can also initialize the model and tokenizer separately:
+
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from outlines import models
+
+llm = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+model = models.Transformers(llm, tokenizer)
+```
+
+[transformers]: https://github.com/huggingface/transformers
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index d8282c28e..ca3335d08 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -13,6 +13,6 @@
 from .mamba import Mamba, mamba
 from .openai import OpenAI, openai
 from .openai_compatible import OpenAICompatibleAPI, openai_compatible_api
-from .transformers import Transformer, transformers
+from .transformers import Transformers, transformers
 
-LogitsGenerator = Union[Transformer, LlamaCpp, ExLlamaV2Model, Mamba]
+LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba]
diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py
index cf8b40c3d..b06e5e60a 100644
--- a/outlines/models/exllamav2.py
+++ b/outlines/models/exllamav2.py
@@ -21,7 +21,7 @@ def __init__(
     ):
         self.device = device
         self.model = model
-        self.tokenizer = tokenizer
+        self.tokenizer = TransformerTokenizer(tokenizer)
         self.cache = cache
         self.past_seq = None
 
@@ -75,20 +75,21 @@ def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
 
 
 def exl2(
-    model_name: str,
+    model_path: str,
     device: Optional[str] = None,
     model_kwargs: dict = {},
     tokenizer_kwargs: dict = {},
 ):
     try:
         from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+        from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
             "The `exllamav2` library needs to be installed in order to use `exllamav2` models."
         )
 
     config = ExLlamaV2Config()
-    config.model_dir = model_name
+    config.model_dir = model_path
     config.prepare()
 
     config.max_seq_len = model_kwargs.pop("max_seq_len", config.max_seq_len)
@@ -108,7 +109,10 @@ def exl2(
         split = [float(alloc) for alloc in model_kwargs["gpu_split"].split(",")]
     model.load(split)
 
-    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
+
+    tokenizer_kwargs.setdefault("padding_side", "left")
+    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
+
     cache = ExLlamaV2Cache(model)
 
     return ExLlamaV2Model(model, tokenizer, device, cache)
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 24e6cb6bc..09146af6c 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import numpy as np
 import torch
@@ -7,6 +7,9 @@
 
 from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
 
+if TYPE_CHECKING:
+    from llama_cpp import Llama
+
 
 class LlamaSequenceGenerator:
     def __init__(
@@ -87,22 +90,20 @@ def stream(
 class LlamaCpp:
     """Represents a `llama_cpp` model."""
 
-    def __init__(self, model_path, **kwargs):
-        from llama_cpp import Llama
-
-        self.model = Llama(model_path, **kwargs)
-        self.tokenizer = LlamaCppTokenizer(self)
+    def __init__(self, model: "Llama", **kwargs):
+        self.model = model
+        self.tokenizer = LlamaCppTokenizer(model)
 
 
 class LlamaCppTokenizer:
     def __init__(self, model, **kwargs):
-        self.eos_token_id = model.model.token_eos()
+        self.eos_token_id = model.token_eos()
         self.pad_token_id = self.eos_token_id
         self.special_tokens = {}
 
         self.vocabulary = {}
-        for t in range(model.model.n_vocab()):
-            token_piece = model.model.tokenizer().decode([t])
+        for t in range(model.n_vocab()):
+            token_piece = model.tokenizer().decode([t])
             self.vocabulary[token_piece] = t
 
     def convert_token_to_string(self, token: str) -> str:
@@ -110,11 +111,18 @@ def convert_token_to_string(self, token: str) -> str:
 
 
 def llamacpp(
-    model_name: str,
+    model_path: str,
     device: Optional[str] = None,
-    model_kwargs: dict = {},
+    **model_kwargs,
 ):
-    return LlamaCpp(model_name, **model_kwargs)
+    from llama_cpp import Llama
+
+    if device == "cuda":
+        model_kwargs.setdefault("n_gpu_layers", -1)
+
+    model = Llama(model_path, **model_kwargs)
+
+    return LlamaCpp(model)
 
 
 class LogitsProcessor:
diff --git a/outlines/models/mamba.py b/outlines/models/mamba.py
index ea0fcc15f..1375a3811 100644
--- a/outlines/models/mamba.py
+++ b/outlines/models/mamba.py
@@ -20,7 +20,7 @@ def __init__(
     ):
         self.device = device
         self.model = model
-        self.tokenizer = tokenizer
+        self.tokenizer = TransformerTokenizer(tokenizer)
 
     def forward(self, input_ids: torch.LongTensor, *_):
         """Compute a forward pass through the mamba model."""
@@ -41,6 +41,7 @@ def mamba(
 ):
     try:
         from mamba_ssm import MambaLMHeadModel
+        from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
             "The `mamba_ssm` library needs to be installed in order to use Mamba people."
@@ -53,6 +54,8 @@ def mamba(
         device = "cuda"
 
     model = MambaLMHeadModel.from_pretrained(model_name, device=device)
-    tokenizer = TransformerTokenizer(TOKENIZER_MODEL, **tokenizer_kwargs)
+
+    tokenizer_kwargs.setdefault("padding_side", "left")
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, **tokenizer_kwargs)
 
     return Mamba(model, tokenizer, device)
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index da02642d2..1b29ee2f4 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -5,7 +5,7 @@
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizerBase
+    from transformers import PreTrainedModel, PreTrainedTokenizer
 
 __all__ = ["transformers"]
 
@@ -58,22 +58,8 @@ class CodeLlamaTokenizerFast:  # type: ignore
 class TransformerTokenizer(Tokenizer):
     """Represents a tokenizer for models in the `transformers` library."""
 
-    def __init__(
-        self, tokenizer_or_model_name: Union["PreTrainedTokenizerBase", str], **kwargs
-    ):
-        if isinstance(tokenizer_or_model_name, str):
-            from transformers import AutoTokenizer
-
-            kwargs.setdefault("padding_side", "left")
-            self.model_name = tokenizer_or_model_name
-            # TODO: Do something to make this hashable?
-            self.kwargs = kwargs
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_or_model_name, **kwargs
-            )
-        else:
-            self.tokenizer = tokenizer_or_model_name
-
+    def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):
+        self.tokenizer = tokenizer
         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token
 
@@ -129,17 +115,17 @@ def __hash__(self):
         return hash(Hasher.hash(self.tokenizer))
 
 
-class Transformer:
+class Transformers:
     """Represents a `transformers` model."""
 
     def __init__(
         self,
         model: "PreTrainedModel",
-        tokenizer: TransformerTokenizer,
+        tokenizer: "PreTrainedTokenizer",
     ):
         self.device = model.device
         self.model = model
-        self.tokenizer = tokenizer
+        self.tokenizer = TransformerTokenizer(tokenizer)
 
     @torch.inference_mode
     def forward(
@@ -221,7 +207,7 @@ def transformers(
 
     """
    try:
-        from transformers import AutoModelForCausalLM
+        from transformers import AutoModelForCausalLM, AutoTokenizer
     except ImportError:
         raise ImportError(
             "The `transformers` library needs to be installed in order to use `transformers` models."
@@ -231,6 +217,8 @@ def transformers(
         model_kwargs["device_map"] = device
 
     model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
-    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
-    return Transformer(model, tokenizer)
+    tokenizer_kwargs.setdefault("padding_side", "left")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
+
+    return Transformers(model, tokenizer)
 
diff --git a/tests/benchmark/conftest.py b/tests/benchmark/conftest.py
index 67f673007..edf2ff614 100644
--- a/tests/benchmark/conftest.py
+++ b/tests/benchmark/conftest.py
@@ -1,4 +1,5 @@
 import pytest
+from transformers import AutoTokenizer
 
 from outlines.fsm.fsm import RegexFSM
 from outlines.models.transformers import TransformerTokenizer
@@ -6,7 +7,8 @@
 
 @pytest.fixture
 def tokenizer():
-    return TransformerTokenizer("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    return TransformerTokenizer(tokenizer)
 
 
 @pytest.fixture
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index 2d2dbd993..cb3846c5b 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -1,6 +1,7 @@
 import interegular
 import numba
 import pytest
+from transformers import AutoTokenizer
 
 from outlines.fsm.regex import (
     _walk_fsm,
@@ -272,7 +273,8 @@ def test_create_fsm_index_tokenizer():
     num_fsm_states = len(regex_fsm.states)
     assert num_fsm_states == 220
 
-    tokenizer = TransformerTokenizer("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = TransformerTokenizer(tokenizer)
 
     states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(
         regex_fsm, tokenizer
@@ -295,7 +297,8 @@ def test_regex_index_performance():
     num_fsm_states = len(regex_fsm.states)
     assert num_fsm_states == 220
 
-    tokenizer = TransformerTokenizer("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = TransformerTokenizer(tokenizer)
 
     # Pre-compile Numba functions
     res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 32bef74dd..38525a076 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -10,7 +10,7 @@
 import outlines.generate as generate
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
-from outlines.models.transformers import Transformer, TransformerTokenizer
+from outlines.models.transformers import Transformers, TransformerTokenizer
 from outlines.samplers import beam_search, multinomial
 
 
@@ -567,8 +567,11 @@ def test_transformers_json_custom_ws():
 
 
 def test_transformers_reduced_vocabulary_caching():
-    tokenizer = TransformerTokenizer("gpt2")
-    tokenizer2 = TransformerTokenizer("gpt2")
+    from transformers import AutoTokenizer
+
+    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = TransformerTokenizer(hf_tokenizer)
+    tokenizer2 = TransformerTokenizer(hf_tokenizer)
 
     # TODO: We might actually want only one copy of a given tokenizer.
     assert tokenizer is not tokenizer2
@@ -626,7 +629,6 @@ def test_transformers_use_existing_model_and_tokenizer():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
     hf_model = AutoModelForCausalLM.from_pretrained(model_name)
-    tokenizer = TransformerTokenizer(hf_tokenizer)
-    model = Transformer(hf_model, tokenizer)
+    model = Transformers(hf_model, hf_tokenizer)
     sequence = generate.text(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 2687d6a82..b4e410096 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -1,5 +1,6 @@
 import pytest
 import torch
+from transformers import AutoTokenizer
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from outlines.models.transformers import TransformerTokenizer, transformers
@@ -8,7 +9,8 @@
 
 
 def test_tokenizer():
-    tokenizer = TransformerTokenizer(TEST_MODEL)
+    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL, padding_side="left")
+    tokenizer = TransformerTokenizer(tokenizer)
     assert tokenizer.eos_token_id == 0
     assert tokenizer.pad_token_id == 0
     assert isinstance(tokenizer.tokenizer, GPT2TokenizerFast)
@@ -37,15 +39,17 @@ def test_tokenizer():
     isinstance(text[0], str)
     isinstance(text[1], str)
 
-    tokenizer = TransformerTokenizer(
+    tokenizer = AutoTokenizer.from_pretrained(
         TEST_MODEL, additional_special_tokens=["", ""]
     )
+    tokenizer = TransformerTokenizer(tokenizer)
     assert "" in tokenizer.special_tokens
     assert "" in tokenizer.special_tokens
 
 
 def test_llama_tokenizer():
-    tokenizer = TransformerTokenizer("hf-internal-testing/llama-tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    tokenizer = TransformerTokenizer(tokenizer)
 
     # Broken
     assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz"
@@ -100,7 +104,9 @@ def test_model():
 
 
 def test_tokenizer_eq_hash():
-    tokenizer = TransformerTokenizer("gpt2")
-    tokenizer2 = TransformerTokenizer("gpt2")
+    tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")
+
+    tokenizer = TransformerTokenizer(tokenizer_hf)
+    tokenizer2 = TransformerTokenizer(tokenizer_hf)
     assert tokenizer == tokenizer2
     assert hash(tokenizer) == hash(tokenizer2)
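
Putting the pieces of this patch together, the reworked `transformers` interface can be exercised end to end roughly as follows. This is a minimal sketch based on the documentation and tests above: the `hf-internal-testing/tiny-random-GPTJForCausalLM` checkpoint is simply the tiny model the test suite uses, and the unconstrained `generate.text` call mirrors `test_transformers_use_existing_model_and_tokenizer`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

import outlines.generate as generate
import outlines.models as models

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # tiny test checkpoint

# Option 1: let Outlines load the model and tokenizer from the name.
model = models.transformers(model_name)

# Option 2: build the Hugging Face objects yourself and wrap them;
# the new `Transformers` class now takes a `PreTrainedTokenizer` directly.
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = models.Transformers(hf_model, hf_tokenizer)

# Unconstrained generation, as in the integration tests.
sequence = generate.text(model)("Write a short sentence ")
print(sequence)
```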
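
The `llamacpp` entry point follows the same pattern. A sketch under the assumptions stated in the docs above: `llama-cpp-python` is installed, the `./phi-2.Q4_K_M.gguf` path is illustrative, and `generate.text` is assumed to accept a `LlamaCpp` model the same way it accepts a `Transformers` model.

```python
from llama_cpp import Llama

import outlines.generate as generate
import outlines.models as models

# Option 1: let Outlines construct the llama_cpp.Llama object from a local GGUF path.
model = models.llamacpp(
    "./phi-2.Q4_K_M.gguf",  # illustrative path, as in the docs above
    n_gpu_layers=-1,        # forwarded to llama_cpp.Llama to enable GPU offloading
    seed=1337,
)

# Option 2: build the Llama object yourself and wrap it in the new `LlamaCpp` class.
llm = Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q8_0.gguf",
    verbose=False,
)
model = models.LlamaCpp(llm)

sequence = generate.text(model)("Describe the weather in one sentence.")
print(sequence)
```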