diff --git a/backends/infinity/model.py b/backends/infinity/model.py
new file mode 100644
index 00000000..35a4df45
--- /dev/null
+++ b/backends/infinity/model.py
@@ -0,0 +1,66 @@
+import gc
+import pathlib
+import torch
+from loguru import logger
+from typing import List, Optional
+
+from common.utils import unwrap
+
+# Conditionally import infinity to sidestep its logger
+# TODO: Make this prettier
+try:
+    from infinity_emb import EngineArgs, AsyncEmbeddingEngine
+
+    has_infinity_emb = True
+except ImportError:
+    has_infinity_emb = False
+
+
+class InfinityContainer:
+    model_dir: pathlib.Path
+    model_is_loading: bool = False
+    model_loaded: bool = False
+
+    # Conditionally set the type hint based on importability
+    # TODO: Clean this up
+    if has_infinity_emb:
+        engine: Optional[AsyncEmbeddingEngine] = None
+    else:
+        engine = None
+
+    def __init__(self, model_directory: pathlib.Path):
+        self.model_dir = model_directory
+
+    async def load(self, **kwargs):
+        self.model_is_loading = True
+
+        # Use cpu by default
+        device = unwrap(kwargs.get("embeddings_device"), "cpu")
+
+        engine_args = EngineArgs(
+            model_name_or_path=str(self.model_dir),
+            engine="torch",
+            device=device,
+            bettertransformer=False,
+            model_warmup=False,
+        )
+
+        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
+        await self.engine.astart()
+
+        self.model_loaded = True
+        logger.info("Embedding model successfully loaded.")
+
+    async def unload(self):
+        await self.engine.astop()
+        self.engine = None
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        logger.info("Embedding model unloaded.")
+
+    async def generate(self, sentence_input: List[str]):
+        result_embeddings, usage = await self.engine.embed(sentence_input)
+
+        return {"embeddings": result_embeddings, "usage": usage}
diff --git a/common/args.py b/common/args.py
index e57de788..0548eaf6 100644
--- a/common/args.py
+++ b/common/args.py
@@ -23,6 +23,7 @@ def init_argparser():
     )
     add_network_args(parser)
     add_model_args(parser)
+    add_embeddings_args(parser)
     add_logging_args(parser)
     add_developer_args(parser)
     add_sampling_args(parser)
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
     sampling_group.add_argument(
         "--override-preset", type=str, help="Select a sampler override preset"
     )
+
+
+def add_embeddings_args(parser: argparse.ArgumentParser):
+    """Adds arguments specific to embeddings"""
+
+    embeddings_group = parser.add_argument_group("embeddings")
+    embeddings_group.add_argument(
+        "--embedding-model-dir",
+        type=str,
+        help="Overrides the directory to look for models",
+    )
+    embeddings_group.add_argument(
+        "--embedding-model-name", type=str, help="An initial model to load"
+    )
+    embeddings_group.add_argument(
+        "--embeddings-device",
+        type=str,
+        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
+    )
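
For reviewers who want to poke at the new backend in isolation, the container lifecycle can be driven directly. A minimal sketch, assuming infinity-emb is installed and a hypothetical model directory:

    import asyncio
    import pathlib

    from backends.infinity.model import InfinityContainer

    async def demo():
        # Hypothetical local embedding model directory
        container = InfinityContainer(pathlib.Path("models/all-MiniLM-L6-v2"))
        await container.load(embeddings_device="cpu")

        # generate() returns a dict with "embeddings" and "usage"
        result = await container.generate(["hello world"])
        print(len(result["embeddings"]), result["usage"])

        await container.unload()

    asyncio.run(demo())
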
Options: (cpu, auto, cuda)", + ) diff --git a/common/config.py b/common/config.py index 972b382a..9b2f654d 100644 --- a/common/config.py +++ b/common/config.py @@ -59,6 +59,11 @@ def from_args(args: dict): cur_developer_config = developer_config() GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} + embeddings_override = args.get("embeddings") + if embeddings_override: + cur_embeddings_config = embeddings_config() + GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} + def sampling_config(): """Returns the sampling parameter config from the global config""" @@ -95,3 +100,8 @@ def logging_config(): def developer_config(): """Returns the developer specific config from the global config""" return unwrap(GLOBAL_CONFIG.get("developer"), {}) + + +def embeddings_config(): + """Returns the embeddings config from the global config""" + return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) diff --git a/common/model.py b/common/model.py index a6477c26..80858d40 100644 --- a/common/model.py +++ b/common/model.py @@ -20,6 +20,15 @@ # Global model container container: Optional[ExllamaV2Container] = None + embeddings_container = None + + # Type hint the infinity emb container if it exists + from backends.infinity.model import has_infinity_emb + + if has_infinity_emb: + from backends.infinity.model import InfinityContainer + + embeddings_container: Optional[InfinityContainer] = None def load_progress(module, modules): @@ -48,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): f'Model "{loaded_model_name}" is already loaded! Aborting.' ) - # Unload the existing model - if container and container.model: logger.info("Unloading existing model.") await unload_model() @@ -100,6 +107,41 @@ async def unload_loras(): await container.unload(loras_only=True) +async def load_embedding_model(model_path: pathlib.Path, **kwargs): + global embeddings_container + + # Break out if infinity isn't installed + if not has_infinity_emb: + raise ImportError( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" + ) + + # Check if the model is already loaded + if embeddings_container and embeddings_container.engine: + loaded_model_name = embeddings_container.model_dir.name + + if loaded_model_name == model_path.name and embeddings_container.model_loaded: + raise ValueError( + f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' + ) + + logger.info("Unloading existing embeddings model.") + await unload_embedding_model() + + embeddings_container = InfinityContainer(model_path) + await embeddings_container.load(**kwargs) + + +async def unload_embedding_model(): + global embeddings_container + + await embeddings_container.unload() + embeddings_container = None + + def get_config_default(key, fallback=None, is_draft=False): """Fetches a default value from model config if allowed by the user.""" @@ -126,3 +168,21 @@ async def check_model_container(): ).error.message raise HTTPException(400, error_message) + + +async def check_embeddings_container(): + """ + FastAPI depends that checks if an embeddings model is loaded. + + This is the same as the model container check, but with embeddings instead. 
+ """ + + if embeddings_container is None or not ( + embeddings_container.model_is_loading or embeddings_container.model_loaded + ): + error_message = handle_request_error( + "No embedding models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/common/signals.py b/common/signals.py index 07d75644..f0b7f192 100644 --- a/common/signals.py +++ b/common/signals.py @@ -1,16 +1,32 @@ +import asyncio import signal import sys from loguru import logger from types import FrameType +from common import model + def signal_handler(*_): """Signal handler for main function. Run before uvicorn starts.""" logger.warning("Shutdown signal called. Exiting gracefully.") + + # Run async unloads for model + asyncio.ensure_future(signal_handler_async()) + + # Exit the program sys.exit(0) +async def signal_handler_async(*_): + if model.container: + await model.container.unload() + + if model.embeddings_container: + await model.embeddings_container.unload() + + def uvicorn_signal_handler(signal_event: signal.Signals): """Overrides uvicorn's signal handler.""" diff --git a/config_sample.yml b/config_sample.yml index e57d9474..f3a1c515 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -201,3 +201,19 @@ model: #loras: #- name: lora1 # scaling: 1.0 + +# Options for embedding models and loading. +# NOTE: Embeddings requires the "extras" feature to be installed +# Install it via "pip install .[extras]" +embeddings: + # Overrides directory to look for embedding models (default: models) + embedding_model_dir: models + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: + + # Device to load embedding models on (default: cpu) + # Possible values: cpu, auto, cuda + # NOTE: It's recommended to load embedding models on the CPU. + # If you'd like to load on an AMD gpu, set this value to "cuda" as well. 
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index d9701619..c1ee3430 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -5,7 +5,7 @@
 
 from common import config, model
 from common.auth import check_api_key
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@@ -13,6 +13,7 @@
     ChatCompletionRequest,
     ChatCompletionResponse,
 )
+from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
@@ -22,6 +23,7 @@
     generate_completion,
     stream_generate_completion,
 )
+from endpoints.OAI.utils.embeddings import get_embeddings
 
 api_name = "OAI"
 
@@ -134,3 +136,19 @@ async def chat_completion_request(
         disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
     )
     return response
+
+
+# Embeddings endpoint
+@router.post(
+    "/v1/embeddings",
+    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
+)
+async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    embeddings_task = asyncio.create_task(get_embeddings(data, request))
+    response = await run_with_request_disconnect(
+        request,
+        embeddings_task,
+        f"Embeddings request {request.state.id} cancelled by user.",
+    )
+
+    return response
diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py
new file mode 100644
index 00000000..7d5779fa
--- /dev/null
+++ b/endpoints/OAI/types/embedding.py
@@ -0,0 +1,42 @@
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class EmbeddingsRequest(BaseModel):
+    input: List[str] = Field(
+        ..., description="List of input texts to generate embeddings for."
+    )
+    encoding_format: str = Field(
+        "float",
+        description="Encoding format for the embeddings. "
+        "Can be 'float' or 'base64'.",
+    )
+    model: Optional[str] = Field(
+        None,
+        description="Name of the embedding model to use. "
+        "If not provided, the default model will be used.",
+    )
+
+
+class EmbeddingObject(BaseModel):
+    object: str = Field("embedding", description="Type of the object.")
+    embedding: Union[List[float], str] = Field(
+        ..., description="Embedding as a list of floats or a base64 string."
+    )
+    index: int = Field(
+        ..., description="Index of the input text corresponding to the embedding."
+    )
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
+    model: str = Field(..., description="Name of the embedding model used.")
+    usage: UsageInfo = Field(..., description="Information about token usage.")
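
The request and response schemas mirror OpenAI's embeddings API, so a stock HTTP client works. A sketch of a float-format call; the port and auth header are assumptions:

    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/v1/embeddings",
        headers={"Authorization": "Bearer <your-api-key>"},
        json={"input": ["hello world"], "encoding_format": "float"},
    )

    body = resp.json()
    # body["data"][0]["embedding"] is a list of floats in this mode
    print(body["model"], body["usage"]["total_tokens"])
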
diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py
new file mode 100644
index 00000000..5b43953b
--- /dev/null
+++ b/endpoints/OAI/utils/embeddings.py
@@ -0,0 +1,64 @@
+"""
+This file is derived from
+[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
+and modified.
+
+The changes introduced are: suppression of the progress bar,
+typing/pydantic classes moved into this file,
+and the embeddings function declared async.
+"""
+
+import base64
+from fastapi import Request
+import numpy as np
+from loguru import logger
+
+from common import model
+from endpoints.OAI.types.embedding import (
+    EmbeddingObject,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
+    UsageInfo,
+)
+
+
+def float_list_to_base64(float_array: np.ndarray) -> str:
+    """
+    Encodes a float32 numpy array into a base64 string for OpenAI clients.
+    Ex. float_array = np.array(float_list, dtype="float32")
+    """
+
+    # Encode raw bytes into base64
+    encoded_bytes = base64.b64encode(float_array.tobytes())
+
+    # Turn raw base64 encoded bytes into ASCII
+    ascii_string = encoded_bytes.decode("ascii")
+    return ascii_string
+
+
+async def get_embeddings(data: EmbeddingsRequest, request: Request) -> EmbeddingsResponse:
+    model_path = model.embeddings_container.model_dir
+
+    logger.info(f"Received embeddings request {request.state.id}")
+    embedding_data = await model.embeddings_container.generate(data.input)
+
+    # OAI expects base64-encoded embeddings when the request asks for base64
+    embedding_object = [
+        EmbeddingObject(
+            embedding=float_list_to_base64(emb)
+            if data.encoding_format == "base64"
+            else emb.tolist(),
+            index=n,
+        )
+        for n, emb in enumerate(embedding_data.get("embeddings"))
+    ]
+
+    usage = embedding_data.get("usage")
+    response = EmbeddingsResponse(
+        data=embedding_object,
+        model=model_path.name,
+        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
+    )
+
+    logger.info(f"Finished embeddings request {request.state.id}")
+
+    return response
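
Since float_list_to_base64 just base64-encodes raw float32 bytes, clients can invert it with numpy's frombuffer. A round-trip sketch with illustrative values:

    import base64

    import numpy as np

    original = np.array([0.1, -0.2, 0.3], dtype="float32")
    encoded = base64.b64encode(original.tobytes()).decode("ascii")

    # Client-side decode back into float32 values
    decoded = np.frombuffer(base64.b64decode(encoded), dtype="float32")
    assert np.array_equal(original, decoded)
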
+ """ + + if get_key_permission(request) == "admin": + embedding_model_dir = unwrap( + config.embeddings_config().get("embedding_model_dir"), "models" + ) + embedding_model_path = pathlib.Path(embedding_model_dir) + + models = get_model_list(embedding_model_path.resolve()) + else: + models = await get_current_model_list(model_type="embedding") + + return models + + +@router.get( + "/v1/model/embedding", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], +) +async def get_embedding_model() -> ModelList: + """Returns the currently loaded embedding model.""" + + return get_current_model_list(model_type="embedding")[0] + + +@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +async def load_embedding_model( + request: Request, data: EmbeddingModelLoadRequest +) -> ModelLoadResponse: + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + embedding_model_dir = pathlib.Path( + unwrap(config.model_config().get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_dir / data.name + + if not embedding_model_path.exists(): + error_message = handle_request_error( + "Could not find the embedding model path for load. " + + "Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + load_task = asyncio.create_task( + model.load_embedding_model(embedding_model_path, **data.model_dump()) + ) + await run_with_request_disconnect( + request, load_task, "Embedding model load request cancelled by user." + ) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + response = ModelLoadResponse( + model_type="embedding_model", module=1, modules=1, status="finished" + ) + + return response + + +@router.post( + "/v1/model/embedding/unload", + dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], +) +async def unload_embedding_model(): + """Unloads the current embedding model.""" + + await model.unload_embedding_model() + + # Encode tokens endpoint @router.post( "/v1/token/encode", diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 30730b8a..8b3d83ed 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel): skip_queue: Optional[bool] = False +class EmbeddingModelLoadRequest(BaseModel): + name: str + embeddings_device: Optional[str] = None + + class ModelLoadResponse(BaseModel): """Represents a model load response.""" diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 0cfb26a7..fc613377 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(is_draft: bool = False): - """Gets the current model in list format and with path only.""" +async def get_current_model_list(model_type: str = "model"): + """ + Gets the current model in list format and with path only. + + Unified for fetching both models and embedding models. 
+ """ + current_models = [] + model_path = None # Make sure the model container exists - if model.container: - model_path = model.container.get_model_path(is_draft) - if model_path: - current_models.append(ModelCard(id=model_path.name)) + if model_type == "model" or model_type == "draft": + if model.container: + model_path = model.container.get_model_path(model_type == "draft") + elif model_type == "embedding": + if model.embeddings_container: + model_path = model.embeddings_container.model_dir + + if model_path: + current_models.append(ModelCard(id=model_path.name)) return ModelList(data=current_models) diff --git a/main.py b/main.py index c62a381d..bae2f985 100644 --- a/main.py +++ b/main.py @@ -87,6 +87,21 @@ async def entrypoint_async(): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) await model.container.load_loras(lora_dir.resolve(), **lora_config) + # If an initial embedding model name is specified, create a separate container + # and load the model + embedding_config = config.embeddings_config() + embedding_model_name = embedding_config.get("embedding_model_name") + if embedding_model_name: + embedding_model_path = pathlib.Path( + unwrap(embedding_config.get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_path / embedding_model_name + + try: + await model.load_embedding_model(embedding_model_path, **embedding_config) + except ImportError as ex: + logger.error(ex.msg) + await start_api(host, port) diff --git a/pyproject.toml b/pyproject.toml index 19be7f18..547c363f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dependencies = [ [project.optional-dependencies] extras = [ # Heavy dependencies that aren't for everyday use - "outlines" + "outlines", + "sentence-transformers" ] dev = [ "ruff == 0.3.2"