Skip to content

Commit

Permalink
Merge branch 'dev' into feature/qdrant_vector_store_driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hkhajgiwale authored Jun 26, 2024
2 parents 55e8cfa + 2a347f3 commit 940ece5
Show file tree
Hide file tree
Showing 23 changed files with 256 additions and 81 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK.
- `ProxyWebScraperDriver` to web scrape using proxies.
- Parameter `session` on `AmazonBedrockStructureConfig`.
- Parameter `meta` on `TextArtifact`.
- `VectorStoreClient` improvements:
- `VectorStoreClient.query_params` dict for custom query params.
- `VectorStoreClient.process_query_output_fn` for custom query output processing logic.

### Changed
- **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifact()` and `BaseVectorStoreDriver.upsert_text()` use artifact/string values to generate `vector_id` if it wasn't explicitly passed. This change ensures that we don't generate embeddings for the same content every time.
Expand All @@ -45,7 +49,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Merged `BaseVectorStoreDriver.QueryResult` into `BaseVectorStoreDriver.Entry`.
- **BREAKING**: Replaced `query_engine` with `vector_store_driver` in `VectorStoreClient`.
- **BREAKING**: Removed parameters `google_api_lang`, `google_api_key`, `google_api_search_id`, `google_api_country` on `WebSearch` in favor of `web_search_driver`.
- **BREAKING**: Removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`.
- `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api.
- Wrapped all futures executor `submit` calls in a `with` block to address executor shutdown issues.
- Fixed bug in `CoherePromptDriver` to properly handle empty history.

## [0.27.1] - 2024-06-20

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ from griptape.config import StructureConfig
agent = Agent(
config=StructureConfig(
prompt_driver=CoherePromptDriver(
model="command",
model="command-r",
api_key=os.environ['COHERE_API_KEY'],
)
)
Expand Down
13 changes: 13 additions & 0 deletions docs/griptape-framework/structures/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ agent = Agent(
)
```

#### Cohere

The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_config.md) provides default Drivers for Cohere's APIs.


```python
import os
from griptape.config import CohereStructureConfig
from griptape.structures import Agent

# The API key is read from the COHERE_API_KEY environment variable;
# os.environ[...] raises KeyError if that variable is not set.
agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"]))
```

### Custom Configs

You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers.
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-tools/official-tools/vector-store-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ vector_store_driver.upsert_text_artifacts(
vector_db = VectorStoreClient(
description="This DB has information about the Griptape Python framework",
vector_store_driver=vector_store_driver,
namespace="griptape",
query_params={"namespace": "griptape"},
off_prompt=True
)

Expand Down
3 changes: 2 additions & 1 deletion griptape/artifacts/text_artifact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from attrs import define, field
from griptape.artifacts import BaseArtifact

Expand All @@ -13,6 +13,7 @@ class TextArtifact(BaseArtifact):
value: str = field(converter=str, metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_embedding: list[float] = field(factory=list, kw_only=True)

@property
Expand Down
12 changes: 9 additions & 3 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dic
def _base_params(self, prompt_stack: PromptStack) -> dict:
user_message = prompt_stack.inputs[-1].content

history_messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]]
history_messages = [
self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1] if input.content
]

return {
params = {
"message": user_message,
"chat_history": history_messages,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"max_tokens": self.max_tokens,
}

if history_messages:
params["chat_history"] = history_messages

return params
20 changes: 12 additions & 8 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ def to_artifact(self) -> BaseArtifact:
def upsert_text_artifacts(
self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
) -> None:
utils.execute_futures_dict(
{
namespace: self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
}
)
with self.futures_executor as executor:
utils.execute_futures_dict(
{
namespace: executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
}
)

def upsert_text_artifact(
self,
Expand Down Expand Up @@ -92,7 +93,10 @@ def upsert_text(
)

def does_entry_exist(self, vector_id: str, namespace: Optional[str] = None) -> bool:
return self.load_entry(vector_id, namespace) is not None
try:
return self.load_entry(vector_id, namespace) is not None
except Exception:
return False

def load_artifacts(self, namespace: Optional[str] = None) -> ListArtifact:
result = self.load_entries(namespace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ class RelatedQueryGenerationRagModule(BaseQueryRagModule):
def run(self, context: RagContext) -> list[str]:
system_prompt = self.generate_system_template(context.initial_query)

results = utils.execute_futures_list(
[
self.futures_executor.submit(
self.prompt_driver.run, self.generate_query_prompt_stack(system_prompt, "Alternative query: ")
)
for _ in range(self.query_count)
]
)
with self.futures_executor as executor:
results = utils.execute_futures_list(
[
executor.submit(
self.prompt_driver.run, self.generate_query_prompt_stack(system_prompt, "Alternative query: ")
)
for _ in range(self.query_count)
]
)

return [r.value for r in results]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]:
all_queries = [context.initial_query] + context.alternative_queries
namespace = self.namespace or context.namespace

results = utils.execute_futures_list(
[
self.futures_executor.submit(self.vector_store_driver.query, query, self.top_n, namespace, False)
for query in all_queries
]
)
with self.futures_executor as executor:
results = utils.execute_futures_list(
[
executor.submit(self.vector_store_driver.query, query, self.top_n, namespace, False)
for query in all_queries
]
)

return [
artifact
Expand Down
9 changes: 5 additions & 4 deletions griptape/engines/rag/stages/query_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ class QueryRagStage(BaseRagStage):
def run(self, context: RagContext) -> RagContext:
logging.info(f"QueryStage: running {len(self.query_generation_modules)} query generation modules in parallel")

results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.query_generation_modules]
)
with self.futures_executor as executor:
results = utils.execute_futures_list(
[executor.submit(r.run, context) for r in self.query_generation_modules]
)

context.alternative_queries = list(itertools.chain.from_iterable(results))
context.alternative_queries = list(itertools.chain.from_iterable(results))

return context
5 changes: 2 additions & 3 deletions griptape/engines/rag/stages/retrieval_rag_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ class RetrievalRagStage(BaseRagStage):
def run(self, context: RagContext) -> RagContext:
logging.info(f"RetrievalStage: running {len(self.retrieval_modules)} retrieval modules in parallel")

results = utils.execute_futures_list(
[self.futures_executor.submit(r.run, context) for r in self.retrieval_modules]
)
with self.futures_executor as executor:
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules])

