Skip to content

Commit

Permalink
refactor diff retrieval to use more smaller queries
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtmccarty committed Jan 27, 2025
1 parent 7e98238 commit aa379be
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 88 deletions.
39 changes: 17 additions & 22 deletions backend/infrahub/core/diff/query/diff_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
AND ($to_time IS NULL OR diff_root.to_time <= $to_time)
AND ($tracking_id IS NULL OR diff_root.tracking_id = $tracking_id)
AND ($diff_ids IS NULL OR diff_root.uuid IN $diff_ids)
WITH diff_root
ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time
// get all the nodes attached to the diffs
OPTIONAL MATCH (diff_root)-[:DIFF_HAS_NODE]->(diff_node:DiffNode)
"""
Expand Down Expand Up @@ -79,26 +77,21 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.add_to_query(query=query_filters)

query_2 = """
// group by diff node uuid for pagination
WITH diff_node.uuid AS diff_node_uuid, diff_node.kind AS diff_node_kind, collect([diff_root, diff_node]) AS node_root_tuples
// order by kind and latest label for each diff_node uuid
CALL {
WITH node_root_tuples
UNWIND node_root_tuples AS nrt
WITH nrt[0] AS diff_root, nrt[1] AS diff_node
ORDER BY diff_root.from_time DESC
RETURN diff_node.label AS latest_node_label
LIMIT 1
}
WITH diff_node_kind, node_root_tuples, latest_node_label
ORDER BY diff_node_kind, latest_node_label
WITH diff_root, diff_node
ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time, diff_node.uuid
// -------------------------------------
// Limit number of results
// -------------------------------------
SKIP COALESCE($offset, 0)
LIMIT $limit
UNWIND node_root_tuples AS nrt
WITH nrt[0] AS diff_root, nrt[1] AS diff_node
WITH diff_root, diff_node
// -------------------------------------
// Check if more data after this limited group
// -------------------------------------
WITH collect([diff_root, diff_node]) AS limited_results
WITH limited_results, size(limited_results) = $limit AS has_more_data
UNWIND limited_results AS one_result
WITH one_result[0] AS diff_root, one_result[1] AS diff_node, has_more_data
// if depth limit, make sure not to exceed it when traversing linked nodes
WITH diff_root, diff_node
// -------------------------------------
// Retrieve Parents
// -------------------------------------
Expand All @@ -109,12 +102,12 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
ORDER BY size(nodes(parents_path)) DESC
LIMIT 1
}
WITH diff_root, diff_node, parents_path
WITH diff_root, diff_node, has_more_data, parents_path
// -------------------------------------
// Retrieve conflicts
// -------------------------------------
OPTIONAL MATCH (diff_node)-[:DIFF_HAS_CONFLICT]->(diff_node_conflict:DiffConflict)
WITH diff_root, diff_node, parents_path, diff_node_conflict
WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict
// -------------------------------------
// Retrieve Attributes
// -------------------------------------
Expand All @@ -128,7 +121,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
RETURN diff_attribute, diff_attr_property, diff_attr_property_conflict
ORDER BY diff_attribute.name, diff_attr_property.property_type
}
WITH diff_root, diff_node, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes
WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes
// -------------------------------------
// Retrieve Relationships
// -------------------------------------
Expand All @@ -150,6 +143,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
WITH
diff_root,
diff_node,
has_more_data,
parents_path,
diff_node_conflict,
diff_attributes,
Expand All @@ -161,6 +155,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.return_labels = [
"diff_root",
"diff_node",
"has_more_data",
"parents_path",
"diff_node_conflict",
"diff_attributes",
Expand Down
98 changes: 51 additions & 47 deletions backend/infrahub/core/diff/repository/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Iterable

from neo4j.graph import Node as Neo4jNode
from neo4j.graph import Path as Neo4jPath

Expand Down Expand Up @@ -30,37 +28,47 @@ def __init__(self) -> None:
self._diff_node_rel_group_map: dict[tuple[str, str, str], EnrichedDiffRelationship] = {}
self._diff_node_rel_element_map: dict[tuple[str, str, str, str], EnrichedDiffSingleRelationship] = {}
self._diff_prop_map: dict[tuple[str, str, str, str] | tuple[str, str, str, str, str], EnrichedDiffProperty] = {}
# {EnrichedDiffRoot: [(node_uuid, parents_path: Neo4jPath), ...]}
self._parents_path_map: dict[EnrichedDiffRoot, list[tuple[str, Neo4jPath]]] = {}

def _initialize(self) -> None:
def initialize(self) -> None:
self._diff_root_map = {}
self._diff_node_map = {}
self._diff_node_attr_map = {}
self._diff_node_rel_group_map = {}
self._diff_node_rel_element_map = {}
self._diff_prop_map = {}
self._parents_path_map = {}

async def deserialize(
self, database_results: Iterable[QueryResult], include_parents: bool
) -> list[EnrichedDiffRoot]:
self._initialize()
results = list(database_results)
for result in results:
enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
node_node = result.get(label="diff_node")
if not isinstance(node_node, Neo4jNode):
continue
enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root)
node_conflict_node = result.get(label="diff_node_conflict")
if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict:
conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node)
enriched_node.conflict = conflict
self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
def _track_parents_path(self, enriched_root: EnrichedDiffRoot, node_uuid: str, parents_path: Neo4jPath) -> None:
if enriched_root not in self._parents_path_map:
self._parents_path_map[enriched_root] = []
self._parents_path_map[enriched_root].append((node_uuid, parents_path))

