Skip to content

Commit

Permalink
Add BaseGraphStoreDriver and FalkorDBGraphStoreDriver with tests and …
Browse files Browse the repository at this point in the history
…documentation, along with changed to CHANGELOG
  • Loading branch information
Kornspan committed Jul 22, 2024
1 parent 0df5450 commit 5ede77b
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 92 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Added
- `BaseGraphStoreDriver` as the base class for all graph store drivers.
- `FalkorDBGraphStoreDriver` to support the FalkorDB graph database in Griptape.
- Unit tests for `BaseGraphStoreDriver` and `FalkorDBGraphStoreDriver` to ensure their functionality.
- Documentation for `BaseGraphStoreDriver` and `FalkorDBGraphStoreDriver` including examples on how to use them.
- `falkordb` extra and dependency for the `FalkorDBGraphStoreDriver`.

### Added
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.
Expand Down
53 changes: 53 additions & 0 deletions docs/griptape-framework/drivers/graph-store-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
## Overview

Griptape provides a way to build drivers for graph databases where graph data can be stored and queried. Every graph store driver implements the following methods:

- `upsert_node()` for updating or inserting a new node into graph databases.
- `upsert_node_artifacts()` for updating or inserting multiple nodes into graph databases.
- `upsert_edge()` for updating and inserting new edges between nodes in graph databases.
- `query_nodes()` for querying nodes in graph databases.
- `query_edges()` for querying edges in graph databases.

Each graph driver extends the `BaseGraphStoreDriver`, which provides additional utility methods to manage graph structures.

!!! info
When working with graph database indexes with Griptape drivers, ensure that your schema supports the required node and edge properties. Check the documentation for your graph database on how to create and manage graph schemas.

## FalkorDB

!!! info
This driver requires the `drivers-graph-falkordb` [extra](../index.md#extras).

The [FalkorDBGraphStoreDriver](../../reference/griptape/drivers/graph/falkordb_graph_store_driver.md) supports the [FalkorDB graph database](https://www.falkordb.com/).

Here is an example of how the driver can be used to load and query information in a FalkorDB cluster:

```python
import os
from griptape.drivers.graph import FalkorDBGraphStoreDriver

# Initialize the FalkorDB driver
falkordb_driver = FalkorDBGraphStoreDriver(
url=os.environ["FALKORDB_URL"],
api_key=os.environ["FALKORDB_API_KEY"]
)

# Example node data
node_data = {
"id": "1",
"label": "Person",
"properties": {
"name": "Alice",
"age": 30
}
}

# Upsert a node
falkordb_driver.upsert_node(node_data)

# Query nodes
results = falkordb_driver.query_nodes("MATCH (n:Person) RETURN n")

for result in results:
print(result)
```
12 changes: 7 additions & 5 deletions griptape/drivers/graph/base_graph_store_driver.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import Any, Callable, Optional
from attrs import define, field, Factory

from attrs import Factory, define, field

from griptape import utils
from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.artifacts import BaseArtifact, ListArtifact
from griptape.mixins import SerializableMixin


@define
class BaseGraphStoreDriver(SerializableMixin, ABC):
DEFAULT_QUERY_COUNT = 5
Expand All @@ -31,9 +35,7 @@ def to_artifact(self) -> BaseArtifact:
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), kw_only=True
)

