Skip to content

perf: optimize sql for or queries #948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 44 additions & 50 deletions src/tagstudio/core/library/alchemy/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tagstudio.core.library.alchemy.models import Entry, Tag, TagAlias
from tagstudio.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
from tagstudio.core.query_lang.ast import (
AST,
ANDList,
BaseVisitor,
Constraint,
Expand Down Expand Up @@ -58,42 +59,15 @@ def __init__(self, lib: Library) -> None:
self.lib = lib

def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
    """Build the boolean SQL expression for an OR list.

    Tag and TagID constraints are pooled into one "entry has any of these
    tags" membership test instead of one subquery per element. Every other
    element is visited on its own and OR-ed together with that test.

    Args:
        node: The OR list whose ``elements`` are combined.

    Returns:
        A boolean column expression that is true when any element matches.
    """
    tag_ids, bool_expressions = self.__separate_tags(node.elements, only_single=False)
    # Emit the membership test even with no ids when nothing else was collected
    # (e.g. every element was a tag term that yielded no ids). Calling or_() with
    # no arguments is deprecated in SQLAlchemy and does not render as a false
    # condition, whereas IN over an empty list does.
    if tag_ids or not bool_expressions:
        bool_expressions.append(self.__entry_has_any_tags(tag_ids))
    return or_(*bool_expressions)

def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
tag_ids: list[int] = []
bool_expressions: list[ColumnElement[bool]] = []

# Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately
for term in node.terms:
if isinstance(term, Constraint) and len(term.properties) == 0:
match term.type:
case ConstraintType.TagID:
try:
tag_ids.append(int(term.value))
except ValueError:
logger.error(
"[SQLBoolExpressionBuilder] Could not cast value to an int Tag ID",
value=term.value,
)
continue
case ConstraintType.Tag:
if len(ids := self.__get_tag_ids(term.value)) == 1:
tag_ids.append(ids[0])
continue

bool_expressions.append(self.visit(term))

# If there are at least two tag ids use a relational division query
# to efficiently check all of them
if len(tag_ids) > 1:
tag_ids, bool_expressions = self.__separate_tags(node.terms, only_single=True)
if len(tag_ids) > 0:
bool_expressions.append(self.__entry_has_all_tags(tag_ids))
# If there is just one tag id, check the normal way
elif len(tag_ids) == 1:
bool_expressions.append(
self.__entry_satisfies_expression(TagEntry.tag_id == tag_ids[0])
)

Comment on lines -91 to -96
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my testing, using `__entry_has_all_tags` was slightly faster with one tag id. SQLite will optimize out the `GROUP BY` when just one tag_id is present. This can be checked with the query plan:

rows = session.execute(text(f"EXPLAIN QUERY PLAN {query_full}")).fetchall()
for row in rows:
    print(row)

return and_(*bool_expressions)

def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
Expand All @@ -102,9 +76,9 @@ def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG

if node.type == ConstraintType.Tag:
return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value))
return self.__entry_has_any_tags(self.__get_tag_ids(node.value))
elif node.type == ConstraintType.TagID:
return self.__entry_matches_tag_ids([int(node.value)])
return self.__entry_has_any_tags([int(node.value)])
elif node.type == ConstraintType.Path:
ilike = False
glob = False
Expand Down Expand Up @@ -153,15 +127,6 @@ def visit_property(self, node: Property) -> ColumnElement[bool]:
def visit_not(self, node: Not) -> ColumnElement[bool]:
    """Return the negation of the expression built for the node's child."""
    child_expression = self.visit(node.child)
    return ~child_expression

def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]:
    """Build an EXISTS predicate that holds when the entry carries at least one of the given tags."""  # noqa: E501
    # Correlated on Entry: a TagEntry row must belong to the outer entry and
    # reference one of the requested tag ids.
    link_matches = and_(TagEntry.entry_id == Entry.id, TagEntry.tag_id.in_(tag_ids))
    subquery = select(1).correlate(Entry).where(link_matches)
    return subquery.exists()

def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
"""Given a tag name find the ids of all tags that this name could refer to."""
with Session(self.lib.engine) as session:
Expand All @@ -185,6 +150,36 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in
outp.extend(list(session.scalars(TAG_CHILDREN_ID_QUERY, {"tag_id": tag_id})))
return outp

def __separate_tags(
self, terms: list[AST], only_single: bool = True
) -> tuple[list[int], list[ColumnElement[bool]]]:
tag_ids: list[int] = []
bool_expressions: list[ColumnElement[bool]] = []

for term in terms:
if isinstance(term, Constraint) and len(term.properties) == 0:
match term.type:
case ConstraintType.TagID:
try:
tag_ids.append(int(term.value))
except ValueError:
logger.error(
"[SQLBoolExpressionBuilder] Could not cast value to an int Tag ID",
value=term.value,
)
continue
case ConstraintType.Tag:
ids = self.__get_tag_ids(term.value)
if not only_single:
tag_ids.extend(ids)
continue
elif len(ids) == 1:
tag_ids.append(ids[0])
continue

bool_expressions.append(self.visit(term))
return tag_ids, bool_expressions

def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
# Relational Division Query
Expand All @@ -195,9 +190,8 @@ def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
.having(func.count(distinct(TagEntry.tag_id)) == len(tag_ids))
)

def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]:
    """Build a predicate that is true when the Entry satisfies the column expression.

    The expression is evaluated on Entry ⟕ TagEntry (Entry LEFT OUTER JOIN TagEntry).
    """
    matching_entry_ids = select(Entry.id).outerjoin(TagEntry).where(expr)
    return Entry.id.in_(matching_entry_ids)
def __entry_has_any_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
    """Build a predicate that is true when the Entry carries at least one of the given tag ids."""
    # Distinct entry ids that are linked to any of the requested tags.
    tagged_entry_ids = (
        select(TagEntry.entry_id).where(TagEntry.tag_id.in_(tag_ids)).distinct()
    )
    return Entry.id.in_(tagged_entry_ids)