Model: Add vision loading support
Adds the ability to load vision parts of text + image models. Requires
an explicit flag in config because there isn't a way to automatically
determine whether the vision tower should be used.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Nov 11, 2024
1 parent cc25167 commit 69ac0eb
Showing 5 changed files with 42 additions and 5 deletions.
17 changes: 17 additions & 0 deletions backends/exllamav2/model.py
@@ -20,6 +20,7 @@
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -28,6 +29,7 @@
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union
 
 from ruamel.yaml import YAML
@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
 
+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +150,9 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)
 
+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), True)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
@@ -608,6 +617,14 @@ def progress(loaded_modules: int, total_modules: int)
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
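
The new block follows the same generator-based loading pattern the backend already uses for the draft and main models: the tower is only constructed when use_vision is set, and the progress values from its load_gen are re-yielded to the caller. A minimal, self-contained sketch of that pattern under assumed names (FakeTower and load_container are illustrative stand-ins, not tabbyAPI or exllamav2 APIs):

from typing import Iterator, Optional, Tuple


class FakeTower:
    """Stand-in for a vision tower that loads in fixed-size module steps."""

    def __init__(self, modules: int = 4):
        self.modules = modules

    def load_gen(self) -> Iterator[Tuple[int, int]]:
        # Yield (loaded, total) after each module, like a streaming loader
        for loaded in range(1, self.modules + 1):
            yield (loaded, self.modules)


def load_container(use_vision: bool) -> Iterator[Tuple[int, int]]:
    # Only construct and load the tower when the flag is enabled
    vision_model: Optional[FakeTower] = None

    if use_vision:
        vision_model = FakeTower()
        for value in vision_model.load_gen():
            if value:
                yield value

    # ... main model loading would continue here and yield its own progress ...


for progress in load_container(use_vision=True):
    print(progress)  # (1, 4), (2, 4), (3, 4), (4, 4)
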
6 changes: 6 additions & 0 deletions common/config_models.py
@@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(
20 changes: 15 additions & 5 deletions common/model.py
@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"
 
 
 def load_progress(module, modules):
@@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)
 
     progress = get_loading_progress_bar()
     progress.start()
 
     try:
+        index = 0
         async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {model_type[index].value} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)
@@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
 
             if module == modules:
                 # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
     finally:
         progress.stop()
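
The progress display now tracks an ordered list of ModelType stages instead of a single draft/model string, advancing index each time a stage reports its final module so the label switches to the next stage. A rough, self-contained sketch of that bookkeeping (the simulated loading_events generator and the simplified advance condition are illustrative, not the repository's code):

from enum import Enum


class ModelType(Enum):
    MODEL = "model"
    DRAFT = "draft"
    VISION = "vision"


def loading_events():
    # Simulated (module, modules) progress: a 2-module vision stage, then a 3-module model stage
    yield from [(0, 2), (1, 2), (2, 2), (0, 3), (1, 3), (2, 3), (3, 3)]


# Extra stages are inserted ahead of MODEL so the main model always loads last
model_type = [ModelType.MODEL]
model_type.insert(0, ModelType.VISION)  # only done when vision is enabled

index = 0
for module, modules in loading_events():
    if module == 0:
        print(f"Loading {model_type[index].value} modules ({modules} total)")
    if module == modules and index < len(model_type) - 1:
        index += 1  # switch the label to the next stage
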
3 changes: 3 additions & 0 deletions config_sample.yml
@@ -124,6 +124,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.
1 change: 1 addition & 0 deletions endpoints/core/types/model.py
@@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
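
Because vision is now part of ModelLoadRequest, a client can opt into vision loading per request instead of only through the config file. A hedged example of what such a call might look like (the /v1/model/load route, x-admin-key header, port, and model name are assumptions about a typical deployment, not confirmed by this diff):

import requests

payload = {
    "model_name": "my-vision-model",  # hypothetical model directory name
    "vision": True,                   # field added by this commit
}

# URL, port, and admin key header are assumptions; adjust to your deployment
response = requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    headers={"x-admin-key": "your-admin-key"},
    json=payload,
    timeout=600,
)
print(response.status_code, response.text)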
