From ed1413ced3fd6b9e15c2f21ae600d52204387af9 Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Thu, 26 Dec 2024 11:31:23 +0100
Subject: [PATCH 1/3] wip

---
 runner/app/pipelines/embeddings.py | 126 +++++++++++++++++++++++++++++
 runner/app/routes/embeddings.py    |  86 ++++++++++++++++++++
 runner/requirements.llm.in         |   2 +
 3 files changed, 214 insertions(+)
 create mode 100644 runner/app/pipelines/embeddings.py
 create mode 100644 runner/app/routes/embeddings.py

diff --git a/runner/app/pipelines/embeddings.py b/runner/app/pipelines/embeddings.py
new file mode 100644
index 00000000..ddec2b7e
--- /dev/null
+++ b/runner/app/pipelines/embeddings.py
@@ -0,0 +1,126 @@
+# app/pipelines/embeddings.py
+import logging
+import os
+from typing import List, Union, Dict, Any, Optional
+import torch
+from app.pipelines.base import Pipeline
+from app.pipelines.utils import get_model_dir
+from sentence_transformers import SentenceTransformer
+from InstructorEmbedding import INSTRUCTOR
+from dataclasses import dataclass
+from enum import Enum
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingModelType(Enum):
+    SENTENCE_TRANSFORMER = "sentence-transformer"
+    TRANSFORMER = "transformer"
+    INSTRUCTOR = "instructor"
+
+
+@dataclass
+class EmbeddingConfig:
+    normalize: bool = True
+    max_length: int = 512
+    batch_size: int = 32
+
+    def validate(self):
+        """Validate embedding parameters"""
+        if self.max_length < 1:
+            raise ValueError("max_length must be positive")
+        if self.batch_size < 1:
+            raise ValueError("batch_size must be positive")
+
+
+class EmbeddingPipeline(Pipeline):
+    def __init__(self, model_id: str):
+        """Initialize the Embedding Pipeline."""
+        logger.info("Initializing embedding pipeline")
+
+        self.model_id = model_id
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Get configuration from environment
+        model_type = os.getenv("EMBEDDING_MODEL_TYPE", "sentence-transformer")
+        self.model_type = EmbeddingModelType(model_type)
+        self.max_length = int(os.getenv("EMBEDDING_MAX_LENGTH", "512"))
+
+        logger.info(f"Loading embedding model: {model_id} of type {model_type}")
+
+        try:
+            if self.model_type == EmbeddingModelType.SENTENCE_TRANSFORMER:
+                self.model = SentenceTransformer(model_id).to(self.device)
+            elif self.model_type == EmbeddingModelType.INSTRUCTOR:
+                self.model = INSTRUCTOR(model_id).to(self.device)
+
+            logger.info(f"Model loaded successfully on {self.device}")
+
+        except Exception as e:
+            logger.error(f"Error loading model: {e}")
+            raise
+
+    async def generate(
+        self,
+        texts: Union[str, List[str]],
+        embedding_config: Optional[EmbeddingConfig] = None,
+        instruction: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """Generate embeddings for input texts."""
+        config = embedding_config or EmbeddingConfig()
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        try:
+            if self.model_type == EmbeddingModelType.INSTRUCTOR and instruction:
+                texts = [f"{instruction} {text}" for text in texts]
+
+            embeddings = self.model.encode(
+                texts,
+                batch_size=config.batch_size,
+                normalize_embeddings=config.normalize,
+                convert_to_tensor=True,
+                show_progress_bar=False
+            )
+
+            embeddings_list = embeddings.cpu().numpy().tolist()
+
+            return {
+                "object": "list",
+                "data": [
+                    {
+                        "object": "embedding",
+                        "embedding": emb,
+                        "index": i,
+                    }
+                    for i, emb in enumerate(embeddings_list)
+                ],
+                "model": self.model_id,
+                "usage": {
+                    "prompt_tokens": sum(len(text.split()) for text in texts),
+                    "total_tokens": sum(len(text.split()) for text in texts),
+                }
+            }
+
+        except Exception as e:
logger.error(f"Error generating embeddings: {e}") + raise + + async def __call__( + self, + texts: Union[str, List[str]], + **kwargs + ) -> Dict[str, Any]: + """Generate embeddings with configuration.""" + try: + config = EmbeddingConfig(**kwargs) + config.validate() + return await self.generate(texts, config) + except Exception as e: + logger.error(f"Error in pipeline: {e}") + raise + + def __str__(self): + return f"EmbeddingPipeline(model_id={self.model_id})" diff --git a/runner/app/routes/embeddings.py b/runner/app/routes/embeddings.py new file mode 100644 index 00000000..91315283 --- /dev/null +++ b/runner/app/routes/embeddings.py @@ -0,0 +1,86 @@ +# app/routes/embeddings.py +import logging +import os +from typing import Union, List +from fastapi import APIRouter, Depends, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel, Field +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.utils import HTTPError, http_error + +router = APIRouter() +logger = logging.getLogger(__name__) + + +class EmbeddingRequest(BaseModel): + input: Union[str, List[str]] = Field(..., description="Text to embed") + model: str = Field("", description="Model to use") + instruction: Optional[str] = Field( + None, description="Instruction for instructor models") + normalize: bool = Field(True, description="Whether to normalize embeddings") + + +class EmbeddingResponse(BaseModel): + object: str + data: List[Dict[str, Union[List[float], int]]] + model: str + usage: Dict[str, int] + + +RESPONSES = { + status.HTTP_200_OK: {"model": EmbeddingResponse}, + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +@router.post( + "/embeddings", + response_model=EmbeddingResponse, + responses=RESPONSES, + operation_id="createEmbeddings", + description="Generate embeddings for provided text", + summary="Create Embeddings", + tags=["embeddings"], +) +async def create_embeddings( + request: EmbeddingRequest, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + # Auth check + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + # Model check + if request.model != "" and request.model != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"Pipeline configured with {pipeline.model_id} but called with {request.model}" + ), + ) + + try: + response = await pipeline( + texts=request.input, + normalize=request.normalize, + instruction=request.instruction + ) + return response + + except Exception as e: + logger.error(f"Embedding generation error: {str(e)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error during embedding generation."} + ) diff --git a/runner/requirements.llm.in b/runner/requirements.llm.in index bd148088..73c7b2e0 100644 --- a/runner/requirements.llm.in +++ b/runner/requirements.llm.in @@ -20,3 +20,5 @@ sentencepiece protobuf bitsandbytes psutil +sentence-transformers +InstructorEmbedding \ No newline at end of file From 
From 79876017d8541ce67f912fa2abdcf13464a40f87 Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Thu, 26 Dec 2024 11:45:08 +0100
Subject: [PATCH 2/3] wip

---
 runner/app/routes/embeddings.py | 18 +-----------------
 runner/app/routes/utils.py      | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/runner/app/routes/embeddings.py b/runner/app/routes/embeddings.py
index 91315283..eb7c3411 100644
--- a/runner/app/routes/embeddings.py
+++ b/runner/app/routes/embeddings.py
@@ -5,30 +5,14 @@
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from pydantic import BaseModel, Field
 from app.dependencies import get_pipeline
 from app.pipelines.base import Pipeline
 from app.routes.utils import HTTPError, http_error
+from app.routes.utils import EmbeddingRequest, EmbeddingResponse
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
 
-
-class EmbeddingRequest(BaseModel):
-    input: Union[str, List[str]] = Field(..., description="Text to embed")
-    model: str = Field("", description="Model to use")
-    instruction: Optional[str] = Field(
-        None, description="Instruction for instructor models")
-    normalize: bool = Field(True, description="Whether to normalize embeddings")
-
-
-class EmbeddingResponse(BaseModel):
-    object: str
-    data: List[Dict[str, Union[List[float], int]]]
-    model: str
-    usage: Dict[str, int]
-
-
 RESPONSES = {
     status.HTTP_200_OK: {"model": EmbeddingResponse},
     status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py
index 6f8271c0..258d5e9c 100644
--- a/runner/app/routes/utils.py
+++ b/runner/app/routes/utils.py
@@ -95,6 +95,21 @@ class LLMResponse(BaseModel):
     created: int
 
 
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]] = Field(..., description="Text to embed")
+    model: str = Field("", description="Model to use")
+    instruction: Optional[str] = Field(
+        None, description="Instruction for instructor models")
+    normalize: bool = Field(True, description="Whether to normalize embeddings")
+
+
+class EmbeddingResponse(BaseModel):
+    object: str
+    data: List[Dict[str, Union[List[float], int]]]
+    model: str
+    usage: Dict[str, int]
+
+
 class ImageToTextResponse(BaseModel):
     """Response model for text generation."""
 

From be5f89032bfd78503640421a4700b49ba3f5050c Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Thu, 26 Dec 2024 11:53:00 +0100
Subject: [PATCH 3/3] wip

---
 runner/app/pipelines/embeddings.py | 29 ++++++++++++++++++++++++-----
 runner/app/routes/embeddings.py    |  2 --
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/runner/app/pipelines/embeddings.py b/runner/app/pipelines/embeddings.py
index ddec2b7e..35bc6961 100644
--- a/runner/app/pipelines/embeddings.py
+++ b/runner/app/pipelines/embeddings.py
@@ -1,4 +1,3 @@
-# app/pipelines/embeddings.py
 import logging
 import os
 from typing import List, Union, Dict, Any, Optional
@@ -9,6 +8,7 @@
 from InstructorEmbedding import INSTRUCTOR
 from dataclasses import dataclass
 from enum import Enum
+from huggingface_hub import file_download
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +39,16 @@ def __init__(self, model_id: str):
         logger.info("Initializing embedding pipeline")
 
         self.model_id = model_id
+        folder_name = file_download.repo_folder_name(
+            repo_id=model_id, repo_type="model")
+        base_path = os.path.join(get_model_dir(), folder_name)
+
+        # Find the actual model path
+        self.local_model_path = self._find_model_path(base_path)
+
+        if not self.local_model_path:
+            raise ValueError(f"Could not find model files for {model_id}")
+
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
         # Get configuration from environment
@@ -50,16 +60,15 @@ def __init__(self, model_id: str):
 
         try:
             if self.model_type == EmbeddingModelType.SENTENCE_TRANSFORMER:
-                self.model = SentenceTransformer(model_id).to(self.device)
+                self.model = SentenceTransformer(self.local_model_path).to(self.device)
             elif self.model_type == EmbeddingModelType.INSTRUCTOR:
-                self.model = INSTRUCTOR(model_id).to(self.device)
+                self.model = INSTRUCTOR(self.local_model_path).to(self.device)
 
             logger.info(f"Model loaded successfully on {self.device}")
 
         except Exception as e:
             logger.error(f"Error loading model: {e}")
             raise
-
     async def generate(
         self,
         texts: Union[str, List[str]],
@@ -124,3 +133,15 @@ async def __call__(
 
     def __str__(self):
        return f"EmbeddingPipeline(model_id={self.model_id})"
+
+    def _find_model_path(self, base_path):
+        if not os.path.isdir(base_path):
+            return None
+        # Check if the model files are directly in the base path
+        if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
+            return base_path
+
+        # If not, look in subdirectories
+        for root, dirs, files in os.walk(base_path):
+            if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
+                return root
diff --git a/runner/app/routes/embeddings.py b/runner/app/routes/embeddings.py
index eb7c3411..03d5a40d 100644
--- a/runner/app/routes/embeddings.py
+++ b/runner/app/routes/embeddings.py
@@ -1,7 +1,5 @@
-# app/routes/embeddings.py
 import logging
 import os
-from typing import Dict, List, Optional, Union
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
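For completeness, a direct-usage sketch of the pipeline class as it stands after PATCH 3/3. This is editorial commentary: it assumes the model snapshot has already been downloaded under get_model_dir() (the patched __init__ raises ValueError otherwise), and the model id is illustrative rather than mandated by the patches.

    # Direct-usage sketch for EmbeddingPipeline after PATCH 3/3.
    import asyncio

    from app.pipelines.embeddings import EmbeddingPipeline

    async def main():
        # Assumed model id; any sentence-transformers checkpoint whose .bin or
        # .safetensors weights exist under get_model_dir() should behave the same.
        pipeline = EmbeddingPipeline("sentence-transformers/all-MiniLM-L6-v2")
        result = await pipeline(["first text", "second text"], normalize=True)
        for item in result["data"]:
            print(item["index"], len(item["embedding"]))

    asyncio.run(main())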