From aa379be532190de9544d09f2a781be5685c60590 Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Fri, 24 Jan 2025 16:32:14 -0800 Subject: [PATCH] refactor diff retrieval to use more smaller queries --- backend/infrahub/core/diff/query/diff_get.py | 39 ++++---- .../core/diff/repository/deserializer.py | 98 ++++++++++--------- .../core/diff/repository/repository.py | 71 ++++++++++---- .../unit/core/diff/test_diff_repository.py | 1 + 4 files changed, 121 insertions(+), 88 deletions(-) diff --git a/backend/infrahub/core/diff/query/diff_get.py b/backend/infrahub/core/diff/query/diff_get.py index 86cfa0e18e..0fd0f25f56 100644 --- a/backend/infrahub/core/diff/query/diff_get.py +++ b/backend/infrahub/core/diff/query/diff_get.py @@ -17,8 +17,6 @@ AND ($to_time IS NULL OR diff_root.to_time <= $to_time) AND ($tracking_id IS NULL OR diff_root.tracking_id = $tracking_id) AND ($diff_ids IS NULL OR diff_root.uuid IN $diff_ids) - WITH diff_root - ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time // get all the nodes attached to the diffs OPTIONAL MATCH (diff_root)-[:DIFF_HAS_NODE]->(diff_node:DiffNode) """ @@ -79,26 +77,21 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: self.add_to_query(query=query_filters) query_2 = """ - // group by diff node uuid for pagination - WITH diff_node.uuid AS diff_node_uuid, diff_node.kind AS diff_node_kind, collect([diff_root, diff_node]) AS node_root_tuples - // order by kind and latest label for each diff_node uuid - CALL { - WITH node_root_tuples - UNWIND node_root_tuples AS nrt - WITH nrt[0] AS diff_root, nrt[1] AS diff_node - ORDER BY diff_root.from_time DESC - RETURN diff_node.label AS latest_node_label - LIMIT 1 - } - WITH diff_node_kind, node_root_tuples, latest_node_label - ORDER BY diff_node_kind, latest_node_label + WITH diff_root, diff_node + ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time, diff_node.uuid + // ------------------------------------- + // Limit number of results + // ------------------------------------- SKIP COALESCE($offset, 0) LIMIT $limit - UNWIND node_root_tuples AS nrt - WITH nrt[0] AS diff_root, nrt[1] AS diff_node - WITH diff_root, diff_node + // ------------------------------------- + // Check if more data after this limited group + // ------------------------------------- + WITH collect([diff_root, diff_node]) AS limited_results + WITH limited_results, size(limited_results) = $limit AS has_more_data + UNWIND limited_results AS one_result + WITH one_result[0] AS diff_root, one_result[1] AS diff_node, has_more_data // if depth limit, make sure not to exceed it when traversing linked nodes - WITH diff_root, diff_node // ------------------------------------- // Retrieve Parents // ------------------------------------- @@ -109,12 +102,12 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: ORDER BY size(nodes(parents_path)) DESC LIMIT 1 } - WITH diff_root, diff_node, parents_path + WITH diff_root, diff_node, has_more_data, parents_path // ------------------------------------- // Retrieve conflicts // ------------------------------------- OPTIONAL MATCH (diff_node)-[:DIFF_HAS_CONFLICT]->(diff_node_conflict:DiffConflict) - WITH diff_root, diff_node, parents_path, diff_node_conflict + WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict // ------------------------------------- // Retrieve Attributes // ------------------------------------- @@ -128,7 +121,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: RETURN diff_attribute, diff_attr_property, diff_attr_property_conflict ORDER BY diff_attribute.name, diff_attr_property.property_type } - WITH diff_root, diff_node, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes + WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes // ------------------------------------- // Retrieve Relationships // ------------------------------------- @@ -150,6 +143,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: WITH diff_root, diff_node, + has_more_data, parents_path, diff_node_conflict, diff_attributes, @@ -161,6 +155,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: self.return_labels = [ "diff_root", "diff_node", + "has_more_data", "parents_path", "diff_node_conflict", "diff_attributes", diff --git a/backend/infrahub/core/diff/repository/deserializer.py b/backend/infrahub/core/diff/repository/deserializer.py index 246375dec4..05ef2ced1d 100644 --- a/backend/infrahub/core/diff/repository/deserializer.py +++ b/backend/infrahub/core/diff/repository/deserializer.py @@ -1,5 +1,3 @@ -from typing import Iterable - from neo4j.graph import Node as Neo4jNode from neo4j.graph import Path as Neo4jPath @@ -30,37 +28,47 @@ def __init__(self) -> None: self._diff_node_rel_group_map: dict[tuple[str, str, str], EnrichedDiffRelationship] = {} self._diff_node_rel_element_map: dict[tuple[str, str, str, str], EnrichedDiffSingleRelationship] = {} self._diff_prop_map: dict[tuple[str, str, str, str] | tuple[str, str, str, str, str], EnrichedDiffProperty] = {} + # {EnrichedDiffRoot: [(node_uuid, parents_path: Neo4jPath), ...]} + self._parents_path_map: dict[EnrichedDiffRoot, list[tuple[str, Neo4jPath]]] = {} - def _initialize(self) -> None: + def initialize(self) -> None: self._diff_root_map = {} self._diff_node_map = {} self._diff_node_attr_map = {} self._diff_node_rel_group_map = {} self._diff_node_rel_element_map = {} self._diff_prop_map = {} + self._parents_path_map = {} - async def deserialize( - self, database_results: Iterable[QueryResult], include_parents: bool - ) -> list[EnrichedDiffRoot]: - self._initialize() - results = list(database_results) - for result in results: - enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root")) - node_node = result.get(label="diff_node") - if not isinstance(node_node, Neo4jNode): - continue - enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root) - node_conflict_node = result.get(label="diff_node_conflict") - if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict: - conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node) - enriched_node.conflict = conflict - self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node) - self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node) + def _track_parents_path(self, enriched_root: EnrichedDiffRoot, node_uuid: str, parents_path: Neo4jPath) -> None: + if enriched_root not in self._parents_path_map: + self._parents_path_map[enriched_root] = [] + self._parents_path_map[enriched_root].append((node_uuid, parents_path)) + + async def read_result(self, result: QueryResult, include_parents: bool) -> None: + enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root")) + node_node = result.get(label="diff_node") + if not isinstance(node_node, Neo4jNode): + return + enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root) if include_parents: - for result in results: - enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root")) - self._deserialize_parents(result=result, enriched_root=enriched_root) + parents_path = result.get("parents_path") + if parents_path and isinstance(parents_path, Neo4jPath): + self._track_parents_path( + enriched_root=enriched_root, node_uuid=enriched_node.uuid, parents_path=parents_path + ) + + node_conflict_node = result.get(label="diff_node_conflict") + if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict: + conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node) + enriched_node.conflict = conflict + self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node) + self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node) + + async def deserialize(self, include_parents: bool = True) -> list[EnrichedDiffRoot]: + if include_parents: + self._deserialize_parents() return list(self._diff_root_map.values()) @@ -117,30 +125,26 @@ def _deserialize_relationships( conflict = self.deserialize_conflict(diff_conflict_node=property_conflict) element_property.conflict = conflict - def _deserialize_parents(self, result: QueryResult, enriched_root: EnrichedDiffRoot) -> None: - parents_path = result.get("parents_path") - if not parents_path or not isinstance(parents_path, Neo4jPath): - return - - node_uuid = result.get(label="diff_node").get("uuid") - - # Remove the node itself from the path - parents_path = parents_path.nodes[1:] # type: ignore[union-attr] - - # TODO Ensure the list is even - current_node_uuid = node_uuid - for rel, parent in zip(parents_path[::2], parents_path[1::2]): - enriched_root.add_parent( - node_id=current_node_uuid, - parent_id=parent.get("uuid"), - parent_kind=parent.get("kind"), - parent_label=parent.get("label"), - parent_rel_name=rel.get("name"), - parent_rel_identifier=rel.get("identifier"), - parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")), - parent_rel_label=rel.get("label"), - ) - current_node_uuid = parent.get("uuid") + def _deserialize_parents(self) -> None: + for enriched_root, node_path_tuples in self._parents_path_map.items(): + for node_uuid, parents_path in node_path_tuples: + # Remove the node itself from the path + parents_path_slice = parents_path.nodes[1:] + + # TODO Ensure the list is even + current_node_uuid = node_uuid + for rel, parent in zip(parents_path_slice[::2], parents_path_slice[1::2]): + enriched_root.add_parent( + node_id=current_node_uuid, + parent_id=parent.get("uuid"), + parent_kind=parent.get("kind"), + parent_label=parent.get("label"), + parent_rel_name=rel.get("name"), + parent_rel_identifier=rel.get("identifier"), + parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")), + parent_rel_label=rel.get("label"), + ) + current_node_uuid = parent.get("uuid") @classmethod def _get_str_or_none_property_value(cls, node: Neo4jNode, property_name: str) -> str | None: diff --git a/backend/infrahub/core/diff/repository/repository.py b/backend/infrahub/core/diff/repository/repository.py index aa61449181..4273490755 100644 --- a/backend/infrahub/core/diff/repository/repository.py +++ b/backend/infrahub/core/diff/repository/repository.py @@ -47,6 +47,49 @@ def __init__(self, db: InfrahubDatabase, deserializer: EnrichedDiffDeserializer) self.db = db self.deserializer = deserializer + async def _run_get_diff_query( + self, + base_branch_name: str, + diff_branch_names: list[str], + limit: int, + from_time: Timestamp | None = None, + to_time: Timestamp | None = None, + filters: EnrichedDiffQueryFilters | None = None, + offset: int = 0, + include_parents: bool = True, + max_depth: int | None = None, + tracking_id: TrackingId | None = None, + diff_ids: list[str] | None = None, + ) -> list[EnrichedDiffRoot]: + self.deserializer.initialize() + has_more_data = True + while has_more_data: + get_query = await EnrichedDiffGetQuery.init( + db=self.db, + base_branch_name=base_branch_name, + diff_branch_names=diff_branch_names, + from_time=from_time, + to_time=to_time, + filters=filters, + max_depth=max_depth, + limit=limit, + offset=offset, + tracking_id=tracking_id, + diff_ids=diff_ids, + ) + log.info(f"Beginning enriched diff get query {limit=}, {offset=}") + await get_query.execute(db=self.db) + log.info("Enriched diff get query complete") + last_result = None + for query_result in get_query.get_results(): + await self.deserializer.read_result(result=query_result, include_parents=include_parents) + last_result = query_result + has_more_data = False + if last_result: + has_more_data = last_result.get_as_type("has_more_data", bool) + offset += limit + return await self.deserializer.deserialize() + async def get( self, base_branch_name: str, @@ -62,23 +105,20 @@ async def get( include_empty: bool = False, ) -> list[EnrichedDiffRoot]: final_max_depth = config.SETTINGS.database.max_depth_search_hierarchy - query = await EnrichedDiffGetQuery.init( - db=self.db, + limit = limit or int(config.SETTINGS.database.query_size_limit / 10) + diff_roots = await self._run_get_diff_query( base_branch_name=base_branch_name, diff_branch_names=diff_branch_names, from_time=from_time, to_time=to_time, filters=EnrichedDiffQueryFilters(**dict(filters or {})), + include_parents=include_parents, max_depth=final_max_depth, limit=limit, - offset=offset, + offset=offset or 0, tracking_id=tracking_id, diff_ids=diff_ids, ) - await query.execute(db=self.db) - diff_roots = await self.deserializer.deserialize( - database_results=query.get_results(), include_parents=include_parents - ) if not include_empty: diff_roots = [dr for dr in diff_roots if len(dr.nodes) > 0] return diff_roots @@ -91,30 +131,23 @@ async def get_pairs( to_time: Timestamp, ) -> list[EnrichedDiffs]: max_depth = config.SETTINGS.database.max_depth_search_hierarchy - query = await EnrichedDiffGetQuery.init( - db=self.db, + limit = int(config.SETTINGS.database.query_size_limit / 10) + diff_branch_roots = await self._run_get_diff_query( base_branch_name=base_branch_name, diff_branch_names=[diff_branch_name], from_time=from_time, to_time=to_time, max_depth=max_depth, - ) - await query.execute(db=self.db) - diff_branch_roots = await self.deserializer.deserialize( - database_results=query.get_results(), include_parents=True + limit=limit, ) diffs_by_uuid = {dbr.uuid: dbr for dbr in diff_branch_roots} - base_partner_query = await EnrichedDiffGetQuery.init( - db=self.db, + base_branch_roots = await self._run_get_diff_query( base_branch_name=base_branch_name, diff_branch_names=[base_branch_name], max_depth=max_depth, + limit=limit, diff_ids=[d.partner_uuid for d in diffs_by_uuid.values()], ) - await base_partner_query.execute(db=self.db) - base_branch_roots = await self.deserializer.deserialize( - database_results=base_partner_query.get_results(), include_parents=True - ) diffs_by_uuid.update({bbr.uuid: bbr for bbr in base_branch_roots}) return [ EnrichedDiffs( diff --git a/backend/tests/unit/core/diff/test_diff_repository.py b/backend/tests/unit/core/diff/test_diff_repository.py index d5eafe40bc..4a12b822fc 100644 --- a/backend/tests/unit/core/diff/test_diff_repository.py +++ b/backend/tests/unit/core/diff/test_diff_repository.py @@ -42,6 +42,7 @@ async def reset_database(self, db: InfrahubDatabase, default_branch): @pytest.fixture def diff_repository(self, db: InfrahubDatabase) -> DiffRepository: config.SETTINGS.database.max_depth_search_hierarchy = 10 + config.SETTINGS.database.query_size_limit = 50 return DiffRepository(db=db, deserializer=EnrichedDiffDeserializer()) def build_diff_node(self, num_sub_fields=2, no_recurse=False) -> EnrichedDiffNode: