API: Don't fallback to default values on model load request
It's best to pass these values down the config stack:

API/User config.yml -> model config.yml -> model config.json -> fallback.

Doing this allows for a seamless flow and yields control to each
member of the stack.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Sep 1, 2024
1 parent 4452d6f commit a96fa5f
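
To illustrate the intent described in the commit message, here is a minimal, hypothetical sketch of that resolution order. The function `resolve_load_value` and its arguments are stand-ins for illustration only, not code from this repository; the real project spreads this logic across `get_config_default`, per-model `tabby_config.yml` overrides, and the loader's own config.json defaults.

# Hypothetical illustration of the config stack; names are assumptions,
# not the project's actual API.
from typing import Any, Optional


def resolve_load_value(
    key: str,
    api_request: dict,
    user_config: dict,
    model_yaml: dict,
    model_json: dict,
    fallback: Any = None,
) -> Optional[Any]:
    """Walk the stack until some layer provides a value."""
    for layer in (api_request, user_config, model_yaml, model_json):
        value = layer.get(key)
        if value is not None:
            return value
    # Only the last member of the stack supplies a hard-coded fallback
    return fallback


# Example: the request omits cache_mode and the user config sets it,
# so the user config wins before the model's own files are consulted.
print(resolve_load_value(
    "cache_mode",
    api_request={},
    user_config={"cache_mode": "Q4"},
    model_yaml={},
    model_json={},
    fallback="FP16",
))  # -> "Q4"

Each layer only answers when it actually holds a value, so the hard-coded fallback is consulted last instead of being baked into the API request defaults.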
Showing 4 changed files with 19 additions and 17 deletions.
2 changes: 2 additions & 0 deletions backends/exllamav2/model.py
@@ -361,6 +361,8 @@ def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.max_attention_size = chunk_size**2

     def set_model_overrides(self, **kwargs):
+        """Sets overrides from a model folder's config yaml."""
+
         override_config_path = self.model_dir / "tabby_config.yml"

         if not override_config_path.exists():
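For context, a rough sketch of what a per-model override loader like `set_model_overrides` might do. The YAML handling and kwarg merge below are assumptions for illustration; only the `tabby_config.yml` path check is taken from the diff above.

# Rough sketch of a per-model override loader; the real method lives in
# backends/exllamav2/model.py and may differ in details.
import pathlib

import yaml


def load_model_overrides(model_dir: pathlib.Path, **kwargs) -> dict:
    """Merge overrides from the model folder's tabby_config.yml into kwargs."""
    override_config_path = model_dir / "tabby_config.yml"

    if not override_config_path.exists():
        return kwargs

    with open(override_config_path, "r", encoding="utf8") as f:
        overrides = yaml.safe_load(f) or {}

    # Values from the model folder fill in anything the caller didn't pass
    return {**overrides, **kwargs}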
10 changes: 4 additions & 6 deletions common/model.py
@@ -149,7 +149,7 @@ async def unload_embedding_model():
     embeddings_container = None


-def get_config_default(key: str, fallback=None, model_type: str = "model"):
+def get_config_default(key: str, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""

     model_config = config.model_config()
@@ -162,14 +162,12 @@ def get_config_default(key: str, fallback=None, model_type: str = "model"):
         # Is this a draft model load parameter?
         if model_type == "draft":
             draft_config = config.draft_model_config()
-            return unwrap(draft_config.get(key), fallback)
+            return draft_config.get(key)
         elif model_type == "embedding":
             embeddings_config = config.embeddings_config()
-            return unwrap(embeddings_config.get(key), fallback)
+            return embeddings_config.get(key)
         else:
-            return unwrap(model_config.get(key), fallback)
-    else:
-        return fallback
+            return model_config.get(key)


 async def check_model_container():
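After this change, `get_config_default` returns whatever the relevant config section holds (or None) instead of substituting a caller-supplied fallback. A condensed sketch of that behavior, with the surrounding `use_as_default` gating omitted and the config sections reduced to plain dicts for illustration:

# Condensed illustration of the new lookup behavior; the real function
# also checks the user's allow-list before returning anything.
def get_config_default_sketch(
    key: str,
    model_config: dict,
    draft_config: dict,
    embeddings_config: dict,
    model_type: str = "model",
):
    if model_type == "draft":
        return draft_config.get(key)
    elif model_type == "embedding":
        return embeddings_config.get(key)
    else:
        # A missing key now yields None, letting the next layer of the
        # stack (model config.yml / config.json) decide the value.
        return model_config.get(key)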
20 changes: 10 additions & 10 deletions endpoints/core/types/model.py
@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
     # Config arguments
     draft_rope_scale: Optional[float] = Field(
         default_factory=lambda: get_config_default(
-            "draft_rope_scale", 1.0, model_type="draft"
+            "draft_rope_scale", model_type="draft"
         )
     )
     draft_rope_alpha: Optional[float] = Field(
         description="Automatically calculated if not present",
         default_factory=lambda: get_config_default(
-            "draft_rope_alpha", None, model_type="draft"
+            "draft_rope_alpha", model_type="draft"
         ),
         examples=[1.0],
     )
     draft_cache_mode: Optional[str] = Field(
         default_factory=lambda: get_config_default(
-            "draft_cache_mode", "FP16", model_type="draft"
+            "draft_cache_mode", model_type="draft"
         )
     )

@@ -97,16 +97,16 @@ class ModelLoadRequest(BaseModel):
         examples=[4096],
     )
     tensor_parallel: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("tensor_parallel", False)
+        default_factory=lambda: get_config_default("tensor_parallel")
     )
     gpu_split_auto: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("gpu_split_auto", True)
+        default_factory=lambda: get_config_default("gpu_split_auto")
     )
     autosplit_reserve: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("autosplit_reserve", [96])
+        default_factory=lambda: get_config_default("autosplit_reserve")
     )
     gpu_split: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("gpu_split", []),
+        default_factory=lambda: get_config_default("gpu_split"),
         examples=[[24.0, 20.0]],
     )
     rope_scale: Optional[float] = Field(
@@ -120,10 +120,10 @@ class ModelLoadRequest(BaseModel):
         examples=[1.0],
     )
     cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default("cache_mode", "FP16")
+        default_factory=lambda: get_config_default("cache_mode")
     )
     chunk_size: Optional[int] = Field(
-        default_factory=lambda: get_config_default("chunk_size", 2048)
+        default_factory=lambda: get_config_default("chunk_size")
     )
     prompt_template: Optional[str] = Field(
         default_factory=lambda: get_config_default("prompt_template")
@@ -132,7 +132,7 @@ class ModelLoadRequest(BaseModel):
         default_factory=lambda: get_config_default("num_experts_per_token")
     )
     fasttensors: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("fasttensors", False)
+        default_factory=lambda: get_config_default("fasttensors")
     )

     # Non-config arguments
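Because these request fields use `default_factory`, Pydantic only calls `get_config_default` when a load request omits the field; with the hard-coded fallbacks removed, an unset key simply resolves to None. A minimal standalone sketch of that mechanism, where `fake_user_config` and `lookup` are stand-ins for the project's config layer rather than its real API:

# Standalone illustration of default_factory deferring to a config lookup.
from typing import Optional

from pydantic import BaseModel, Field

fake_user_config = {"cache_mode": "Q8"}  # pretend the user set only this key


def lookup(key: str):
    # Stand-in for get_config_default: no hard-coded fallback anymore
    return fake_user_config.get(key)


class LoadRequestSketch(BaseModel):
    cache_mode: Optional[str] = Field(default_factory=lambda: lookup("cache_mode"))
    chunk_size: Optional[int] = Field(default_factory=lambda: lookup("chunk_size"))


req = LoadRequestSketch()
print(req.cache_mode)  # -> "Q8" (supplied by the user config)
print(req.chunk_size)  # -> None (left for lower layers of the stack)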
4 changes: 3 additions & 1 deletion endpoints/core/utils/model.py
@@ -95,8 +95,10 @@ async def stream_model_load(
 ):
     """Request generation wrapper for the loading process."""

+    # Get trimmed load data
+    load_data = data.model_dump(exclude_none=True)
+
     # Set the draft model path if it exists
-    load_data = data.model_dump()
     if draft_model_path:
         load_data["draft"]["draft_model_dir"] = draft_model_path

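The `exclude_none=True` dump is what lets the chain work end to end: request fields that resolved to None (nothing from the API call, nothing from the user config) are dropped from the load payload, so lower layers of the stack keep control of those values. A small example of the difference, using a stand-in request class rather than the project's ModelLoadRequest:

# Illustrates Pydantic v2's model_dump(exclude_none=True).
from typing import Optional

from pydantic import BaseModel


class LoadSketch(BaseModel):
    cache_mode: Optional[str] = None
    chunk_size: Optional[int] = None


data = LoadSketch(cache_mode="FP16")

print(data.model_dump())
# {'cache_mode': 'FP16', 'chunk_size': None} -> the None would clobber downstream defaults

print(data.model_dump(exclude_none=True))
# {'cache_mode': 'FP16'} -> unset keys are omitted and fall through the stack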
