Skip to content

Commit

Permalink
BaseVectorStoreDriver.upsert_text_artifacts concurrency bugfix (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov authored Aug 19, 2024
1 parent a931176 commit 3aaeb6e
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 52 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `JsonArtifact` for handling de/seralization of values.
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.
- `utils.execute_futures_list_dict` for executing a dict of lists of futures.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand Down Expand Up @@ -58,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- `JsonExtractionEngine` failing to parse json when the LLM outputs more than just the json.
- Exception when adding `ErrorArtifact`'s to the Prompt Stack.
- Concurrency bug in `BaseVectorStoreDriver.upsert_text_artifacts`.

## [0.29.2] - 2024-08-16

Expand Down
21 changes: 13 additions & 8 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,20 @@ def upsert_text_artifacts(
],
)
else:
utils.execute_futures_dict(
{
namespace: self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
)
)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
},
)

utils.execute_futures_list_dict(futures_dict)

def upsert_text_artifact(
self,
Expand Down
32 changes: 17 additions & 15 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,19 @@ def __attrs_post_init__(self) -> None:

if not os.path.isfile(self.persist_file):
with open(self.persist_file, "w") as file:
self.save_entries_to_file(file)
self.__save_entries_to_file(file)

with open(self.persist_file, "r+") as file:
if os.path.getsize(self.persist_file) > 0:
self.entries = self.load_entries_from_file(file)
else:
self.save_entries_to_file(file)

def save_entries_to_file(self, json_file: TextIO) -> None:
with self.thread_lock:
serialized_data = {k: asdict(v) for k, v in self.entries.items()}

json.dump(serialized_data, json_file)
self.__save_entries_to_file(file)

def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]:
data = json.load(json_file)
with self.thread_lock:
data = json.load(json_file)

return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}
return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}

