Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseVectorStoreDriver.upsert_text_artifacts concurrency bugfix #1074

Merged
merged 18 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `JsonArtifact` for handling de/serialization of values.
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.
- `utils.execute_futures_list_dict` for executing a dict of lists of futures.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand Down Expand Up @@ -58,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- `JsonExtractionEngine` failing to parse json when the LLM outputs more than just the json.
- Exception when adding `ErrorArtifact`'s to the Prompt Stack.
- Concurrency bug in `BaseVectorStoreDriver.upsert_text_artifacts`.

## [0.29.2] - 2024-08-16

Expand Down
21 changes: 13 additions & 8 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,20 @@ def upsert_text_artifacts(
],
)
else:
utils.execute_futures_dict(
{
namespace: self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
self.futures_executor.submit(
self.upsert_text_artifact, a, namespace=namespace, meta=meta, **kwargs
)
)
for namespace, artifact_list in artifacts.items()
for a in artifact_list
},
)

utils.execute_futures_list_dict(futures_dict)

def upsert_text_artifact(
self,
Expand Down
32 changes: 17 additions & 15 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,19 @@

if not os.path.isfile(self.persist_file):
with open(self.persist_file, "w") as file:
self.save_entries_to_file(file)
self.__save_entries_to_file(file)

with open(self.persist_file, "r+") as file:
if os.path.getsize(self.persist_file) > 0:
self.entries = self.load_entries_from_file(file)
else:
self.save_entries_to_file(file)

def save_entries_to_file(self, json_file: TextIO) -> None:
with self.thread_lock:
serialized_data = {k: asdict(v) for k, v in self.entries.items()}

json.dump(serialized_data, json_file)
self.__save_entries_to_file(file)

Check warning on line 40 in griptape/drivers/vector/local_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/local_vector_store_driver.py#L40

Added line #L40 was not covered by tests

def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]:
data = json.load(json_file)
with self.thread_lock:
data = json.load(json_file)

return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}
return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}

