Merge pull request #158 from AlpinDale/embeddings
feat: add embeddings support via Infinity-emb
Showing 14 changed files with 443 additions and 11 deletions.
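The request and response schemas added in this PR mirror the OpenAI embeddings API, so a client call would presumably look like the sketch below. The /v1/embeddings route, host, port, and API key header are assumptions for illustration; the routing code itself is not shown in this excerpt.

import requests

# Hypothetical endpoint and API key; the actual route is not part of this excerpt
resp = requests.post(
    "http://localhost:5000/v1/embeddings",
    headers={"Authorization": "Bearer <api-key>"},
    json={
        "input": ["first sentence", "second sentence"],
        "encoding_format": "float",
    },
)
resp.raise_for_status()

# Each item follows the EmbeddingObject schema defined in this PR
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))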
@@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path
    model_is_loading: bool = False
    model_loaded: bool = False

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        self.model_is_loading = True

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.model_loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
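For context, a minimal sketch of how this container might be driven from async code. The module path in the import, the model directory, and the embeddings_device value are assumptions for illustration, not part of the diff.

import asyncio
import pathlib

# Hypothetical import path; the new file's location is not visible in this excerpt
from backends.infinity.model import InfinityContainer


async def main():
    # Placeholder path to a local sentence-transformers style embedding model
    container = InfinityContainer(pathlib.Path("models/my-embedding-model"))

    # "cuda" assumes a GPU-capable infinity-emb install; the default is "cpu"
    await container.load(embeddings_device="cuda")

    result = await container.generate(["hello world", "embeddings via infinity-emb"])
    print(len(result["embeddings"]), "embeddings;", "usage:", result["usage"])

    await container.unload()


asyncio.run(main())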
@@ -0,0 +1,42 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for."
    )
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.",
    )


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats."
    )
    index: int = Field(
        ..., description="Index of the input text corresponding to the embedding."
    )


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
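To show how these schemas connect to the container above, here is a hedged sketch that wraps InfinityContainer.generate() output in an EmbeddingsResponse. The helper name and the mapping of the raw usage count onto UsageInfo are assumptions, not code from this PR.

def build_embeddings_response(raw: dict, model_name: str) -> EmbeddingsResponse:
    """Hypothetical helper: convert {"embeddings": ..., "usage": ...} from
    InfinityContainer.generate() into the OpenAI-style response models above."""
    data = [
        EmbeddingObject(embedding=[float(x) for x in vector], index=i)
        for i, vector in enumerate(raw["embeddings"])
    ]
    # infinity-emb reports a single token count; assume it maps to prompt tokens
    usage = UsageInfo(prompt_tokens=raw["usage"], total_tokens=raw["usage"])
    return EmbeddingsResponse(data=data, model=model_name, usage=usage)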