diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py
index b06e5e60a..22a52f8cb 100644
--- a/outlines/models/exllamav2.py
+++ b/outlines/models/exllamav2.py
@@ -1,3 +1,4 @@
+import os
 from typing import TYPE_CHECKING, Optional
 
 import torch
@@ -76,43 +77,117 @@ def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
 
 def exl2(
     model_path: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
+    device: Optional[torch.device] = None,
+    max_seq_len: Optional[int] = None,
+    scale_pos_emb: Optional[float] = None,
+    scale_alpha_value: Optional[float] = None,
+    no_flash_attn: Optional[bool] = None,
+    num_experts_per_token: Optional[int] = None,
+    cache_8bit: bool = False,
+    cache_q4: bool = False,
     tokenizer_kwargs: dict = {},
-):
+    gpu_split: Optional[str] = None,
+    low_mem: Optional[bool] = None,
+    verbose: Optional[bool] = None,
+) -> ExLlamaV2Model:
+    """
+    Load an ExLlamaV2 model.
+
+    Args:
+        model_path (str): Path to the model directory.
+        device (Optional[torch.device], optional): Device to load the model on. Defaults to None.
+        max_seq_len (Optional[int], optional): Maximum sequence length. Defaults to None.
+        scale_pos_emb (Optional[float], optional): Scale factor for positional embeddings. Defaults to None.
+        scale_alpha_value (Optional[float], optional): Scale alpha value. Defaults to None.
+        no_flash_attn (Optional[bool], optional): Disable flash attention. Defaults to None.
+        num_experts_per_token (Optional[int], optional): Number of experts per token. Defaults to None.
+        cache_8bit (bool, optional): Use 8-bit cache. Defaults to False.
+        cache_q4 (bool, optional): Use Q4 cache. Defaults to False.
+        tokenizer_kwargs (dict, optional): Additional keyword arguments for the tokenizer. Defaults to {}.
+        gpu_split (str): "auto", or VRAM allocation per GPU in GB. Auto will use exllama's autosplit feature
+        low_mem (bool, optional): Enable VRAM optimizations, potentially trading off speed
+        verbose (bool, optional): Enable if you want debugging statements
+
+    Returns:
+        ExLlamaV2Model: Loaded ExLlamaV2 model.
+
+    Raises:
+        ImportError: If the `exllamav2` library is not installed.
+    """
+
     try:
-        from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+        from exllamav2 import (  # , ExLlamaV2Cache_Q4
+            ExLlamaV2,
+            ExLlamaV2Cache,
+            ExLlamaV2Cache_8bit,
+            ExLlamaV2Config,
+        )
         from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
             "The `exllamav2` library needs to be installed in order to use `exllamav2` models."
         )
 
+    if os.name != "nt":
+        use_fasttensors = True
+    else:
+        use_fasttensors = False
+
+    # Create config
+
     config = ExLlamaV2Config()
     config.model_dir = model_path
+    config.fasttensors = use_fasttensors
     config.prepare()
 
-    config.max_seq_len = model_kwargs.pop("max_seq_len", config.max_seq_len)
-    config.scale_pos_emb = model_kwargs.pop("scale_pos_emb", config.scale_pos_emb)
-    config.scale_alpha_value = model_kwargs.pop(
-        "scale_alpha_value", config.scale_alpha_value
-    )
-    config.no_flash_attn = model_kwargs.pop("no_flash_attn", config.no_flash_attn)
-    config.num_experts_per_token = int(
-        model_kwargs.pop("num_experts_per_token", config.num_experts_per_token)
-    )
+    # Set config options
+
+    config.max_seq_len = max_seq_len
+    config.scale_pos_emb = scale_pos_emb
+    config.scale_alpha_value = scale_alpha_value
+    config.no_flash_attn = no_flash_attn
+    if num_experts_per_token:
+        config.num_experts_per_token = num_experts_per_token
+    if low_mem:
+        config.set_low_mem()
+
+    # Load the model
 
     model = ExLlamaV2(config)
 
     split = None
-    if "gpu_split" in model_kwargs.keys():
-        split = [float(alloc) for alloc in model_kwargs["gpu_split"].split(",")]
+    if gpu_split and gpu_split != "auto":
+        split = [float(alloc) for alloc in gpu_split.split(",")]
 
-    model.load(split)
+    if gpu_split != "auto":
+        if not verbose:
+            print(" -- Loading model...")
+        model.load(split)
+
+    # Load tokenizer
+
+    if not verbose:
+        print(" -- Loading tokenizer...")
+
+    # tokenizer = ExLlamaV2Tokenizer(config)
 
     tokenizer_kwargs.setdefault("padding_side", "left")
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
+    # tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs)
+
+    # Create cache
+
+    if cache_8bit:
+        cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded)
+    # elif cache_q4:
+    #     cache = ExLlamaV2Cache_Q4(model, lazy = not model.loaded)
+    else:
+        cache = ExLlamaV2Cache(model, lazy=not model.loaded)
+
+    # Load model now if auto split enabled
 
-    cache = ExLlamaV2Cache(model)
+    if not model.loaded:
+        print(" -- Loading model...")
+        model.load_autosplit(cache)
 
     return ExLlamaV2Model(model, tokenizer, device, cache)
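For context, a minimal usage sketch of the reworked `exl2` entry point as defined in this diff. The model directory path and the parameter values below are illustrative placeholders, not part of the change:

    from outlines.models.exllamav2 import exl2

    # Hypothetical EXL2-quantized model directory (placeholder path).
    model = exl2(
        model_path="/path/to/exl2-model",
        max_seq_len=4096,   # overrides the value taken from the model config after prepare()
        gpu_split="auto",   # "auto" skips model.load() and defers to load_autosplit() with the cache
        cache_8bit=True,    # builds an ExLlamaV2Cache_8bit instead of the default FP16 cache
    )

With `gpu_split` set to comma-separated per-GPU allocations instead of "auto", the model is loaded eagerly via `model.load(split)` and the cache is created non-lazily afterwards.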