From 9010fdc7e30cf1d1501013c560d53e917ec0029e Mon Sep 17 00:00:00 2001 From: CJ Kindel Date: Mon, 24 Jun 2024 12:45:31 -0700 Subject: [PATCH 1/6] `meta` parameter added to TextArtifact (#891) --- CHANGELOG.md | 1 + griptape/artifacts/text_artifact.py | 3 ++- tests/unit/artifacts/test_text_artifact.py | 8 ++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 710a1c76f..435b62505 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK. - `ProxyWebScraperDriver` to web scrape using proxies. - Parameter `session` on `AmazonBedrockStructureConfig`. +- Parameter `meta` on `TextArtifact`. ### Changed - **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifact()` and `BaseVectorStoreDriver.upsert_text()` use artifact/string values to generate `vector_id` if it wasn't implicitly passed. This change ensures that we don't generate embeddings for the same content every time. diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index e8a2bb2a7..8b83303f0 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import define, field from griptape.artifacts import BaseArtifact @@ -13,6 +13,7 @@ class TextArtifact(BaseArtifact): value: str = field(converter=str, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) + meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) _embedding: list[float] = field(factory=list, kw_only=True) @property diff --git a/tests/unit/artifacts/test_text_artifact.py b/tests/unit/artifacts/test_text_artifact.py index f913429d7..6ea2c6697 100644 --- a/tests/unit/artifacts/test_text_artifact.py +++ b/tests/unit/artifacts/test_text_artifact.py @@ -61,3 +61,11 @@ def test_name(self): assert artifact.name == artifact.id assert TextArtifact("foo", name="bar").name == "bar" + + def test_meta(self): + artifact = TextArtifact("foo") + + assert artifact.meta == {} + + meta = {"foo": "bar"} + assert TextArtifact("foo", meta=meta).meta == meta From c2575f5318fccb7d11efd83d909ec248e9a7a28e Mon Sep 17 00:00:00 2001 From: CJ Kindel Date: Tue, 25 Jun 2024 08:05:38 -0700 Subject: [PATCH 2/6] Handle error on vector entry not existing (#893) --- griptape/drivers/vector/base_vector_store_driver.py | 5 ++++- ...re_driver.py => test_base_local_vector_store_driver.py} | 7 ++++++- .../vector/test_in_memory_local_vector_store_driver.py | 2 +- .../vector/test_persistent_local_vector_store_driver.py | 4 ++-- 4 files changed, 13 insertions(+), 5 deletions(-) rename tests/unit/drivers/vector/{base_local_vector_store_driver.py => test_base_local_vector_store_driver.py} (91%) diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 71f1c1061..b1d9ed6d0 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -92,7 +92,10 @@ def upsert_text( ) def does_entry_exist(self, vector_id: str, namespace: Optional[str] = None) -> bool: - return self.load_entry(vector_id, namespace) is not None + try: + return self.load_entry(vector_id, namespace) is not None + except Exception: + return False def load_artifacts(self, namespace: Optional[str] = None) -> ListArtifact: result = self.load_entries(namespace) diff --git a/tests/unit/drivers/vector/base_local_vector_store_driver.py b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py similarity index 91% rename from tests/unit/drivers/vector/base_local_vector_store_driver.py rename to tests/unit/drivers/vector/test_base_local_vector_store_driver.py index 07cf284f8..c34a54d98 100644 --- a/tests/unit/drivers/vector/base_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_local_vector_store_driver.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import pytest -from griptape.artifacts import TextArtifact, BaseArtifact +from unittest.mock import patch +from griptape.artifacts import TextArtifact class BaseLocalVectorStoreDriver(ABC): @@ -65,3 +66,7 @@ def test_load_artifacts(self, driver): assert len(driver.load_artifacts()) == 3 assert len(driver.load_artifacts("test-namespace-1")) == 2 assert len(driver.load_artifacts("test-namespace-2")) == 1 + + def test_does_entry_exist_exception(self, driver): + with patch.object(driver, "load_entry", side_effect=Exception): + assert driver.does_entry_exist("does_not_exist") is False diff --git a/tests/unit/drivers/vector/test_in_memory_local_vector_store_driver.py b/tests/unit/drivers/vector/test_in_memory_local_vector_store_driver.py index c426ea5b4..cb8fcbefe 100644 --- a/tests/unit/drivers/vector/test_in_memory_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_in_memory_local_vector_store_driver.py @@ -1,7 +1,7 @@ import pytest from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver class TestInMemoryLocalVectorStoreDriver(BaseLocalVectorStoreDriver): diff --git a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py index 1f800967a..8f6773fc1 100644 --- a/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_persistent_local_vector_store_driver.py @@ -1,10 +1,10 @@ import os import tempfile import pytest -from griptape.artifacts import TextArtifact, BaseArtifact +from griptape.artifacts import TextArtifact from griptape.drivers import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.unit.drivers.vector.base_local_vector_store_driver import BaseLocalVectorStoreDriver +from tests.unit.drivers.vector.test_base_local_vector_store_driver import BaseLocalVectorStoreDriver class TestPersistentLocalVectorStoreDriver(BaseLocalVectorStoreDriver): From 2df58d551963d214892642bc6645b407d45d2219 Mon Sep 17 00:00:00 2001 From: Emily Danielson <2302515+emjay07@users.noreply.github.com> Date: Tue, 25 Jun 2024 11:56:09 -0700 Subject: [PATCH 3/6] Bumping gemini to latest version (#892) --- poetry.lock | 92 +++++++++++++++++++++++++++++++++++++++++++++----- pyproject.toml | 2 +- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1b70d0c93..e030c8906 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1543,17 +1543,18 @@ dev = ["flake8", "markdown", "twine", "wheel"] [[package]] name = "google-ai-generativelanguage" -version = "0.4.0" +version = "0.6.5" description = "Google Ai Generativelanguage API client library" optional = true python-versions = ">=3.7" files = [ - {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"}, - {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"}, + {file = "google-ai-generativelanguage-0.6.5.tar.gz", hash = "sha256:c4089c277fa4e26722f76ab03ee3039f28be8bf1c9be282948b9583a154c6d79"}, + {file = "google_ai_generativelanguage-0.6.5-py3-none-any.whl", hash = "sha256:236875bb4a6d6ebdba2f12bd9d5e776100fd913402157a47b5e9fb80a13f25a7"}, ] [package.dependencies] -google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -1588,6 +1589,24 @@ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +[[package]] +name = "google-api-python-client" +version = "2.134.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google-api-python-client-2.134.0.tar.gz", hash = "sha256:4a8f0bea651a212997cc83c0f271fc86f80ef93d1cee9d84de7dfaeef2a858b6"}, + {file = "google_api_python_client-2.134.0-py2.py3-none-any.whl", hash = "sha256:ba05d60f6239990b7994f6328f17bb154c602d31860fb553016dc9f8ce886945"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + [[package]] name = "google-auth" version = "2.29.0" @@ -1611,19 +1630,35 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + [[package]] name = "google-generativeai" -version = "0.4.1" +version = "0.7.0" description = "Google Generative AI High level API client library and tools." optional = true python-versions = ">=3.9" files = [ - {file = "google_generativeai-0.4.1-py3-none-any.whl", hash = "sha256:89be3c00c2e688108fccefc50f47f45fc9d37ecd53c1ade9d86b5d982919c24a"}, + {file = "google_generativeai-0.7.0-py3-none-any.whl", hash = "sha256:7be4b634afeb8b6bebde1af7271e94d2af84d2d28b5988c7ed9921733c40fe63"}, ] [package.dependencies] -google-ai-generativelanguage = "0.4.0" +google-ai-generativelanguage = "0.6.5" google-api-core = "*" +google-api-python-client = "*" google-auth = ">=2.15.0" protobuf = "*" pydantic = "*" @@ -1863,6 +1898,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<0.26.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.0" @@ -4242,6 +4291,20 @@ cryptography = ">=41.0.5,<43" docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"] test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = true +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pypdf" version = "3.17.4" @@ -5944,6 +6007,17 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "1.26.18" @@ -6321,4 +6395,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a5e1a9aaf0fa253d904eee8803dab5f943be59c872c2449b73aea917ecb1c543" +content-hash = "ce26764ee2c4a9a99d24ef4afc7efa6aa894a7560a725388ff24db15f6014e9a" diff --git a/pyproject.toml b/pyproject.toml index 972f0fe81..29d4fda44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ redis = { version = "^4.6.0", optional = true } opensearch-py = { version = "^2.3.1", optional = true } pgvector = { version = "^0.2.3", optional = true } psycopg2-binary = { version = "^2.9.9", optional = true } -google-generativeai = { version = "^0.4.1", optional = true } +google-generativeai = { version = "^0.7.0", optional = true } trafilatura = {version = "^1.6", optional = true} playwright = {version = "^1.42", optional = true} beautifulsoup4 = {version = "^4.12.3", optional = true} From 80c27cb1fff80a4d3c9200f6a95db28c59de71ca Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Tue, 25 Jun 2024 15:06:43 -0600 Subject: [PATCH 4/6] VectorStoreClient improvements (#899) --- CHANGELOG.md | 4 +++ .../official-tools/vector-store-client.md | 2 +- griptape/tools/vector_store_client/tool.py | 30 +++++++++---------- tests/unit/tools/test_vector_store_client.py | 25 +++++++++++++++- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 435b62505..818211d8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ProxyWebScraperDriver` to web scrape using proxies. - Parameter `session` on `AmazonBedrockStructureConfig`. - Parameter `meta` on `TextArtifact`. +- `VectorStoreClient` improvements: + - `VectorStoreClient.query_params` dict for custom query params. + - `VectorStoreClient.process_query_output_fn` for custom query output processing logic. ### Changed - **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifact()` and `BaseVectorStoreDriver.upsert_text()` use artifact/string values to generate `vector_id` if it wasn't implicitly passed. This change ensures that we don't generate embeddings for the same content every time. @@ -46,6 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Merged `BaseVectorStoreDriver.QueryResult` into `BaseVectorStoreDriver.Entry`. - **BREAKING**: Replaced `query_engine` with `vector_store_driver` in `VectorStoreClient`. - **BREAKING**: removed parameters `google_api_lang`, `google_api_key`, `google_api_search_id`, `google_api_country` on `WebSearch` in favor of `web_search_driver`. +- **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. ## [0.27.1] - 2024-06-20 diff --git a/docs/griptape-tools/official-tools/vector-store-client.md b/docs/griptape-tools/official-tools/vector-store-client.md index ffdbbba91..f3cab2065 100644 --- a/docs/griptape-tools/official-tools/vector-store-client.md +++ b/docs/griptape-tools/official-tools/vector-store-client.md @@ -21,7 +21,7 @@ vector_store_driver.upsert_text_artifacts( vector_db = VectorStoreClient( description="This DB has information about the Griptape Python framework", vector_store_driver=vector_store_driver, - namespace="griptape", + query_params={"namespace": "griptape"}, off_prompt=True ) diff --git a/griptape/tools/vector_store_client/tool.py b/griptape/tools/vector_store_client/tool.py index 38d7784c2..8d4e73022 100644 --- a/griptape/tools/vector_store_client/tool.py +++ b/griptape/tools/vector_store_client/tool.py @@ -1,34 +1,36 @@ from __future__ import annotations -from typing import Optional -from attrs import define, field +from typing import Callable, Any +from attrs import define, field, Factory from schema import Schema, Literal -from griptape.artifacts import ErrorArtifact +from griptape.artifacts import ErrorArtifact, BaseArtifact from griptape.artifacts import ListArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.tools import BaseTool from griptape.utils.decorators import activity -@define +@define(kw_only=True) class VectorStoreClient(BaseTool): """ Attributes: description: LLM-friendly vector DB description. - namespace: Vector storage namespace. vector_store_driver: `BaseVectorStoreDriver`. - top_n: Max number of results returned for the query engine query. + query_params: Optional dictionary of vector store driver query parameters. + process_query_output_fn: Optional lambda for processing vector store driver query output `Entry`s. """ DEFAULT_TOP_N = 5 - description: str = field(kw_only=True) - vector_store_driver: BaseVectorStoreDriver = field(kw_only=True) - top_n: int = field(default=DEFAULT_TOP_N, kw_only=True) - namespace: Optional[str] = field(default=None, kw_only=True) + description: str = field() + vector_store_driver: BaseVectorStoreDriver = field() + query_params: dict[str, Any] = field(factory=dict) + process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( + default=Factory(lambda: lambda es: ListArtifact([e.to_artifact() for e in es])) + ) @activity( config={ - "description": "Can be used to search a vector database with the following description: {{ _self.description }}", + "description": "Can be used to search a database with the following description: {{ _self.description }}", "schema": Schema( { Literal( @@ -38,12 +40,10 @@ class VectorStoreClient(BaseTool): ), } ) - def search(self, params: dict) -> ListArtifact | ErrorArtifact: + def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] try: - entries = self.vector_store_driver.query(query, namespace=self.namespace, count=self.top_n) - - return ListArtifact([e.to_artifact() for e in entries]) + return self.process_query_output_fn(self.vector_store_driver.query(query, **self.query_params)) except Exception as e: return ErrorArtifact(f"error querying vector store: {e}") diff --git a/tests/unit/tools/test_vector_store_client.py b/tests/unit/tools/test_vector_store_client.py index 9503501b5..45018b847 100644 --- a/tests/unit/tools/test_vector_store_client.py +++ b/tests/unit/tools/test_vector_store_client.py @@ -1,5 +1,5 @@ import pytest -from griptape.artifacts import TextArtifact +from griptape.artifacts import TextArtifact, ListArtifact from griptape.drivers import LocalVectorStoreDriver from griptape.tools import VectorStoreClient from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -17,3 +17,26 @@ def test_search(self): driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) assert set([a.value for a in tool.search({"values": {"query": "test"}})]) == {"foo", "bar"} + + def test_search_with_namespace(self): + driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) + tool1 = VectorStoreClient(description="Test", vector_store_driver=driver, query_params={"namespace": "test"}) + tool2 = VectorStoreClient(description="Test", vector_store_driver=driver, query_params={"namespace": "test2"}) + + driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) + + assert len(tool1.search({"values": {"query": "test"}})) == 2 + assert len(tool2.search({"values": {"query": "test"}})) == 0 + + def test_custom_process_query_output_fn(self): + driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) + tool1 = VectorStoreClient( + description="Test", + vector_store_driver=driver, + process_query_output_fn=lambda es: ListArtifact([e.vector for e in es]), + query_params={"include_vectors": True}, + ) + + driver.upsert_text_artifacts({"test": [TextArtifact("foo"), TextArtifact("bar")]}) + + assert tool1.search({"values": {"query": "test"}}).value == [[0, 1], [0, 1]] From 22981b12529eadb96842d601e7eedf99d9d1bf6d Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Tue, 25 Jun 2024 15:19:37 -0600 Subject: [PATCH 5/6] Wrap future execution with context managers (#898) --- CHANGELOG.md | 1 + .../drivers/vector/base_vector_store_driver.py | 15 ++++++++------- .../related_query_generation_rag_module.py | 17 +++++++++-------- .../retrieval/text_retrieval_rag_module.py | 13 +++++++------ griptape/engines/rag/stages/query_rag_stage.py | 9 +++++---- .../engines/rag/stages/retrieval_rag_stage.py | 5 ++--- griptape/loaders/base_loader.py | 11 +++++------ griptape/tasks/actions_subtask.py | 5 ++--- griptape/utils/file_utils.py | 7 ++++--- 9 files changed, 43 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 818211d8f..d1e581b4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed parameters `google_api_lang`, `google_api_key`, `google_api_search_id`, `google_api_country` on `WebSearch` in favor of `web_search_driver`. - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. +- Wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. ## [0.27.1] - 2024-06-20 diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index b1d9ed6d0..8002101b7 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -36,13 +36,14 @@ def to_artifact(self) -> BaseArtifact: def upsert_text_artifacts( self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs ) -> None: - utils.execute_futures_dict( - { - namespace: self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs) - for namespace, artifact_list in artifacts.items() - for a in artifact_list - } - ) + with self.futures_executor as executor: + utils.execute_futures_dict( + { + namespace: executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs) + for namespace, artifact_list in artifacts.items() + for a in artifact_list + } + ) def upsert_text_artifact( self, diff --git a/griptape/engines/rag/modules/query/related_query_generation_rag_module.py b/griptape/engines/rag/modules/query/related_query_generation_rag_module.py index 4661e24a2..9f610c0e5 100644 --- a/griptape/engines/rag/modules/query/related_query_generation_rag_module.py +++ b/griptape/engines/rag/modules/query/related_query_generation_rag_module.py @@ -18,14 +18,15 @@ class RelatedQueryGenerationRagModule(BaseQueryRagModule): def run(self, context: RagContext) -> list[str]: system_prompt = self.generate_system_template(context.initial_query) - results = utils.execute_futures_list( - [ - self.futures_executor.submit( - self.prompt_driver.run, self.generate_query_prompt_stack(system_prompt, "Alternative query: ") - ) - for _ in range(self.query_count) - ] - ) + with self.futures_executor as executor: + results = utils.execute_futures_list( + [ + executor.submit( + self.prompt_driver.run, self.generate_query_prompt_stack(system_prompt, "Alternative query: ") + ) + for _ in range(self.query_count) + ] + ) return [r.value for r in results] diff --git a/griptape/engines/rag/modules/retrieval/text_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_retrieval_rag_module.py index 5c6bedac4..1434b9880 100644 --- a/griptape/engines/rag/modules/retrieval/text_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_retrieval_rag_module.py @@ -21,12 +21,13 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: all_queries = [context.initial_query] + context.alternative_queries namespace = self.namespace or context.namespace - results = utils.execute_futures_list( - [ - self.futures_executor.submit(self.vector_store_driver.query, query, self.top_n, namespace, False) - for query in all_queries - ] - ) + with self.futures_executor as executor: + results = utils.execute_futures_list( + [ + executor.submit(self.vector_store_driver.query, query, self.top_n, namespace, False) + for query in all_queries + ] + ) return [ artifact diff --git a/griptape/engines/rag/stages/query_rag_stage.py b/griptape/engines/rag/stages/query_rag_stage.py index b122fa7a6..8c93ad763 100644 --- a/griptape/engines/rag/stages/query_rag_stage.py +++ b/griptape/engines/rag/stages/query_rag_stage.py @@ -14,10 +14,11 @@ class QueryRagStage(BaseRagStage): def run(self, context: RagContext) -> RagContext: logging.info(f"QueryStage: running {len(self.query_generation_modules)} query generation modules in parallel") - results = utils.execute_futures_list( - [self.futures_executor.submit(r.run, context) for r in self.query_generation_modules] - ) + with self.futures_executor as executor: + results = utils.execute_futures_list( + [executor.submit(r.run, context) for r in self.query_generation_modules] + ) - context.alternative_queries = list(itertools.chain.from_iterable(results)) + context.alternative_queries = list(itertools.chain.from_iterable(results)) return context diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py index 8a0ceca44..77443fb2f 100644 --- a/griptape/engines/rag/stages/retrieval_rag_stage.py +++ b/griptape/engines/rag/stages/retrieval_rag_stage.py @@ -19,9 +19,8 @@ class RetrievalRagStage(BaseRagStage): def run(self, context: RagContext) -> RagContext: logging.info(f"RetrievalStage: running {len(self.retrieval_modules)} retrieval modules in parallel") - results = utils.execute_futures_list( - [self.futures_executor.submit(r.run, context) for r in self.retrieval_modules] - ) + with self.futures_executor as executor: + results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules]) # flatten the list of lists results = list(itertools.chain.from_iterable(results)) diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 1648b8f26..40121067c 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -26,12 +26,11 @@ def load_collection( # Create a dictionary before actually submitting the jobs to the executor # to avoid duplicate work. sources_by_key = {self.to_key(source): source for source in sources} - return execute_futures_dict( - { - key: self.futures_executor.submit(self.load, source, *args, **kwargs) - for key, source in sources_by_key.items() - } - ) + + with self.futures_executor as executor: + return execute_futures_dict( + {key: executor.submit(self.load, source, *args, **kwargs) for key, source in sources_by_key.items()} + ) def to_key(self, source: Any, *args, **kwargs) -> str: if isinstance(source, bytes): diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 1546a825d..ae3893abb 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -119,9 +119,8 @@ def run(self) -> BaseArtifact: return ErrorArtifact("no tool output") def execute_actions(self, actions: list[Action]) -> list[tuple[str, BaseArtifact]]: - results = utils.execute_futures_dict( - {a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions} - ) + with self.futures_executor as executor: + results = utils.execute_futures_dict({a.tag: executor.submit(self.execute_action, a) for a in actions}) return [r for r in results.values()] diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py index 402436a2f..ebe5ba456 100644 --- a/griptape/utils/file_utils.py +++ b/griptape/utils/file_utils.py @@ -30,6 +30,7 @@ def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolEx if futures_executor is None: futures_executor = futures.ThreadPoolExecutor() - return utils.execute_futures_dict( - {utils.str_to_hash(str(path)): futures_executor.submit(load_file, path) for path in paths} - ) + with futures_executor as executor: + return utils.execute_futures_dict( + {utils.str_to_hash(str(path)): executor.submit(load_file, path) for path in paths} + ) From 2a347f32ed1ce948ce4e1bf8654c13640ee342a5 Mon Sep 17 00:00:00 2001 From: Emily Danielson <2302515+emjay07@users.noreply.github.com> Date: Tue, 25 Jun 2024 15:39:12 -0700 Subject: [PATCH 6/6] Bug Fix: Cohere prompts with no history (#900) --- CHANGELOG.md | 1 + .../drivers/prompt-drivers.md | 2 +- docs/griptape-framework/structures/config.md | 13 ++++++ .../drivers/prompt/cohere_prompt_driver.py | 12 ++++-- .../prompt/test_cohere_prompt_driver.py | 41 +++++++++++++++++-- 5 files changed, 61 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1e581b4b..183c3a34a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. - Wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. +- Fixed bug in `CoherePromptDriver` to properly handle empty history ## [0.27.1] - 2024-06-20 diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 0100ccbac..96a1be4e1 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -141,7 +141,7 @@ from griptape.config import StructureConfig agent = Agent( config=StructureConfig( prompt_driver=CoherePromptDriver( - model="command", + model="command-r", api_key=os.environ['COHERE_API_KEY'], ) ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 1a69d70c0..969163bf6 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -95,6 +95,19 @@ agent = Agent( ) ``` +#### Cohere + +The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_config.md) provides default Drivers for Cohere's APIs. + + +```python +import os +from griptape.config import CohereStructureConfig +from griptape.structures import Agent + +agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"])) +``` + ### Custom Configs You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers. diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index ca199011a..f4a306ebe 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -52,12 +52,18 @@ def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dic def _base_params(self, prompt_stack: PromptStack) -> dict: user_message = prompt_stack.inputs[-1].content - history_messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]] + history_messages = [ + self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1] if input.content + ] - return { + params = { "message": user_message, - "chat_history": history_messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, } + + if history_messages: + params["chat_history"] = history_messages + + return params diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index b3ceb11a4..6e5063b26 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -7,16 +7,16 @@ class TestCoherePromptDriver: @pytest.fixture def mock_client(self, mocker): - mock_client = mocker.patch("cohere.Client").return_value - mock_client.chat.return_value = Mock(text="model-output") + mock_client = mocker.patch("cohere.Client") + mock_client.return_value.chat.return_value = Mock(text="model-output") return mock_client @pytest.fixture def mock_stream_client(self, mocker): - mock_client = mocker.patch("cohere.Client").return_value + mock_client = mocker.patch("cohere.Client") mock_chunk = Mock(text="model-output", event_type="text-generation") - mock_client.chat_stream.return_value = iter([mock_chunk]) + mock_client.return_value.chat_stream.return_value = iter([mock_chunk]) return mock_client @@ -42,8 +42,41 @@ def test_try_run(self, mock_client, prompt_stack): # pyright: ignore # When text_artifact = driver.try_run(prompt_stack) + print(f"Called methods: {mock_client}") # Then + expected_message = "assistant-input" + expected_history = [ + {"role": "ASSISTANT", "text": "generic-input"}, + {"role": "SYSTEM", "text": "system-input"}, + {"role": "USER", "text": "user-input"}, + ] + mock_client.return_value.chat.assert_called_once_with( + message=expected_message, + temperature=driver.temperature, + stop_sequences=driver.tokenizer.stop_sequences, + max_tokens=driver.max_tokens, + chat_history=expected_history, + ) + assert text_artifact.value == "model-output" + + def test_try_run_no_history(self, mock_client, prompt_stack): + # Given + prompt_stack_no_history = PromptStack() + prompt_stack_no_history.add_user_input("user-input") + driver = CoherePromptDriver(model="command", api_key="api-key") + + # When + text_artifact = driver.try_run(prompt_stack_no_history) + + # Then + expected_message = "user-input" + mock_client.return_value.chat.assert_called_once_with( + message=expected_message, + temperature=driver.temperature, + stop_sequences=driver.tokenizer.stop_sequences, + max_tokens=driver.max_tokens, + ) assert text_artifact.value == "model-output" def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ignore