diff --git a/CHANGELOG.md b/CHANGELOG.md index e7d2f757a..365cfb1d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +### Added +- `BaseGraphStoreDriver` as the base class for all graph store drivers. +- `FalkorDBGraphStoreDriver` to support the FalkorDB graph database in Griptape. +- Unit tests for `BaseGraphStoreDriver` and `FalkorDBGraphStoreDriver` to ensure their functionality. +- Documentation for `BaseGraphStoreDriver` and `FalkorDBGraphStoreDriver` including examples of how to use them. +- `drivers-graph-falkordb` extra and `falkordb` dependency for the `FalkorDBGraphStoreDriver`. + ### Added - Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`. - `OllamaEmbeddingDriver` for generating embeddings with Ollama. diff --git a/docs/griptape-framework/drivers/graph-store-drivers.md b/docs/griptape-framework/drivers/graph-store-drivers.md new file mode 100644 index 000000000..1240bcc92 --- /dev/null +++ b/docs/griptape-framework/drivers/graph-store-drivers.md @@ -0,0 +1,53 @@ +## Overview + +Griptape provides a way to build drivers for graph databases where graph data can be stored and queried. Every graph store driver implements the following methods: + +- `upsert_node()` for updating or inserting a new node into graph databases. +- `upsert_node_artifacts()` for updating or inserting multiple nodes into graph databases. +- `upsert_edge()` for updating or inserting new edges between nodes in graph databases. +- `query_nodes()` for querying nodes in graph databases. +- `query_edges()` for querying edges in graph databases. + +Each graph store driver extends `BaseGraphStoreDriver`, which provides additional utility methods to manage graph structures. + +!!! 
info + When working with graph database indexes with Griptape drivers, ensure that your schema supports the required node and edge properties. Check the documentation for your graph database on how to create and manage graph schemas. + +## FalkorDB + +!!! info + This driver requires the `drivers-graph-falkordb` [extra](../index.md#extras). + +The [FalkorDBGraphStoreDriver](../../reference/griptape/drivers/graph/falkordb_graph_store_driver.md) supports the [FalkorDB graph database](https://www.falkordb.com/). + +Here is an example of how the driver can be used to load and query information in a FalkorDB cluster: + +```python +import os +from griptape.drivers.graph import FalkorDBGraphStoreDriver + +# Initialize the FalkorDB driver +falkordb_driver = FalkorDBGraphStoreDriver( + url=os.environ["FALKORDB_URL"], + api_key=os.environ["FALKORDB_API_KEY"] +) + +# Example node data +node_data = { + "id": "1", + "label": "Person", + "properties": { + "name": "Alice", + "age": 30 + } +} + +# Upsert a node +falkordb_driver.upsert_node(node_data) + +# Query nodes +results = falkordb_driver.query_nodes("MATCH (n:Person) RETURN n") + +for result in results: + print(result) +``` \ No newline at end of file diff --git a/griptape/drivers/graph/base_graph_store_driver.py b/griptape/drivers/graph/base_graph_store_driver.py index 129a7a032..a76d5b27a 100644 --- a/griptape/drivers/graph/base_graph_store_driver.py +++ b/griptape/drivers/graph/base_graph_store_driver.py @@ -1,14 +1,18 @@ from __future__ import annotations + import uuid from abc import ABC, abstractmethod from concurrent import futures from dataclasses import dataclass from typing import Any, Callable, Optional -from attrs import define, field, Factory + +from attrs import Factory, define, field + from griptape import utils -from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BaseArtifact, ListArtifact from griptape.mixins import SerializableMixin + @define class 
BaseGraphStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @@ -31,9 +35,7 @@ def to_artifact(self) -> BaseArtifact: default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), kw_only=True ) - def upsert_artifacts( - self, artifacts: dict[str, list[BaseArtifact]], meta: Optional[dict] = None, **kwargs - ) -> None: + def upsert_artifacts(self, artifacts: dict[str, list[BaseArtifact]], meta: Optional[dict] = None, **kwargs) -> None: with self.futures_executor_fn() as executor: utils.execute_futures_dict( { diff --git a/griptape/drivers/graph/falkordb_graph_store_driver.py b/griptape/drivers/graph/falkordb_graph_store_driver.py index 7205bee84..fed00f622 100644 --- a/griptape/drivers/graph/falkordb_graph_store_driver.py +++ b/griptape/drivers/graph/falkordb_graph_store_driver.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional + import redis from falkordb import FalkorDB -from base_graph_store_driver import BaseGraphStoreDriver + +from .base_graph_store_driver import BaseGraphStoreDriver logger = logging.getLogger(__name__) + class FalkorDBGraphStoreDriver(BaseGraphStoreDriver): """FalkorDB Graph Store Driver with triplet handling, schema management, and relationship mapping.""" @@ -18,7 +23,7 @@ def __init__(self, url: str, database: str = "falkor", node_label: str = "Entity if not self.index_exists("id"): self._driver.query(f"CREATE INDEX FOR (n:`{self._node_label}`) ON (n.id)") except redis.exceptions.ResponseError as e: - if 'already indexed' in str(e): + if "already indexed" in str(e): logger.warning("Index on 'id' already exists: %s", e) else: raise e @@ -40,14 +45,16 @@ def index_exists(self, attribute: str) -> bool: result = self._driver.query(query) return result.result_set[0][0] > 0 - def get(self, subj: str) -> List[List[str]]: + def get(self, subj: str) -> list[list[str]]: """Get triplets for a given subject.""" result = 
self._driver.query(self.get_query, params={"subj": subj}) return result.result_set - def get_rel_map(self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30) -> Dict[str, List[List[str]]]: + def get_rel_map( + self, subjs: Optional[list[str]] = None, depth: int = 2, limit: int = 30 + ) -> dict[str, list[list[str]]]: """Get flat relationship map.""" - rel_map: Dict[Any, List[Any]] = {} + rel_map: dict[Any, list[Any]] = {} if subjs is None or len(subjs) == 0: return rel_map @@ -75,7 +82,7 @@ def get_rel_map(self, subjs: Optional[List[str]] = None, depth: int = 2, limit: path.append(edge.relation) path.append(dest_id) - paths = rel_map[subj_id] if subj_id in rel_map else [] + paths = rel_map.get(subj_id, []) paths.append(path) rel_map[subj_id] = paths @@ -135,7 +142,7 @@ def refresh_schema(self) -> None: Relationships: {relationships} """ - def get_schema(self, refresh: bool = False) -> str: + def get_schema(self, *, refresh: bool = False) -> str: """Get the schema of the FalkorDBGraph store.""" if self.schema and not refresh: return self.schema @@ -143,18 +150,18 @@ def get_schema(self, refresh: bool = False) -> str: logger.debug(f"get_schema() schema:\n{self.schema}") return self.schema - def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any: + def query(self, query: str, params: Optional[dict[str, Any]] = None) -> Any: """Execute a query on the database.""" result = self._driver.query(query, params=params) return result.result_set - def create_connection(self, connection_params): + def create_connection(self, connection_params: dict) -> Optional[FalkorDB]: """Create a connection to FalkorDB.""" try: connection = FalkorDB(**connection_params) return connection except Exception as e: - print(f"Error connecting to FalkorDB: {e}") + logger.error(f"Error connecting to FalkorDB: {e}") return None # Implement abstract methods @@ -176,6 +183,8 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseGraphStoreDr entries 
= [BaseGraphStoreDriver.Entry(id=node.id, properties=node.properties) for node in result] return entries - def upsert_node(self, node_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs) -> str: + def upsert_node( + self, node_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs + ) -> str: self.upsert_triplet(node_id, "", "") return node_id diff --git a/griptape/drivers/graph/testgraphstore.py b/griptape/drivers/graph/testgraphstore.py deleted file mode 100644 index ec2657428..000000000 --- a/griptape/drivers/graph/testgraphstore.py +++ /dev/null @@ -1,75 +0,0 @@ -import unittest -from falkordb_graph_store_driver import FalkorDBGraphStoreDriver - -class TestFalkorDBGraphStoreDriver(unittest.TestCase): - def setUp(self): - self.url = "redis://localhost:6379" - self.database = "falkor" - self.node_label = "Entity" - self.driver = FalkorDBGraphStoreDriver(url=self.url, database=self.database, node_label=self.node_label) - - def test_connect(self): - self.assertIsNotNone(self.driver.client) - print("test_connect passed") - - def test_upsert_triplet(self): - subj = "test_subject" - rel = "test_relation" - obj = "test_object" - self.driver.upsert_triplet(subj, rel, obj) - # Verify triplet is upserted - triplets = self.driver.get(subj) - self.assertTrue(any(obj in triplet for triplet in triplets)) - print("test_upsert_triplet passed") - - def test_get_triplets(self): - subj = "test_subject" - rel = "test_relation" - obj = "test_object" - self.driver.upsert_triplet(subj, rel, obj) - triplets = self.driver.get(subj) - self.assertEqual(len(triplets), 1) - self.assertEqual(triplets[0][1], obj) - print("test_get_triplets passed") - - def test_get_rel_map(self): - subj = "test_subject" - rel = "test_relation" - obj = "test_object" - self.driver.upsert_triplet(subj, rel, obj) - subjs = [subj] - rel_map = self.driver.get_rel_map(subjs=subjs) - self.assertIn(subj, rel_map) - 
print("test_get_rel_map passed") - - def test_delete_triplet(self): - subj = "test_subject" - rel = "test_relation" - obj = "test_object" - self.driver.upsert_triplet(subj, rel, obj) - self.driver.delete(subj, rel, obj) - # Verify triplet is deleted - triplets = self.driver.get(subj) - self.assertFalse(any(obj in triplet for triplet in triplets)) - print("test_delete_triplet passed") - - def test_refresh_schema(self): - self.driver.refresh_schema() - self.assertIn("Properties", self.driver.schema) - self.assertIn("Relationships", self.driver.schema) - print("test_refresh_schema passed") - - def test_get_schema(self): - schema = self.driver.get_schema(refresh=True) - self.assertIn("Properties", schema) - self.assertIn("Relationships", schema) - print("test_get_schema passed") - - def test_query(self): - query = "MATCH (n) RETURN n LIMIT 1" - result = self.driver.query(query) - self.assertIsNotNone(result) - print("test_query passed") - -if __name__ == '__main__': - unittest.main() diff --git a/mkdocs.yml b/mkdocs.yml index 5c603e4d6..4b4099005 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -113,6 +113,7 @@ nav: - Text to Speech Drivers: "griptape-framework/drivers/text-to-speech-drivers.md" - Audio Transcription Drivers: "griptape-framework/drivers/audio-transcription-drivers.md" - Web Search Drivers: "griptape-framework/drivers/web-search-drivers.md" + - Graph Store Drivers: "griptape-framework/drivers/graph-store-drivers.md" - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" diff --git a/pyproject.toml b/pyproject.toml index 12fc49199..6ddb64166 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ pypdf = {version = "^3.9", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} filetype = {version = "^1.2", optional = true} +falkordb = { version = "*", optional = true } [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -141,6 
+142,8 @@ loaders-email = ["mail-parser"] loaders-audio = ["filetype"] loaders-sql = ["sqlalchemy"] +drivers-graph-falkordb = ["falkordb"] + all = [ # drivers "cohere", diff --git a/tests/unit/drivers/graph/__init__.py b/tests/unit/drivers/graph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/graph/test_falkordb_graph_store_driver.py b/tests/unit/drivers/graph/test_falkordb_graph_store_driver.py new file mode 100644 index 000000000..e692be6f4 --- /dev/null +++ b/tests/unit/drivers/graph/test_falkordb_graph_store_driver.py @@ -0,0 +1,110 @@ +import logging +from unittest.mock import MagicMock + +import pytest + +from griptape.drivers.graph.falkordb_graph_store_driver import FalkorDBGraphStoreDriver + +logger = logging.getLogger(__name__) + + +class TestFalkorDBGraphStoreDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_client = mocker.patch("falkordb.FalkorDB.from_url").return_value.select_graph.return_value + mock_client.query.return_value = MagicMock(result_set=[[1]]) + return mock_client + + @pytest.fixture() + def driver(self, mock_client): + url = "redis://localhost:6379" + database = "falkor" + node_label = "Entity" + return FalkorDBGraphStoreDriver(url=url, database=database, node_label=node_label) + + def test_connect(self, driver): + assert driver.client is not None + logger.info("test_connect passed") + + def test_upsert_triplet(self, driver, mock_client): + subj = "test_subject" + rel = "test_relation" + obj = "test_object" + driver.upsert_triplet(subj, rel, obj) + mock_client.query.assert_called() + logger.info("test_upsert_triplet passed") + + def test_get_triplets(self, driver, mock_client): + subj = "test_subject" + rel = "test_relation" + obj = "test_object" + mock_client.query.return_value.result_set = [[rel, obj]] + triplets = driver.get(subj) + assert len(triplets) == 1 + assert triplets[0][1] == obj + mock_client.query.assert_called() + logger.info("test_get_triplets 
passed") + + def test_get_rel_map(self, driver, mock_client): + subj = "test_subject" + rel = "test_relation" + obj = "test_object" + mock_client.query.return_value.result_set = [ + [ + MagicMock( + nodes=lambda: [MagicMock(properties={"id": subj}), MagicMock(properties={"id": obj})], + edges=lambda: [MagicMock(relation=rel)], + ) + ] + ] + subjs = [subj] + rel_map = driver.get_rel_map(subjs=subjs) + logger.debug(f"rel_map: {rel_map}") # Debugging line to print the rel_map + for k, v in rel_map.items(): + logger.debug(f"Key: {k}, Value: {v}") # Debugging line to print keys and values in rel_map + assert subj in rel_map # Ensure the key exists + assert rel_map[subj] == [[rel, obj]] + mock_client.query.assert_called() + logger.info("test_get_rel_map passed") + + def test_delete_triplet(self, driver, mock_client): + subj = "test_subject" + rel = "test_relation" + obj = "test_object" + driver.upsert_triplet(subj, rel, obj) + driver.delete(subj, rel, obj) + mock_client.query.return_value.result_set = [] + triplets = driver.get(subj) + assert not any(obj in triplet for triplet in triplets) + mock_client.query.assert_called() + logger.info("test_delete_triplet passed") + + def test_refresh_schema(self, driver, mock_client): + mock_client.query.side_effect = [ + MagicMock(result_set=[["Property1"], ["Property2"]]), + MagicMock(result_set=[["Relationship1"], ["Relationship2"]]), + ] + driver.refresh_schema() + assert "Properties" in driver.schema + assert "Relationships" in driver.schema + mock_client.query.assert_called() + logger.info("test_refresh_schema passed") + + def test_get_schema(self, driver, mock_client): + mock_client.query.side_effect = [ + MagicMock(result_set=[["Property1"], ["Property2"]]), + MagicMock(result_set=[["Relationship1"], ["Relationship2"]]), + ] + schema = driver.get_schema(refresh=True) + assert "Properties" in schema + assert "Relationships" in schema + mock_client.query.assert_called() + logger.info("test_get_schema passed") + + def 
test_query(self, driver, mock_client): + query = "MATCH (n) RETURN n LIMIT 1" + mock_client.query.return_value.result_set = [["n"]] + result = driver.query(query) + assert result is not None + mock_client.query.assert_any_call(query, params=None) + logger.info("test_query passed")