Skip to content

Commit

Permalink
Load the models at startup (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranjan-stha authored Nov 21, 2024
1 parent 573288e commit 4b8405e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 32 deletions.
6 changes: 5 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
OPENAI_API_KEY=
EMBEDDING_MODEL_NAME=
EMBEDDING_MODEL_TYPE=
EMBEDDING_MODEL_VECTOR_SIZE=
OLLAMA_BASE_URL=
51 changes: 24 additions & 27 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from enum import Enum
from typing import List, Optional, Union
from typing import List, Union

from dotenv import load_dotenv
from fastapi import FastAPI, Response, status
Expand Down Expand Up @@ -31,13 +32,25 @@ class EmbeddingModelType(Enum):
OPENAI = 3


# Model selection is resolved once at import (startup) from environment
# variables, so requests no longer choose/load a model per call.
MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/gtr-t5-large")
# EMBEDDING_MODEL_TYPE is the integer value of EmbeddingModelType
# (defaults to "1" — SENTENCE_TRANSFORMERS, per the branch below).
MODEL_TYPE = EmbeddingModelType(int(os.getenv("EMBEDDING_MODEL_TYPE", "1")))
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
# NOTE(review): read here but not used in this span — presumably consumed
# by OpenAIEmbeddingModel / the SDK via the environment; confirm.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)

# Stays None when MODEL_TYPE matches no branch; the endpoint treats a
# falsy model as "not configured" and returns an empty result.
embedding_model = None

# Instantiate exactly one backend, eagerly, so the (potentially slow)
# model load happens at startup instead of on the first request.
if MODEL_TYPE == EmbeddingModelType.SENTENCE_TRANSFORMERS:
    embedding_model = SentenceTransformerEmbeddingModel(model=MODEL_NAME)
elif MODEL_TYPE == EmbeddingModelType.OLLAMA:
    embedding_model = OllamaEmbeddingModel(model=MODEL_NAME, base_url=OLLAMA_BASE_URL)
elif MODEL_TYPE == EmbeddingModelType.OPENAI:
    embedding_model = OpenAIEmbeddingModel(model=MODEL_NAME)


class RequestSchemaForEmbeddings(BaseModel):
    """Request schema for the embeddings endpoint.

    Callers only supply the text payload; the embedding backend (type,
    model name, base URL) is fixed at startup from environment variables.
    """

    # A single text (embed_query) or a batch of texts (embed_documents).
    texts: Union[str, List[str]]


class RequestSchemaForTextSplitter(BaseModel):
async def generate_embeddings(item: RequestSchemaForEmbeddings):
    """
    Generates the embedding vectors for the text/documents
    based on the model selected at startup.
    """
    # The backend was chosen and loaded once at startup from env vars;
    # requests only carry the text payload.
    if embedding_model:
        if isinstance(item.texts, str):
            # Single string -> one embedding vector.
            return embedding_model.embed_query(text=item.texts)
        elif isinstance(item.texts, list):
            # Batch of strings -> list of embedding vectors.
            return embedding_model.embed_documents(texts=item.texts)
    # No model configured (or unsupported payload type): empty result.
    return []


@app.post("/split_docs_based_on_tokens")
Expand Down
7 changes: 4 additions & 3 deletions reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from sentence_transformers import CrossEncoder
from torch import Tensor

# Load the cross-encoder once at import time so repeated scoring calls
# reuse the same model instead of re-building it per request.
cross_encoder_model = CrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-2-v2", max_length=512)


def get_scores(query: str, documents: List[str]):
    """Score each document's relevance to *query* with the shared cross-encoder.

    Returns a list of floats in the same order as *documents*,
    softmax-normalised over the documents (they sum to 1.0).
    """
    doc_tuple = [(query, doc) for doc in documents]
    scores = cross_encoder_model.predict(doc_tuple)
    # Normalise raw model scores into a probability distribution.
    return F.softmax(Tensor(scores), dim=0).tolist()
4 changes: 3 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from huggingface_hub import snapshot_download

logger = logging.getLogger(__name__)
# Configure logging globally once at import so records from this module
# are actually emitted at INFO level (basicConfig replaces the old
# per-logger setLevel call).
logging.basicConfig(level=logging.INFO)


def download_model(embedding_model: str, models_path: str):
Expand All @@ -21,6 +21,8 @@ def check_models(sent_embedding_model: str):
models_path = Path("/opt/models")
models_info_path = models_path / "model_info.json"

logging.info("Checking models status.")

if not os.path.exists(models_path):
os.makedirs(models_path)

Expand Down

0 comments on commit 4b8405e

Please sign in to comment.