Skip to content

Commit

Permalink
Model: Fix inline loading and draft key (#225)
Browse files Browse the repository at this point in the history
* Model: Fix inline loading and draft key

There was a lack of foresight in how the new config.yml was structured.
The "draft" key became "draft_model" without updating both the API
request and inline loading keys to match.

For the API requests, still support "draft" as legacy, but the "draft_model"
key is preferred.

Signed-off-by: kingbri <[email protected]>

* OAI: Add draft model dir to inline load

This change was not pushed before and caused errors due to the kwargs being None.

Signed-off-by: kingbri <[email protected]>

* Model: Fix draft args application

Draft model args weren't applying since there was a reset due to how
the old override behavior worked.

Signed-off-by: kingbri <[email protected]>

* OAI: Change embedding model load params

Use embedding_model_name to be in line with the config.

Signed-off-by: kingbri <[email protected]>

* API: Fix parameter for draft model load

Alias name to draft_model_name.

Signed-off-by: kingbri <[email protected]>

* API: Fix parameter for template switch

Add prompt_template_name to be more descriptive.

Signed-off-by: kingbri <[email protected]>

* API: Fix parameter for model load

Alias name to model_name for config parity.

Signed-off-by: kingbri <[email protected]>

* API: Add alias documentation

Signed-off-by: kingbri <[email protected]>

---------

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 authored Oct 25, 2024
1 parent f20857c commit 6e48bb4
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 46 deletions.
52 changes: 26 additions & 26 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name

Expand All @@ -154,25 +173,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

# Get cache mode
Expand Down Expand Up @@ -338,9 +338,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})

self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
Expand Down Expand Up @@ -384,9 +381,12 @@ async def set_model_overrides(self, **kwargs):
override_args = unwrap(yaml.load(contents), {})

# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
draft_override_args = unwrap(override_args.get("draft_model"), {})
if draft_override_args:
kwargs["draft_model"] = {
**draft_override_args,
**unwrap(kwargs.get("draft_model"), {}),
}

# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
Expand Down
7 changes: 5 additions & 2 deletions endpoints/OAI/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,11 @@ async def load_inline_model(model_name: str, request: Request):

return

# Load the model
await model.load_model(model_path)
# Load the model and also add draft dir
await model.load_model(
model_path,
draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
)


async def stream_generate_completion(
Expand Down
19 changes: 10 additions & 9 deletions endpoints/core/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""

# Verify request parameters
if not data.name:
if not data.model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
Expand All @@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)

model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
model_path = model_path / data.model_name

draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
if data.draft_model:
if not data.draft_model.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
Expand Down Expand Up @@ -301,7 +301,7 @@ async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
# Verify request parameters
if not data.name:
if not data.embedding_model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
Expand All @@ -310,7 +310,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)

embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
embedding_model_path = embedding_model_dir / data.embedding_model_name

if not embedding_model_path.exists():
error_message = handle_request_error(
Expand Down Expand Up @@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""

if not data.name:
if not data.prompt_template_name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
Expand All @@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)

try:
template_path = pathlib.Path("templates") / data.name
template_path = pathlib.Path("templates") / data.prompt_template_name
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
f"The template name {data.prompt_template_name} doesn't exist. "
+ "Check the spelling?",
exc_info=False,
).error.message

Expand Down
25 changes: 20 additions & 5 deletions endpoints/core/types/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Contains model card types."""

from pydantic import BaseModel, Field, ConfigDict
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union

Expand Down Expand Up @@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""

# Required
draft_model_name: str
draft_model_name: str = Field(
alias=AliasChoices("draft_model_name", "name"),
description="Aliases: name",
)

# Config arguments
draft_rope_scale: Optional[float] = None
Expand All @@ -63,8 +66,14 @@ class DraftModelLoadRequest(BaseModel):
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""

# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])

# Required
name: str
model_name: str = Field(
alias=AliasChoices("model_name", "name"),
description="Aliases: name",
)

# Config arguments

Expand Down Expand Up @@ -108,12 +117,18 @@ class ModelLoadRequest(BaseModel):
num_experts_per_token: Optional[int] = None

# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
skip_queue: Optional[bool] = False


class EmbeddingModelLoadRequest(BaseModel):
name: str
embedding_model_name: str = Field(
alias=AliasChoices("embedding_model_name", "name"),
description="Aliases: name",
)

# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
Expand Down
7 changes: 5 additions & 2 deletions endpoints/core/types/template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
from typing import List


Expand All @@ -12,4 +12,7 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""

name: str
prompt_template_name: str = Field(
alias=AliasChoices("prompt_template_name", "name"),
description="Aliases: name",
)
2 changes: 1 addition & 1 deletion endpoints/core/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def stream_model_load(

# Set the draft model path if it exists
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_data["draft_model"]["draft_model_dir"] = draft_model_path

load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def entrypoint_async():
await model.load_model(
model_path.resolve(),
**config.model.model_dump(exclude_none=True),
draft=config.draft_model.model_dump(exclude_none=True),
draft_model=config.draft_model.model_dump(exclude_none=True),
)

# Load loras after loading the model
Expand Down

0 comments on commit 6e48bb4

Please sign in to comment.