From 6e48bb420a41cb0f8eeb3999ef270f7747f0d6fe Mon Sep 17 00:00:00 2001
From: Brian Dashore
Date: Thu, 24 Oct 2024 23:35:05 -0400
Subject: [PATCH] Model: Fix inline loading and draft key (#225)

* Model: Fix inline loading and draft key

There was a lack of foresight in how the new config.yml was structured.
The "draft" key became "draft_model" without updating both the API
request and inline loading keys.

For API requests, "draft" is still supported as a legacy key, but the
"draft_model" key is preferred.

Signed-off-by: kingbri

* OAI: Add draft model dir to inline load

This was not pushed before and caused errors because the kwargs were
None.

Signed-off-by: kingbri

* Model: Fix draft args application

Draft model args weren't applying because they were reset due to how
the old override behavior worked.

Signed-off-by: kingbri

* OAI: Change embedding model load params

Use embedding_model_name to be in line with the config.

Signed-off-by: kingbri

* API: Fix parameter for draft model load

Alias name to draft_model_name.

Signed-off-by: kingbri

* API: Fix parameter for template switch

Add prompt_template_name to be more descriptive.

Signed-off-by: kingbri

* API: Fix parameter for model load

Alias name to model_name for config parity.

Signed-off-by: kingbri

* API: Add alias documentation

Signed-off-by: kingbri

---------

Signed-off-by: kingbri
---
 backends/exllamav2/model.py       | 52 +++++++++++++++----------------
 endpoints/OAI/utils/completion.py |  7 +++--
 endpoints/core/router.py          | 19 +++++------
 endpoints/core/types/model.py     | 25 ++++++++++++---
 endpoints/core/types/template.py  |  7 +++--
 endpoints/core/utils/model.py     |  2 +-
 main.py                           |  2 +-
 7 files changed, 68 insertions(+), 46 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b610cfdc..0f9ba3a2 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -129,8 +129,27 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
 
+        # Create the hf_config
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
+
+        # Load generation config overrides
+        generation_config_path = model_directory / "generation_config.json"
+        if generation_config_path.exists():
+            try:
+                self.generation_config = await GenerationConfig.from_file(
+                    generation_config_path.parent
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
+        # Apply a model's config overrides while respecting user settings
+        kwargs = await self.set_model_overrides(**kwargs)
+
         # Prepare the draft model config if necessary
-        draft_args = unwrap(kwargs.get("draft"), {})
+        draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
         enable_draft = draft_args and draft_model_name
 
@@ -154,25 +173,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
-        # Create the hf_config
-        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
-
-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    generation_config_path.parent
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Apply a model's config overrides while respecting user settings
-        kwargs = await self.set_model_overrides(**kwargs)
-
         # MARK: User configuration
 
         # Get cache mode
@@ -338,9 +338,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
 
         # Set user-configured draft model values
         if enable_draft:
-            # Fetch from the updated kwargs
-            draft_args = unwrap(kwargs.get("draft"), {})
-
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             self.draft_config.scale_pos_emb = unwrap(
@@ -384,9 +381,12 @@ async def set_model_overrides(self, **kwargs):
             override_args = unwrap(yaml.load(contents), {})
 
             # Merge draft overrides beforehand
-            draft_override_args = unwrap(override_args.get("draft"), {})
-            if self.draft_config and draft_override_args:
-                kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
+            draft_override_args = unwrap(override_args.get("draft_model"), {})
+            if draft_override_args:
+                kwargs["draft_model"] = {
+                    **draft_override_args,
+                    **unwrap(kwargs.get("draft_model"), {}),
+                }
 
             # Merge the override and model kwargs
             merged_kwargs = {**override_args, **kwargs}
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index c8b02c81..59f3844b 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -149,8 +149,11 @@ async def load_inline_model(model_name: str, request: Request):
 
         return
 
-    # Load the model
-    await model.load_model(model_path)
+    # Load the model and also add draft dir
+    await model.load_model(
+        model_path,
+        draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
+    )
 
 
 async def stream_generate_completion(
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 2c60cd77..f2b42473 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
     """Loads a model into the model container. This returns an SSE stream."""
 
     # Verify request parameters
-    if not data.name:
+    if not data.model_name:
         error_message = handle_request_error(
             "A model name was not provided for load.",
             exc_info=False,
@@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
         raise HTTPException(400, error_message)
 
     model_path = pathlib.Path(config.model.model_dir)
-    model_path = model_path / data.name
+    model_path = model_path / data.model_name
 
     draft_model_path = None
-    if data.draft:
-        if not data.draft.draft_model_name:
+    if data.draft_model:
+        if not data.draft_model.draft_model_name:
             error_message = handle_request_error(
                 "Could not find the draft model name for model load.",
                 exc_info=False,
@@ -301,7 +301,7 @@ async def load_embedding_model(
     request: Request, data: EmbeddingModelLoadRequest
 ) -> ModelLoadResponse:
     # Verify request parameters
-    if not data.name:
+    if not data.embedding_model_name:
         error_message = handle_request_error(
             "A model name was not provided for load.",
             exc_info=False,
@@ -310,7 +310,7 @@ async def load_embedding_model(
         raise HTTPException(400, error_message)
 
     embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
-    embedding_model_path = embedding_model_dir / data.name
+    embedding_model_path = embedding_model_dir / data.embedding_model_name
 
     if not embedding_model_path.exists():
         error_message = handle_request_error(
@@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template."""
 
-    if not data.name:
+    if not data.prompt_template_name:
         error_message = handle_request_error(
             "New template name not found.",
             exc_info=False,
@@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message)
 
     try:
-        template_path = pathlib.Path("templates") / data.name
+        template_path = pathlib.Path("templates") / data.prompt_template_name
         model.container.prompt_template = await PromptTemplate.from_file(template_path)
     except FileNotFoundError as e:
         error_message = handle_request_error(
-            f"The template name {data.name} doesn't exist. Check the spelling?",
+            f"The template name {data.prompt_template_name} doesn't exist. "
+            + "Check the spelling?",
             exc_info=False,
         ).error.message
 
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index 25a40329..58ab0dc4 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -1,6 +1,6 @@
 """Contains model card types."""
 
-from pydantic import BaseModel, Field, ConfigDict
+from pydantic import AliasChoices, BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Literal, Optional, Union
 
@@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
     """Represents a draft model load request."""
 
     # Required
-    draft_model_name: str
+    draft_model_name: str = Field(
+        alias=AliasChoices("draft_model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Config arguments
     draft_rope_scale: Optional[float] = None
@@ -63,8 +66,14 @@ class ModelLoadRequest(BaseModel):
     """Represents a model load request."""
 
+    # Avoids pydantic namespace warning
+    model_config = ConfigDict(protected_namespaces=[])
+
     # Required
-    name: str
+    model_name: str = Field(
+        alias=AliasChoices("model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Config arguments
@@ -108,12 +117,18 @@ class ModelLoadRequest(BaseModel):
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
-    draft: Optional[DraftModelLoadRequest] = None
+    draft_model: Optional[DraftModelLoadRequest] = Field(
+        default=None,
+        alias=AliasChoices("draft_model", "draft"),
+    )
     skip_queue: Optional[bool] = False
 
 
 class EmbeddingModelLoadRequest(BaseModel):
-    name: str
+    embedding_model_name: str = Field(
+        alias=AliasChoices("embedding_model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Set default from the config
     embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
diff --git a/endpoints/core/types/template.py b/endpoints/core/types/template.py
index d72d6210..010c9dbd 100644
--- a/endpoints/core/types/template.py
+++ b/endpoints/core/types/template.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import AliasChoices, BaseModel, Field
 from typing import List
 
 
@@ -12,4 +12,7 @@ class TemplateList(BaseModel):
 class TemplateSwitchRequest(BaseModel):
     """Request to switch a template."""
 
-    name: str
+    prompt_template_name: str = Field(
+        alias=AliasChoices("prompt_template_name", "name"),
+        description="Aliases: name",
+    )
diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py
index d151fdd1..973337d0 100644
--- a/endpoints/core/utils/model.py
+++ b/endpoints/core/utils/model.py
@@ -104,7 +104,7 @@ async def stream_model_load(
 
     # Set the draft model path if it exists
     if draft_model_path:
-        load_data["draft"]["draft_model_dir"] = draft_model_path
+        load_data["draft_model"]["draft_model_dir"] = draft_model_path
 
     load_status = model.load_model_gen(
         model_path, skip_wait=data.skip_queue, **load_data
diff --git a/main.py b/main.py
index 4b420dff..e17e2f8a 100644
--- a/main.py
+++ b/main.py
@@ -70,7 +70,7 @@ async def entrypoint_async():
     await model.load_model(
         model_path.resolve(),
         **config.model.model_dump(exclude_none=True),
-        draft_model=config.draft_model.model_dump(exclude_none=True),
+        draft_model=config.draft_model.model_dump(exclude_none=True),
     )
 
     # Load loras after loading the model