Skip to content

Commit

Permalink
fix(backend): do not cleanup labels for each db result row
Browse files Browse the repository at this point in the history
There is a few microseconds overhead for this function...
Also add a span for Query.execute in order to know the overhead of
fetching the results.

Signed-off-by: Fatih Acar <[email protected]>
  • Loading branch information
fatih-acar committed Jan 9, 2025
1 parent 6b21034 commit a07bb7f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
7 changes: 5 additions & 2 deletions backend/infrahub/core/query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from neo4j.graph import Node as Neo4jNode
from neo4j.graph import Path as Neo4jPath
from neo4j.graph import Relationship as Neo4jRelationship
from opentelemetry import trace

from infrahub import config
from infrahub.core.constants import PermissionLevel
Expand Down Expand Up @@ -161,7 +162,7 @@ def cleanup_return_labels(labels: list[str]) -> list[str]:
class QueryResult:
def __init__(self, data: list[Union[Neo4jNode, Neo4jRelationship, list[Neo4jNode]]], labels: list[str]):
self.data = data
self.labels = cleanup_return_labels(labels)
self.labels = labels
self.branch_score: int = 0
self.time_score: int = 0
self.permission_score = PermissionLevel.DEFAULT
Expand Down Expand Up @@ -523,6 +524,7 @@ def _get_params_for_neo4j_shell(self) -> str:

return ":params { " + ", ".join(params) + " }"

@trace.get_tracer(__name__).start_as_current_span("Query.execute")
async def execute(self, db: InfrahubDatabase) -> Self:
# Ensure all mandatory params have been provided
# Ensure at least 1 return obj has been defined
Expand Down Expand Up @@ -552,7 +554,8 @@ async def execute(self, db: InfrahubDatabase) -> Self:
if not results and self.raise_error_if_empty:
raise QueryError(query=query_str, params=self.params)

self.results = [QueryResult(data=result, labels=self.return_labels) for result in results]
clean_labels = cleanup_return_labels(self.return_labels)
self.results = [QueryResult(data=result, labels=clean_labels) for result in results]
self.has_been_executed = True

return self
Expand Down
20 changes: 11 additions & 9 deletions backend/tests/unit/core/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,14 @@ async def test_query_result_getters(neo4j_factory):

qr = QueryResult(
data=[n1, r1, r2, n2],
labels=[
"n1",
"r1",
"r2",
"n2",
],
labels=cleanup_return_labels(
[
"n1",
"r1",
"r2",
"n2",
]
),
)
assert list(qr.get_rels()) == [r1, r2]
assert list(qr.get_nodes()) == [n1, n2]
Expand Down Expand Up @@ -228,9 +230,9 @@ async def test_sort_results_by_time(neo4j_factory):
},
)

qr1 = QueryResult(data=[n1, n2, r1], labels=["n1", "n2", "r"])
qr2 = QueryResult(data=[n1, n2, r2], labels=["n1", "n2", "r"])
qr3 = QueryResult(data=[n1, n2, r3], labels=["n1", "n2", "r"])
qr1 = QueryResult(data=[n1, n2, r1], labels=cleanup_return_labels(["n1", "n2", "r"]))
qr2 = QueryResult(data=[n1, n2, r2], labels=cleanup_return_labels(["n1", "n2", "r"]))
qr3 = QueryResult(data=[n1, n2, r3], labels=cleanup_return_labels(["n1", "n2", "r"]))

results = sort_results_by_time(results=[qr1, qr2, qr3], rel_label="r")
assert list(results) == [qr3, qr1, qr2]
Expand Down

0 comments on commit a07bb7f

Please sign in to comment.