Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama-index-embeddings-vertex #11528

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Map the google-cloud-aiplatform distribution to the `vertexai` module name,
# because the import name differs from the package name.
poetry_requirements(
name="poetry", module_mapping={"google-cloud-aiplatform": ["vertexai"]}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Repository root, overridable from the environment or the command line.
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

# These targets are commands, not files: declare them phony so that a file or
# directory with the same name (e.g. `test`) can never make them "up to date".
.PHONY: help format lint test watch-docs

help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# LlamaIndex Embeddings Integration: Vertex AI
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Declares the python_sources target for this directory with default arguments.
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llama_index.embeddings.vertex.base import VertexEmbedding

__all__ = ["VertexEmbedding"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Any, Dict, List, Optional

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from vertexai import init
from vertexai.language_models import TextEmbeddingModel

from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE


class VertexEmbedding(BaseEmbedding):
    """Embedding model backed by a Vertex AI ``TextEmbeddingModel``.

    The model is loaded with ``TextEmbeddingModel.from_pretrained`` and queried
    through its ``get_embeddings`` method.
    """

    # Name of the pretrained Vertex AI embedding model to load.
    model_name: str = Field(
        description="The Vertex model to use.", default="textembedding-gecko"
    )
    # Number of texts handed to one embedding batch. The inclusive upper bound
    # must be spelled ``le``: pydantic does not recognise ``lte`` as a
    # constraint, so with that spelling the 2048 limit was never enforced.
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        le=2048,
    )
    # Extra keyword arguments intended for the Vertex AI API.
    vertex_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Vertex AI API."
    )
    # Loaded Vertex AI model handle; private, so it is not a pydantic field.
    _client: TextEmbeddingModel = PrivateAttr()

def __init__(
    self,
    model_name: str = "textembedding-gecko",
    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
    credentials: Optional[Any] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    vertex_additional_kwargs: Optional[Dict[str, Any]] = None,
    callback_manager: Optional[CallbackManager] = None,
) -> None:
    """Initialise the Vertex AI SDK and load the embedding model.

    Args:
        model_name: Name of the pretrained Vertex AI embedding model. Defaults
            to the same value as the ``model_name`` field.
        embed_batch_size: Number of texts per embedding batch.
        credentials: Credentials object passed to ``vertexai.init``.
        project: Project passed to ``vertexai.init``.
        location: Location passed to ``vertexai.init``.
        vertex_additional_kwargs: Extra keyword arguments for the Vertex AI
            API; ``None`` is stored as an empty dict.
        callback_manager: Callback manager forwarded to the base class.
    """
    # vertexai.init configures the SDK globally for this process.
    init(project=project, location=location, credentials=credentials)
    super().__init__(
        model_name=model_name,
        embed_batch_size=embed_batch_size,
        vertex_additional_kwargs=vertex_additional_kwargs or {},
        callback_manager=callback_manager,
    )
    # Assign the private attribute only after the pydantic model has been
    # initialised: newer pydantic versions create private-attribute storage
    # during model init and reject earlier assignment.
    self._client = TextEmbeddingModel.from_pretrained(model_name)

@classmethod
def class_name(cls) -> str:
return "VertexEmbedding"

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self.get_general_text_embedding(query)

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
return self.get_general_text_embedding(query)

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self.get_general_text_embedding(text)

async def _aget_text_embedding(self, text: str) -> List[float]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vertexai don't have async client to get embeddings?

"""Asynchronously get text embedding."""
return self.get_general_text_embedding(text)

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
embeddings_list: List[List[float]] = []
for text in texts:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a sequential operation, is there no batch operation available here?

embeddings = self.get_general_text_embedding(text)
embeddings_list.append(embeddings)

return embeddings_list

async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
return self._get_text_embeddings(texts)

def get_general_text_embedding(self, prompt: str) -> List[float]:
    """Get Vertex embedding.

    Sends ``prompt`` as a single-item request to the loaded model and
    forwards ``vertex_additional_kwargs``, which was previously declared but
    never passed to the API.

    Args:
        prompt: The text to embed.

    Returns:
        The ``values`` of the first (and only) embedding in the response.
    """
    embeddings = self._client.get_embeddings(
        [prompt], **self.vertex_additional_kwargs
    )
    return embeddings[0].values
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.embeddings.vertex"

[tool.llamahub.class_authors]
VertexEmbedding = "skvrd"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Grisha Romanyuk"]
description = "llama-index embeddings vertex integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-vertex"
readme = "README.md"
version = "0.0.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
google-cloud-aiplatform = "^1.39.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Declares the python_tests target for this directory with default arguments.
python_tests()
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.vertex import VertexEmbedding


def test_embedding_class() -> None:
    """Check that ``VertexEmbedding`` lists ``BaseEmbedding`` among its bases.

    The comparison is by class name over the MRO, not by class identity.
    """
    names_of_base_classes = [b.__name__ for b in VertexEmbedding.__mro__]
    assert BaseEmbedding.__name__ in names_of_base_classes
Loading