Skip to content

Commit

Permalink
Config: Embeddings: Make embeddings_device a default when loading via the API
Browse files Browse the repository at this point in the history
When loading from the API, the fallback for embeddings_device will be
the same as the config.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Aug 1, 2024
1 parent 54aeeba commit 3e42211
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
18 changes: 16 additions & 2 deletions common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import pathlib
from enum import Enum
from fastapi import HTTPException
from loguru import logger
from typing import Optional
Expand All @@ -31,6 +32,12 @@
embeddings_container: Optional[InfinityContainer] = None


class ModelType(Enum):
    """Kinds of model a load/config request can refer to.

    Used by get_config_default to decide which config section
    (model, draft_model, or embeddings) supplies fallback values.
    """

    MODEL = "model"          # primary text-generation model
    DRAFT = "draft"          # speculative-decoding draft model
    EMBEDDING = "embedding"  # embedding model (infinity backend)


def load_progress(module, modules):
"""Wrapper callback for load progress."""
yield module, modules
Expand Down Expand Up @@ -142,16 +149,23 @@ async def unload_embedding_model():
embeddings_container = None


def get_config_default(key, fallback=None, is_draft=False):
def get_config_default(key: str, fallback=None, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""

model_config = config.model_config()
default_keys = unwrap(model_config.get("use_as_default"), [])

# Add extra keys to defaults
default_keys.append("embeddings_device")

if key in default_keys:
# Is this a draft model load parameter?
if is_draft:
if model_type == "draft":
draft_config = config.draft_model_config()
return unwrap(draft_config.get(key), fallback)
elif model_type == "embedding":
embeddings_config = config.embeddings_config()
return unwrap(embeddings_config.get(key), fallback)
else:
return unwrap(model_config.get(key), fallback)
else:
Expand Down
9 changes: 6 additions & 3 deletions config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,14 @@ embeddings:
# Overrides directory to look for embedding models (default: models)
embedding_model_dir: models

# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:

# Device to load embedding models on (default: cpu)
# Possible values: cpu, auto, cuda
# NOTE: It's recommended to load embedding models on the CPU.
# If you'd like to load on an AMD GPU, set this value to "cuda" as well.
embeddings_device: cpu

# The below parameters only apply for initial loads
# All API based loads do NOT inherit these settings unless specified in use_as_default

# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:
12 changes: 8 additions & 4 deletions endpoints/core/types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
# Config arguments
draft_rope_scale: Optional[float] = Field(
default_factory=lambda: get_config_default(
"draft_rope_scale", 1.0, is_draft=True
"draft_rope_scale", 1.0, model_type="draft"
)
)
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default_factory=lambda: get_config_default(
"draft_rope_alpha", None, is_draft=True
"draft_rope_alpha", None, model_type="draft"
),
examples=[1.0],
)
draft_cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default(
"draft_cache_mode", "FP16", is_draft=True
"draft_cache_mode", "FP16", model_type="draft"
)
)

Expand Down Expand Up @@ -139,7 +139,11 @@ class ModelLoadRequest(BaseModel):

class EmbeddingModelLoadRequest(BaseModel):
    """Request body for loading an embedding model via the API.

    If the client omits embeddings_device, the default is resolved from the
    embeddings section of the server config (when the user has allowed it via
    use_as_default) rather than hard-coding None.
    """

    # Name of the embedding model to load (required).
    name: str
    # Device to load on; falls back to the config's embeddings_device.
    # NOTE: removed the superseded `embeddings_device: Optional[str] = None`
    # assignment that shadowed-then-was-overwritten by this Field definition.
    embeddings_device: Optional[str] = Field(
        default_factory=lambda: get_config_default(
            "embeddings_device", model_type="embedding"
        )
    )


class ModelLoadResponse(BaseModel):
Expand Down

0 comments on commit 3e42211

Please sign in to comment.