Skip to content

Commit

Permalink
embed using huggingface inference endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
meal committed Sep 22, 2023
1 parent eb099c9 commit d6d8e2a
Show file tree
Hide file tree
Showing 7 changed files with 1,714 additions and 1,966 deletions.
3 changes: 2 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ local-server/
*.md
*.pyc
.dockerignore
Dockerfile
Dockerfile
.env
4 changes: 2 additions & 2 deletions app/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ OPENAI_API_KEY=<openai_api_key>
OPENAI_GPT_MODEL=<gpt-3.5-turbo, gpt-4 etc>
SENTRY_DSN=http://SENTRY_DSN
# CHROMA_HOST=<chroma_server_IP>
EMBEDDINGS="huggingface"
EMBEDDINGS="huggingface-inference"
# EMBEDDINGS="openai"
# DATABASE="chroma"
DATABASE="pinecone"
CHROMA_COLLECTION=<collection_name>
CUDA_ENABLED=True or False
PINECONE_API_KEY=<pinecone_api_key>
PINECONE_ENVIRONMENT="northamerica-northeast1-gcp"
PINECONE_INDEX_NAME="<index_name>
PINECONE_INDEX_NAME="<index_name>
3 changes: 1 addition & 2 deletions app/resources/arxiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from langchain.chat_models import ChatOpenAI
from templates.condense_prompt import CONDENSE_PROMPT
from templates.qa_prompt import QA_PROMPT
from tools.factory import get_database, get_embeddings
from tools.factory import get_database

embeddings = get_embeddings()
db = get_database()


Expand Down
27 changes: 12 additions & 15 deletions app/tools/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def get_embeddings():
model_kwargs = {"device": "cpu"}
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1",
cache_folder="./models",
model_kwargs=model_kwargs,
)

Expand All @@ -40,25 +41,19 @@ def get_embeddings():
task="feature-extraction",
huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
)
return embeddings
case "huggingface-instruct":
from langchain.embeddings import HuggingFaceInstructEmbeddings

if os.environ.get("CUDA_ENABLED") == "True":
model_kwargs = {"device": "cuda"}
elif os.environ.get("MPS_ENABLED") == "True":
model_kwargs = {"device": "mps"}
else:
model_kwargs = {"device": "cpu"}
return embeddings
case "huggingface-inference":
from .inference_embeddings import HuggingFaceInferenceEmbeddings

# model_name = "hkunlp/instructor-large"
model_name = "hkunlp/instructor-xl"
embeddings = HuggingFaceInstructEmbeddings(
model_name=model_name, model_kwargs=model_kwargs
embeddings = HuggingFaceInferenceEmbeddings(
model_url=os.environ.get("HUGGINGFACE_MODEL_URL"),
huggingface_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
)

return embeddings
case _:
raise ValueError(f"Unsupported embeddings : {embeddings}")
raise ValueError(f"Unsupported embeddings : {embeddings_engine}")


def get_database():
Expand Down Expand Up @@ -127,7 +122,9 @@ def get_database():
)
index_name = os.getenv("PINECONE_INDEX_NAME", "arxiv")
db = Pinecone.from_existing_index(
embedding=embeddings, index_name=index_name
embedding=embeddings,
index_name=index_name,
namespace=os.getenv("PINECONE_NAMESPACE", "website"),
)

return db
Expand Down
70 changes: 70 additions & 0 deletions app/tools/inference_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Wrapper around HuggingFace Hub embedding models."""
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from pydantic import BaseModel, Extra, root_validator


class HuggingFaceInferenceEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private:
"""Model url to use for embedding."""
model_url: str = "model-url"
"""Task to call the model with."""
model_kwargs: Optional[dict] = None
"""Key word arguments to pass to the model."""

huggingface_api_token: Optional[str] = None

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingface_api_token = get_from_dict_or_env(
values, "huggingface_api_token", "HUGGINGFACE_API_TOKEN"
)
try:
from huggingface_hub import InferenceClient

client = InferenceClient(
token=huggingface_api_token, model=values["model_url"]
)
values["client"] = client
except ImportError:
raise ValueError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
return values

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to HuggingFaceHub's Inference Endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
# replace newlines, which can negatively affect performance.
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
_res = self.client.post(json={"inputs": texts})
responses = _res.json()["embeddings"]
return responses

def embed_query(self, text: str) -> List[float]:
"""Call out to HuggingFace's Inference Endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
response = self.embed_documents([text])[0]
return response
Loading

0 comments on commit d6d8e2a

Please sign in to comment.