def upsert_vector(
self,
Expand All @@ -62,7 +57,7 @@
vector_id = vector_id or utils.str_to_hash(str(vector))

with self.thread_lock:
self.entries[self._namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
self.entries[self.__namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
id=vector_id,
vector=vector,
meta=meta,
Expand All @@ -73,12 +68,12 @@
# TODO: optimize later since it reserializes all entries from memory and stores them in the JSON file
# every time a new vector is inserted
with open(self.persist_file, "w") as file:
self.save_entries_to_file(file)
self.__save_entries_to_file(file)

return vector_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
return self.entries.get(self._namespaced_vector_id(vector_id, namespace=namespace), None)
return self.entries.get(self.__namespaced_vector_id(vector_id, namespace=namespace), None)

def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]
Expand All @@ -100,8 +95,9 @@
entries = self.entries

entries_and_relatednesses = [
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

result = [
Expand All @@ -120,5 +116,11 @@
def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

def _namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
def __save_entries_to_file(self, json_file: TextIO) -> None:
    """Serialize all in-memory entries to ``json_file`` as JSON.

    The snapshot of ``self.entries`` is taken while holding ``thread_lock`` so
    concurrent upserts cannot mutate the dict mid-iteration; the file write
    itself happens outside the lock to keep the critical section short.
    """
    with self.thread_lock:
        # Entry dataclasses -> plain dicts so json.dump can handle them.
        serialized_data = {k: asdict(v) for k, v in self.entries.items()}

    json.dump(serialized_data, json_file)

def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
    """Build the storage key: the bare id, or ``"<namespace>-<id>"`` when namespaced."""
    if namespace is None:
        return vector_id
    return f"{namespace}-{vector_id}"
22 changes: 8 additions & 14 deletions griptape/mixins/futures_executor_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import threading
from abc import ABC
from concurrent import futures
from typing import Callable, Optional
Expand All @@ -13,20 +12,15 @@ class FuturesExecutorMixin(ABC):
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

_futures_executor: Optional[futures.Executor] = field(init=False, default=None)
futures_executor: Optional[futures.Executor] = field(
default=Factory(lambda self: self.futures_executor_fn(), takes_self=True)
)

@property
def futures_executor(self) -> futures.Executor:
with self.thread_lock:
if self._futures_executor is None:
self._futures_executor = self.futures_executor_fn()
def __del__(self) -> None:
executor = self.futures_executor

return self._futures_executor
if executor:
self.futures_executor = None

def __del__(self) -> None:
with self.thread_lock:
if self._futures_executor:
self._futures_executor.shutdown(wait=True)
self._futures_executor = None
executor.shutdown(wait=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None check

if (executor := self.futures_executor) is not None:
    executor.shutdown(wait=True)
    self.futures_executor = None

6 changes: 1 addition & 5 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,7 @@ def run(self) -> BaseArtifact:
return ErrorArtifact("no tool output")

def execute_actions(self, actions: list[ToolAction]) -> list[tuple[str, BaseArtifact]]:
results = utils.execute_futures_dict(
{a.tag: self.futures_executor.submit(self.execute_action, a) for a in actions}
)

return list(results.values())
return utils.execute_futures_list([self.futures_executor.submit(self.execute_action, a) for a in actions])

def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]:
if action.tool is not None:
Expand Down
4 changes: 2 additions & 2 deletions griptape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from .python_runner import PythonRunner
from .command_runner import CommandRunner
from .chat import Chat
from .futures import execute_futures_dict
from .futures import execute_futures_list
from .futures import execute_futures_dict, execute_futures_list, execute_futures_list_dict
from .token_counter import TokenCounter
from .dict_utils import remove_null_values_in_dict_recursively, dict_merge, remove_key_in_dict_recursively
from .file_utils import load_file, load_files
Expand Down Expand Up @@ -37,6 +36,7 @@ def minify_json(value: str) -> str:
"is_dependency_installed",
"execute_futures_dict",
"execute_futures_list",
"execute_futures_list_dict",
"TokenCounter",
"remove_null_values_in_dict_recursively",
"dict_merge",
Expand Down
6 changes: 6 additions & 0 deletions griptape/utils/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@ def execute_futures_list(fs_list: list[futures.Future[T]]) -> list[T]:
futures.wait(fs_list, timeout=None, return_when=futures.ALL_COMPLETED)

return [future.result() for future in fs_list]


def execute_futures_list_dict(fs_dict: dict[str, list[futures.Future[T]]]) -> dict[str, list[T]]:
    """Wait for every future in a dict of lists of futures and collect their results.

    Args:
        fs_dict: Mapping from a key (e.g. a namespace) to a list of futures.

    Returns:
        A dict with the same keys, each mapped to the list of its futures'
        results in the original submission order.

    Raises:
        Any exception raised by a future is re-raised by ``Future.result()``.
    """
    # Flatten and block until every future has completed. Calling futures.wait
    # directly avoids building (and discarding) an intermediate results list,
    # which delegating to execute_futures_list did — that also called
    # .result() twice per future.
    all_futures = [future for fs_list in fs_dict.values() for future in fs_list]
    futures.wait(all_futures, timeout=None, return_when=futures.ALL_COMPLETED)

    # Everything is done, so .result() returns (or raises) immediately.
    return {key: [future.result() for future in fs_list] for key, fs_list in fs_dict.items()}
13 changes: 13 additions & 0 deletions tests/unit/drivers/vector/test_local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,16 @@ def test_upsert_text_artifacts_list(self, driver):

assert len(driver.load_artifacts(namespace="foo")) == 0
assert len(driver.load_artifacts()) == 2

def test_upsert_text_artifacts_stress_test(self, driver):
    """Upsert 1000 artifacts into each of three namespaces and verify query counts."""
    namespaces = ("test1", "test2", "test3")

    driver.upsert_text_artifacts(
        {ns: [TextArtifact(f"foo-{i}") for i in range(1000)] for ns in namespaces}
    )

    for ns in namespaces:
        assert len(driver.query("foo", namespace=ns)) == 1000
6 changes: 0 additions & 6 deletions tests/unit/tools/test_vector_store_tool.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import pytest

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.drivers import LocalVectorStoreDriver
from griptape.tools import VectorStoreTool
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver


class TestVectorStoreTool:
@pytest.fixture(autouse=True)
def _mock_try_run(self, mocker):
mocker.patch("griptape.drivers.OpenAiEmbeddingDriver.try_embed_chunk", return_value=[0, 1])

def test_search(self):
driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver())
tool = VectorStoreTool(description="Test", vector_store_driver=driver)
Expand Down
17 changes: 15 additions & 2 deletions tests/unit/utils/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,21 @@ def test_execute_futures_list(self):
[executor.submit(self.foobar, "foo"), executor.submit(self.foobar, "baz")]
)

assert result[0] == "foo-bar"
assert result[1] == "baz-bar"
assert set(result) == {"foo-bar", "baz-bar"}

def test_execute_futures_list_dict(self):
    """Submit 1000 futures per key and verify each key collects all its results."""
    keys = ("test1", "test2", "test3")

    with futures.ThreadPoolExecutor() as executor:
        submitted = {
            key: [executor.submit(self.foobar, f"foo-{i}") for i in range(1000)]
            for key in keys
        }
        result = utils.execute_futures_list_dict(submitted)

    for key in keys:
        assert len(result[key]) == 1000

def foobar(self, foo):
    """Helper: return the given value suffixed with ``-bar``."""
    return "{}-bar".format(foo)
Loading