Skip to content

Commit

Permalink
Allow boolean columns to be used in ExpressionFactory
Browse files Browse the repository at this point in the history
Add an as_boolean() method to the expression proxies used to reference dimension fields, so boolean-valued fields can be converted to Predicate instances that can be used to constrain a query.
  • Loading branch information
dhirving committed Aug 9, 2024
1 parent 94456a4 commit 81f50ca
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 8 deletions.
58 changes: 57 additions & 1 deletion python/lsst/daf/butler/queries/expression_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@ def desc(self) -> tree.Reversed:
"""
return tree.Reversed(operand=self._expression)

def as_boolean(self) -> tree.Predicate:
"""If this scalar expression is a boolean, convert it to a `Predicate`
so it can be used as a boolean expression.
Raises
------
InvalidQueryError
If this expression is not a boolean.
Returns
-------
predicate : `Predicate`
This expression converted to a `Predicate`.
"""
expr = self._expression
raise InvalidQueryError(
f"Expression '{expr}' with type"
f" '{expr.column_type}' can't be used directly as a boolean value."
" Use a comparison operator like '>' or '==' instead."
)

def __eq__(self, other: object) -> tree.Predicate: # type: ignore[override]
return self._make_comparison(other, "==")

Expand Down Expand Up @@ -233,6 +254,38 @@ def _expression(self) -> tree.ColumnExpression:
return self._expr


class BooleanScalarExpressionProxy(ScalarExpressionProxy):
"""A `ScalarExpressionProxy` representing a boolean column. You should
call `as_boolean()` on this object to convert it to an instance of
`Predicate` before attempting to use it.
Parameters
----------
expression : `.tree.ColumnReference`
Boolean column reference that backs this proxy.
"""

# This is a hack/work-around to make static typing work when referencing
# dimension record metadata boolean columns. From the perspective of
# typing, anything boolean should be a `Predicate`, but the type system has
# no way of knowing whether a given column is a bool or some other type.

def __init__(self, expression: tree.ColumnReference) -> None:
if expression.column_type != "bool":
raise ValueError(f"Expression is a {expression.column_type}, not a 'bool': {expression}")

Check warning on line 275 in python/lsst/daf/butler/queries/expression_factory.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/expression_factory.py#L275

Added line #L275 was not covered by tests
self._boolean_expression = expression

def as_boolean(self) -> tree.Predicate:
return tree.Predicate.from_bool_expression(self._boolean_expression)

@property
def _expression(self) -> tree.ColumnExpression:
raise InvalidQueryError(
f"Boolean expression '{self._boolean_expression}' can't be used directly in other expressions."
" Call the 'as_boolean()' method to convert it to a Predicate instead."
)


class TimespanProxy(ExpressionProxy):
"""An `ExpressionProxy` specialized for timespan columns and literals.
Expand Down Expand Up @@ -350,7 +403,10 @@ def __getattr__(self, field: str) -> ScalarExpressionProxy:
expression = tree.DimensionFieldReference(element=self._element, field=field)
except InvalidQueryError:
raise AttributeError(field)
return ResolvedScalarExpressionProxy(expression)
if expression.column_type == "bool":
return BooleanScalarExpressionProxy(expression)
else:
return ResolvedScalarExpressionProxy(expression)

@property
def region(self) -> RegionProxy:
Expand Down
52 changes: 45 additions & 7 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..dimensions import DataCoordinate, DimensionRecord
from ..direct_query_driver import DirectQueryDriver
from ..queries import DimensionRecordQueryResults
from ..queries.tree import Predicate
from ..registry import CollectionType, NoDefaultCollectionError, RegistryDefaults
from ..registry.sql_registry import SqlRegistry
from ..transfers import YamlRepoImportBackend
Expand Down Expand Up @@ -943,15 +944,15 @@ def test_boolean_columns(self) -> None:

base_data = {"instrument": "HSC", "physical_filter": "HSC-R", "group": "903342", "day_obs": 20130617}

TRUE_ID_1 = 1001
TRUE_ID_2 = 2001
FALSE_ID_1 = 1002
TRUE_ID = 1000
FALSE_ID_1 = 2001
FALSE_ID_2 = 2002
NULL_ID = 3000
records = [
{"id": TRUE_ID_1, "obs_id": "true-1", "can_see_sky": True},
{"id": TRUE_ID_2, "obs_id": "true-2", "can_see_sky": True},
{"id": TRUE_ID, "obs_id": "true-1", "can_see_sky": True},
{"id": FALSE_ID_1, "obs_id": "false-1", "can_see_sky": False},
{"id": FALSE_ID_2, "obs_id": "false-2", "can_see_sky": False},
{"id": NULL_ID, "obs_id": "null-1", "can_see_sky": None},
# There is also a record ID 903342 from the YAML file with a NULL
# value for can_see_sky.
]
Expand All @@ -971,10 +972,11 @@ def _run_query(where: str) -> list[int]:
query.dimension_records("exposure").where(where, instrument="HSC")
)

# Test boolean columns in the `where` string syntax.
for test, query_func in [("registry", _run_registry_query), ("new-query", _run_query)]:
with self.subTest(test):
# Boolean columns should be usable standalone as an expression.
self.assertCountEqual(query_func("exposure.can_see_sky"), [TRUE_ID_1, TRUE_ID_2])
self.assertCountEqual(query_func("exposure.can_see_sky"), [TRUE_ID])

# You can find false values in the column with NOT. The NOT of
# NULL is NULL, consistent with SQL semantics -- so records
Expand All @@ -984,9 +986,45 @@ def _run_query(where: str) -> list[int]:
# Make sure the bare column composes with other expressions
# correctly.
self.assertCountEqual(
query_func("exposure.can_see_sky OR exposure = 1002"), [TRUE_ID_1, TRUE_ID_2, FALSE_ID_1]
query_func("exposure.can_see_sky OR exposure = 2001"), [TRUE_ID, FALSE_ID_1]
)

# Test boolean columns in ExpressionFactory.
with butler._query() as query:
x = query.expression_factory

def do_query(constraint: Predicate) -> list[int]:
return _get_exposure_ids_from_dimension_records(
query.dimension_records("exposure").where(constraint, instrument="HSC")
)

# Boolean columns should be usable standalone as a Predicate.
self.assertCountEqual(do_query(x.exposure.can_see_sky.as_boolean()), [TRUE_ID])

# You can find false values in the column with NOT. The NOT of
# NULL is NULL, consistent with SQL semantics -- so records
# with NULL can_see_sky are not included here.
self.assertCountEqual(
do_query(x.exposure.can_see_sky.as_boolean().logical_not()), [FALSE_ID_1, FALSE_ID_2]
)

# Attempting to use operators that only apply to non-boolean types
# is an error.
with self.assertRaisesRegex(
InvalidQueryError,
r"Boolean expression 'exposure.can_see_sky' can't be used directly in other expressions."
r" Call the 'as_boolean\(\)' method to convert it to a Predicate instead.",
):
x.exposure.can_see_sky == 1

# Non-boolean types can't be converted directly to Predicate.
with self.assertRaisesRegex(
InvalidQueryError,
r"Expression 'exposure.observation_type' with type 'string' can't be used directly"
r" as a boolean value.",
):
x.exposure.observation_type.as_boolean()


def _get_exposure_ids_from_dimension_records(dimension_records: Iterable[DimensionRecord]) -> list[int]:
output = []
Expand Down

0 comments on commit 81f50ca

Please sign in to comment.