Skip to content

Commit

Permalink
Model: Fix inline loading and draft key
Browse files Browse the repository at this point in the history
There was a lack of foresight between the new config.yml and how
it was structured. The "draft" key became "draft_model" without updating
both the API request and inline loading keys.

For the API requests, still support "draft" as legacy, but the "draft_model"
key is preferred.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Oct 22, 2024
1 parent f20857c commit fba462f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
49 changes: 26 additions & 23 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name

Expand All @@ -154,25 +173,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

# Get cache mode
Expand Down Expand Up @@ -384,9 +384,12 @@ async def set_model_overrides(self, **kwargs):
override_args = unwrap(yaml.load(contents), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
draft_override_args = unwrap(override_args.get("draft_model"), {})
if draft_override_args:
kwargs["draft_model"] = {
**draft_override_args,
**kwargs.get("draft_model"),
}

# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
Expand Down
4 changes: 2 additions & 2 deletions endpoints/core/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
model_path = model_path / data.name

draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
if data.draft_model:
if not data.draft_model.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
Expand Down
9 changes: 6 additions & 3 deletions endpoints/core/types/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Contains model card types."""

from pydantic import BaseModel, Field, ConfigDict
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union

Expand Down Expand Up @@ -64,7 +64,7 @@ class ModelLoadRequest(BaseModel):
"""Represents a model load request."""

# Required
name: str
name: str = Field(alias=AliasChoices("model_name", "name"))

# Config arguments

Expand Down Expand Up @@ -108,7 +108,10 @@ class ModelLoadRequest(BaseModel):
num_experts_per_token: Optional[int] = None

# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
skip_queue: Optional[bool] = False


Expand Down
2 changes: 1 addition & 1 deletion endpoints/core/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def stream_model_load(

# Set the draft model path if it exists
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_data["draft_model"]["draft_model_dir"] = draft_model_path

load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def entrypoint_async():
await model.load_model(
model_path.resolve(),
**config.model.model_dump(exclude_none=True),
draft=config.draft_model.model_dump(exclude_none=True),
draft_model=config.draft_model.model_dump(exclude_none=True),
)

# Load loras after loading the model
Expand Down

0 comments on commit fba462f

Please sign in to comment.