Skip to content

Commit

Permalink
llama-index-embeddings-vertex
Browse files Browse the repository at this point in the history
  • Loading branch information
skvrd committed Mar 1, 2024
1 parent 65e96de commit 835bdfc
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
poetry_requirements(
name="poetry", module_mapping={"google-cloud-aiplatform": ["vertexai"]}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
pytest tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# LlamaIndex Embeddings Integration: Vertex AI
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llama_index.embeddings.vertex.base import VertexEmbedding

__all__ = ["VertexEmbedding"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Any, Dict, List, Optional

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from vertexai import init
from vertexai.language_models import TextEmbeddingModel

from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE


class VertexEmbedding(BaseEmbedding):
"""Class for Vertex embeddings."""

model_name: str = Field(
description="The Vertex model to use.", default="textembedding-gecko"
)
embed_batch_size: int = Field(
default=DEFAULT_EMBED_BATCH_SIZE,
description="The batch size for embedding calls.",
gt=0,
lte=2048,
)
vertex_additional_kwargs: Dict[str, Any] = Field(
default_factory=dict, description="Additional kwargs for the Vertex AI API."
)
_client: TextEmbeddingModel = PrivateAttr()

def __init__(
self,
model_name: str,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
credentials: Optional[Any] = None,
project: Optional[str] = None,
location: Optional[str] = None,
vertex_additional_kwargs: Optional[Dict[str, Any]] = None,
callback_manager: Optional[CallbackManager] = None,
) -> None:
init(project=project, location=location, credentials=credentials)
self._client = TextEmbeddingModel.from_pretrained(model_name)
super().__init__(
model_name=model_name,
embed_batch_size=embed_batch_size,
vertex_additional_kwargs=vertex_additional_kwargs or {},
callback_manager=callback_manager,
)

@classmethod
def class_name(cls) -> str:
return "VertexEmbedding"

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self.get_general_text_embedding(query)

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
return self.get_general_text_embedding(query)

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self.get_general_text_embedding(text)

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
return self.get_general_text_embedding(text)

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
embeddings_list: List[List[float]] = []
for text in texts:
embeddings = self.get_general_text_embedding(text)
embeddings_list.append(embeddings)

return embeddings_list

async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
return self._get_text_embeddings(texts)

def get_general_text_embedding(self, prompt: str) -> List[float]:
"""Get Vertex embedding."""
embeddings = self._client.get_embeddings([prompt])
return embeddings[0].values
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = false
import_path = "llama_index.embeddings.vertex"

[tool.llamahub.class_authors]
VertexEmbedding = "skvrd"

[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Grisha Romanyuk"]
description = "llama-index embeddings vertex integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-vertex"
readme = "README.md"
version = "0.0.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
google-cloud-aiplatform = "^1.39.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"

[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"

[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"

[[tool.poetry.packages]]
include = "llama_index/"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_tests()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.vertex import VertexEmbedding


def test_embedding_class():
names_of_base_classes = [b.__name__ for b in VertexEmbedding.__mro__]
assert BaseEmbedding.__name__ in names_of_base_classes

0 comments on commit 835bdfc

Please sign in to comment.