def upsert_vector(
self,
Expand All @@ -62,7 +57,7 @@ def upsert_vector(
vector_id = vector_id or utils.str_to_hash(str(vector))

with self.thread_lock:
self.entries[self._namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
self.entries[self.__namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
id=vector_id,
vector=vector,
meta=meta,
Expand All @@ -73,12 +68,12 @@ def upsert_vector(
# TODO: optimize later since it reserializes all entries from memory and stores them in the JSON file
# every time a new vector is inserted
with open(self.persist_file, "w") as file:
self.save_entries_to_file(file)
self.__save_entries_to_file(file)

return vector_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
return self.entries.get(self._namespaced_vector_id(vector_id, namespace=namespace), None)
return self.entries.get(self.__namespaced_vector_id(vector_id, namespace=namespace), None)

def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]
Expand All @@ -100,8 +95,9 @@ def query(
entries = self.entries

entries_and_relatednesses = [
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

result = [
Expand All @@ -120,5 +116,11 @@ def query(
def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

def _namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
def __save_entries_to_file(self, json_file: TextIO) -> None:
with self.thread_lock:
serialized_data = {k: asdict(v) for k, v in self.entries.items()}

json.dump(serialized_data, json_file)

def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
return vector_id if namespace is None else f"{namespace}-{vector_id}"
22 changes: 8 additions & 14 deletions griptape/mixins/futures_executor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import threading
from abc import ABC
from concurrent import futures
from typing import Callable, Optional
Expand All @@ -13,20 +12,15 @@ class FuturesExecutorMixin(ABC):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

_futures_executor: Optional[futures.Executor] = field(init=False, default=None)
futures_executor: Optional[futures.Executor] = field(
default=Factory(lambda self: self.futures_executor_fn(), takes_self=True)
)

@property
def futures_executor(self) -> futures.Executor:
with self.thread_lock:
if self._futures_executor is None:
self._futures_executor = self.futures_executor_fn()
def __del__(self) -> None:
executor = self.futures_executor

return self._futures_executor
if executor is not None:
self.futures_executor = None

def __del__(self) -> None:
with self.thread_lock:
if self._futures_executor:
self._futures_executor.shutdown(wait=True)
self._futures_executor = None
executor.shutdown(wait=True)
6 changes: 1 addition & 5 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,7 @@ def run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
results = utils.execute_futures_dict(
{a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions}
)

return list(results.values())
return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions])

def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
Expand Down
4 changes: 2 additions & 2 deletions griptape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from .python_runner import PythonRunner
from .command_runner import CommandRunner
from .chat import Chat
from .futures import execute_futures_dict
from .futures import execute_futures_list
from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict
from .token_counter import TokenCounter
from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively
from .file_utils import load_file, load_files
Expand Down Expand Up @@ -37,6 +36,7 @@ def minify_json(value: str) -> str:
"is_dependency_installed",
"execute_futures_dict",
"execute_futures_list",
"execute_futures_list_dict",
"TokenCounter",
"remove_null_values_in_dict_recursively",
"dict_merge",
Expand Down
6 changes: 6 additions & 0 deletions griptape/utils/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@ def execute_futures_list(fs_list: list[futures.Future[T]]) -> list[T]:
futures.wait(fs_list, timeout=None, return_when=futures.ALL_COMPLETED)

return [future.result() for future in fs_list]


def execute_futures_list_dict(fs_dict: dict[str, list[futures.Future[T]]]) -> dict[str, list[T]]:
execute_futures_list([item for sublist in fs_dict.values() for item in sublist])

return {key: [f.result() for f in fs] for key, fs in fs_dict.items()}
13 changes: 13 additions & 0 deletions tests/unit/drivers/vector/test_local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,16 @@ def test_upsert_text_artifacts_list(self, driver):

assert len(driver.load_artifacts(namespace="foo")) == 0
assert len(driver.load_artifacts()) == 2

def test_upsert_text_artifacts_stress_test(self, driver):
driver.upsert_text_artifacts(
{
"test1": [TextArtifact(f"foo-{i}") for i in range(0, 1000)],
"test2": [TextArtifact(f"foo-{i}") for i in range(0, 1000)],
"test3": [TextArtifact(f"foo-{i}") for i in range(0, 1000)],
}
)

assert len(driver.query("foo", namespace="test1")) == 1000
assert len(driver.query("foo", namespace="test2")) == 1000
assert len(driver.query("foo", namespace="test3")) == 1000
6 changes: 0 additions & 6 deletions tests/unit/tools/test_vector_store_tool.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import pytest

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.drivers import LocalVectorStoreDriver
from griptape.tools import VectorStoreTool
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver


class TestVectorStoreTool:
@pytest.fixture(autouse=True)
def _mock_try_run(self, mocker):
mocker.patch("griptape.drivers.OpenAiEmbeddingDriver.try_embed_chunk", return_value=[0, 1])

def test_search(self):
driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver())
tool = VectorStoreTool(description="Test", vector_store_driver=driver)
Expand Down
17 changes: 15 additions & 2 deletions tests/unit/utils/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,21 @@ def test_execute_futures_list(self):
[executor.submit(self.foobar, "foo"), executor.submit(self.foobar, "baz")]
)

assert result[0] == "foo-bar"
assert result[1] == "baz-bar"
assert set(result) == {"foo-bar", "baz-bar"}

def test_execute_futures_list_dict(self):
with futures.ThreadPoolExecutor() as executor:
result = utils.execute_futures_list_dict(
{
"test1": [executor.submit(self.foobar, f"foo-{i}") for i in range(0, 1000)],
"test2": [executor.submit(self.foobar, f"foo-{i}") for i in range(0, 1000)],
"test3": [executor.submit(self.foobar, f"foo-{i}") for i in range(0, 1000)],
}
)

assert len(result["test1"]) == 1000
assert len(result["test2"]) == 1000
assert len(result["test3"]) == 1000

def foobar(self, foo):
return f"{foo}-bar"

0 comments on commit 3aaeb6e

Please sign in to comment.