# Developer convenience targets for this package.
# The text after `##` on a target line is the description that `make help`
# extracts and prints next to the target name.

# Repository root as reported by git; `?=` keeps any value already set in the
# environment or on the command line.
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

# Scan the makefile(s) for "target: ## description" lines and print each
# target (in colour, padded to 30 columns) followed by its description.
help: ## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

# Install the pre-commit hooks, then run only the `black` hook over every
# file tracked by git.
format: ## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

# Install the pre-commit hooks and run all configured hooks over every
# tracked file, showing the diff of anything a hook rewrote.
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

# Run the test suite located in the `tests` directory.
test: ## Run tests via pytest.
	pytest tests

# Serve docs/ with live rebuild, additionally watching the llama_index/
# directory under the repository root for changes.
# NOTE(review): this diff adds no docs/ directory to the package — confirm
# the target is usable here.
watch-docs: ## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
class VertexEmbedding(BaseEmbedding):
    """Embedding model backed by a Vertex AI ``TextEmbeddingModel``.

    Texts are sent to the model in batches (one request per batch handed over
    by ``BaseEmbedding``) and the ``values`` of each returned embedding are
    used as the vector.

    Attributes:
        model_name: Name of the Vertex model passed to
            ``TextEmbeddingModel.from_pretrained``.
        embed_batch_size: Number of texts per embedding call (1..2048).
        vertex_additional_kwargs: Extra keyword arguments forwarded verbatim
            to every ``get_embeddings`` / ``get_embeddings_async`` call.
    """

    model_name: str = Field(
        description="The Vertex model to use.", default="textembedding-gecko"
    )
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # `le` is pydantic's "less than or equal" constraint; `lte` is not a
        # recognised keyword and would leave the upper bound unenforced.
        le=2048,
    )
    vertex_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Vertex AI API."
    )
    _client: TextEmbeddingModel = PrivateAttr()

    def __init__(
        self,
        model_name: str = "textembedding-gecko",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        credentials: Optional[Any] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        vertex_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialise Vertex AI and load the embedding model.

        Args:
            model_name: Vertex model to load; defaults to the same value as
                the ``model_name`` field.
            embed_batch_size: Number of texts per embedding call.
            credentials: Forwarded to ``init_vertexai``.
            project: Forwarded to ``init_vertexai``.
            location: Forwarded to ``init_vertexai``.
            vertex_additional_kwargs: Extra keyword arguments for the
                embedding calls; ``None`` means no extra arguments.
            callback_manager: Forwarded to ``BaseEmbedding``.
        """
        init_vertexai(project=project, location=location, credentials=credentials)
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            vertex_additional_kwargs=vertex_additional_kwargs or {},
            callback_manager=callback_manager,
        )
        # Private attributes are assigned only once the pydantic model has
        # been initialised; assigning earlier is not supported by every
        # pydantic version.
        self._client = TextEmbeddingModel.from_pretrained(model_name)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this embedding class."""
        return "VertexEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return (await self._aget_text_embeddings([query]))[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        return (await self._aget_text_embeddings([text]))[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with a single request.

        Args:
            texts: Texts to embed; an empty list yields an empty result
                without contacting the service.

        Returns:
            One vector per input text, in input order.
        """
        if not texts:
            return []
        # NOTE(review): the service caps the number of inputs per request;
        # confirm `embed_batch_size` fits the chosen model/region.
        results = self._client.get_embeddings(
            list(texts), **self.vertex_additional_kwargs
        )
        return [result.values for result in results]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a batch of texts with a single request.

        Uses the client's native coroutine so the event loop is not blocked
        while waiting for the service.
        """
        if not texts:
            return []
        results = await self._client.get_embeddings_async(
            list(texts), **self.vertex_additional_kwargs
        )
        return [result.values for result in results]

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Get the Vertex embedding for a single piece of text.

        Args:
            prompt: Text to embed.

        Returns:
            The embedding vector for ``prompt``.
        """
        return self._get_text_embeddings([prompt])[0]
def test_embedding_class():
    """VertexEmbedding must have BaseEmbedding somewhere in its MRO."""
    mro_names = {klass.__name__ for klass in VertexEmbedding.__mro__}
    assert BaseEmbedding.__name__ in mro_names