Skip to content

Commit

Permalink
feat: Use unwind for batch edge save and add unit tests for get_graph…
Browse files Browse the repository at this point in the history
…_from_model

* feat: adds some unit tests for get_graph_from_model

* feat: updates neo4j add_edges cypher and deletes shallow get_graph_from_model

* fix: fixing merge conflict false resolve

* chore: deletes old only_root unit test
  • Loading branch information
hajdul88 authored Jan 31, 2025
1 parent a79f713 commit f843c25
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cognee/api/v1/cognify/cognify_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def get_default_tasks(
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, only_root=True, task_config={"batch_size": 10}),
Task(add_data_points, task_config={"batch_size": 10}),
Task(store_descriptive_metrics),
]
except Exception as error:
Expand Down
20 changes: 14 additions & 6 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,20 @@ async def add_edge(

async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
query = """
UNWIND $edges AS edge
MATCH (from_node {id: edge.from_node})
MATCH (to_node {id: edge.to_node})
CALL apoc.create.relationship(from_node, edge.relationship_name, edge.properties, to_node) YIELD rel
RETURN rel
"""
UNWIND $edges AS edge
MATCH (from_node {id: edge.from_node})
MATCH (to_node {id: edge.to_node})
CALL apoc.merge.relationship(
from_node,
edge.relationship_name,
{
source_node_id: edge.from_node,
target_node_id: edge.to_node
},
edge.properties,
to_node
) YIELD rel
RETURN rel"""

edges = [
{
Expand Down
3 changes: 1 addition & 2 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ async def get_graph_from_model(
added_nodes: dict,
added_edges: dict,
visited_properties: dict = None,
only_root=False,
include_root=True,
):
if str(data_point.id) in added_nodes:
Expand Down Expand Up @@ -98,7 +97,7 @@ async def get_graph_from_model(
)
added_edges[str(edge_key)] = True

if str(field_value.id) in added_nodes or only_root:
if str(field_value.id) in added_nodes:
continue

property_nodes, property_edges = await get_graph_from_model(
Expand Down
3 changes: 1 addition & 2 deletions cognee/tasks/storage/add_data_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .index_data_points import index_data_points


async def add_data_points(data_points: list[DataPoint], only_root=False):
async def add_data_points(data_points: list[DataPoint]):
nodes = []
edges = []

Expand All @@ -20,7 +20,6 @@ async def add_data_points(data_points: list[DataPoint], only_root=False):
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
only_root=only_root,
)
for data_point in data_points
]
Expand Down
120 changes: 120 additions & 0 deletions cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import pytest
import asyncio
import random
from typing import List
from uuid import NAMESPACE_OID, uuid5
from uuid import uuid4

from IPython.utils.wildcard import is_type

from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.Entity import Entity, EntityType
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document
from cognee.modules.graph.utils import get_graph_from_model


@pytest.mark.asyncio
async def test_get_graph_from_model_basic_initialization():
"""Test the basic behavior of get_graph_from_model with a simple data point - without connection."""
data_point = DataPoint(id=uuid4(), attributes={"name": "Node1"})
added_nodes = {}
added_edges = {}
visited_properties = {}

nodes, edges = await get_graph_from_model(
data_point, added_nodes, added_edges, visited_properties
)

assert len(nodes) == 1
assert len(edges) == 0
assert str(data_point.id) in added_nodes


@pytest.mark.asyncio
async def test_get_graph_from_model_with_single_neighbor():
"""Test the behavior of get_graph_from_model when a data point has a single DataPoint property."""
type_node = EntityType(
id=uuid4(),
name="Vehicle",
description="This is a Vehicle node",
)

entity_node = Entity(
id=uuid4(),
name="Car",
is_a=type_node,
description="This is a car node",
)
added_nodes = {}
added_edges = {}
visited_properties = {}

nodes, edges = await get_graph_from_model(
entity_node, added_nodes, added_edges, visited_properties
)

assert len(nodes) == 2
assert len(edges) == 1
assert str(entity_node.id) in added_nodes
assert str(type_node.id) in added_nodes
assert (str(entity_node.id) + str(type_node.id) + "is_a") in added_edges


@pytest.mark.asyncio
async def test_get_graph_from_model_with_multiple_nested_connections():
"""Test the behavior of get_graph_from_model when a data point has multiple nested DataPoint property."""
type_node = EntityType(
id=uuid4(),
name="Transportation tool",
description="This is a Vehicle node",
)

entity_node_1 = Entity(
id=uuid4(),
name="Car",
is_a=type_node,
description="This is a car node",
)

entity_node_2 = Entity(
id=uuid4(),
name="Bus",
is_a=type_node,
description="This is a bus node",
)

document = Document(
name="main_document", raw_data_location="home/", metadata_id=uuid4(), mime_type="test"
)

chunk = DocumentChunk(
id=uuid4(),
word_count=8,
chunk_index=0,
cut_type="test",
text="The car and the bus are transportation tools",
is_part_of=document,
contains=[entity_node_1, entity_node_2],
)

added_nodes = {}
added_edges = {}
visited_properties = {}

nodes, edges = await get_graph_from_model(chunk, added_nodes, added_edges, visited_properties)

assert len(nodes) == 5
assert len(edges) == 5

assert str(entity_node_1.id) in added_nodes
assert str(entity_node_2.id) in added_nodes
assert str(type_node.id) in added_nodes
assert str(document.id) in added_nodes
assert str(chunk.id) in added_nodes

assert (str(entity_node_1.id) + str(type_node.id) + "is_a") in added_edges
assert (str(entity_node_2.id) + str(type_node.id) + "is_a") in added_edges
assert (str(chunk.id) + str(document.id) + "is_part_of") in added_edges
assert (str(chunk.id) + str(entity_node_1.id) + "contains") in added_edges
assert (str(chunk.id) + str(entity_node_2.id) + "contains") in added_edges

0 comments on commit f843c25

Please sign in to comment.