Embeddings endpoint and pipeline #385

Draft · wants to merge 3 commits into base: main
144 changes: 144 additions & 0 deletions runner/app/pipelines/embeddings.py
@@ -0,0 +1,144 @@
import logging
import os
from typing import List, Union, Dict, Any, Optional
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR
from dataclasses import dataclass
from enum import Enum
from huggingface_hub import file_download

logger = logging.getLogger(__name__)


class EmbeddingModelType(Enum):
    SENTENCE_TRANSFORMER = "sentence-transformer"
    TRANSFORMER = "transformer"
    INSTRUCTOR = "instructor"


@dataclass
class EmbeddingConfig:
    normalize: bool = True
    max_length: int = 512
    batch_size: int = 32

    def validate(self):
        """Validate embedding parameters."""
        if self.max_length < 1:
            raise ValueError("max_length must be positive")
        if self.batch_size < 1:
            raise ValueError("batch_size must be positive")


class EmbeddingPipeline(Pipeline):
    def __init__(self, model_id: str):
        """Initialize the Embedding Pipeline."""
        logger.info("Initializing embedding pipeline")

        self.model_id = model_id
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model")
        base_path = os.path.join(get_model_dir(), folder_name)

        # Find the actual model path
        self.local_model_path = self._find_model_path(base_path)

        if not self.local_model_path:
            raise ValueError(f"Could not find model files for {model_id}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Get configuration from environment
        model_type = os.getenv("EMBEDDING_MODEL_TYPE", "sentence-transformer")
        self.model_type = EmbeddingModelType(model_type)
        self.max_length = int(os.getenv("EMBEDDING_MAX_LENGTH", "512"))

        logger.info(f"Loading embedding model: {model_id} of type {model_type}")

        try:
            if self.model_type == EmbeddingModelType.SENTENCE_TRANSFORMER:
                self.model = SentenceTransformer(self.local_model_path).to(self.device)
            elif self.model_type == EmbeddingModelType.INSTRUCTOR:
                self.model = INSTRUCTOR(self.local_model_path).to(self.device)
            else:
                # Guard against unhandled types (e.g. TRANSFORMER), which would
                # otherwise leave self.model unset.
                raise ValueError(f"Unsupported embedding model type: {model_type}")

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    async def generate(
        self,
        texts: Union[str, List[str]],
        embedding_config: Optional[EmbeddingConfig] = None,
        instruction: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate embeddings for input texts."""
        config = embedding_config or EmbeddingConfig()

        if isinstance(texts, str):
            texts = [texts]

        try:
            if self.model_type == EmbeddingModelType.INSTRUCTOR and instruction:
                texts = [f"{instruction} {text}" for text in texts]

            embeddings = self.model.encode(
                texts,
                batch_size=config.batch_size,
                normalize_embeddings=config.normalize,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            embeddings_list = embeddings.cpu().numpy().tolist()

            # Note: whitespace word counts are only an approximation of
            # tokenizer-based token counts.
            token_count = sum(len(text.split()) for text in texts)

            return {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "embedding": emb,
                        "index": i,
                    }
                    for i, emb in enumerate(embeddings_list)
                ],
                "model": self.model_id,
                "usage": {
                    "prompt_tokens": token_count,
                    "total_tokens": token_count,
                }
            }

        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise

    async def __call__(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Dict[str, Any]:
        """Generate embeddings with configuration."""
        try:
            # "instruction" is not an EmbeddingConfig field; pop it and pass it
            # through to generate() instead of letting it break the constructor.
            instruction = kwargs.pop("instruction", None)
            config = EmbeddingConfig(**kwargs)
            config.validate()
            return await self.generate(texts, config, instruction=instruction)
        except Exception as e:
            logger.error(f"Error in pipeline: {e}")
            raise

    def __str__(self):
        return f"EmbeddingPipeline(model_id={self.model_id})"

    def _find_model_path(self, base_path):
        if not os.path.isdir(base_path):
            return None

        # Check if the model files are directly in the base path
        if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
            return base_path

        # If not, look in subdirectories
        for root, dirs, files in os.walk(base_path):
            if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
                return root

        return None
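
For local testing, a minimal sketch of driving the pipeline directly might look like the following, assuming the weights are already downloaded under get_model_dir() and EMBEDDING_MODEL_TYPE is left at its sentence-transformer default (the model ID here is illustrative):

import asyncio

from app.pipelines.embeddings import EmbeddingPipeline


async def main():
    # Hypothetical model ID; any sentence-transformers checkpoint already
    # present under get_model_dir() would do.
    pipeline = EmbeddingPipeline("sentence-transformers/all-MiniLM-L6-v2")
    result = await pipeline(["first document", "second document"], normalize=True)
    print(result["model"], len(result["data"]), "embeddings returned")


asyncio.run(main())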
68 changes: 68 additions & 0 deletions runner/app/routes/embeddings.py
@@ -0,0 +1,68 @@
import logging
import os
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import EmbeddingRequest, EmbeddingResponse, HTTPError, http_error

router = APIRouter()
logger = logging.getLogger(__name__)

RESPONSES = {
    status.HTTP_200_OK: {"model": EmbeddingResponse},
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
    "/embeddings",
    response_model=EmbeddingResponse,
    responses=RESPONSES,
    operation_id="createEmbeddings",
    description="Generate embeddings for provided text",
    summary="Create Embeddings",
    tags=["embeddings"],
)
async def create_embeddings(
    request: EmbeddingRequest,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    # Auth check
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    # Model check
    if request.model != "" and request.model != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"Pipeline configured with {pipeline.model_id} but called with {request.model}"
            ),
        )

    try:
        response = await pipeline(
            texts=request.input,
            normalize=request.normalize,
            instruction=request.instruction
        )
        return response

    except Exception as e:
        logger.error(f"Embedding generation error: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "Internal server error during embedding generation."}
        )
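
For reference, a request against the new route could look like the sketch below, assuming the runner listens on localhost:8000 and AUTH_TOKEN is unset (host, port, and payload values are illustrative):

import requests

resp = requests.post(
    "http://localhost:8000/embeddings",
    json={
        "input": ["hello world", "goodbye world"],
        "model": "",  # empty string skips the pipeline model check
        "normalize": True,
    },
)
resp.raise_for_status()
body = resp.json()
print(len(body["data"]), "embeddings from", body["model"])

If AUTH_TOKEN is set, the same request would additionally need an Authorization: Bearer header.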
15 changes: 15 additions & 0 deletions runner/app/routes/utils.py
@@ -95,6 +95,21 @@ class LLMResponse(BaseModel):
    created: int


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    model: str = Field("", description="Model to use")
    instruction: Optional[str] = Field(
        None, description="Instruction for instructor models")
    normalize: bool = Field(True, description="Whether to normalize embeddings")


class EmbeddingResponse(BaseModel):
    object: str
    data: List[Dict[str, Union[List[float], int]]]
    model: str
    usage: Dict[str, int]


class ImageToTextResponse(BaseModel):
    """Response model for text generation."""

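As a quick illustration of the new schemas, EmbeddingRequest accepts either a single string or a list of strings for input; a sketch with illustrative values:

from app.routes.utils import EmbeddingRequest

single = EmbeddingRequest(input="one document")
batch = EmbeddingRequest(
    input=["doc one", "doc two"],
    instruction="Represent the document for retrieval:",
    normalize=False,
)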
2 changes: 2 additions & 0 deletions runner/requirements.llm.in
@@ -20,3 +20,5 @@ sentencepiece
protobuf
bitsandbytes
psutil
sentence-transformers
InstructorEmbedding