From f20cd330ef3b1e9999c4242db4a83e75eb392050 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:45:07 +0000 Subject: [PATCH 01/11] feat: add embeddings support via sentence-transformers --- endpoints/OAI/embeddings.py | 145 +++++++++++++++++++++++++++++++ endpoints/OAI/router.py | 23 +++++ endpoints/OAI/types/embedding.py | 39 +++++++++ pyproject.toml | 3 +- tabbyAPI | 1 + 5 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 endpoints/OAI/embeddings.py create mode 100644 endpoints/OAI/types/embedding.py create mode 160000 tabbyAPI diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/embeddings.py new file mode 100644 index 00000000..3cc59de8 --- /dev/null +++ b/endpoints/OAI/embeddings.py @@ -0,0 +1,145 @@ +""" +This file is derived from +[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py) +and modified. +The changes introduced are: Suppression of progress bar, +typing/pydantic classes moved into this file, +embeddings function declared async. +""" + +import os +import base64 +import numpy as np +from transformers import AutoModel + +embeddings_params_initialized = False + + +def initialize_embedding_params(): + ''' + using 'lazy loading' to avoid circular import + so this function will be executed only once + ''' + global embeddings_params_initialized + if not embeddings_params_initialized: + + global st_model, embeddings_model, embeddings_device + + st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", + 'all-mpnet-base-v2') + embeddings_model = None + # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), + # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, + # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, + # hpu, mtia, privateuseone + embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu') + if embeddings_device.lower() == 'auto': + embeddings_device = None + + embeddings_params_initialized = True + + +def load_embedding_model(model: str): + try: + from sentence_transformers import SentenceTransformer + except ModuleNotFoundError: + print("The sentence_transformers module has not been found. 
" + + "Please install it manually with " + + "pip install -U sentence-transformers.") + raise ModuleNotFoundError from None + + initialize_embedding_params() + global embeddings_device, embeddings_model + try: + print(f"Try embedding model: {model} on {embeddings_device}") + if 'jina-embeddings' in model: + # trust_remote_code is needed to use the encode method + embeddings_model = AutoModel.from_pretrained( + model, trust_remote_code=True) + embeddings_model = embeddings_model.to(embeddings_device) + else: + embeddings_model = SentenceTransformer( + model, + device=embeddings_device, + ) + + print(f"Loaded embedding model: {model}") + except Exception as e: + embeddings_model = None + raise Exception(f"Error: Failed to load embedding model: {model}", + internal_message=repr(e)) from None + + +def get_embeddings_model(): + initialize_embedding_params() + global embeddings_model, st_model + if st_model and not embeddings_model: + load_embedding_model(st_model) # lazy load the model + + return embeddings_model + + +def get_embeddings_model_name() -> str: + initialize_embedding_params() + global st_model + return st_model + + +def get_embeddings(input: list) -> np.ndarray: + model = get_embeddings_model() + embedding = model.encode(input, + convert_to_numpy=True, + normalize_embeddings=True, + convert_to_tensor=False, + show_progress_bar=False) + return embedding + + +async def embeddings(input: list, + encoding_format: str, + model: str = None) -> dict: + if model is None: + model = st_model + else: + load_embedding_model(model) + + embeddings = get_embeddings(input) + if encoding_format == "base64": + data = [{ + "object": "embedding", + "embedding": float_list_to_base64(emb), + "index": n + } for n, emb in enumerate(embeddings)] + else: + data = [{ + "object": "embedding", + "embedding": emb.tolist(), + "index": n + } for n, emb in enumerate(embeddings)] + + response = { + "object": "list", + "data": data, + "model": st_model if model is None else model, + "usage": { + "prompt_tokens": 0, + "total_tokens": 0, + } + } + return response + + +def float_list_to_base64(float_array: np.ndarray) -> str: + # Convert the list to a float32 array that the OpenAPI client expects + # float_array = np.array(float_list, dtype="float32") + + # Get raw bytes + bytes_array = float_array.tobytes() + + # Encode bytes into base64 + encoded_bytes = base64.b64encode(bytes_array) + + # Turn raw base64 encoded bytes into ASCII + ascii_string = encoded_bytes.decode('ascii') + return ascii_string + diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 771b7f39..98973530 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,5 +1,6 @@ import asyncio from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse from sse_starlette import EventSourceResponse from sys import maxsize @@ -8,11 +9,16 @@ from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap +import endpoints.OAI.embeddings as OAIembeddings from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) +from endpoints.OAI.types.embedding import ( + EmbeddingsRequest, + EmbeddingsResponse +) from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -125,3 +131,20 @@ async def chat_completion_request( 
disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response + +# Embeddings endpoint +@router.post( + "/v1/embeddings", + dependencies=[Depends(check_api_key), Depends(check_model_container)], + response_model=EmbeddingsResponse +) +async def handle_embeddings(request: EmbeddingsRequest): + input = request.input + if not input: + raise JSONResponse(status_code=400, + content={"error": "Missing required argument input"}) + model = request.model if request.model else None + response = await OAIembeddings.embeddings(input, request.encoding_format, + model) + return JSONResponse(response) + diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py new file mode 100644 index 00000000..3dd5e864 --- /dev/null +++ b/endpoints/OAI/types/embedding.py @@ -0,0 +1,39 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + +class EmbeddingsRequest(BaseModel): + input: List[str] = Field( + ..., description="List of input texts to generate embeddings for.") + encoding_format: str = Field( + "float", + description="Encoding format for the embeddings. " + "Can be 'float' or 'base64'.") + model: Optional[str] = Field( + None, + description="Name of the embedding model to use. " + "If not provided, the default model will be used.") + + +class EmbeddingObject(BaseModel): + object: str = Field("embedding", description="Type of the object.") + embedding: List[float] = Field( + ..., description="Embedding values as a list of floats.") + index: int = Field( + ..., + description="Index of the input text corresponding to " + "the embedding.") + + +class EmbeddingsResponse(BaseModel): + object: str = Field("list", description="Type of the response object.") + data: List[EmbeddingObject] = Field( + ..., description="List of embedding objects.") + model: str = Field(..., description="Name of the embedding model used.") + usage: UsageInfo = Field(..., description="Information about token usage.") diff --git a/pyproject.toml b/pyproject.toml index 38cdbcf4..f9340964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dependencies = [ [project.optional-dependencies] extras = [ # Heavy dependencies that aren't for everyday use - "outlines" + "outlines", + "sentence-transformers" ] dev = [ "ruff == 0.3.2" diff --git a/tabbyAPI b/tabbyAPI new file mode 160000 index 00000000..1650e6e6 --- /dev/null +++ b/tabbyAPI @@ -0,0 +1 @@ +Subproject commit 1650e6e6406edf797576c077aaceafcf28895c26 From 765d3593b3e22888149292355a8d5ae03bd2f630 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:52:18 +0000 Subject: [PATCH 02/11] remove submodule --- tabbyAPI | 1 - 1 file changed, 1 deletion(-) delete mode 160000 tabbyAPI diff --git a/tabbyAPI b/tabbyAPI deleted file mode 160000 index 1650e6e6..00000000 --- a/tabbyAPI +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1650e6e6406edf797576c077aaceafcf28895c26 From 5adfab1cbd7dc4d23501384bebb0ac6566c4fd62 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 26 Jul 2024 02:53:14 +0000 Subject: [PATCH 03/11] ruff: formatting --- endpoints/OAI/embeddings.py | 69 +++++++++++++++----------------- endpoints/OAI/router.py | 17 ++++---- endpoints/OAI/types/embedding.py | 21 +++++----- 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/embeddings.py index 3cc59de8..725d7ba9 100644 --- a/endpoints/OAI/embeddings.py +++ 
b/endpoints/OAI/embeddings.py @@ -16,24 +16,22 @@ def initialize_embedding_params(): - ''' + """ using 'lazy loading' to avoid circular import so this function will be executed only once - ''' + """ global embeddings_params_initialized if not embeddings_params_initialized: - global st_model, embeddings_model, embeddings_device - st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", - 'all-mpnet-base-v2') + st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2") embeddings_model = None # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, # hpu, mtia, privateuseone - embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu') - if embeddings_device.lower() == 'auto': + embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu") + if embeddings_device.lower() == "auto": embeddings_device = None embeddings_params_initialized = True @@ -43,19 +41,20 @@ def load_embedding_model(model: str): try: from sentence_transformers import SentenceTransformer except ModuleNotFoundError: - print("The sentence_transformers module has not been found. " + - "Please install it manually with " + - "pip install -U sentence-transformers.") + print( + "The sentence_transformers module has not been found. " + + "Please install it manually with " + + "pip install -U sentence-transformers." + ) raise ModuleNotFoundError from None initialize_embedding_params() global embeddings_device, embeddings_model try: print(f"Try embedding model: {model} on {embeddings_device}") - if 'jina-embeddings' in model: + if "jina-embeddings" in model: # trust_remote_code is needed to use the encode method - embeddings_model = AutoModel.from_pretrained( - model, trust_remote_code=True) + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) embeddings_model = embeddings_model.to(embeddings_device) else: embeddings_model = SentenceTransformer( @@ -66,8 +65,9 @@ def load_embedding_model(model: str): print(f"Loaded embedding model: {model}") except Exception as e: embeddings_model = None - raise Exception(f"Error: Failed to load embedding model: {model}", - internal_message=repr(e)) from None + raise Exception( + f"Error: Failed to load embedding model: {model}", internal_message=repr(e) + ) from None def get_embeddings_model(): @@ -87,17 +87,17 @@ def get_embeddings_model_name() -> str: def get_embeddings(input: list) -> np.ndarray: model = get_embeddings_model() - embedding = model.encode(input, - convert_to_numpy=True, - normalize_embeddings=True, - convert_to_tensor=False, - show_progress_bar=False) + embedding = model.encode( + input, + convert_to_numpy=True, + normalize_embeddings=True, + convert_to_tensor=False, + show_progress_bar=False, + ) return embedding -async def embeddings(input: list, - encoding_format: str, - model: str = None) -> dict: +async def embeddings(input: list, encoding_format: str, model: str = None) -> dict: if model is None: model = st_model else: @@ -105,17 +105,15 @@ async def embeddings(input: list, embeddings = get_embeddings(input) if encoding_format == "base64": - data = [{ - "object": "embedding", - "embedding": float_list_to_base64(emb), - "index": n - } for n, emb in enumerate(embeddings)] + data = [ + {"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} + for n, emb in enumerate(embeddings) + ] else: - data = [{ - "object": "embedding", - "embedding": emb.tolist(), - "index": n - } for n, emb in enumerate(embeddings)] + data = 
[ + {"object": "embedding", "embedding": emb.tolist(), "index": n} + for n, emb in enumerate(embeddings) + ] response = { "object": "list", @@ -124,7 +122,7 @@ async def embeddings(input: list, "usage": { "prompt_tokens": 0, "total_tokens": 0, - } + }, } return response @@ -140,6 +138,5 @@ def float_list_to_base64(float_array: np.ndarray) -> str: encoded_bytes = base64.b64encode(bytes_array) # Turn raw base64 encoded bytes into ASCII - ascii_string = encoded_bytes.decode('ascii') + ascii_string = encoded_bytes.decode("ascii") return ascii_string - diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 98973530..039042a1 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -15,10 +15,7 @@ ChatCompletionRequest, ChatCompletionResponse, ) -from endpoints.OAI.types.embedding import ( - EmbeddingsRequest, - EmbeddingsResponse -) +from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -132,19 +129,19 @@ async def chat_completion_request( ) return response + # Embeddings endpoint @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], - response_model=EmbeddingsResponse + response_model=EmbeddingsResponse, ) async def handle_embeddings(request: EmbeddingsRequest): input = request.input if not input: - raise JSONResponse(status_code=400, - content={"error": "Missing required argument input"}) + raise JSONResponse( + status_code=400, content={"error": "Missing required argument input"} + ) model = request.model if request.model else None - response = await OAIembeddings.embeddings(input, request.encoding_format, - model) + response = await OAIembeddings.embeddings(input, request.encoding_format, model) return JSONResponse(response) - diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py index 3dd5e864..7d5779fa 100644 --- a/endpoints/OAI/types/embedding.py +++ b/endpoints/OAI/types/embedding.py @@ -8,32 +8,35 @@ class UsageInfo(BaseModel): total_tokens: int = 0 completion_tokens: Optional[int] = 0 + class EmbeddingsRequest(BaseModel): input: List[str] = Field( - ..., description="List of input texts to generate embeddings for.") + ..., description="List of input texts to generate embeddings for." + ) encoding_format: str = Field( "float", description="Encoding format for the embeddings. " - "Can be 'float' or 'base64'.") + "Can be 'float' or 'base64'.", + ) model: Optional[str] = Field( None, description="Name of the embedding model to use. " - "If not provided, the default model will be used.") + "If not provided, the default model will be used.", + ) class EmbeddingObject(BaseModel): object: str = Field("embedding", description="Type of the object.") embedding: List[float] = Field( - ..., description="Embedding values as a list of floats.") + ..., description="Embedding values as a list of floats." + ) index: int = Field( - ..., - description="Index of the input text corresponding to " - "the embedding.") + ..., description="Index of the input text corresponding to " "the embedding." 
+ ) class EmbeddingsResponse(BaseModel): object: str = Field("list", description="Type of the response object.") - data: List[EmbeddingObject] = Field( - ..., description="List of embedding objects.") + data: List[EmbeddingObject] = Field(..., description="List of embedding objects.") model: str = Field(..., description="Name of the embedding model used.") usage: UsageInfo = Field(..., description="Information about token usage.") From c9a5d2c363ce2e941937ec34a335a60bc34bbd65 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 28 Jul 2024 14:10:51 -0400 Subject: [PATCH 04/11] OAI: Refactor embeddings Move files and rewrite routes to adhere to Tabby's code style. Signed-off-by: kingbri --- endpoints/OAI/router.py | 17 +++++------------ endpoints/OAI/{ => utils}/embeddings.py | 0 2 files changed, 5 insertions(+), 12 deletions(-) rename endpoints/OAI/{ => utils}/embeddings.py (100%) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 039042a1..ffb678bd 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,6 +1,5 @@ import asyncio from fastapi import APIRouter, Depends, HTTPException, Request -from fastapi.responses import JSONResponse from sse_starlette import EventSourceResponse from sys import maxsize @@ -9,7 +8,6 @@ from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap -import endpoints.OAI.embeddings as OAIembeddings from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, @@ -25,6 +23,7 @@ generate_completion, stream_generate_completion, ) +from endpoints.OAI.utils.embeddings import embeddings router = APIRouter() @@ -134,14 +133,8 @@ async def chat_completion_request( @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], - response_model=EmbeddingsResponse, ) -async def handle_embeddings(request: EmbeddingsRequest): - input = request.input - if not input: - raise JSONResponse( - status_code=400, content={"error": "Missing required argument input"} - ) - model = request.model if request.model else None - response = await OAIembeddings.embeddings(input, request.encoding_format, model) - return JSONResponse(response) +async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: + response = await embeddings(data.input, data.encoding_format, data.model) + + return response diff --git a/endpoints/OAI/embeddings.py b/endpoints/OAI/utils/embeddings.py similarity index 100% rename from endpoints/OAI/embeddings.py rename to endpoints/OAI/utils/embeddings.py From 3f21d9ef96a8c80b90e4444024d8bf3d4ac10a5b Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 29 Jul 2024 13:42:03 -0400 Subject: [PATCH 05/11] Embeddings: Switch to Infinity Infinity-emb is an async batching engine for embeddings. This is preferable to sentence-transformers since it handles scalable usecases without the need for external thread intervention. 
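
For reference, the engine lifecycle this backend builds on looks roughly
like the following. This is a minimal sketch using the same infinity_emb
calls introduced in the diff below; the model path and input texts are
placeholders, not values from this repository:

    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    engine_args = EngineArgs(
        model_name_or_path="/path/to/embedding-model",  # placeholder path
        engine="torch",
        device="cpu",
        bettertransformer=False,
        model_warmup=False,
    )
    engine = AsyncEmbeddingEngine.from_args(engine_args)

    async def embed_once(texts):
        # The async context manager starts and stops the engine
        # around a single batched embed call.
        async with engine:
            embeddings, usage = await engine.embed(texts)
        return embeddings, usage
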
Signed-off-by: kingbri --- common/config.py | 5 + config_sample.yml | 7 ++ endpoints/OAI/router.py | 2 +- endpoints/OAI/utils/embeddings.py | 181 +++++++++++++----------------- 4 files changed, 91 insertions(+), 104 deletions(-) diff --git a/common/config.py b/common/config.py index 972b382a..5546240e 100644 --- a/common/config.py +++ b/common/config.py @@ -95,3 +95,8 @@ def logging_config(): def developer_config(): """Returns the developer specific config from the global config""" return unwrap(GLOBAL_CONFIG.get("developer"), {}) + + +def embeddings_config(): + """Returns the embeddings config from the global config""" + return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) diff --git a/config_sample.yml b/config_sample.yml index c92f6730..053feb62 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -72,6 +72,13 @@ developer: # Otherwise, the priority will be set to high #realtime_process_priority: False +embeddings: + embeddings_model_dir: models + + embeddings_model_name: + + embeddings_device: cpu + # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index ffb678bd..2cad8762 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -135,6 +135,6 @@ async def chat_completion_request( dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: - response = await embeddings(data.input, data.encoding_format, data.model) + response = await embeddings(data) return response diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index 725d7ba9..cf5b799e 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -7,135 +7,110 @@ embeddings function declared async. """ +import asyncio import os import base64 +import pathlib +from loguru import logger import numpy as np from transformers import AutoModel -embeddings_params_initialized = False +from common import config +from common.utils import unwrap +from endpoints.OAI.types.embedding import ( + EmbeddingObject, + EmbeddingsRequest, + EmbeddingsResponse, +) -def initialize_embedding_params(): - """ - using 'lazy loading' to avoid circular import - so this function will be executed only once - """ - global embeddings_params_initialized - if not embeddings_params_initialized: - global st_model, embeddings_model, embeddings_device - - st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2") - embeddings_model = None - # OPENAI_EMBEDDING_DEVICE: auto (best or cpu), - # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, - # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, - # hpu, mtia, privateuseone - embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu") - if embeddings_device.lower() == "auto": - embeddings_device = None +embeddings_model = None - embeddings_params_initialized = True - -def load_embedding_model(model: str): +def load_embedding_model(model_path: pathlib.Path, device: str): try: - from sentence_transformers import SentenceTransformer + from infinity_emb import EngineArgs, AsyncEmbeddingEngine except ModuleNotFoundError: - print( - "The sentence_transformers module has not been found. " - + "Please install it manually with " - + "pip install -U sentence-transformers." 
+ logger.error( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" ) raise ModuleNotFoundError from None - initialize_embedding_params() - global embeddings_device, embeddings_model + global embeddings_model try: - print(f"Try embedding model: {model} on {embeddings_device}") - if "jina-embeddings" in model: - # trust_remote_code is needed to use the encode method - embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) - embeddings_model = embeddings_model.to(embeddings_device) - else: - embeddings_model = SentenceTransformer( - model, - device=embeddings_device, - ) - - print(f"Loaded embedding model: {model}") + engine_args = EngineArgs( + model_name_or_path=str(model_path.resolve()), + engine="torch", + device="cpu", + bettertransformer=False, + model_warmup=False, + ) + embeddings_model = AsyncEmbeddingEngine.from_args(engine_args) + logger.info(f"Trying to load embeddings model: {model_path.name} on {device}") except Exception as e: embeddings_model = None - raise Exception( - f"Error: Failed to load embedding model: {model}", internal_message=repr(e) - ) from None + raise e -def get_embeddings_model(): - initialize_embedding_params() - global embeddings_model, st_model - if st_model and not embeddings_model: - load_embedding_model(st_model) # lazy load the model +async def embeddings(data: EmbeddingsRequest) -> dict: + embeddings_config = config.embeddings_config() - return embeddings_model + # Use CPU by default + device = embeddings_config.get("embeddings_device", "cpu") + if device == "auto": + device = None - -def get_embeddings_model_name() -> str: - initialize_embedding_params() - global st_model - return st_model - - -def get_embeddings(input: list) -> np.ndarray: - model = get_embeddings_model() - embedding = model.encode( - input, - convert_to_numpy=True, - normalize_embeddings=True, - convert_to_tensor=False, - show_progress_bar=False, + model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir")) + model_path: pathlib.Path = model_path / embeddings_config.get( + "embeddings_model_name" ) - return embedding - - -async def embeddings(input: list, encoding_format: str, model: str = None) -> dict: - if model is None: - model = st_model - else: - load_embedding_model(model) - - embeddings = get_embeddings(input) - if encoding_format == "base64": - data = [ - {"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} - for n, emb in enumerate(embeddings) - ] - else: - data = [ - {"object": "embedding", "embedding": emb.tolist(), "index": n} - for n, emb in enumerate(embeddings) - ] - - response = { - "object": "list", - "data": data, - "model": st_model if model is None else model, - "usage": { - "prompt_tokens": 0, - "total_tokens": 0, - }, - } - return response + if not model_path: + logger.info("Embeddings model path not found") + + load_embedding_model(model_path, device) + + async with embeddings_model: + embeddings, usage = await embeddings_model.embed(data.input) + + # OAI expects a return of base64 if the input is base64 + if data.encoding_format == "base64": + embedding_data = [ + { + "object": "embedding", + "embedding": float_list_to_base64(emb), + "index": n, + } + for n, emb in enumerate(embeddings) + ] + else: + embedding_data = [ + {"object": "embedding", "embedding": emb.tolist(), "index": n} + for n, emb in enumerate(embeddings) + ] + + response = { + "object": "list", + "data": 
embedding_data, + "model": model_path.name, + "usage": { + "prompt_tokens": usage, + "total_tokens": usage, + }, + } + return response def float_list_to_base64(float_array: np.ndarray) -> str: - # Convert the list to a float32 array that the OpenAPI client expects - # float_array = np.array(float_list, dtype="float32") - - # Get raw bytes - bytes_array = float_array.tobytes() + """ + Converts the provided list to a float32 array for OpenAI + Ex. float_array = np.array(float_list, dtype="float32") + """ - # Encode bytes into base64 - encoded_bytes = base64.b64encode(bytes_array) + # Encode raw bytes into base64 + encoded_bytes = base64.b64encode(float_array.tobytes()) # Turn raw base64 encoded bytes into ASCII ascii_string = encoded_bytes.decode("ascii") From ac1afcc5886aba728fbe4df3f049dc131cc19ae3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 29 Jul 2024 14:15:40 -0400 Subject: [PATCH 06/11] Embeddings: Use response classes instead of dicts Follows the existing code style. Signed-off-by: kingbri --- endpoints/OAI/utils/embeddings.py | 62 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index cf5b799e..1ce611cf 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -7,37 +7,41 @@ embeddings function declared async. """ -import asyncio -import os import base64 import pathlib -from loguru import logger import numpy as np -from transformers import AutoModel +from loguru import logger from common import config -from common.utils import unwrap from endpoints.OAI.types.embedding import ( EmbeddingObject, EmbeddingsRequest, EmbeddingsResponse, + UsageInfo, ) +# Conditionally import infinity embeddings engine +# Required so the logger doesn't take over tabby's logging handlers +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + embeddings_model = None def load_embedding_model(model_path: pathlib.Path, device: str): - try: - from infinity_emb import EngineArgs, AsyncEmbeddingEngine - except ModuleNotFoundError: + if not has_infinity_emb: logger.error( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - raise ModuleNotFoundError from None + raise ModuleNotFoundError global embeddings_model try: @@ -76,30 +80,22 @@ async def embeddings(data: EmbeddingsRequest) -> dict: embeddings, usage = await embeddings_model.embed(data.input) # OAI expects a return of base64 if the input is base64 - if data.encoding_format == "base64": - embedding_data = [ - { - "object": "embedding", - "embedding": float_list_to_base64(emb), - "index": n, - } - for n, emb in enumerate(embeddings) - ] - else: - embedding_data = [ - {"object": "embedding", "embedding": emb.tolist(), "index": n} - for n, emb in enumerate(embeddings) - ] - - response = { - "object": "list", - "data": embedding_data, - "model": model_path.name, - "usage": { - "prompt_tokens": usage, - "total_tokens": usage, - }, - } + embedding_data = [ + EmbeddingObject( + embedding=float_list_to_base64(emb) + if data.encoding_format == "base64" + else emb.tolist(), + index=n, + ) + for n, emb in enumerate(embeddings) + ] + + response = EmbeddingsResponse( + data=embedding_data, + model=model_path.name, + usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), + ) + return response From 
fbf1455db18a3ac2f4f312796150e99536d2361c Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 11:00:23 -0400 Subject: [PATCH 07/11] Embeddings: Migrate and organize Infinity Use Infinity as a separate backend and handle the model within the common module. This separates out the embeddings model from the endpoint which allows for model loading/unloading in core. Signed-off-by: kingbri --- backends/infinity/model.py | 56 +++++++++++++++ common/model.py | 45 ++++++++++++ common/signals.py | 13 ++++ endpoints/OAI/router.py | 11 ++- endpoints/OAI/utils/embeddings.py | 111 +++++++++--------------------- main.py | 12 ++++ 6 files changed, 165 insertions(+), 83 deletions(-) create mode 100644 backends/infinity/model.py diff --git a/backends/infinity/model.py b/backends/infinity/model.py new file mode 100644 index 00000000..2d4ae831 --- /dev/null +++ b/backends/infinity/model.py @@ -0,0 +1,56 @@ +import gc +import pathlib +import torch +from typing import List, Optional + +from common.utils import unwrap + +# Conditionally import infinity to sidestep its logger +# TODO: Make this prettier +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + + +class InfinityContainer: + model_dir: pathlib.Path + + # Conditionally set the type hint based on importablity + # TODO: Clean this up + if has_infinity_emb: + engine: Optional[AsyncEmbeddingEngine] = None + else: + engine = None + + def __init__(self, model_directory: pathlib.Path): + self.model_dir = model_directory + + async def load(self, **kwargs): + # Use cpu by default + device = unwrap(kwargs.get("device"), "cpu") + + engine_args = EngineArgs( + model_name_or_path=str(self.model_dir), + engine="torch", + device=device, + bettertransformer=False, + model_warmup=False, + ) + + self.engine = AsyncEmbeddingEngine.from_args(engine_args) + await self.engine.astart() + + async def unload(self): + await self.engine.astop() + self.engine = None + + gc.collect() + torch.cuda.empty_cache() + + async def generate(self, sentence_input: List[str]): + result_embeddings, usage = await self.engine.embed(sentence_input) + + return {"embeddings": result_embeddings, "usage": usage} diff --git a/common/model.py b/common/model.py index a6477c26..b4b259e8 100644 --- a/common/model.py +++ b/common/model.py @@ -20,6 +20,15 @@ # Global model container container: Optional[ExllamaV2Container] = None + embeddings_container = None + + # Type hint the infinity emb container if it exists + from backends.infinity.model import has_infinity_emb + + if has_infinity_emb: + from backends.infinity.model import InfinityContainer + + embeddings_container: Optional[InfinityContainer] = None def load_progress(module, modules): @@ -100,6 +109,30 @@ async def unload_loras(): await container.unload(loras_only=True) +async def load_embeddings_model(model_path: pathlib.Path, **kwargs): + global embeddings_container + + # Break out if infinity isn't installed + if not has_infinity_emb: + logger.warning( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" + ) + return + + embeddings_container = InfinityContainer(model_path) + await embeddings_container.load(**kwargs) + + +async def unload_embeddings_model(): + global embeddings_container + + await embeddings_container.unload() + embeddings_container = None + + def get_config_default(key, fallback=None, is_draft=False): 
"""Fetches a default value from model config if allowed by the user.""" @@ -126,3 +159,15 @@ async def check_model_container(): ).error.message raise HTTPException(400, error_message) + + +async def check_embeddings_container(): + """FastAPI depends that checks if an embeddings model is loaded.""" + + if embeddings_container is None: + error_message = handle_request_error( + "No embeddings models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/common/signals.py b/common/signals.py index 07d75644..d4b67bc0 100644 --- a/common/signals.py +++ b/common/signals.py @@ -1,13 +1,26 @@ +import asyncio import signal import sys from loguru import logger from types import FrameType +from common import model + def signal_handler(*_): """Signal handler for main function. Run before uvicorn starts.""" logger.warning("Shutdown signal called. Exiting gracefully.") + + # Run async unloads for model + loop = asyncio.get_running_loop() + if model.container: + loop.create_task(model.container.unload()) + + if model.embeddings_container: + loop.create_task(model.embeddings_container.unload()) + + # Exit the program sys.exit(0) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 2cad8762..b702e520 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -23,7 +23,7 @@ generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.embeddings import embeddings +from endpoints.OAI.utils.embeddings import get_embeddings router = APIRouter() @@ -134,7 +134,12 @@ async def chat_completion_request( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse: - response = await embeddings(data) +async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: + embeddings_task = asyncio.create_task(get_embeddings(data, request)) + response = await run_with_request_disconnect( + request, + embeddings_task, + f"Embeddings request {request.state.id} cancelled by user.", + ) return response diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index 1ce611cf..5b43953b 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -8,11 +8,11 @@ """ import base64 -import pathlib +from fastapi import Request import numpy as np from loguru import logger -from common import config +from common import model from endpoints.OAI.types.embedding import ( EmbeddingObject, EmbeddingsRequest, @@ -20,84 +20,6 @@ UsageInfo, ) -# Conditionally import infinity embeddings engine -# Required so the logger doesn't take over tabby's logging handlers -try: - from infinity_emb import EngineArgs, AsyncEmbeddingEngine - - has_infinity_emb = True -except ImportError: - has_infinity_emb = False - - -embeddings_model = None - - -def load_embedding_model(model_path: pathlib.Path, device: str): - if not has_infinity_emb: - logger.error( - "Skipping embeddings because infinity-emb is not installed.\n" - "Please run the following command in your environment " - "to install extra packages:\n" - "pip install -U .[extras]" - ) - raise ModuleNotFoundError - - global embeddings_model - try: - engine_args = EngineArgs( - model_name_or_path=str(model_path.resolve()), - engine="torch", - device="cpu", - bettertransformer=False, - model_warmup=False, - ) - embeddings_model = AsyncEmbeddingEngine.from_args(engine_args) - logger.info(f"Trying to load embeddings model: 
{model_path.name} on {device}")
-    except Exception as e:
-        embeddings_model = None
-        raise e
-
-
-async def embeddings(data: EmbeddingsRequest) -> dict:
-    embeddings_config = config.embeddings_config()
-
-    # Use CPU by default
-    device = embeddings_config.get("embeddings_device", "cpu")
-    if device == "auto":
-        device = None
-
-    model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
-    model_path: pathlib.Path = model_path / embeddings_config.get(
-        "embeddings_model_name"
-    )
-    if not model_path:
-        logger.info("Embeddings model path not found")
-
-    load_embedding_model(model_path, device)
-
-    async with embeddings_model:
-        embeddings, usage = await embeddings_model.embed(data.input)
-
-    # OAI expects a return of base64 if the input is base64
-    embedding_data = [
-        EmbeddingObject(
-            embedding=float_list_to_base64(emb)
-            if data.encoding_format == "base64"
-            else emb.tolist(),
-            index=n,
-        )
-        for n, emb in enumerate(embeddings)
-    ]
-
-    response = EmbeddingsResponse(
-        data=embedding_data,
-        model=model_path.name,
-        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
-    )
-
-    return response
-
 
 def float_list_to_base64(float_array: np.ndarray) -> str:
     """
@@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
     # Turn raw base64 encoded bytes into ASCII
     ascii_string = encoded_bytes.decode("ascii")
     return ascii_string
+
+
+async def get_embeddings(
+    data: EmbeddingsRequest, request: Request
+) -> EmbeddingsResponse:
+    model_path = model.embeddings_container.model_dir
+
+    logger.info(f"Received embeddings request {request.state.id}")
+    embedding_data = await model.embeddings_container.generate(data.input)
+
+    # OAI expects a return of base64 if the input is base64
+    embedding_object = [
+        EmbeddingObject(
+            embedding=float_list_to_base64(emb)
+            if data.encoding_format == "base64"
+            else emb.tolist(),
+            index=n,
+        )
+        for n, emb in enumerate(embedding_data.get("embeddings"))
+    ]
+
+    usage = embedding_data.get("usage")
+    response = EmbeddingsResponse(
+        data=embedding_object,
+        model=model_path.name,
+        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
+    )
+
+    logger.info(f"Finished embeddings request {request.state.id}")
+
+    return response
diff --git a/main.py b/main.py
index c62a381d..56873c41 100644
--- a/main.py
+++ b/main.py
@@ -87,6 +87,18 @@ async def entrypoint_async():
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         await model.container.load_loras(lora_dir.resolve(), **lora_config)
 
+    # If an initial embedding model name is specified, create a separate container
+    # and load the model
+    embedding_config = config.embeddings_config()
+    embedding_model_name = embedding_config.get("embeddings_model_name")
+    if embedding_model_name:
+        embedding_model_path = pathlib.Path(
+            unwrap(embedding_config.get("embeddings_model_dir"), "models")
+        )
+        embedding_model_path = embedding_model_path / embedding_model_name
+
+        await model.load_embeddings_model(embedding_model_path, **embedding_config)
+
     await start_api(host, port)
 
 

From 01c77028599f4fdb58a01433903ff5f57a0cf64c Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 30 Jul 2024 11:11:05 -0400
Subject: [PATCH 08/11] Signal: Fix async signal handling

Run unload async functions before exiting the program.
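
The handler registered with signal.signal() stays synchronous, so the
coroutine is scheduled rather than awaited. A simplified sketch of the
pattern this commit adopts in common/signals.py:

    import asyncio
    import sys

    def signal_handler(*_):
        # Schedule the async cleanup; ensure_future falls back to the
        # current event loop rather than requiring get_running_loop()
        asyncio.ensure_future(signal_handler_async())

        # Exit the program
        sys.exit(0)
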
Signed-off-by: kingbri --- backends/infinity/model.py | 3 +++ common/signals.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 2d4ae831..27fc9e5a 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -1,6 +1,7 @@ import gc import pathlib import torch +from loguru import logger from typing import List, Optional from common.utils import unwrap @@ -50,6 +51,8 @@ async def unload(self): gc.collect() torch.cuda.empty_cache() + logger.info("Embedding model unloaded.") + async def generate(self, sentence_input: List[str]): result_embeddings, usage = await self.engine.embed(sentence_input) diff --git a/common/signals.py b/common/signals.py index d4b67bc0..f0b7f192 100644 --- a/common/signals.py +++ b/common/signals.py @@ -13,17 +13,20 @@ def signal_handler(*_): logger.warning("Shutdown signal called. Exiting gracefully.") # Run async unloads for model - loop = asyncio.get_running_loop() - if model.container: - loop.create_task(model.container.unload()) - - if model.embeddings_container: - loop.create_task(model.embeddings_container.unload()) + asyncio.ensure_future(signal_handler_async()) # Exit the program sys.exit(0) +async def signal_handler_async(*_): + if model.container: + await model.container.unload() + + if model.embeddings_container: + await model.embeddings_container.unload() + + def uvicorn_signal_handler(signal_event: signal.Signals): """Overrides uvicorn's signal handler.""" From f13d0fb8b3ed36af37b89102c50c6fcc6a7179e5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 11:17:36 -0400 Subject: [PATCH 09/11] Embeddings: Add model load checks Same as the normal model container. Signed-off-by: kingbri --- backends/infinity/model.py | 7 +++++++ common/model.py | 10 ++++++++-- endpoints/OAI/router.py | 4 ++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 27fc9e5a..4c9bb697 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -18,6 +18,8 @@ class InfinityContainer: model_dir: pathlib.Path + model_is_loading: bool = False + model_loaded: bool = False # Conditionally set the type hint based on importablity # TODO: Clean this up @@ -30,6 +32,8 @@ def __init__(self, model_directory: pathlib.Path): self.model_dir = model_directory async def load(self, **kwargs): + self.model_is_loading = True + # Use cpu by default device = unwrap(kwargs.get("device"), "cpu") @@ -44,6 +48,9 @@ async def load(self, **kwargs): self.engine = AsyncEmbeddingEngine.from_args(engine_args) await self.engine.astart() + self.model_loaded = True + logger.info("Embedding model successfully loaded.") + async def unload(self): await self.engine.astop() self.engine = None diff --git a/common/model.py b/common/model.py index b4b259e8..3776ff9e 100644 --- a/common/model.py +++ b/common/model.py @@ -162,9 +162,15 @@ async def check_model_container(): async def check_embeddings_container(): - """FastAPI depends that checks if an embeddings model is loaded.""" + """ + FastAPI depends that checks if an embeddings model is loaded. - if embeddings_container is None: + This is the same as the model container check, but with embeddings instead. 
+ """ + + if embeddings_container is None or not ( + embeddings_container.model_is_loading or embeddings_container.model_loaded + ): error_message = handle_request_error( "No embeddings models are currently loaded.", exc_info=False, diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b702e520..b428c006 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -5,7 +5,7 @@ from common import config, model from common.auth import check_api_key -from common.model import check_model_container +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.utils import unwrap from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse @@ -132,7 +132,7 @@ async def chat_completion_request( # Embeddings endpoint @router.post( "/v1/embeddings", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: embeddings_task = asyncio.create_task(get_embeddings(data, request)) From bfa011e0cea4a1bc934222ce4502e096df2ecad6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 15:19:27 -0400 Subject: [PATCH 10/11] Embeddings: Add model management Embedding models are managed on a separate backend, but are run in parallel with the model itself. Therefore, manage this in a separate container with separate routes. Signed-off-by: kingbri --- common/model.py | 23 ++++++--- config_sample.yml | 4 +- endpoints/core/router.py | 90 ++++++++++++++++++++++++++++++++++- endpoints/core/types/model.py | 5 ++ endpoints/core/utils/model.py | 23 ++++++--- main.py | 9 ++-- 6 files changed, 135 insertions(+), 19 deletions(-) diff --git a/common/model.py b/common/model.py index 3776ff9e..80858d40 100644 --- a/common/model.py +++ b/common/model.py @@ -57,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): f'Model "{loaded_model_name}" is already loaded! Aborting.' ) - # Unload the existing model - if container and container.model: logger.info("Unloading existing model.") await unload_model() @@ -109,24 +107,35 @@ async def unload_loras(): await container.unload(loras_only=True) -async def load_embeddings_model(model_path: pathlib.Path, **kwargs): +async def load_embedding_model(model_path: pathlib.Path, **kwargs): global embeddings_container # Break out if infinity isn't installed if not has_infinity_emb: - logger.warning( + raise ImportError( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - return + + # Check if the model is already loaded + if embeddings_container and embeddings_container.engine: + loaded_model_name = embeddings_container.model_dir.name + + if loaded_model_name == model_path.name and embeddings_container.model_loaded: + raise ValueError( + f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' 
+ ) + + logger.info("Unloading existing embeddings model.") + await unload_embedding_model() embeddings_container = InfinityContainer(model_path) await embeddings_container.load(**kwargs) -async def unload_embeddings_model(): +async def unload_embedding_model(): global embeddings_container await embeddings_container.unload() @@ -172,7 +181,7 @@ async def check_embeddings_container(): embeddings_container.model_is_loading or embeddings_container.model_loaded ): error_message = handle_request_error( - "No embeddings models are currently loaded.", + "No embedding models are currently loaded.", exc_info=False, ).error.message diff --git a/config_sample.yml b/config_sample.yml index 053feb62..71a58d2a 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -73,9 +73,9 @@ developer: #realtime_process_priority: False embeddings: - embeddings_model_dir: models + embedding_model_dir: models - embeddings_model_name: + embedding_model_name: embeddings_device: cpu diff --git a/endpoints/core/router.py b/endpoints/core/router.py index cd0ed377..5aabd481 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -7,7 +7,7 @@ from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download -from common.model import check_model_container +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap @@ -15,6 +15,7 @@ from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse from endpoints.core.types.model import ( + EmbeddingModelLoadRequest, ModelCard, ModelList, ModelLoadRequest, @@ -253,6 +254,93 @@ async def unload_loras(): await model.unload_loras() +@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +async def list_embedding_models(request: Request) -> ModelList: + """ + Lists all embedding models in the model directory. + + Requires an admin key to see all embedding models. 
+ """ + + if get_key_permission(request) == "admin": + embedding_model_dir = unwrap( + config.embeddings_config().get("embedding_model_dir"), "models" + ) + embedding_model_path = pathlib.Path(embedding_model_dir) + + models = get_model_list(embedding_model_path.resolve()) + else: + models = await get_current_model_list(model_type="embedding") + + return models + + +@router.get( + "/v1/model/embedding", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], +) +async def get_embedding_model() -> ModelList: + """Returns the currently loaded embedding model.""" + + return get_current_model_list(model_type="embedding")[0] + + +@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +async def load_embedding_model( + request: Request, data: EmbeddingModelLoadRequest +) -> ModelLoadResponse: + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + embedding_model_dir = pathlib.Path( + unwrap(config.model_config().get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_dir / data.name + + if not embedding_model_path.exists(): + error_message = handle_request_error( + "Could not find the embedding model path for load. " + + "Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + load_task = asyncio.create_task( + model.load_embedding_model(embedding_model_path, **data.model_dump()) + ) + await run_with_request_disconnect( + request, load_task, "Embedding model load request cancelled by user." + ) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + response = ModelLoadResponse( + model_type="embedding_model", module=1, modules=1, status="finished" + ) + + return response + + +@router.post( + "/v1/model/embedding/unload", + dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], +) +async def unload_embedding_model(): + """Unloads the current embedding model.""" + + await model.unload_embedding_model() + + # Encode tokens endpoint @router.post( "/v1/token/encode", diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 30730b8a..c107dde9 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel): skip_queue: Optional[bool] = False +class EmbeddingModelLoadRequest(BaseModel): + name: str + device: Optional[str] = None + + class ModelLoadResponse(BaseModel): """Represents a model load response.""" diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 0cfb26a7..fc613377 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(is_draft: bool = False): - """Gets the current model in list format and with path only.""" +async def get_current_model_list(model_type: str = "model"): + """ + Gets the current model in list format and with path only. + + Unified for fetching both models and embedding models. 
+ """ + current_models = [] + model_path = None # Make sure the model container exists - if model.container: - model_path = model.container.get_model_path(is_draft) - if model_path: - current_models.append(ModelCard(id=model_path.name)) + if model_type == "model" or model_type == "draft": + if model.container: + model_path = model.container.get_model_path(model_type == "draft") + elif model_type == "embedding": + if model.embeddings_container: + model_path = model.embeddings_container.model_dir + + if model_path: + current_models.append(ModelCard(id=model_path.name)) return ModelList(data=current_models) diff --git a/main.py b/main.py index 56873c41..bae2f985 100644 --- a/main.py +++ b/main.py @@ -90,14 +90,17 @@ async def entrypoint_async(): # If an initial embedding model name is specified, create a separate container # and load the model embedding_config = config.embeddings_config() - embedding_model_name = embedding_config.get("embeddings_model_name") + embedding_model_name = embedding_config.get("embedding_model_name") if embedding_model_name: embedding_model_path = pathlib.Path( - unwrap(embedding_config.get("embeddings_model_dir"), "models") + unwrap(embedding_config.get("embedding_model_dir"), "models") ) embedding_model_path = embedding_model_path / embedding_model_name - await model.load_embeddings_model(embedding_model_path, **embedding_config) + try: + await model.load_embedding_model(embedding_model_path, **embedding_config) + except ImportError as ex: + logger.error(ex.msg) await start_api(host, port) From dc3dcc9c0ddf721ee67a54b2395df271f0393d2a Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 30 Jul 2024 15:32:26 -0400 Subject: [PATCH 11/11] Embeddings: Update config, args, and parameter names Use embeddings_device as the parameter for device to remove ambiguity. Signed-off-by: kingbri --- backends/infinity/model.py | 2 +- common/args.py | 20 ++++++++++++++++++++ common/config.py | 5 +++++ config_sample.yml | 23 ++++++++++++++++------- endpoints/core/types/model.py | 2 +- 5 files changed, 43 insertions(+), 9 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 4c9bb697..35a4df45 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -35,7 +35,7 @@ async def load(self, **kwargs): self.model_is_loading = True # Use cpu by default - device = unwrap(kwargs.get("device"), "cpu") + device = unwrap(kwargs.get("embeddings_device"), "cpu") engine_args = EngineArgs( model_name_or_path=str(self.model_dir), diff --git a/common/args.py b/common/args.py index e57de788..0548eaf6 100644 --- a/common/args.py +++ b/common/args.py @@ -23,6 +23,7 @@ def init_argparser(): ) add_network_args(parser) add_model_args(parser) + add_embeddings_args(parser) add_logging_args(parser) add_developer_args(parser) add_sampling_args(parser) @@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser): sampling_group.add_argument( "--override-preset", type=str, help="Select a sampler override preset" ) + + +def add_embeddings_args(parser: argparse.ArgumentParser): + """Adds arguments specific to embeddings""" + + embeddings_group = parser.add_argument_group("embeddings") + embeddings_group.add_argument( + "--embedding-model-dir", + type=str, + help="Overrides the directory to look for models", + ) + embeddings_group.add_argument( + "--embedding-model-name", type=str, help="An initial model to load" + ) + embeddings_group.add_argument( + "--embeddings-device", + type=str, + help="Device to use for embeddings. 
Options: (cpu, auto, cuda)", + ) diff --git a/common/config.py b/common/config.py index 5546240e..9b2f654d 100644 --- a/common/config.py +++ b/common/config.py @@ -59,6 +59,11 @@ def from_args(args: dict): cur_developer_config = developer_config() GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} + embeddings_override = args.get("embeddings") + if embeddings_override: + cur_embeddings_config = embeddings_config() + GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} + def sampling_config(): """Returns the sampling parameter config from the global config""" diff --git a/config_sample.yml b/config_sample.yml index 71a58d2a..09ae000b 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -72,13 +72,6 @@ developer: # Otherwise, the priority will be set to high #realtime_process_priority: False -embeddings: - embedding_model_dir: models - - embedding_model_name: - - embeddings_device: cpu - # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: @@ -208,3 +201,19 @@ model: #loras: #- name: lora1 # scaling: 1.0 + +# Options for embedding models and loading. +# NOTE: Embeddings requires the "extras" feature to be installed +# Install it via "pip install .[extras]" +embeddings: + # Overrides directory to look for embedding models (default: models) + embedding_model_dir: models + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: + + # Device to load embedding models on (default: cpu) + # Possible values: cpu, auto, cuda + # NOTE: It's recommended to load embedding models on the CPU. + # If you'd like to load on an AMD gpu, set this value to "cuda" as well. + embeddings_device: cpu diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index c107dde9..8b3d83ed 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str - device: Optional[str] = None + embeddings_device: Optional[str] = None class ModelLoadResponse(BaseModel):
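
For reference, once an embedding model is loaded, the endpoint added by
this series can be exercised as sketched below. The host, port, and auth
header are deployment-specific assumptions rather than values defined in
these patches; the request and response shapes follow the
EmbeddingsRequest and EmbeddingsResponse types above:

    import requests  # third-party HTTP client, assumed available

    response = requests.post(
        "http://127.0.0.1:5000/v1/embeddings",  # host/port are assumptions
        headers={"x-api-key": "YOUR_API_KEY"},  # placeholder credential
        json={
            "input": ["TabbyAPI now serves embeddings."],
            "encoding_format": "float",  # or "base64"
        },
        timeout=30,
    )

    payload = response.json()
    # EmbeddingsResponse fields: object, data[*].embedding, model, usage
    print(payload["model"], len(payload["data"][0]["embedding"]))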