-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] feat: add Voyage embeddings #408
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Implements embeddings from [Voyage AI](https://voyageai.com). | ||
""" | ||
|
||
import importlib | ||
|
||
from kotaemon.base import Document, DocumentWithEmbedding, Param | ||
|
||
from .base import BaseEmbeddings | ||
|
||
|
||
vo = None | ||
|
||
|
||
def _import_voyageai(): | ||
global vo | ||
if not vo: | ||
vo = importlib.import_module("voyageai") | ||
return vo | ||
|
||
|
||
def _format_output(texts: list[str], embeddings: list[list]): | ||
"""Formats the output of all `.embed` calls. | ||
|
||
Args: | ||
texts: List of original documents | ||
embeddings: Embeddings corresponding to each document | ||
""" | ||
return [ | ||
DocumentWithEmbedding(content=text, embedding=embedding) | ||
for text, embedding in zip(texts, embeddings) | ||
] | ||
|
||
|
||
class VoyageAIEmbeddings(BaseEmbeddings): | ||
"""Voyage AI provides best-in-class embedding models and rerankers.""" | ||
|
||
api_key: str = Param(None, help="Voyage API key", required=False) | ||
model_name: str = Param( | ||
"voyage-3", | ||
help=( | ||
"Model name to use. The Voyage " | ||
"[documentation](https://docs.voyageai.com/docs/embeddings) " | ||
"provides a list of all available embedding models." | ||
), | ||
required=True, | ||
) | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._client = _import_voyageai().Client(api_key=self.api_key) | ||
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key) | ||
|
||
def invoke( | ||
self, text: str | list[str] | Document | list[Document], *args, **kwargs | ||
) -> list[DocumentWithEmbedding]: | ||
texts = [t.content for t in self.prepare_input(text)] | ||
embeddings = self._client.embed(texts, model=self.model_name).embeddings | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: Handle exceptions when calling the embed method to prevent crashes. Existing Code: embeddings = self._client.embed(texts, model=self.model_name).embeddings Improved Code: try:
embeddings = self._client.embed(texts, model=self.model_name).embeddings
except Exception as e:
raise RuntimeError(f"Failed to retrieve embeddings: {e}") Details: |
||
return _format_output(texts, embeddings) | ||
|
||
async def ainvoke( | ||
self, text: str | list[str] | Document | list[Document], *args, **kwargs | ||
) -> list[DocumentWithEmbedding]: | ||
texts = [t.content for t in self.prepare_input(text)] | ||
embeddings = await self._aclient.embed(texts, model=self.model_name).embeddings | ||
return _format_output(texts, embeddings) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,12 +11,14 @@ | |
LCCohereEmbeddings, | ||
LCHuggingFaceEmbeddings, | ||
OpenAIEmbeddings, | ||
VoyageAIEmbeddings, | ||
) | ||
|
||
from .conftest import ( | ||
skip_when_cohere_not_installed, | ||
skip_when_fastembed_not_installed, | ||
skip_when_sentence_bert_not_installed, | ||
skip_when_voyageai_not_installed, | ||
) | ||
|
||
with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f: | ||
|
@@ -155,3 +157,18 @@ def test_fastembed_embeddings(): | |
model = FastEmbedEmbeddings() | ||
output = model("Hello World") | ||
assert_embedding_result(output) | ||
|
||
|
||
@skip_when_voyageai_not_installed | ||
@patch( | ||
"voyageai.Client.embed", | ||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], | ||
) | ||
@patch( | ||
"voyageai.AsyncClient.embed", | ||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], | ||
) | ||
def test_voyageai_embeddings(): | ||
model = VoyageAIEmbeddings() | ||
output = model("Hello, world!") | ||
assert_embedding_result(output) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: Ensure the test checks for the correct output structure. Existing Code: assert_embedding_result(output) Improved Code: assert isinstance(output, list) and all(isinstance(doc, DocumentWithEmbedding) for doc in output) Details: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: Validate the API key before using it in the client.
Existing Code:
Improved Code:
Details:
Ensure that the
api_key
is validated before using it to create the client.