Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-45680: Allow boolean columns to be used in query 'where' #1051

Merged
merged 6 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changes/DM-45680.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix an issue where boolean metadata columns (like `exposure.can_see_sky` and
`exposure.has_simulated`) were not usable in `where` clauses for Registry query
functions. These column names can now be used as a boolean expression, for
example `where="exposure.can_see_sky` or `where="NOT exposure.can_see_sky"`.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def visit_reversed(self, expression: qt.Reversed) -> sqlalchemy.ColumnElement[An
# Docstring inherited.
return self.expect_scalar(expression.operand).desc()

def visit_boolean_wrapper(
self, value: qt.ColumnExpression, flags: PredicateVisitFlags
) -> sqlalchemy.ColumnElement[bool]:
return self.expect_scalar(value)

def visit_comparison(
self,
a: qt.ColumnExpression,
Expand Down
45 changes: 44 additions & 1 deletion python/lsst/daf/butler/queries/_expression_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .tree import (
BinaryExpression,
ColumnExpression,
ColumnReference,
ComparisonOperator,
LiteralValue,
Predicate,
Expand Down Expand Up @@ -158,6 +159,16 @@ def visitBinaryOp(
return Predicate.is_null(rhs.value)
case ["!=", _Null(), _ColExpr() as rhs]:
return Predicate.is_null(rhs.value).logical_not()
# Boolean columns can be null, but will have been converted to
# Predicate, so we need additional cases.
case ["=" | "!=", Predicate() as pred, _Null()] | ["=" | "!=", _Null(), Predicate() as pred]:
column_ref = _get_boolean_column_reference(pred)
if column_ref is not None:
match operator:
case "=":
return Predicate.is_null(column_ref)
case "!=":
return Predicate.is_null(column_ref).logical_not()

# Handle arithmetic operations
case [("+" | "-" | "*" | "/" | "%") as op, _ColExpr() as lhs, _ColExpr() as rhs]:
Expand Down Expand Up @@ -198,7 +209,23 @@ def visitIdentifier(self, name: str, node: Node) -> _VisitorResult:
if categorizeConstant(name) == ExpressionConstant.NULL:
return _Null()

return _ColExpr(interpret_identifier(self.context, name))
column_expression = interpret_identifier(self.context, name)
if column_expression.column_type == "bool":
# Expression-handling code (in this file and elsewhere) expects
# boolean-valued expressions to be represented as Predicate, not a
# ColumnExpression.

# We should only be getting direct references to a column, not a
# more complicated expression.
# (Anything more complicated should be a Predicate already.)
assert (
column_expression.expression_type == "dataset_field"
or column_expression.expression_type == "dimension_field"
or column_expression.expression_type == "dimension_key"
)
return Predicate.from_bool_expression(column_expression)
else:
return _ColExpr(column_expression)

def visitNumericLiteral(self, value: str, node: Node) -> _VisitorResult:
numeric: int | float
Expand Down Expand Up @@ -303,3 +330,19 @@ def _convert_in_clause_to_predicate(lhs: ColumnExpression, rhs: _VisitorResult,
return Predicate.is_null(lhs)
case _:
raise InvalidQueryError(f"Invalid IN expression: '{node!s}")


def _get_boolean_column_reference(predicate: Predicate) -> ColumnReference | None:
"""Unwrap a predicate to recover the boolean ColumnReference it contains.
Returns `None` if this Predicate contains anything other than a single
boolean ColumnReference operand.

This undoes the ColumnReference to Predicate conversion that occurs in
visitIdentifier for boolean columns.
"""
if len(predicate.operands) == 1 and len(predicate.operands[0]) == 1:
predicate_leaf = predicate.operands[0][0]
if predicate_leaf.predicate_type == "boolean_wrapper":
return predicate_leaf.operand

return None
62 changes: 61 additions & 1 deletion python/lsst/daf/butler/queries/expression_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@
"""
return tree.Reversed(operand=self._expression)

def as_boolean(self) -> tree.Predicate:
"""If this scalar expression is a boolean, convert it to a `Predicate`
so it can be used as a boolean expression.

Raises
------
InvalidQueryError
If this expression is not a boolean.

Returns
-------
predicate : `Predicate`
This expression converted to a `Predicate`.
"""
expr = self._expression
raise InvalidQueryError(
f"Expression '{expr}' with type"
f" '{expr.column_type}' can't be used directly as a boolean value."
" Use a comparison operator like '>' or '==' instead."
)

def __eq__(self, other: object) -> tree.Predicate: # type: ignore[override]
return self._make_comparison(other, "==")

Expand Down Expand Up @@ -233,6 +254,42 @@
return self._expr


class BooleanScalarExpressionProxy(ScalarExpressionProxy):
"""A `ScalarExpressionProxy` representing a boolean column. You should
call `as_boolean()` on this object to convert it to an instance of
`Predicate` before attempting to use it.

Parameters
----------
expression : `.tree.ColumnReference`
Boolean column reference that backs this proxy.
"""

# This is a hack/work-around to make static typing work when referencing
# dimension record metadata boolean columns. From the perspective of
# typing, anything boolean should be a `Predicate`, but the type system has
# no way of knowing whether a given column is a bool or some other type.

def __init__(self, expression: tree.ColumnReference) -> None:
if expression.column_type != "bool":
raise ValueError(f"Expression is a {expression.column_type}, not a 'bool': {expression}")

Check warning on line 275 in python/lsst/daf/butler/queries/expression_factory.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/expression_factory.py#L275

Added line #L275 was not covered by tests
self._boolean_expression = expression

@property
def is_null(self) -> tree.Predicate:
return ResolvedScalarExpressionProxy(self._boolean_expression).is_null

def as_boolean(self) -> tree.Predicate:
Copy link
Contributor Author

@dhirving dhirving Aug 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like forcing users to call as_boolean() to get a Predicate out of this, but I didn't really have a better idea.

We can't just return a Predicate directly from DimensionElementProxy.__getattr__ because it would completely destroy the static typing. The return type would become ScalarExpressionProxy | Predicate and those types have almost no methods in common.

This at least throws a semi-helpful runtime error if you fail to call as_boolean() or attempt to do a comparison like boolean_column == 0.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really wish Python typing had something like Any[A | B], where it'd happily do an implicit cast to one of the options you give it without losing the typing entirely (the way Any does).

I'm not convinced the static typing is so important here that it's worth doing even slight harm to the runtime interface - I'm imagining users in Jupyter notebooks doing tab completion as the most important use case for ExpressionFactory. But I do expect boolean columns to be very rare, and this is a small change that's easy to revisit later, and the same cannot be said of trying to put the entire Predicate interface into ScalarExpressionProxy.

return tree.Predicate.from_bool_expression(self._boolean_expression)

@property
def _expression(self) -> tree.ColumnExpression:
raise InvalidQueryError(
f"Boolean expression '{self._boolean_expression}' can't be used directly in other expressions."
" Call the 'as_boolean()' method to convert it to a Predicate instead."
)


class TimespanProxy(ExpressionProxy):
"""An `ExpressionProxy` specialized for timespan columns and literals.

Expand Down Expand Up @@ -350,7 +407,10 @@
expression = tree.DimensionFieldReference(element=self._element, field=field)
except InvalidQueryError:
raise AttributeError(field)
return ResolvedScalarExpressionProxy(expression)
if expression.column_type == "bool":
return BooleanScalarExpressionProxy(expression)
else:
return ResolvedScalarExpressionProxy(expression)

@property
def region(self) -> RegionProxy:
Expand Down
43 changes: 42 additions & 1 deletion python/lsst/daf/butler/queries/tree/_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ._base import QueryTreeBase
from ._column_expression import (
ColumnExpression,
ColumnReference,
is_one_datetime_and_one_ingest_date,
is_one_timespan_and_one_datetime,
)
Expand Down Expand Up @@ -155,6 +156,26 @@
#
return cls.model_construct(operands=() if value else ((),))

@classmethod
def from_bool_expression(cls, value: ColumnReference) -> Predicate:
"""Construct a predicate that wraps a boolean ColumnReference, taking
on the value of the underlying ColumnReference.

Parameters
----------
value : `ColumnExpression`
Boolean-valued expression to convert to Predicate.

Returns
-------
predicate : `Predicate`
Predicate representing the expression.
"""
if value.column_type != "bool":
raise ValueError(f"ColumnExpression must have column type 'bool', not '{value.column_type}'")

Check warning on line 175 in python/lsst/daf/butler/queries/tree/_predicate.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/tree/_predicate.py#L175

Added line #L175 was not covered by tests

return cls._from_leaf(BooleanWrapper(operand=value))

@classmethod
def compare(cls, a: ColumnExpression, operator: ComparisonOperator, b: ColumnExpression) -> Predicate:
"""Construct a predicate representing a binary comparison between
Expand Down Expand Up @@ -412,6 +433,26 @@
return visitor._visit_logical_not(self.operand, flags)


class BooleanWrapper(PredicateLeafBase):
"""Pass-through to a pre-existing boolean column expression."""

predicate_type: Literal["boolean_wrapper"] = "boolean_wrapper"

operand: ColumnReference
"""Wrapped expression that will be used as the value for this predicate."""

def gather_required_columns(self, columns: ColumnSet) -> None:
# Docstring inherited.
self.operand.gather_required_columns(columns)

def __str__(self) -> str:
return f"{self.operand}"

Check warning on line 449 in python/lsst/daf/butler/queries/tree/_predicate.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/tree/_predicate.py#L449

Added line #L449 was not covered by tests

def visit(self, visitor: PredicateVisitor[_A, _O, _L], flags: PredicateVisitFlags) -> _L:
# Docstring inherited.
return visitor.visit_boolean_wrapper(self.operand, flags)


@final
class IsNull(PredicateLeafBase):
"""A boolean column expression that tests whether its operand is NULL."""
Expand Down Expand Up @@ -639,7 +680,7 @@
return self


LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery
LogicalNotOperand: TypeAlias = IsNull | Comparison | InContainer | InRange | InQuery | BooleanWrapper
PredicateLeaf: TypeAlias = Annotated[
LogicalNotOperand | LogicalNot, pydantic.Field(discriminator="predicate_type")
]
Expand Down
24 changes: 24 additions & 0 deletions python/lsst/daf/butler/queries/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,25 @@ class PredicateVisitor(Generic[_A, _O, _L]):
visit method arguments.
"""

@abstractmethod
def visit_boolean_wrapper(self, value: tree.ColumnExpression, flags: PredicateVisitFlags) -> _L:
"""Visit a boolean-valued column expression.

Parameters
----------
value : `tree.ColumnExpression`
Column expression, guaranteed to have `column_type == "bool"`.
flags : `PredicateVisitFlags`
Information about where this leaf appears in the larger predicate
tree.

Returns
-------
result : `object`
Implementation-defined.
"""
raise NotImplementedError()

@abstractmethod
def visit_comparison(
self,
Expand Down Expand Up @@ -448,6 +467,11 @@ class SimplePredicateVisitor(
return a replacement `Predicate` to construct a new tree.
"""

def visit_boolean_wrapper(
self, value: tree.ColumnExpression, flags: PredicateVisitFlags
) -> tree.Predicate | None:
return None

def visit_comparison(
self,
a: tree.ColumnExpression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,12 @@ def visitIdentifier(self, name: str, node: Node) -> VisitorResult:
if column == timespan_database_representation.TimespanDatabaseRepresentation.NAME
else element.RecordClass.fields.standard[column].getPythonType()
)
return ColumnExpression.reference(tag, dtype)
if dtype is bool:
# ColumnExpression is for non-boolean columns only. Booleans
# are represented as Predicate.
return Predicate.reference(tag)
else:
return ColumnExpression.reference(tag, dtype)
else:
tag = DimensionKeyColumnTag(element.name)
assert isinstance(element, Dimension)
Expand Down
Loading
Loading