Skip to content

Commit

Permalink
feat: add Jina Embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jan 23, 2024
1 parent 13527a8 commit 0639020
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions dbgpt/rag/embedding/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,69 @@ def embed_query(self, text: str) -> List[float]:
Embeddings for the text.
"""
return self.embed_documents([text])[0]


class JinaEmbeddings(BaseModel, Embeddings):
"""
This class is used to get embeddings for a list of texts using the Jina AI API.
It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
"""
api_url: Any #: :meta private:
session: Any #: :meta private:
api_key: str
"""our API key for the Jina AI API.."""
model_name: str = "jina-embeddings-v2-base-en"
"""he name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en"."""

def __init__(self, **kwargs):
"""
Initialize the JinaEmbeddings.
"""
super().__init__(**kwargs)
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. Please install it with `pip install requests`"
)
self.api_url = 'https://api.jina.ai/v1/embeddings'
self.session = requests.Session()
self.session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
# Call Jina AI Embedding API
resp = self.session.post( # type: ignore
self.api_url, json={"input": texts, "model": self.model_name}
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])

embeddings = resp["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore

# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]

def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

0 comments on commit 0639020

Please sign in to comment.