def upsert_artifacts(
self, artifacts: dict[str, list[BaseArtifact]], meta: Optional[dict] = None, **kwargs
) -> None:
def upsert_artifacts(self, artifacts: dict[str, list[BaseArtifact]], meta: Optional[dict] = None, **kwargs) -> None:
with self.futures_executor_fn() as executor:
utils.execute_futures_dict(
{
Expand Down
33 changes: 21 additions & 12 deletions griptape/drivers/graph/falkordb_graph_store_driver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import redis
from falkordb import FalkorDB
from base_graph_store_driver import BaseGraphStoreDriver

from .base_graph_store_driver import BaseGraphStoreDriver

logger = logging.getLogger(__name__)


class FalkorDBGraphStoreDriver(BaseGraphStoreDriver):
"""FalkorDB Graph Store Driver with triplet handling, schema management, and relationship mapping."""

Expand All @@ -18,7 +23,7 @@ def __init__(self, url: str, database: str = "falkor", node_label: str = "Entity
if not self.index_exists("id"):
self._driver.query(f"CREATE INDEX FOR (n:`{self._node_label}`) ON (n.id)")
except redis.exceptions.ResponseError as e:
if 'already indexed' in str(e):
if "already indexed" in str(e):
logger.warning("Index on 'id' already exists: %s", e)
else:
raise e
Expand All @@ -40,14 +45,16 @@ def index_exists(self, attribute: str) -> bool:
result = self._driver.query(query)
return result.result_set[0][0] > 0

def get(self, subj: str) -> List[List[str]]:
def get(self, subj: str) -> list[list[str]]:
"""Get triplets for a given subject."""
result = self._driver.query(self.get_query, params={"subj": subj})
return result.result_set

def get_rel_map(self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30) -> Dict[str, List[List[str]]]:
def get_rel_map(
self, subjs: Optional[list[str]] = None, depth: int = 2, limit: int = 30
) -> dict[str, list[list[str]]]:
"""Get flat relationship map."""
rel_map: Dict[Any, List[Any]] = {}
rel_map: dict[Any, list[Any]] = {}
if subjs is None or len(subjs) == 0:
return rel_map

Expand Down Expand Up @@ -75,7 +82,7 @@ def get_rel_map(self, subjs: Optional[List[str]] = None, depth: int = 2, limit:
path.append(edge.relation)
path.append(dest_id)

paths = rel_map[subj_id] if subj_id in rel_map else []
paths = rel_map.get(subj_id, [])
paths.append(path)
rel_map[subj_id] = paths

Expand Down Expand Up @@ -135,26 +142,26 @@ def refresh_schema(self) -> None:
Relationships: {relationships}
"""

def get_schema(self, refresh: bool = False) -> str:
def get_schema(self, *, refresh: bool = False) -> str:
"""Get the schema of the FalkorDBGraph store."""
if self.schema and not refresh:
return self.schema
self.refresh_schema()
logger.debug(f"get_schema() schema:\n{self.schema}")
return self.schema

def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
def query(self, query: str, params: Optional[dict[str, Any]] = None) -> Any:
"""Execute a query on the database."""
result = self._driver.query(query, params=params)
return result.result_set

def create_connection(self, connection_params):
def create_connection(self, connection_params: dict) -> Optional[FalkorDB]:
"""Create a connection to FalkorDB."""
try:
connection = FalkorDB(**connection_params)
return connection
except Exception as e:
print(f"Error connecting to FalkorDB: {e}")
logger.error(f"Error connecting to FalkorDB: {e}")
return None

# Implement abstract methods
Expand All @@ -176,6 +183,8 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseGraphStoreDr
entries = [BaseGraphStoreDriver.Entry(id=node.id, properties=node.properties) for node in result]
return entries

def upsert_node(self, node_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs) -> str:
def upsert_node(
self, node_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
) -> str:
self.upsert_triplet(node_id, "", "")
return node_id
75 changes: 0 additions & 75 deletions griptape/drivers/graph/testgraphstore.py

This file was deleted.

1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ nav:
- Text to Speech Drivers: "griptape-framework/drivers/text-to-speech-drivers.md"
- Audio Transcription Drivers: "griptape-framework/drivers/audio-transcription-drivers.md"
- Web Search Drivers: "griptape-framework/drivers/web-search-drivers.md"
- Graph Store Drivers: "griptape-framework/drivers/graph-store-drivers.md"
- Data:
- Overview: "griptape-framework/data/index.md"
- Artifacts: "griptape-framework/data/artifacts.md"
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pypdf = {version = "^3.9", optional = true}
pillow = {version = "^10.2.0", optional = true}
mail-parser = {version = "^3.15.0", optional = true}
filetype = {version = "^1.2", optional = true}
falkordb = { version = "*", optional = true }

[tool.poetry.extras]
drivers-prompt-cohere = ["cohere"]
Expand Down Expand Up @@ -141,6 +142,8 @@ loaders-email = ["mail-parser"]
loaders-audio = ["filetype"]
loaders-sql = ["sqlalchemy"]

drivers-graph-falkordb = ["falkordb"]

all = [
# drivers
"cohere",
Expand Down
Empty file.
Loading

0 comments on commit 5ede77b

Please sign in to comment.