Merge pull request #158 from AlpinDale/embeddings
feat: add embeddings support via Infinity-emb
bdashore3 authored Aug 1, 2024
2 parents f111052 + dc3dcc9 commit 1bf0625
Showing 14 changed files with 443 additions and 11 deletions.
66 changes: 66 additions & 0 deletions backends/infinity/model.py
@@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path
    model_is_loading: bool = False
    model_loaded: bool = False

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        self.model_is_loading = True

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.model_loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
20 changes: 20 additions & 0 deletions common/args.py
@@ -23,6 +23,7 @@ def init_argparser():
    )
    add_network_args(parser)
    add_model_args(parser)
    add_embeddings_args(parser)
    add_logging_args(parser)
    add_developer_args(parser)
    add_sampling_args(parser)
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
    sampling_group.add_argument(
        "--override-preset", type=str, help="Select a sampler override preset"
    )


def add_embeddings_args(parser: argparse.ArgumentParser):
    """Adds arguments specific to embeddings"""

    embeddings_group = parser.add_argument_group("embeddings")
    embeddings_group.add_argument(
        "--embedding-model-dir",
        type=str,
        help="Overrides the directory to look for models",
    )
    embeddings_group.add_argument(
        "--embedding-model-name", type=str, help="An initial model to load"
    )
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )
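
A quick sketch of how these flags surface after parsing, assuming the arguments in the other groups all remain optional (the model name is a placeholder):

parser = init_argparser()
args = parser.parse_args(
    ["--embedding-model-name", "bge-small-en-v1.5", "--embeddings-device", "cpu"]
)
# argparse converts dashes to underscores in attribute names
print(args.embedding_model_name, args.embeddings_device)  # bge-small-en-v1.5 cpu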
10 changes: 10 additions & 0 deletions common/config.py
@@ -59,6 +59,11 @@ def from_args(args: dict):
    cur_developer_config = developer_config()
    GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

    embeddings_override = args.get("embeddings")
    if embeddings_override:
        cur_embeddings_config = embeddings_config()
        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}

def sampling_config():
"""Returns the sampling parameter config from the global config"""
@@ -95,3 +100,8 @@ def logging_config():
def developer_config():
    """Returns the developer specific config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("developer"), {})


def embeddings_config():
    """Returns the embeddings config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
64 changes: 62 additions & 2 deletions common/model.py
@@ -20,6 +20,15 @@

# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None

# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb

if has_infinity_emb:
    from backends.infinity.model import InfinityContainer

    embeddings_container: Optional[InfinityContainer] = None


def load_progress(module, modules):
@@ -48,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
f'Model "{loaded_model_name}" is already loaded! Aborting.'
)

# Unload the existing model
if container and container.model:
logger.info("Unloading existing model.")
await unload_model()

@@ -100,6 +107,41 @@ async def unload_loras():
    await container.unload(loras_only=True)


async def load_embedding_model(model_path: pathlib.Path, **kwargs):
    global embeddings_container

    # Break out if infinity isn't installed
    if not has_infinity_emb:
        raise ImportError(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )

    # Check if the model is already loaded
    if embeddings_container and embeddings_container.engine:
        loaded_model_name = embeddings_container.model_dir.name

        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
            raise ValueError(
                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing embeddings model.")
        await unload_embedding_model()

    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)


async def unload_embedding_model():
    global embeddings_container

    await embeddings_container.unload()
    embeddings_container = None


def get_config_default(key, fallback=None, is_draft=False):
    """Fetches a default value from model config if allowed by the user."""

@@ -126,3 +168,21 @@ async def check_model_container():
        ).error.message

        raise HTTPException(400, error_message)


async def check_embeddings_container():
    """
    FastAPI depends that checks if an embeddings model is loaded.
    This is the same as the model container check, but with embeddings instead.
    """

    if embeddings_container is None or not (
        embeddings_container.model_is_loading or embeddings_container.model_loaded
    ):
        error_message = handle_request_error(
            "No embedding models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)
16 changes: 16 additions & 0 deletions common/signals.py
@@ -1,16 +1,32 @@
import asyncio
import signal
import sys
from loguru import logger
from types import FrameType

from common import model


def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts."""

    logger.warning("Shutdown signal called. Exiting gracefully.")

    # Run async unloads for model
    asyncio.ensure_future(signal_handler_async())

    # Exit the program
    sys.exit(0)


async def signal_handler_async(*_):
    if model.container:
        await model.container.unload()

    if model.embeddings_container:
        await model.embeddings_container.unload()


def uvicorn_signal_handler(signal_event: signal.Signals):
    """Overrides uvicorn's signal handler."""
16 changes: 16 additions & 0 deletions config_sample.yml
@@ -201,3 +201,19 @@ model:
#loras:
#- name: lora1
#  scaling: 1.0

# Options for embedding models and loading.
# NOTE: Embeddings support requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
  # Overrides the directory to look for embedding models (default: models)
  embedding_model_dir: models

  # An initial embedding model to load on the infinity backend (default: None)
  embedding_model_name:

  # Device to load embedding models on (default: cpu)
  # Possible values: cpu, auto, cuda
  # NOTE: It's recommended to load embedding models on the CPU.
  # To load on an AMD GPU (ROCm), also set this value to "cuda".
  embeddings_device: cpu
20 changes: 19 additions & 1 deletion endpoints/OAI/router.py
@@ -5,14 +5,15 @@

from common import config, model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
    format_prompt_with_template,
    generate_chat_completion,
@@ -22,6 +23,7 @@
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings


api_name = "OAI"
@@ -134,3 +136,19 @@ async def chat_completion_request(
        disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
    )
    return response


# Embeddings endpoint
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    embeddings_task = asyncio.create_task(get_embeddings(data, request))
    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )

    return response
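
A sketch of what a client call against this endpoint might look like, assuming a local instance on the default port and tabbyAPI's x-api-key header auth (URL, port, key, and model name are all placeholders):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/embeddings",
    headers={"x-api-key": "<your-api-key>"},
    json={"input": ["The quick brown fox"], "model": "bge-small-en-v1.5"},
)
# Response follows the OpenAI embeddings shape defined in types/embedding.py
print(resp.json()["data"][0]["embedding"][:4])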
42 changes: 42 additions & 0 deletions endpoints/OAI/types/embedding.py
@@ -0,0 +1,42 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for."
    )
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.",
    )


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats."
    )
    index: int = Field(
        ..., description="Index of the input text corresponding to the embedding."
    )


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
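
A quick sketch of the shapes these models validate; field values are made up, and model_dump_json assumes pydantic v2 (use .json() on v1):

request = EmbeddingsRequest(input=["hello", "world"])

response = EmbeddingsResponse(
    data=[EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)],
    model="example-model",
    usage=UsageInfo(prompt_tokens=2, total_tokens=2),
)
print(response.model_dump_json(indent=2))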