-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
180 additions
and
0 deletions.
There are no files selected for viewing
3 changes: 3 additions & 0 deletions
3
llama-index-integrations/embeddings/llama-index-embeddings-vertex/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
poetry_requirements( | ||
name="poetry", | ||
) |
17 changes: 17 additions & 0 deletions
17
llama-index-integrations/embeddings/llama-index-embeddings-vertex/Makefile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
GIT_ROOT ?= $(shell git rev-parse --show-toplevel) | ||
|
||
help: ## Show all Makefile targets. | ||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' | ||
|
||
format: ## Run code autoformatters (black). | ||
pre-commit install | ||
git ls-files | xargs pre-commit run black --files | ||
|
||
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy | ||
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files | ||
|
||
test: ## Run tests via pytest. | ||
pytest tests | ||
|
||
watch-docs: ## Build and watch documentation. | ||
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ |
1 change: 1 addition & 0 deletions
1
llama-index-integrations/embeddings/llama-index-embeddings-vertex/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# LlamaIndex Embeddings Integration: Vertex AI |
1 change: 1 addition & 0 deletions
1
...integrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python_sources() |
3 changes: 3 additions & 0 deletions
3
...ations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from llama_index.embeddings.vertex.base import VertexEmbedding | ||
|
||
__all__ = ["VertexEmbedding"] |
84 changes: 84 additions & 0 deletions
84
...tegrations/embeddings/llama-index-embeddings-vertex/llama_index/embeddings/vertex/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
from llama_index.core.base.embeddings.base import BaseEmbedding | ||
from llama_index.core.bridge.pydantic import Field, PrivateAttr | ||
from llama_index.core.callbacks.base import CallbackManager | ||
from llama_index.llms.vertex.utils import init_vertexai | ||
from vertexai.language_models import TextEmbeddingModel | ||
|
||
from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE | ||
|
||
|
||
class VertexEmbedding(BaseEmbedding):
    """Embedding model backed by a Vertex AI ``TextEmbeddingModel``.

    The constructor initialises the Vertex AI SDK through ``init_vertexai``
    and loads the requested model with ``TextEmbeddingModel.from_pretrained``.
    Every embedding request is sent to that client one text at a time.

    Args:
        model_name: Name of the Vertex embedding model to load.
        embed_batch_size: Batch size for embedding calls (0 < size <= 2048).
        credentials: Optional credentials object handed to ``init_vertexai``.
        project: Optional project handed to ``init_vertexai``.
        location: Optional location handed to ``init_vertexai``.
        vertex_additional_kwargs: Extra keyword arguments forwarded to every
            ``get_embeddings`` call on the Vertex client.
        callback_manager: Optional callback manager passed to the base class.
    """

    model_name: str = Field(
        description="The Vertex model to use.", default="textembedding-gecko"
    )
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # ``le`` is the pydantic keyword for an inclusive upper bound; the
        # previous ``lte`` spelling was not a constraint and was never enforced.
        le=2048,
    )
    vertex_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Vertex AI API."
    )
    _client: TextEmbeddingModel = PrivateAttr()

    def __init__(
        self,
        model_name: str = "textembedding-gecko",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        credentials: Optional[Any] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        vertex_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialise the Vertex AI SDK, load the model and set up the fields."""
        init_vertexai(project=project, location=location, credentials=credentials)
        client = TextEmbeddingModel.from_pretrained(model_name)
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            vertex_additional_kwargs=vertex_additional_kwargs or {},
            callback_manager=callback_manager,
        )
        # Assign the private attribute only after the pydantic base initialiser
        # has run, so the client does not depend on surviving model set-up.
        self._client = client

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "VertexEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding vector for a query string."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async counterpart of ``_get_query_embedding``.

        NOTE: delegates to the synchronous implementation, so the call blocks
        until the request completes.
        """
        return self.get_general_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding vector for a single text."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async counterpart of ``_get_text_embedding``.

        NOTE: delegates to the synchronous implementation, so the call blocks
        until the request completes.
        """
        return self.get_general_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get one embedding vector per input text, in input order.

        Each text is sent as its own request; an empty ``texts`` list yields
        an empty result without calling the client.
        """
        return [self.get_general_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``_get_text_embeddings``.

        NOTE: delegates to the synchronous implementation, so the call blocks
        until all requests complete.
        """
        return self._get_text_embeddings(texts)

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Embed a single string with the Vertex client.

        Sends ``prompt`` as a one-element list to ``get_embeddings``,
        forwarding ``vertex_additional_kwargs`` as keyword arguments, and
        returns the ``values`` of the first result.
        """
        embeddings = self._client.get_embeddings(
            [prompt], **self.vertex_additional_kwargs
        )
        return embeddings[0].values
63 changes: 63 additions & 0 deletions
63
llama-index-integrations/embeddings/llama-index-embeddings-vertex/pyproject.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
[build-system] | ||
build-backend = "poetry.core.masonry.api" | ||
requires = ["poetry-core"] | ||
|
||
[tool.codespell] | ||
check-filenames = true | ||
check-hidden = true | ||
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" | ||
|
||
[tool.llamahub] | ||
contains_example = false | ||
import_path = "llama_index.embeddings.vertex" | ||
|
||
[tool.llamahub.class_authors] | ||
VertexEmbedding = "skvrd" | ||
|
||
[tool.mypy] | ||
disallow_untyped_defs = true | ||
exclude = ["_static", "build", "examples", "notebooks", "venv"] | ||
ignore_missing_imports = true | ||
python_version = "3.8" | ||
|
||
[tool.poetry] | ||
authors = ["Grisha Romanyuk"] | ||
description = "llama-index embeddings vertex integration" | ||
exclude = ["**/BUILD"] | ||
license = "MIT" | ||
name = "llama-index-embeddings-vertex" | ||
readme = "README.md" | ||
version = "0.0.1" | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8.1,<4.0" | ||
llama-index-core = "^0.10.1" | ||
# NOTE(review): the package's base.py imports llama_index.llms.vertex.utils, but no distribution providing that module is declared in this table; confirm whether a dependency is missing.
google-cloud-aiplatform = "^1.39.0" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
ipython = "8.10.0" | ||
jupyter = "^1.0.0" | ||
mypy = "0.991" | ||
pre-commit = "3.2.0" | ||
pylint = "2.15.10" | ||
pytest = "7.2.1" | ||
pytest-mock = "3.11.1" | ||
ruff = "0.0.292" | ||
tree-sitter-languages = "^1.8.0" | ||
types-Deprecated = ">=0.1.0" | ||
types-PyYAML = "^6.0.12.12" | ||
types-protobuf = "^4.24.0.4" | ||
types-redis = "4.5.5.0" | ||
types-requests = "2.28.11.8" | ||
types-setuptools = "67.1.0.0" | ||
|
||
[tool.poetry.group.dev.dependencies.black] | ||
extras = ["jupyter"] | ||
version = "<=23.9.1,>=23.7.0" | ||
|
||
[tool.poetry.group.dev.dependencies.codespell] | ||
extras = ["toml"] | ||
version = ">=v2.2.6" | ||
|
||
[[tool.poetry.packages]] | ||
include = "llama_index/" |
1 change: 1 addition & 0 deletions
1
llama-index-integrations/embeddings/llama-index-embeddings-vertex/tests/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python_tests() |
Empty file.
7 changes: 7 additions & 0 deletions
7
...dex-integrations/embeddings/llama-index-embeddings-vertex/tests/test_embeddings_vertex.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from llama_index.core.base.embeddings.base import BaseEmbedding | ||
from llama_index.embeddings.vertex import VertexEmbedding | ||
|
||
|
||
def test_embedding_class():
    """Check that ``VertexEmbedding`` has ``BaseEmbedding`` in its MRO (by name)."""
    expected = BaseEmbedding.__name__
    # Compare class names rather than class objects, as the original check did.
    mro_names = []
    for klass in VertexEmbedding.__mro__:
        mro_names.append(klass.__name__)
    assert expected in mro_names