Fix RedisVectorStoreDriver bugs #782

Merged
merged 3 commits into from
May 15, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver.

### Fixed
- Honor `namespace` in `RedisVectorStoreDriver.query()`.
- Correctly set the `meta`, `score`, and `vector` fields of query result returned from `RedisVectorStoreDriver.query()`.

## [0.25.1] - 2024-05-09
### Added
- Optional event batching on Event Listener Drivers.
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/vector-store-drivers.md
@@ -316,7 +316,7 @@ print(result)

The format for creating a vector index should be similar to the following:
```
- FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
+ FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE
```

## OpenSearch Vector Store Driver
12 changes: 8 additions & 4 deletions griptape/drivers/vector/redis_vector_store_driver.py
@@ -64,6 +64,9 @@ def upsert_vector(
mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
mapping["vec_string"] = bytes_vector

+ if namespace:
+     mapping["namespace"] = namespace

if meta:
mapping["metadata"] = json.dumps(meta)

@@ -120,8 +123,9 @@ def query(

vector = self.embedding_driver.embed_string(query)

+ filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
Member:

What happens if the namespace field doesn't exist?

Contributor Author:

If the `namespace` field does not exist in the index schema, passing `namespace` to `query()` will never return any results, including when the driver is used by a tool.

So in my example above, `driver.query("What is griptape?", namespace="griptape")` would return 0 results instead of 1, and the VectorStoreClient tool (configured with a namespace) would never find any results.

I think we might be able to inspect the schema to either change this behavior or emit a warning, so people don't shoot themselves in the foot. Not on every query, just on initialization. I'll see what I can do. Let me know if you have other ideas.
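
The initialization-time check proposed in this comment could look roughly like the sketch below. It is only an illustration: `namespace_warning` and its `indexed_fields` argument are hypothetical names, not part of this PR or of griptape, and how the field names are read from the Redis index is deliberately left out.

```python
from typing import Iterable, Optional


def namespace_warning(indexed_fields: Iterable[str], index_name: str) -> Optional[str]:
    """Return a warning message if the index has no `namespace` field.

    `indexed_fields` is assumed to be an already-extracted collection of
    field names from the index schema (hypothetical input for this sketch).
    Returns None when the field is present and no warning is needed.
    """
    if "namespace" in set(indexed_fields):
        return None
    return (
        f"Index {index_name!r} has no 'namespace' field; "
        "queries that pass a namespace will return no results."
    )


# Old-style schema from the docs (tag + vector) versus the updated one.
old_schema_message = namespace_warning(["tag", "vector"], "idx:griptape")
new_schema_message = namespace_warning(["namespace", "vector"], "idx:griptape")
```

A caller could pass a non-`None` message to `warnings.warn` once when the driver is constructed, which matches the "on initialization, not on every query" idea in the comment.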

Member:

Gotcha. So, when no namespace parameter is specified and we add the @namespace:* filter, will it return empty results as well? If so, we should probably not include that filter, right?

Contributor Author:

If no namespace is provided, filter_expression will evaluate to '*' rather than @namespace:* (there's a ternary). Providing no namespace will search through all vectors in the index just as before.

This is illustrated in the example program via driver.query("What is griptape?"), which returns a result.
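
The ternary described in this comment can be exercised on its own. The sketch below copies the two f-strings from the diff into a standalone helper; `build_knn_query` and `DEFAULT_COUNT` are made-up names for illustration, and the real driver builds the string inline inside `query()`.

```python
from typing import Optional

DEFAULT_COUNT = 10  # mirrors the `count or 10` fallback in the diff


def build_knn_query(namespace: Optional[str] = None, count: Optional[int] = None) -> str:
    # A missing or empty namespace falls back to "*", which matches every
    # document in the index, so behavior without a namespace is unchanged.
    filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
    return f"{filter_expression}=>[KNN {count or DEFAULT_COUNT} @vector $vector as score]"


print(build_knn_query())
# *=>[KNN 10 @vector $vector as score]
print(build_knn_query("griptape", 5))
# (@namespace:{griptape})=>[KNN 5 @vector $vector as score]
```

The triple braces in the f-string produce one literal `{`/`}` pair around the interpolated namespace. This sketch does not escape special characters in the namespace value.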

Member:

Oops, I misread the code, sorry! Let's ship it!

query_expression = (
- Query(f"*=>[KNN {count or 10} @vector $vector as score]")
+ Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
.sort_by("score")
.return_fields("id", "score", "metadata", "vec_string")
.paging(0, count or 10)
@@ -134,15 +138,15 @@ def query(

query_results = []
for document in results:
- metadata = getattr(document, "metadata", None)
+ metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
namespace = document.id.split(":")[0] if ":" in document.id else None
vector_id = document.id.split(":")[1] if ":" in document.id else document.id
- vector_float_list = json.loads(document["vec_string"]) if include_vectors else None
+ vector_float_list = json.loads(document.vec_string) if include_vectors else None
query_results.append(
BaseVectorStoreDriver.QueryResult(
id=vector_id,
vector=vector_float_list,
- score=float(document["score"]),
+ score=float(document.score),
meta=metadata,
namespace=namespace,
)
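
To see what the updated result-parsing loop does to a single search hit, the logic can be pulled into a small standalone function. This is a sketch under assumptions: `parse_document` is a hypothetical helper, a plain dict stands in for `BaseVectorStoreDriver.QueryResult`, and `SimpleNamespace` stands in for the document object returned by the search call (only attribute access is assumed).

```python
import json
from types import SimpleNamespace


def parse_document(document, include_vectors=False):
    # Metadata is stored as a JSON string, so it has to be decoded.
    metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
    # Keys look like "<namespace>:<vector_id>"; ids without ":" have no namespace.
    namespace = document.id.split(":")[0] if ":" in document.id else None
    vector_id = document.id.split(":")[1] if ":" in document.id else document.id
    # The vector is only decoded when the caller asks for it.
    vector = json.loads(document.vec_string) if include_vectors else None
    return {
        "id": vector_id,
        "namespace": namespace,
        "score": float(document.score),  # the score arrives as a string
        "meta": metadata,
        "vector": vector,
    }


doc = SimpleNamespace(
    id="some_namespace:some_vector_id",
    score="0.456198036671",
    metadata='{"foo": "bar"}',
    vec_string="[1.0, 2.0, 3.0]",
)
with_vector = parse_document(doc, include_vectors=True)
without_vector = parse_document(doc)
bare = parse_document(SimpleNamespace(id="plain_id", score="1", vec_string="[]"))
```

The sample values match the mocked document used in the unit tests for this PR, so `with_vector` carries the same fields those tests assert on.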
80 changes: 65 additions & 15 deletions tests/unit/drivers/vector/test_redis_vector_store_driver.py
@@ -1,3 +1,4 @@
from unittest.mock import MagicMock
import pytest
import redis
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver
@@ -6,43 +7,92 @@

class TestRedisVectorStorageDriver:
@pytest.fixture(autouse=True)
def mock_redis(self, mocker):
fake_hgetall_response = {b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@", b"metadata": b'{"foo": "bar"}'}
def mock_client(self, mocker):
return mocker.patch("redis.Redis").return_value

mocker.patch.object(redis.StrictRedis, "hset", return_value=None)
mocker.patch.object(redis.StrictRedis, "hgetall", return_value=fake_hgetall_response)
mocker.patch.object(redis.StrictRedis, "keys", return_value=[b"some_namespace:some_vector_id"])

fake_redisearch = mocker.MagicMock()
fake_redisearch.search = mocker.MagicMock(return_value=mocker.MagicMock(docs=[]))
fake_redisearch.info = mocker.MagicMock(side_effect=Exception("Index not found"))
fake_redisearch.create_index = mocker.MagicMock(return_value=None)
@pytest.fixture
def mock_keys(self, mock_client):
mock_client.keys.return_value = [b"some_vector_id"]
return mock_client.keys

mocker.patch.object(redis.StrictRedis, "ft", return_value=fake_redisearch)
@pytest.fixture
def mock_hgetall(self, mock_client):
mock_client.hgetall.return_value = {
b"vector": b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@",
b"metadata": b'{"foo": "bar"}',
}
return mock_client.hgetall

@pytest.fixture
def driver(self):
return RedisVectorStoreDriver(
host="localhost", port=6379, index="test_index", db=0, embedding_driver=MockEmbeddingDriver()
)

@pytest.fixture
def mock_search(self, mock_client):
mock_client.ft.return_value.search.return_value.docs = [
MagicMock(
id="some_namespace:some_vector_id",
score="0.456198036671",
metadata='{"foo": "bar"}',
vec_string="[1.0, 2.0, 3.0]",
)
]
return mock_client.ft.return_value.search

def test_upsert_vector(self, driver):
assert (
driver.upsert_vector([1.0, 2.0, 3.0], vector_id="some_vector_id", namespace="some_namespace")
== "some_vector_id"
)

def test_load_entry(self, driver):
def test_load_entry(self, driver, mock_hgetall):
entry = driver.load_entry("some_vector_id")
mock_hgetall.assert_called_once_with("some_vector_id")
assert entry.id == "some_vector_id"
assert entry.vector == [1.0, 2.0, 3.0]
assert entry.meta == {"foo": "bar"}

def test_load_entry_with_namespace(self, driver, mock_hgetall):
entry = driver.load_entry("some_vector_id", namespace="some_namespace")
mock_hgetall.assert_called_once_with("some_namespace:some_vector_id")
assert entry.id == "some_vector_id"
assert entry.vector == [1.0, 2.0, 3.0]
assert entry.meta == {"foo": "bar"}

def test_load_entries(self, driver):
def test_load_entries(self, driver, mock_keys, mock_hgetall):
entries = driver.load_entries()
mock_keys.assert_called_once_with("*")
mock_hgetall.assert_called_once_with("some_vector_id")
assert len(entries) == 1
assert entries[0].vector == [1.0, 2.0, 3.0]
assert entries[0].meta == {"foo": "bar"}

def test_load_entries_with_namespace(self, driver, mock_keys, mock_hgetall):
entries = driver.load_entries(namespace="some_namespace")
mock_keys.assert_called_once_with("some_namespace:*")
mock_hgetall.assert_called_once_with("some_namespace:some_vector_id")
assert len(entries) == 1
assert entries[0].vector == [1.0, 2.0, 3.0]
assert entries[0].meta == {"foo": "bar"}

def test_query(self, driver):
assert driver.query("some_vector_id") == []
def test_query(self, driver, mock_search):
results = driver.query("Some query")
mock_search.assert_called_once()
assert len(results) == 1
assert results[0].namespace == "some_namespace"
assert results[0].id == "some_vector_id"
assert results[0].score == 0.456198036671
assert results[0].meta == {"foo": "bar"}
assert results[0].vector is None

def test_query_with_include_vectors(self, driver, mock_search):
results = driver.query("Some query", include_vectors=True)
mock_search.assert_called_once()
assert len(results) == 1
assert results[0].namespace == "some_namespace"
assert results[0].id == "some_vector_id"
assert results[0].score == 0.456198036671
assert results[0].meta == {"foo": "bar"}
assert results[0].vector == [1.0, 2.0, 3.0]
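
The byte string in the `mock_hgetall` fixture is the float32 encoding of `[1.0, 2.0, 3.0]`, which is why the load tests expect that vector back. The sketch below decodes it with the standard library instead of NumPy; it assumes little-endian byte order, whereas the driver's `np.array(vector, dtype=np.float32).tobytes()` call uses the machine's native order.

```python
import struct

raw = b"\x00\x00\x80?\x00\x00\x00@\x00\x00@@"

# Each float32 takes 4 bytes, so 12 bytes hold 3 values.
count = len(raw) // 4
values = list(struct.unpack(f"<{count}f", raw))

# Packing the same values reproduces the fixture bytes exactly.
round_trip = struct.pack(f"<{count}f", *values)

print(values)
# [1.0, 2.0, 3.0]
```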