
Fixed per review
isamu-isozaki committed Sep 20, 2024
1 parent e6b3af6 commit 5508c92
Showing 2 changed files with 31 additions and 14 deletions.
43 changes: 31 additions & 12 deletions outlines/models/exllamav2.py
@@ -1,12 +1,12 @@
import dataclasses
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union

from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.transformers import TransformerTokenizer

if TYPE_CHECKING:
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler


@@ -24,7 +24,7 @@ class ExLlamaV2Model:
def __init__(
self,
generator: "ExLlamaV2DynamicGenerator",
-tokenizer: TransformerTokenizer,
+tokenizer: "ExLlamaV2Tokenizer",
max_seq_len: int,
):
self.generator = generator
@@ -54,7 +54,7 @@ def prepare_generation_parameters(
max_tokens = []
for prompt in prompts:
ids = self.generator.tokenizer.encode(
-prompt, encode_special_tokens=False
+prompt, encode_special_tokens=True
)
prompt_tokens = ids.shape[-1]
max_tokens.append(self.max_seq_len - prompt_tokens)
@@ -96,6 +96,14 @@ def prepare_generation_parameters(
def reformat_output(
self, output: Union[str, List[str]], sampling_parameters: SamplingParameters
):
"""
The purpose of this function is to reformat the output from exllamav2's output format to outline's output format
For exllamav2, it mainly accepts only a list or a string(they also do cfg sampling with tuples but we will ignore this for now)
The exllamav2's logic is
1. If the prompt is a string, return a string. This is the same as outlines
2. If a prompt is a list, return a list. This is not the same as outlines output in that if the list is only one element, the string is expected to be outputted.
3. There is no such thing as num_samples, so the prompts had to be duplicated by num_samples times. Then, we had the function output a list of lists
"""
if isinstance(output, str):
return output
if len(output) == 1:
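
For reference, the reshaping rules described in the docstring can be sketched as a standalone function (an illustrative sketch, not the committed implementation; the grouping in rule 3 is inferred from the docstring):

from typing import List, Union

def reshape_output_sketch(output: Union[str, List[str]], num_samples: int):
    # Rule 1: a string prompt yields a string output.
    if isinstance(output, str):
        return output
    # Rule 2: outlines expects a bare string for a single-element list.
    if len(output) == 1:
        return output[0]
    # One sample per prompt: the flat list is already in outlines' format.
    if num_samples == 1:
        return output
    # Rule 3: prompts were duplicated num_samples times, so group the flat
    # output back into one list of samples per original prompt.
    return [
        output[i : i + num_samples]
        for i in range(0, len(output), num_samples)
    ]

For example, with num_samples=2, ["a1", "a2", "b1", "b2"] regroups into [["a1", "a2"], ["b1", "b2"]].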
@@ -128,12 +136,19 @@ def generate(
sampling_parameters,
structure_logits_processor,
)
"""
In exllamav2, it needs the max amount of new tokens generated.
The reason exllamav2_params["max_new_tokens"] is a list is because in prepare_generation_parameters
the max amount of tokens that can be generated by the model for each prompt(by encoding with tokenizer) is calculated.
The minimum is picked because otherwise it might be possible for one of the
prompts to exceed the max sequence length.
"""
output = self.generator.generate(
prompt=prompts,
gen_settings=exllamav2_params["gen_settings"],
max_new_tokens=min(exllamav2_params["max_new_tokens"]),
completion_only=True,
-encode_special_tokens=False,
+encode_special_tokens=True,
stop_conditions=exllamav2_params["stop_conditions"],
add_bos=False,
seed=exllamav2_params["seed"],
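
As a runnable illustration of the budget arithmetic described in the docstring (a whitespace splitter stands in for exllamav2's tokenizer; the numbers are examples):

# Stand-in tokenizer so the arithmetic runs without exllamav2; the real code
# uses self.generator.tokenizer.encode(prompt, ...).shape[-1].
def encoded_len(prompt: str) -> int:
    return len(prompt.split())

max_seq_len = 2048
prompts = ["one two three", "one two three four five six"]
max_new_tokens = [max_seq_len - encoded_len(p) for p in prompts]
print(max_new_tokens)       # [2045, 2042]
print(min(max_new_tokens))  # 2042: the shared budget keeps both in bounds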
@@ -165,7 +180,7 @@ def stream(
seed = exllamav2_params["seed"]
for idx, p in enumerate(prompts):
input_ids = self.generator.tokenizer.encode(
-p, encode_special_tokens=False, add_bos=False
+p, encode_special_tokens=True, add_bos=False
)

job = ExLlamaV2DynamicJob(
@@ -205,6 +220,14 @@ def token_generator() -> Iterator[str]:
return token_generator()


+# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b
+def patch_tokenizer(tokenizer):
+tokenizer.vocabulary = tokenizer.piece_to_id
+tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)
+tokenizer.convert_token_to_string = lambda t: t
+return tokenizer
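
This patch gives exllamav2's tokenizer the three attributes outlines reads: vocabulary, special_tokens, and convert_token_to_string. A minimal mock shows the resulting surface (the piece-to-id contents below are made up for illustration; patch_tokenizer is the function defined above):

class MockTokenizer:
    # Stand-in for an ExLlamaV2Tokenizer; real instances come from exllamav2.
    piece_to_id = {"<s>": 1, "hello": 2, "world": 3}
    extended_piece_to_id = {"<s>": 1}

tok = patch_tokenizer(MockTokenizer())
assert tok.vocabulary == {"<s>": 1, "hello": 2, "world": 3}
assert tok.special_tokens == {"<s>"}              # keys of extended_piece_to_id
assert tok.convert_token_to_string("hi") == "hi"  # identity mapping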


def exl2(
model_path: str,
draft_model_path: Optional[str] = None,
@@ -261,7 +284,6 @@ def exl2(
ExLlamaV2Tokenizer,
)
from exllamav2.generator import ExLlamaV2DynamicGenerator
-from transformers import AutoTokenizer

except ImportError:
raise ImportError(
@@ -284,7 +306,7 @@ def exl2(

print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
-tokenizer.vocabulary = tokenizer.extended_piece_to_id
+tokenizer = patch_tokenizer(tokenizer)
max_batch_size = 4 if paged else 1

draft_model = None
@@ -314,8 +336,5 @@ def exl2(
max_chunk_size=max_chunk_size,
paged=paged,
)
-hf_tokenizer_kwargs: dict[str, Any] = {}
-hf_tokenizer_kwargs.setdefault("padding_side", "left")
-hf_tokenizer = AutoTokenizer.from_pretrained(model_path, **hf_tokenizer_kwargs)
max_seq_len = cache.max_seq_len
-return ExLlamaV2Model(generator, TransformerTokenizer(hf_tokenizer), max_seq_len)
+return ExLlamaV2Model(generator, tokenizer, max_seq_len)
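
With the tokenizer patch in place, constructing a model through this entry point looks roughly like the following (a sketch: the path is a placeholder, and only the parameters visible in the diff above are used):

import outlines.models as models

# Hypothetical usage; requires exllamav2 and a local EXL2-quantized model.
model = models.exl2(
    model_path="path/to/exl2-model",  # placeholder path
    paged=False,                      # parameter referenced in the diff
)
# model is an ExLlamaV2Model wrapping the dynamic generator and the patched
# ExLlamaV2Tokenizer, so its tokenizer now satisfies outlines' interface.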
2 changes: 0 additions & 2 deletions tests/generate/test_integration_exllamav2.py
@@ -6,7 +6,6 @@
import outlines.models as models
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.exllamav2 import ExLlamaV2Model
-from outlines.models.transformers import TransformerTokenizer


@pytest.fixture(scope="session")
@@ -34,7 +33,6 @@ def test_model_attributes(request, model_fixture):
model = request.getfixturevalue(model_fixture)
assert hasattr(model, "generator")
assert hasattr(model, "tokenizer")
-assert isinstance(model.tokenizer, TransformerTokenizer)
assert hasattr(model, "max_seq_len")
assert isinstance(model.max_seq_len, int)

