Skip to content

Commit

Permalink
Embeddings: Update config, args, and parameter names
Browse files Browse the repository at this point in the history
Use embeddings_device as the parameter for device to remove ambiguity.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Jul 30, 2024
1 parent bfa011e commit dc3dcc9
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 9 deletions.
2 changes: 1 addition & 1 deletion backends/infinity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def load(self, **kwargs):
self.model_is_loading = True

# Use cpu by default
device = unwrap(kwargs.get("device"), "cpu")
device = unwrap(kwargs.get("embeddings_device"), "cpu")

engine_args = EngineArgs(
model_name_or_path=str(self.model_dir),
Expand Down
20 changes: 20 additions & 0 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def init_argparser():
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
Expand Down Expand Up @@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)


def add_embeddings_args(parser: argparse.ArgumentParser):
    """Adds arguments specific to embeddings.

    Registers an "embeddings" argument group on ``parser`` with three
    optional string flags: ``--embedding-model-dir``,
    ``--embedding-model-name`` and ``--embeddings-device``.

    No explicit defaults are set, so any flag that is not supplied parses
    as ``None``.
    """

    embeddings_group = parser.add_argument_group("embeddings")

    # Help text names "embedding models" explicitly so these flags are not
    # confused with similarly worded generic model options in --help output.
    embeddings_group.add_argument(
        "--embedding-model-dir",
        type=str,
        help="Overrides the directory to look for embedding models",
    )
    embeddings_group.add_argument(
        "--embedding-model-name",
        type=str,
        help="An initial embedding model to load",
    )
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )
5 changes: 5 additions & 0 deletions common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def from_args(args: dict):
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

embeddings_override = args.get("embeddings")
if embeddings_override:
cur_embeddings_config = embeddings_config()
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}


def sampling_config():
"""Returns the sampling parameter config from the global config"""
Expand Down
23 changes: 16 additions & 7 deletions config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,6 @@ developer:
# Otherwise, the priority will be set to high
#realtime_process_priority: False

embeddings:
embedding_model_dir: models

embedding_model_name:

embeddings_device: cpu

# Options for model overrides and loading
# Please read the comments to understand how arguments are handled between initial and API loads
model:
Expand Down Expand Up @@ -208,3 +201,19 @@ model:
#loras:
#- name: lora1
# scaling: 1.0

# Options for embedding models and loading.
# NOTE: Embeddings require the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
# Overrides directory to look for embedding models (default: models)
embedding_model_dir: models

# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:

# Device to load embedding models on (default: cpu)
# Possible values: cpu, auto, cuda
# NOTE: It's recommended to load embedding models on the CPU.
# To load on an AMD GPU, also set this value to "cuda".
embeddings_device: cpu
2 changes: 1 addition & 1 deletion endpoints/core/types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel):

class EmbeddingModelLoadRequest(BaseModel):
name: str
device: Optional[str] = None
embeddings_device: Optional[str] = None


class ModelLoadResponse(BaseModel):
Expand Down

0 comments on commit dc3dcc9

Please sign in to comment.