diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py index 03fe512b9..c27c8bd2e 100644 --- a/outlines/generate/fsm.py +++ b/outlines/generate/fsm.py @@ -4,11 +4,10 @@ from outlines.fsm.guide import RegexGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, TransformersVision +from outlines.models import TransformersVision from outlines.samplers import Sampler, multinomial @@ -30,13 +29,3 @@ def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm) return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) - - -@fsm.register(ExLlamaV2Model) -def fsm_exllamav2( - model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() -) -> SequenceGenerator: - fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - return generator diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 815a8b1b9..673880e49 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -1,12 +1,10 @@ from functools import singledispatch -from outlines.fsm.guide import RegexGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -49,20 +47,6 @@ def regex_vision( return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) -@regex.register(ExLlamaV2Model) -def regex_exllamav2( - model, - regex_str: str, - sampler: Sampler = multinomial(), -) -> SequenceGenerator: - fsm = RegexGuide(regex_str, model.tokenizer) - - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - - return generator - - @regex.register(OpenAI) def regex_openai( model: OpenAI, diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 3fe3dc553..32530d0c4 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -1,12 +1,10 @@ from functools import singledispatch -from outlines.fsm.guide import StopAtEOSGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -36,13 +34,6 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter: return SequenceGeneratorAdapter(model, None, sampler) -@text.register(ExLlamaV2Model) -def text_exllamav2(model, sampler: Sampler = multinomial()) -> SequenceGenerator: - fsm = StopAtEOSGuide(model.tokenizer) - device = model.device - return SequenceGenerator(fsm, model, sampler, device) - - @text.register(TransformersVision) def text_vision(model, sampler: Sampler = multinomial()): return VisionSequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py index 0ec6ef033..f06b7e46e 100644 --- a/outlines/models/exllamav2.py +++ b/outlines/models/exllamav2.py @@ -1,12 +1,21 @@ -import os -from typing import TYPE_CHECKING, Optional +import dataclasses +from typing import 
TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union + +from typing_extensions import Unpack + +from outlines.generate.api import GenerationParameters, SamplingParameters if TYPE_CHECKING: - from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora - from transformers import PreTrainedTokenizer - import torch + from exllamav2 import ExLlamaV2Tokenizer + from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler + -from .transformers import TransformerTokenizer +class ExllamaV2Params(TypedDict, total=False): + max_tokens: int + stop_conditions: Optional[List[Union[int, str]]] + seed: Optional[int] + gen_settings: "ExLlamaV2Sampler.Settings" + max_new_tokens: List[int] class ExLlamaV2Model: @@ -14,108 +23,218 @@ class ExLlamaV2Model: def __init__( self, - model: "ExLlamaV2", - tokenizer: "PreTrainedTokenizer", - device, - cache: "ExLlamaV2Cache", - lora: Optional["ExLlamaV2Lora"] = None, + generator: "ExLlamaV2DynamicGenerator", + tokenizer: "ExLlamaV2Tokenizer", + max_seq_len: int, ): - self.device = device - self.model = model - self.tokenizer = TransformerTokenizer(tokenizer) - self.cache = cache - self.past_seq = None - self.lora = lora - - def forward(self, input_ids: "torch.LongTensor", *_): - """Compute a forward pass through the exl2 model.""" - import torch - - # Caching with past_seq - reset = True - seq_tensor = input_ids[0] - - if self.past_seq is not None: - min_length = min(self.past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero( - ~torch.eq(self.past_seq[:min_length], seq_tensor[:min_length]) - ) - if len(indices) > 0: - longest_prefix = indices[0].item() - else: - longest_prefix = min_length - - if longest_prefix > 0: - reset = False - self.cache.current_seq_len = longest_prefix - if seq_tensor.shape[0] - longest_prefix > 1: - self.model.forward( - seq_tensor[longest_prefix:-1].view(1, -1), - self.cache, - preprocess_only=True, - loras=[self.lora], - ) - elif seq_tensor.shape[0] == longest_prefix: - self.cache.current_seq_len -= 1 - - if reset: - self.cache.current_seq_len = 0 - if seq_tensor.shape[0] > 1: - self.model.forward( - seq_tensor[:-1].view(1, -1), - self.cache, - preprocess_only=True, - loras=[self.lora], + self.generator = generator + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + + def prepare_generation_parameters( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + sampling_parameters: SamplingParameters, + structure_logits_processor, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Tuple[ExllamaV2Params, Union[str, List[str]]]: + """Prepare the generation parameters. 
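+
+        Maps outlines' `GenerationParameters` and `SamplingParameters` onto the
+        keyword arguments consumed by `ExLlamaV2DynamicGenerator`: a per-prompt
+        `max_new_tokens` budget, `stop_conditions` (the EOS token plus any
+        `stop_at` strings), the `seed`, and an `ExLlamaV2Sampler.Settings`
+        instance that carries the structured-generation logits processor.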
+ + `exllamav2` uses different default values + + """ + from exllamav2.generator import ExLlamaV2Sampler + + if isinstance(prompts, str): + prompts = [prompts] + max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + + if max_tokens is None: + max_tokens = [] + for prompt in prompts: + ids = self.generator.tokenizer.encode( + prompt, encode_special_tokens=True ) + prompt_tokens = ids.shape[-1] + max_tokens.append(self.max_seq_len - prompt_tokens) + exllamav2_params["max_new_tokens"] = max_tokens + else: + exllamav2_params["max_new_tokens"] = [ + max_tokens for _ in range(len(prompts)) + ] - self.past_seq = seq_tensor + stop_conditions = [self.generator.tokenizer.eos_token_id] + if isinstance(generation_parameters.stop_at, str): + stop_conditions.append(generation_parameters.stop_at) + elif isinstance(generation_parameters.stop_at, list): + for stop_at in generation_parameters.stop_at: + stop_conditions.append(stop_at) + exllamav2_params["stop_conditions"] = stop_conditions + exllamav2_params["seed"] = seed - return self.model.forward( - seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora] - ) + gen_settings = ExLlamaV2Sampler.Settings() + if sampling_parameters.temperature is not None: + gen_settings.temperature = sampling_parameters.temperature + if sampling_parameters.top_p is not None: + gen_settings.top_p = sampling_parameters.top_p + if sampling_parameters.top_k is not None: + gen_settings.top_k = sampling_parameters.top_k + gen_settings.logits_processor = structure_logits_processor + exllamav2_params["gen_settings"] = gen_settings + if sampling_parameters.num_samples > 1: + prompts = prompts * sampling_parameters.num_samples + exllamav2_params["max_new_tokens"] = ( + exllamav2_params["max_new_tokens"] * sampling_parameters.num_samples + ) - def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor": - logits = self.forward(input_ids) - next_token_logits = logits[..., -1, :] + if len(prompts) == 1: + prompts = prompts[0] - return next_token_logits, None + return exllamav2_params, prompts - def update_lora(self, lora_path: Optional[str] = None): + def reformat_output( + self, output: Union[str, List[str]], sampling_parameters: SamplingParameters + ): """ - Update and apply the LoRA to the model. + The purpose of this function is to reformat the output from exllamav2's output format to outline's output format + For exllamav2, it mainly accepts only a list or a string(they also do cfg sampling with tuples but we will ignore this for now) + The exllamav2's logic is + 1. If the prompt is a string, return a string. This is the same as outlines + 2. If a prompt is a list, return a list. This is not the same as outlines output in that if the list is only one element, the string is expected to be outputted. + 3. There is no such thing as num_samples, so the prompts had to be duplicated by num_samples times. 
Then, we had the function output a list of lists + """ + if isinstance(output, str): + return output + if len(output) == 1: + return output[0] + if sampling_parameters.num_samples > 1: + if len(output) == sampling_parameters.num_samples: + return output + assert len(output) % sampling_parameters.num_samples == 0 + num_items_per_sample = len(output) // sampling_parameters.num_samples + new_output = [] + for i in range(sampling_parameters.num_samples): + curr_sample = [] + for j in range(num_items_per_sample): + curr_sample.append(output[i * num_items_per_sample + j]) + new_output.append(curr_sample) + return new_output + return output - Args: - lora_path (Optional[str]): The path to the LoRA directory. If None, the LoRA will be unloaded. + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Union[str, List[str]]: + exllamav2_params, prompts = self.prepare_generation_parameters( + prompts, + generation_parameters, + sampling_parameters, + structure_logits_processor, + ) """ - try: - from exllamav2 import ExLlamaV2Lora - except ImportError: - raise ImportError( - "The `exllamav2` library needs to be installed in order to use `exllamav2` models." + In exllamav2, it needs the max amount of new tokens generated. + The reason exllamav2_params["max_new_tokens"] is a list is because in prepare_generation_parameters + the max amount of tokens that can be generated by the model for each prompt(by encoding with tokenizer) is calculated. + The minimum is picked because otherwise it might be possible for one of the + prompts to exceed the max sequence length. + """ + output = self.generator.generate( + prompt=prompts, + gen_settings=exllamav2_params["gen_settings"], + max_new_tokens=min(exllamav2_params["max_new_tokens"]), + completion_only=True, + encode_special_tokens=True, + stop_conditions=exllamav2_params["stop_conditions"], + add_bos=False, + seed=exllamav2_params["seed"], + ) + + return self.reformat_output(output, sampling_parameters) + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Iterator[Union[str, List[str]]]: + from exllamav2.generator import ExLlamaV2DynamicJob + + exllamav2_params, prompts = self.prepare_generation_parameters( + prompts, + generation_parameters, + sampling_parameters, + structure_logits_processor, + ) + + order = {} + if isinstance(prompts, str): + prompts = [prompts] + batch_size = len(prompts) + seed = exllamav2_params["seed"] + for idx, p in enumerate(prompts): + input_ids = self.generator.tokenizer.encode( + p, encode_special_tokens=True, add_bos=False ) - if lora_path is None: - if self.lora is not None: - print(" -- Unloading LoRA...") - self.lora = None - else: - self.lora = ExLlamaV2Lora.from_directory(self.model, lora_path) - print(" -- Loading LoRA...") + + job = ExLlamaV2DynamicJob( + input_ids=input_ids, + max_new_tokens=exllamav2_params["max_new_tokens"][idx], + min_new_tokens=0, + seed=seed, + stop_conditions=exllamav2_params["stop_conditions"], + gen_settings=exllamav2_params["gen_settings"], + token_healing=False, + decode_special_tokens=False, + ) + + if seed is not None: + seed += 1 + + serial = self.generator.enqueue(job) + order[serial] = idx + + # Collect outputs until all jobs finish + + 
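+        # Each call to generator.iterate() below returns one result dict per
+        # active job; the serial number stored in `order` maps a result back to
+        # the index of the prompt that created it, so one (possibly empty) text
+        # fragment per prompt can be yielded at every step, regrouped by
+        # reformat_output.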
next_text = [""] * batch_size + + def token_generator() -> Iterator[str]: + while self.generator.num_remaining_jobs(): + results = self.generator.iterate() + for r in results: + idx = order[r["serial"]] + if r["stage"] == "streaming": + text = r.get("text", "") + next_text[idx] = text + if r["eos"]: + next_text[idx] = "" + yield self.reformat_output(next_text, sampling_parameters) + return + + return token_generator() + + +# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b +def patch_tokenizer(tokenizer): + tokenizer.vocabulary = tokenizer.piece_to_id + tokenizer.special_tokens = set(tokenizer.extended_piece_to_id) + tokenizer.convert_token_to_string = lambda t: t + return tokenizer def exl2( model_path: str, - device: str, + draft_model_path: Optional[str] = None, max_seq_len: Optional[int] = None, - scale_pos_emb: Optional[float] = None, - scale_alpha_value: Optional[float] = None, - no_flash_attn: Optional[bool] = None, - num_experts_per_token: Optional[int] = None, - cache_8bit: bool = False, cache_q4: bool = False, - tokenizer_kwargs: dict = {}, - gpu_split: Optional[str] = None, - low_mem: Optional[bool] = None, - verbose: Optional[bool] = None, + paged: bool = True, + max_chunk_size: Optional[int] = None, ) -> ExLlamaV2Model: """ Load an ExLlamaV2 model. @@ -136,8 +255,6 @@ def exl2( Disable flash attention. Defaults to None. num_experts_per_token (Optional[int], optional) Number of experts per token. Defaults to None. - cache_8bit (bool, optional) - Use 8-bit cache. Defaults to False. cache_q4 (bool, optional) Use Q4 cache. Defaults to False. tokenizer_kwargs (dict, optional) @@ -162,71 +279,62 @@ def exl2( from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, - ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, ExLlamaV2Config, + ExLlamaV2Tokenizer, ) - from transformers import AutoTokenizer + from exllamav2.generator import ExLlamaV2DynamicGenerator + except ImportError: raise ImportError( "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models." 
) + config = ExLlamaV2Config(model_path) + if max_chunk_size is not None: + config.max_input_len = max_chunk_size + config.max_attention_size = max_chunk_size**2 - # Load tokenizer - if not verbose: - print(" -- Loading tokenizer...") - tokenizer_kwargs.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) - # tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs) - - # Check fasttensors for config - if os.name != "nt": - use_fasttensors = True - else: - use_fasttensors = False - - # Create config - config = ExLlamaV2Config() - config.model_dir = model_path - config.fasttensors = use_fasttensors - config.prepare() - - # Set config options - if max_seq_len is not None: - config.max_seq_len = max_seq_len - if scale_pos_emb is not None: - config.scale_pos_emb = scale_pos_emb - if scale_alpha_value is not None: - config.scale_alpha_value = scale_alpha_value - if no_flash_attn is not None: - config.no_flash_attn = no_flash_attn - if num_experts_per_token is not None: - config.num_experts_per_token = num_experts_per_token - if low_mem: - config.set_low_mem() - - # Prepare the model from the config + config.arch_compat_overrides() model = ExLlamaV2(config) - - # Create cache - if cache_8bit: - cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded) - elif cache_q4: - cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded) + if max_seq_len is None: + max_seq_len = -1 + if cache_q4: + cache = ExLlamaV2Cache_Q4(model, max_seq_len=max_seq_len, lazy=True) else: - cache = ExLlamaV2Cache(model, lazy=not model.loaded) - - # Load the model - split = None - if gpu_split and gpu_split != "auto": - split = [float(alloc) for alloc in gpu_split.split(",")] - if not verbose: - print(" -- Loading model...") - model.load(split) - - # Autoload if no GPU split was provided - if not model.loaded: - print(" -- Loading model...") - model.load_autosplit(cache) - - return ExLlamaV2Model(model, tokenizer, device, cache) + cache = ExLlamaV2Cache(model, max_seq_len=max_seq_len, lazy=True) + model.load_autosplit(cache, progress=True) + + print("Loading tokenizer...") + tokenizer = ExLlamaV2Tokenizer(config) + tokenizer = patch_tokenizer(tokenizer) + max_batch_size = 4 if paged else 1 + + draft_model = None + draft_cache = None + if draft_model_path is not None: + draft_config = ExLlamaV2Config(draft_model_path) + draft_model = ExLlamaV2(draft_config) + + if cache_q4: + draft_cache = ExLlamaV2Cache_Q4( + draft_model, max_seq_len=max_seq_len, lazy=True + ) + else: + draft_cache = ExLlamaV2Cache( + draft_model, max_seq_len=max_seq_len, lazy=True + ) + + # Initialize the generator with all default parameters + generator = ExLlamaV2DynamicGenerator( + model=model, + cache=cache, + draft_model=draft_model, + draft_cache=draft_cache, + tokenizer=tokenizer, + max_batch_size=max_batch_size, + use_ngram_draft=False, + max_chunk_size=max_chunk_size, + paged=paged, + ) + max_seq_len = cache.max_seq_len + return ExLlamaV2Model(generator, tokenizer, max_seq_len) diff --git a/pyproject.toml b/pyproject.toml index 99d4f94e1..82b01c4f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ test = [ "torch", "transformers", "pillow", + "exllamav2", ] serve = [ "vllm>=0.3.0", diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py index ed8830119..abd9c72a4 100644 --- a/tests/generate/conftest.py +++ b/tests/generate/conftest.py @@ -27,9 +27,11 @@ def pytest_collection_modifyitems(config, items): for item in items: if "model_fixture" in 
item.fixturenames: model_param = item.callspec.params.get("model_fixture", None) - if model_param.startswith( - "model_transformers_vision" - ) or model_param.startswith("model_vllm"): + if ( + model_param.startswith("model_transformers_vision") + or model_param.startswith("model_vllm") + or model_param.startswith("model_exllamav2") + ): item.add_marker(skip_marker) if not is_metal_available(): diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 8d5daa37e..19466737a 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -20,6 +20,15 @@ def model_llamacpp(tmp_path_factory): ) +@pytest.fixture(scope="session") +def model_exllamav2(tmp_path_factory): + return models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + @pytest.fixture(scope="session") def model_mlxlm(tmp_path_factory): return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit") @@ -98,6 +107,7 @@ def model_t5(tmp_path_factory): ALL_MODEL_FIXTURES = ( "model_llamacpp", + "model_exllamav2", "model_mlxlm", "model_mlxlm_phi3", "model_transformers_random", diff --git a/tests/generate/test_integration_exllamav2.py b/tests/generate/test_integration_exllamav2.py new file mode 100644 index 000000000..12c4143b3 --- /dev/null +++ b/tests/generate/test_integration_exllamav2.py @@ -0,0 +1,363 @@ +import importlib +from unittest.mock import patch + +import pytest + +import outlines.models as models +from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.exllamav2 import ExLlamaV2Model + + +@pytest.fixture(scope="session") +def model_exllamav2(tmp_path_factory): + return models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_import_error(request, model_fixture): + with patch.dict("sys.modules", {"exllamav2": None}): + with pytest.raises(ImportError): + models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_attributes(request, model_fixture): + model = request.getfixturevalue(model_fixture) + assert hasattr(model, "generator") + assert hasattr(model, "tokenizer") + assert model.tokenizer.convert_token_to_string(1) == 1 + assert hasattr(model, "max_seq_len") + assert isinstance(model.max_seq_len, int) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_prompt_types(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + prompt = ["test"] + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_no_max_tokens(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=None, stop_at=None, seed=None) + 
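+    # With max_tokens=None, prepare_generation_parameters budgets
+    # (max_seq_len - prompt_length) new tokens for each prompt, so generation
+    # is only bounded by the model's context window.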
structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_test_stop_at(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + generation_params = GenerationParameters(max_tokens=10, stop_at=["stop"], seed=None) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_multisampling(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, list) + assert isinstance(output[0], str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_prepare_generation_parameters(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + exllamav2_params, prompts = model.prepare_generation_parameters( + prompt, generation_params, sampling_params, structure_logits_processor + ) + assert isinstance(exllamav2_params, dict) + assert isinstance(prompts, list) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_prompt_types(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + prompt = ["test"] + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_no_max_tokens(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=None, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + 
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_test_stop_at(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + generation_params = GenerationParameters(max_tokens=10, stop_at=["stop"], seed=None) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_multisampling(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, list) + assert isinstance(token[0], str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_seed(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, seed=1, stop_at=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_reformat_output(request, model_fixture): + model = request.getfixturevalue(model_fixture) + sampling_params = SamplingParameters( + "multinomial", + 1, + ) + output = "test" + reformatted_output = model.reformat_output(output, sampling_params) + assert reformatted_output == output + output = ["test"] + reformatted_output = model.reformat_output(output, sampling_params) + assert reformatted_output == output[0] + output = ["test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 1, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == "test" + assert reformatted_output[1] == "test" + output = ["test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == "test" + assert reformatted_output[1] == "test" + output = ["test", "test", "test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == ["test", "test"] + assert reformatted_output[1] == ["test", "test"] + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_max_chunk_size(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + 
        cache_q4=True,
+        paged=False,
+        max_chunk_size=128,
+    )
+    assert isinstance(model, ExLlamaV2Model)
+
+
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"])
+def test_exl2_cache_default(request, model_fixture):
+    model = models.exl2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        paged=False,
+    )
+    assert isinstance(model, ExLlamaV2Model)
+
+
+def is_flash_attn_available():
+    try:
+        importlib.import_module("flash_attn")
+    except (ImportError, AssertionError):
+        return False
+    return True
+
+
+@pytest.mark.skipif(not is_flash_attn_available(), reason="flash-attn is not installed")
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"])
+def test_exl2_paged(request, model_fixture):
+    model = models.exl2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        cache_q4=True,
+        paged=True,
+    )
+    assert isinstance(model, ExLlamaV2Model)
+
+
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"])
+def test_exl2_draft_model(request, model_fixture):
+    model = models.exl2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        draft_model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        cache_q4=True,
+        paged=False,
+    )
+    assert isinstance(model, ExLlamaV2Model)
+
+
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"])
+def test_exl2_draft_model_cache_default(request, model_fixture):
+    model = models.exl2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        draft_model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        paged=False,
+    )
+    assert isinstance(model, ExLlamaV2Model)
+
+
+@pytest.mark.parametrize("model_fixture", ["model_exllamav2"])
+def test_exl2_set_max_seq_len(request, model_fixture):
+    model = models.exl2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        max_seq_len=2048,
+        paged=False,
+        cache_q4=True,
+    )
+    assert isinstance(model, ExLlamaV2Model)
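With the ExLlamaV2-specific `SequenceGenerator` registrations removed from `outlines/generate/{text,regex,fsm}.py`, an `exl2` model is expected to go through the default `SequenceGeneratorAdapter` dispatch like the other logits-processor backends. A minimal usage sketch, not part of the patch, assuming the TinyLlama exl2 checkpoint used in the tests is available and that the generic `outlines.generate.text` / `outlines.generate.regex` handlers accept the new model:

```python
import outlines

# Load the quantized model through the rewritten exl2() entry point.
model = outlines.models.exl2(
    model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
    cache_q4=True,
    paged=False,
)

# Unconstrained text generation now goes through SequenceGeneratorAdapter.
generate_text = outlines.generate.text(model)
print(generate_text("Describe the weather in one sentence.", max_tokens=30))

# Regex-constrained generation uses the generic logits-processor route as well.
generate_ip = outlines.generate.regex(model, r"[0-9]{1,3}(\.[0-9]{1,3}){3}")
print(generate_ip("The private IP address of the router is "))
```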
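The multi-sample handling described in `reformat_output` can be hard to follow from the docstring alone. The standalone sketch below mirrors its regrouping logic; `regroup` and the sample strings are illustrative names, not part of the module:

```python
# Illustration of reformat_output's regrouping: prompts are duplicated
# num_samples times before generation, so the flat list returned by exllamav2
# is folded back into one sub-list per sample.
def regroup(output, num_samples):
    if num_samples == 1 or len(output) == num_samples:
        return output
    assert len(output) % num_samples == 0
    per_sample = len(output) // num_samples
    return [output[i * per_sample:(i + 1) * per_sample] for i in range(num_samples)]

# Two prompts duplicated for two samples -> four completions come back flat,
# ordered as [prompt1/sample1, prompt2/sample1, prompt1/sample2, prompt2/sample2].
flat = ["p1-s1", "p2-s1", "p1-s2", "p2-s2"]
print(regroup(flat, num_samples=2))  # [['p1-s1', 'p2-s1'], ['p1-s2', 'p2-s2']]
```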
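`patch_tokenizer` exists because outlines' logits processors expect `vocabulary`, `special_tokens` and `convert_token_to_string` on the tokenizer object. A quick sanity check with a stand-in class (the fake tokenizer is invented for illustration; only the attribute names come from the patch):

```python
from outlines.models.exllamav2 import patch_tokenizer


class FakeExl2Tokenizer:
    """Stand-in exposing the two attributes patch_tokenizer reads."""

    piece_to_id = {"<s>": 1, "hello": 2}
    extended_piece_to_id = {"<s>": 1}


tok = patch_tokenizer(FakeExl2Tokenizer())
assert tok.vocabulary == {"<s>": 1, "hello": 2}          # aliased from piece_to_id
assert tok.special_tokens == {"<s>"}                     # keys of extended_piece_to_id
assert tok.convert_token_to_string("hello") == "hello"   # identity mapping
```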