# flatten the list of lists
results = list(itertools.chain.from_iterable(results))
Expand Down
11 changes: 5 additions & 6 deletions griptape/loaders/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ def load_collection(
# Create a dictionary before actually submitting the jobs to the executor
# to avoid duplicate work.
sources_by_key = {self.to_key(source): source for source in sources}
return execute_futures_dict(
{
key: self.futures_executor.submit(self.load, source, *args, **kwargs)
for key, source in sources_by_key.items()
}
)

with self.futures_executor as executor:
return execute_futures_dict(
{key: executor.submit(self.load, source, *args, **kwargs) for key, source in sources_by_key.items()}
)

def to_key(self, source: Any, *args, **kwargs) -> str:
if isinstance(source, bytes):
Expand Down
5 changes: 2 additions & 3 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ def run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[Action]) -> list[tuple[str, BaseArtifact]]:
results = utils.execute_futures_dict(
{a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions}
)
with self.futures_executor as executor:
results = utils.execute_futures_dict({a.tag: executor.submit(self.execute_action, a) for a in actions})

return [r for r in results.values()]

Expand Down
30 changes: 15 additions & 15 deletions griptape/tools/vector_store_client/tool.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
from __future__ import annotations
from typing import Optional
from attrs import define, field
from typing import Callable, Any
from attrs import define, field, Factory
from schema import Schema, Literal
from griptape.artifacts import ErrorArtifact
from griptape.artifacts import ErrorArtifact, BaseArtifact
from griptape.artifacts import ListArtifact
from griptape.drivers import BaseVectorStoreDriver
from griptape.tools import BaseTool
from griptape.utils.decorators import activity


@define
@define(kw_only=True)
class VectorStoreClient(BaseTool):
"""
Attributes:
description: LLM-friendly vector DB description.
namespace: Vector storage namespace.
vector_store_driver: `BaseVectorStoreDriver`.
top_n: Max number of results returned for the query engine query.
query_params: Optional dictionary of vector store driver query parameters.
process_query_output_fn: Optional lambda for processing vector store driver query output `Entry`s.
"""

DEFAULT_TOP_N = 5

description: str = field(kw_only=True)
vector_store_driver: BaseVectorStoreDriver = field(kw_only=True)
top_n: int = field(default=DEFAULT_TOP_N, kw_only=True)
namespace: Optional[str] = field(default=None, kw_only=True)
description: str = field()
vector_store_driver: BaseVectorStoreDriver = field()
query_params: dict[str, Any] = field(factory=dict)
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field(
default=Factory(lambda: lambda es: ListArtifact([e.to_artifact() for e in es]))
)

@activity(
config={
"description": "Can be used to search a vector database with the following description: {{ _self.description }}",
"description": "Can be used to search a database with the following description: {{ _self.description }}",
"schema": Schema(
{
Literal(
Expand All @@ -38,12 +40,10 @@ class VectorStoreClient(BaseTool):
),
}
)
def search(self, params: dict) -> ListArtifact | ErrorArtifact:
def search(self, params: dict) -> BaseArtifact:
query = params["values"]["query"]

try:
entries = self.vector_store_driver.query(query, namespace=self.namespace, count=self.top_n)

return ListArtifact([e.to_artifact() for e in entries])
return self.process_query_output_fn(self.vector_store_driver.query(query, **self.query_params))
except Exception as e:
return ErrorArtifact(f"error querying vector store: {e}")
7 changes: 4 additions & 3 deletions griptape/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolEx
if futures_executor is None:
futures_executor = futures.ThreadPoolExecutor()

return utils.execute_futures_dict(
{utils.str_to_hash(str(path)): futures_executor.submit(load_file, path) for path in paths}
)
with futures_executor as executor:
return utils.execute_futures_dict(
{utils.str_to_hash(str(path)): executor.submit(load_file, path) for path in paths}
)
Loading

0 comments on commit 940ece5

Please sign in to comment.