async def read_result(self, result: QueryResult, include_parents: bool) -> None:
enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
node_node = result.get(label="diff_node")
if not isinstance(node_node, Neo4jNode):
return
enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root)

if include_parents:
for result in results:
enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
self._deserialize_parents(result=result, enriched_root=enriched_root)
parents_path = result.get("parents_path")
if parents_path and isinstance(parents_path, Neo4jPath):
self._track_parents_path(
enriched_root=enriched_root, node_uuid=enriched_node.uuid, parents_path=parents_path
)

node_conflict_node = result.get(label="diff_node_conflict")
if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict:
conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node)
enriched_node.conflict = conflict
self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node)

async def deserialize(self, include_parents: bool = True) -> list[EnrichedDiffRoot]:
if include_parents:
self._deserialize_parents()

return list(self._diff_root_map.values())

Expand Down Expand Up @@ -117,30 +125,26 @@ def _deserialize_relationships(
conflict = self.deserialize_conflict(diff_conflict_node=property_conflict)
element_property.conflict = conflict

def _deserialize_parents(self, result: QueryResult, enriched_root: EnrichedDiffRoot) -> None:
parents_path = result.get("parents_path")
if not parents_path or not isinstance(parents_path, Neo4jPath):
return

node_uuid = result.get(label="diff_node").get("uuid")

# Remove the node itself from the path
parents_path = parents_path.nodes[1:] # type: ignore[union-attr]

# TODO Ensure the list is even
current_node_uuid = node_uuid
for rel, parent in zip(parents_path[::2], parents_path[1::2]):
enriched_root.add_parent(
node_id=current_node_uuid,
parent_id=parent.get("uuid"),
parent_kind=parent.get("kind"),
parent_label=parent.get("label"),
parent_rel_name=rel.get("name"),
parent_rel_identifier=rel.get("identifier"),
parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")),
parent_rel_label=rel.get("label"),
)
current_node_uuid = parent.get("uuid")
def _deserialize_parents(self) -> None:
for enriched_root, node_path_tuples in self._parents_path_map.items():
for node_uuid, parents_path in node_path_tuples:
# Remove the node itself from the path
parents_path_slice = parents_path.nodes[1:]

# TODO Ensure the list is even
current_node_uuid = node_uuid
for rel, parent in zip(parents_path_slice[::2], parents_path_slice[1::2]):
enriched_root.add_parent(
node_id=current_node_uuid,
parent_id=parent.get("uuid"),
parent_kind=parent.get("kind"),
parent_label=parent.get("label"),
parent_rel_name=rel.get("name"),
parent_rel_identifier=rel.get("identifier"),
parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")),
parent_rel_label=rel.get("label"),
)
current_node_uuid = parent.get("uuid")

@classmethod
def _get_str_or_none_property_value(cls, node: Neo4jNode, property_name: str) -> str | None:
Expand Down
71 changes: 52 additions & 19 deletions backend/infrahub/core/diff/repository/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,49 @@ def __init__(self, db: InfrahubDatabase, deserializer: EnrichedDiffDeserializer)
self.db = db
self.deserializer = deserializer

async def _run_get_diff_query(
self,
base_branch_name: str,
diff_branch_names: list[str],
limit: int,
from_time: Timestamp | None = None,
to_time: Timestamp | None = None,
filters: EnrichedDiffQueryFilters | None = None,
offset: int = 0,
include_parents: bool = True,
max_depth: int | None = None,
tracking_id: TrackingId | None = None,
diff_ids: list[str] | None = None,
) -> list[EnrichedDiffRoot]:
self.deserializer.initialize()
has_more_data = True
while has_more_data:
get_query = await EnrichedDiffGetQuery.init(
db=self.db,
base_branch_name=base_branch_name,
diff_branch_names=diff_branch_names,
from_time=from_time,
to_time=to_time,
filters=filters,
max_depth=max_depth,
limit=limit,
offset=offset,
tracking_id=tracking_id,
diff_ids=diff_ids,
)
log.info(f"Beginning enriched diff get query {limit=}, {offset=}")
await get_query.execute(db=self.db)
log.info("Enriched diff get query complete")
last_result = None
for query_result in get_query.get_results():
await self.deserializer.read_result(result=query_result, include_parents=include_parents)
last_result = query_result
has_more_data = False
if last_result:
has_more_data = last_result.get_as_type("has_more_data", bool)
offset += limit
return await self.deserializer.deserialize()

async def get(
self,
base_branch_name: str,
Expand All @@ -62,23 +105,20 @@ async def get(
include_empty: bool = False,
) -> list[EnrichedDiffRoot]:
final_max_depth = config.SETTINGS.database.max_depth_search_hierarchy
query = await EnrichedDiffGetQuery.init(
db=self.db,
limit = limit or int(config.SETTINGS.database.query_size_limit / 10)
diff_roots = await self._run_get_diff_query(
base_branch_name=base_branch_name,
diff_branch_names=diff_branch_names,
from_time=from_time,
to_time=to_time,
filters=EnrichedDiffQueryFilters(**dict(filters or {})),
include_parents=include_parents,
max_depth=final_max_depth,
limit=limit,
offset=offset,
offset=offset or 0,
tracking_id=tracking_id,
diff_ids=diff_ids,
)
await query.execute(db=self.db)
diff_roots = await self.deserializer.deserialize(
database_results=query.get_results(), include_parents=include_parents
)
if not include_empty:
diff_roots = [dr for dr in diff_roots if len(dr.nodes) > 0]
return diff_roots
Expand All @@ -91,30 +131,23 @@ async def get_pairs(
to_time: Timestamp,
) -> list[EnrichedDiffs]:
max_depth = config.SETTINGS.database.max_depth_search_hierarchy
query = await EnrichedDiffGetQuery.init(
db=self.db,
limit = int(config.SETTINGS.database.query_size_limit / 10)
diff_branch_roots = await self._run_get_diff_query(
base_branch_name=base_branch_name,
diff_branch_names=[diff_branch_name],
from_time=from_time,
to_time=to_time,
max_depth=max_depth,
)
await query.execute(db=self.db)
diff_branch_roots = await self.deserializer.deserialize(
database_results=query.get_results(), include_parents=True
limit=limit,
)
diffs_by_uuid = {dbr.uuid: dbr for dbr in diff_branch_roots}
base_partner_query = await EnrichedDiffGetQuery.init(
db=self.db,
base_branch_roots = await self._run_get_diff_query(
base_branch_name=base_branch_name,
diff_branch_names=[base_branch_name],
max_depth=max_depth,
limit=limit,
diff_ids=[d.partner_uuid for d in diffs_by_uuid.values()],
)
await base_partner_query.execute(db=self.db)
base_branch_roots = await self.deserializer.deserialize(
database_results=base_partner_query.get_results(), include_parents=True
)
diffs_by_uuid.update({bbr.uuid: bbr for bbr in base_branch_roots})
return [
EnrichedDiffs(
Expand Down
1 change: 1 addition & 0 deletions backend/tests/unit/core/diff/test_diff_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def reset_database(self, db: InfrahubDatabase, default_branch):
@pytest.fixture
def diff_repository(self, db: InfrahubDatabase) -> DiffRepository:
config.SETTINGS.database.max_depth_search_hierarchy = 10
config.SETTINGS.database.query_size_limit = 50
return DiffRepository(db=db, deserializer=EnrichedDiffDeserializer())

def build_diff_node(self, num_sub_fields=2, no_recurse=False) -> EnrichedDiffNode:
Expand Down

0 comments on commit aa379be

Please sign in to comment.