From 69ac0eb8aad783eec9581e5f76224b2e1df58b69 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Mon, 11 Nov 2024 12:04:40 -0500
Subject: [PATCH] Model: Add vision loading support

Adds the ability to load vision parts of text + image models. Requires
an explicit flag in config because there isn't a way to automatically
determine whether the vision tower should be used.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py   | 17 +++++++++++++++++
 common/config_models.py       |  6 ++++++
 common/model.py               | 20 +++++++++++++++-----
 config_sample.yml             |  3 +++
 endpoints/core/types/model.py |  1 +
 5 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c7d2069..df8cacf 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -20,6 +20,7 @@
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -28,6 +29,7 @@
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union
 from ruamel.yaml import YAML
 
@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
 
+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +150,9 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)
 
+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), False)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
@@ -608,6 +617,14 @@ def progress(loaded_modules: int, total_modules: int)
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
diff --git a/common/config_models.py b/common/config_models.py
index 40b4109..8102333 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -270,6 +270,12 @@
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(
diff --git a/common/model.py b/common/model.py
index 87b06ad..d30d11b 100644
--- a/common/model.py
+++ b/common/model.py
@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"
 
 
 def load_progress(module, modules):
@@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)
 
     progress = get_loading_progress_bar()
     progress.start()
 
     try:
+        index = 0
         async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {model_type[index].value} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)
@@ -89,10 +99,10 @@
 
             if module == modules:
                 # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
 
     finally:
         progress.stop()
diff --git a/config_sample.yml b/config_sample.yml
index 83f2fc7..388dcf4 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -124,6 +124,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index f2817f0..17fa0a7 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
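
Usage sketch (container level): a minimal example of how the new "vision"
kwarg reaches ExllamaV2Container once this change is applied. The model
directory is a placeholder and is assumed to point at an EXL2 quant of a
text + image model whose vision tower exllamav2 supports.

    # Minimal sketch; the model directory below is a placeholder.
    import asyncio
    import pathlib

    from backends.exllamav2.model import ExllamaV2Container
    from common.model import load_progress


    async def main():
        # vision=True sets container.use_vision, so load_gen() builds and
        # loads an ExLlamaV2VisionTower before the language model itself.
        container = await ExllamaV2Container.create(
            pathlib.Path("/models/my-vision-model"), False, vision=True
        )

        # load_gen is an async generator yielding (loaded_modules, total_modules)
        # progress tuples for each component: vision tower first, then the model.
        async for module, modules in container.load_gen(load_progress):
            print(f"{module}/{modules} modules loaded")


    asyncio.run(main())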
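
Usage sketch (API level): the same flag is exposed on ModelLoadRequest, so it
can also be sent with a model load request. The endpoint path, admin-key
header, "name" field, key, and model folder below are assumptions and
placeholders based on the server's existing admin model-load route; check the
running server's API schema before relying on them.

    # Minimal sketch; endpoint path, "name" field, admin key, and model folder
    # are assumptions/placeholders.
    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/v1/model/load",
        headers={"x-admin-key": "<admin key>"},
        json={
            "name": "<vision-capable model folder>",
            "vision": True,  # new field added to ModelLoadRequest by this patch
        },
        timeout=600,
    )
    resp.raise_for_status()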