Merge pull request #1075 from lsst/tickets/DM-45993
DM-45993: Optimize DirectButlerCollections.query_info to avoid too many queries
andy-slac authored Sep 10, 2024
2 parents 1d1bf7b + 76d34d6 commit 5c4a71f
Showing 10 changed files with 179 additions and 52 deletions.
45 changes: 43 additions & 2 deletions python/lsst/daf/butler/_butler_collections.py
Original file line number Diff line number Diff line change
@@ -30,13 +30,17 @@
__all__ = ("ButlerCollections", "CollectionInfo")

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence, Set
from typing import Any, overload
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, overload

from pydantic import BaseModel

from ._collection_type import CollectionType

if TYPE_CHECKING:
from ._dataset_type import DatasetType


class CollectionInfo(BaseModel):
"""Information about a single Butler collection."""
@@ -275,6 +279,8 @@ def query_info(
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
"""Query the butler for collections matching an expression and
return detailed information about those collections.
@@ -298,6 +304,14 @@
include_summary : `bool`, optional
Whether the returned information includes dataset type and
governor information for the collections.
include_doc : `bool`, optional
Whether the returned information includes collection documentation
string.
summary_datasets : `~collections.abc.Iterable` [ `DatasetType` ], \
optional
Dataset types to include in returned summaries. Only used if
``include_summary`` is `True`. If not specified then all dataset
types will be included.

Returns
-------
@@ -411,3 +425,30 @@ def _filter_dataset_types(
collection_dataset_types.update(info.dataset_types)
dataset_types_set = dataset_types_set.intersection(collection_dataset_types)
return dataset_types_set

def _group_by_dataset_type(
self, dataset_types: Set[str], collection_infos: Iterable[CollectionInfo]
) -> Mapping[str, list[str]]:
"""Filter dataset types and collections names based on summary in
collecion infos.
Parameters
----------
dataset_types : `~collections.abc.Set` [`str`]
Set of dataset type names to extract.
collection_infos : `~collections.abc.Iterable` [`CollectionInfo`]
Collection infos; must contain dataset type summaries.

Returns
-------
filtered : `~collections.abc.Mapping` [`str`, `list`[`str`]]
Mapping of the dataset type name to its corresponding list of
collection names.
"""
dataset_type_collections: dict[str, list[str]] = defaultdict(list)
for info in collection_infos:
if info.dataset_types is None:
raise RuntimeError("Can only filter by collections if include_summary was True")
for dataset_type in info.dataset_types & dataset_types:
dataset_type_collections[dataset_type].append(info.name)
return dataset_type_collections
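
The new include_doc and summary_datasets keywords above, together with the _group_by_dataset_type helper, are what the command-line scripts later in this commit rely on. A minimal usage sketch (not part of the commit), assuming a repository at the placeholder path "repo" and the illustrative dataset type names "raw" and "calexp":

# Illustrative sketch only: fetch collection info, dataset-type summaries, and
# documentation strings in a single query_info call.
from lsst.daf.butler import Butler

butler = Butler("repo")  # placeholder repository path
dataset_types = list(butler.registry.queryDatasetTypes(("raw", "calexp")))  # illustrative names

for info in butler.collections.query_info(
    "HSC/runs/*",                    # illustrative collection wildcard
    include_summary=True,            # attach dataset type / governor summaries
    include_doc=True,                # new: also return documentation strings
    summary_datasets=dataset_types,  # new: restrict summaries to these dataset types
):
    print(info.name, info.type.name, repr(info.doc), sorted(info.dataset_types or ()))

With include_summary enabled, the resulting CollectionInfo objects carry the dataset type names that _group_by_dataset_type uses to map each dataset type to the collections that can contain it.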
59 changes: 51 additions & 8 deletions python/lsst/daf/butler/direct_butler/_direct_butler_collections.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,8 @@

__all__ = ("DirectButlerCollections",)

from collections.abc import Iterable, Sequence, Set
from collections.abc import Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any

import sqlalchemy
from lsst.utils.iteration import ensure_iterable
@@ -39,6 +40,11 @@
from ..registry._exceptions import OrphanedRecordError
from ..registry.interfaces import ChainedCollectionRecord
from ..registry.sql_registry import SqlRegistry
from ..registry.wildcards import CollectionWildcard

if TYPE_CHECKING:
from .._dataset_type import DatasetType
from ..registry._collection_summary import CollectionSummary


class DirectButlerCollections(ButlerCollections):
@@ -107,20 +113,57 @@ def query_info(
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
info = []
with self._registry.caching_context():
if collection_types is None:
collection_types = CollectionType.all()
for name in self._registry.queryCollections(
expression,
collectionTypes=collection_types,
flattenChains=flatten_chains,
includeChains=include_chains,
):
elif isinstance(collection_types, CollectionType):
collection_types = {collection_types}

records = self._registry._managers.collections.resolve_wildcard(
CollectionWildcard.from_expression(expression),
collection_types=collection_types,
flatten_chains=flatten_chains,
include_chains=include_chains,
)

summaries: Mapping[Any, CollectionSummary] = {}
if include_summary:
summaries = self._registry._managers.datasets.fetch_summaries(records, summary_datasets)

docs: Mapping[Any, str] = {}
if include_doc:
docs = self._registry._managers.collections.get_docs(record.key for record in records)

for record in records:
doc = docs.get(record.key, "")
children: tuple[str, ...] = tuple()
if record.type == CollectionType.CHAINED:
assert isinstance(record, ChainedCollectionRecord)
children = tuple(record.children)
parents: frozenset[str] | None = None
if include_parents:
# TODO: This is non-vectorized, so expensive to do in a
# loop.
parents = frozenset(self._registry.getCollectionParentChains(record.name))
dataset_types: Set[str] | None = None
if summary := summaries.get(record.key):
dataset_types = frozenset([dt.name for dt in summary.dataset_types])

info.append(
self.get_info(name, include_parents=include_parents, include_summary=include_summary)
CollectionInfo(
name=record.name,
type=record.type,
doc=doc,
parents=parents,
children=children,
dataset_types=dataset_types,
)
)

return info

def get_info(
26 changes: 19 additions & 7 deletions python/lsst/daf/butler/registry/collections/_base.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@
from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeVar, cast

import sqlalchemy
from lsst.utils.iteration import chunk_iterable

from ..._collection_type import CollectionType
from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError
@@ -450,13 +451,24 @@ def filter_types(records: Iterable[CollectionRecord[K]]) -> Iterator[CollectionR

def getDocumentation(self, key: K) -> str | None:
# Docstring inherited from CollectionManager.
sql = (
sqlalchemy.sql.select(self._tables.collection.columns.doc)
.select_from(self._tables.collection)
.where(self._tables.collection.columns[self._collectionIdName] == key)
)
with self._db.query(sql) as sql_result:
return sql_result.scalar()
docs = self.get_docs([key])
return docs.get(key)

def get_docs(self, keys: Iterable[K]) -> Mapping[K, str]:
# Docstring inherited from CollectionManager.
docs: dict[K, str] = {}
id_column = self._tables.collection.columns[self._collectionIdName]
doc_column = self._tables.collection.columns.doc
for chunk in chunk_iterable(keys):
sql = (
sqlalchemy.sql.select(id_column, doc_column)
.select_from(self._tables.collection)
.where(sqlalchemy.sql.and_(id_column.in_(chunk), doc_column != sqlalchemy.literal("")))
)
with self._db.query(sql) as sql_result:
for row in sql_result:
docs[row[0]] = row[1]
return docs

def setDocumentation(self, key: K, doc: str | None) -> None:
# Docstring inherited from CollectionManager.
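
getDocumentation now delegates to the batched get_docs, which replaces one SELECT per collection key with chunked IN (...) lookups that skip empty docs. A self-contained sketch of the same pattern, using an in-memory SQLite table with illustrative column names rather than the real daf_butler schema:

# Standalone sketch of the chunked-IN lookup used by get_docs; the table and
# column names here are illustrative, not the real daf_butler collection schema.
import sqlalchemy
from lsst.utils.iteration import chunk_iterable

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
collection = sqlalchemy.Table(
    "collection",
    metadata,
    sqlalchemy.Column("collection_id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("doc", sqlalchemy.String, nullable=False),
)
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        collection.insert(),
        [{"collection_id": i, "doc": f"doc {i}" if i % 2 else ""} for i in range(10)],
    )


def get_docs(keys):
    """Return {key: doc} for keys with non-empty docs, a few keys per SELECT."""
    docs = {}
    id_column = collection.columns["collection_id"]
    doc_column = collection.columns["doc"]
    for chunk in chunk_iterable(keys, chunk_size=4):  # tiny chunks just for the demo
        sql = sqlalchemy.select(id_column, doc_column).where(
            sqlalchemy.and_(id_column.in_(chunk), doc_column != sqlalchemy.literal(""))
        )
        with engine.connect() as conn:
            for key, doc in conn.execute(sql):
                docs[key] = doc
    return docs


print(get_docs(range(10)))  # only the odd keys carry non-empty documentation

Chunking the keys presumably keeps each statement's IN list bounded, and filtering out empty strings matches the new interface contract that only collections with non-empty documentation appear in the result.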
19 changes: 18 additions & 1 deletion python/lsst/daf/butler/registry/interfaces/_collections.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
]

from abc import abstractmethod
from collections.abc import Iterable, Set
from collections.abc import Iterable, Mapping, Set
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar

import sqlalchemy
@@ -570,6 +570,23 @@ def getDocumentation(self, key: _Key) -> str | None:
"""
raise NotImplementedError()

@abstractmethod
def get_docs(self, key: Iterable[_Key]) -> Mapping[_Key, str]:
"""Retrieve the documentation string for multiple collections.
Parameters
----------
key : `~collections.abc.Iterable` [ _Key ]
Internal primary key values for the collections.

Returns
-------
docs : `~collections.abc.Mapping` [ _Key, `str`]
Documentation strings indexed by collection key. Only collections
with non-empty documentation strings are returned.
"""
raise NotImplementedError()

@abstractmethod
def setDocumentation(self, key: _Key, doc: str | None) -> None:
"""Set the documentation string for a collection.
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@
from .._collection_type import CollectionType

if TYPE_CHECKING:
from .._dataset_type import DatasetType
from ._registry import RemoteButlerRegistry


@@ -79,6 +80,8 @@ def query_info(
include_chains: bool | None = None,
include_parents: bool = False,
include_summary: bool = False,
include_doc: bool = False,
summary_datasets: Iterable[DatasetType] | None = None,
) -> Sequence[CollectionInfo]:
# This should become a single call on the server in the future.
if collection_types is None:
1 change: 1 addition & 0 deletions python/lsst/daf/butler/script/exportCalibs.py
Original file line number Diff line number Diff line change
@@ -138,6 +138,7 @@ def exportCalibs(
collections_query,
flatten_chains=True,
include_chains=True,
include_doc=True,
collection_types={CollectionType.CALIBRATION, CollectionType.CHAINED},
):
log.info("Checking collection: %s", collection.name)
21 changes: 9 additions & 12 deletions python/lsst/daf/butler/script/queryDataIds.py
Original file line number Diff line number Diff line change
@@ -177,27 +177,24 @@ def queryDataIds(
if datasets:
# Need to constrain results based on dataset type and collection.
query_collections = collections or "*"
collections_info = butler.collections.query_info(query_collections, include_summary=True)
collections_info = butler.collections.query_info(
query_collections, include_summary=True, summary_datasets=dataset_types
)
expanded_collections = [info.name for info in collections_info]
filtered_dataset_types = list(
butler.collections._filter_dataset_types([dt.name for dt in dataset_types], collections_info)
dataset_type_collections = butler.collections._group_by_dataset_type(
{dt.name for dt in dataset_types}, collections_info
)
if not filtered_dataset_types:
if not dataset_type_collections:
return (
None,
f"No datasets of type {datasets!r} existed in the specified "
f"collections {','.join(expanded_collections)}.",
)

sub_query = query.join_dataset_search(
filtered_dataset_types.pop(0), collections=expanded_collections
)
for dt in filtered_dataset_types:
sub_query = sub_query.join_dataset_search(dt, collections=expanded_collections)
for dt, dt_collections in dataset_type_collections.items():
query = query.join_dataset_search(dt, collections=dt_collections)

results = sub_query.data_ids(dimensions)
else:
results = query.data_ids(dimensions)
results = query.data_ids(dimensions)

if where:
results = results.where(where)
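
queryDataIds.py now constrains each dataset search to only the collections whose summaries report that dataset type, instead of joining every dataset type against the full expanded collection list. A hedged sketch of that pattern outside the script; the butler.query() context is assumed from the surrounding script code not shown in this diff, and the repository path, wildcard, glob, and dimensions are illustrative:

# Sketch of the per-dataset-type join pattern used by the updated queryDataIds.py;
# repository path, collection wildcard, glob, and dimensions are illustrative.
from lsst.daf.butler import Butler

butler = Butler("repo")  # placeholder repository path
dataset_types = list(butler.registry.queryDatasetTypes("*calexp*"))

collections_info = butler.collections.query_info(
    "HSC/runs/*", include_summary=True, summary_datasets=dataset_types
)
# Internal helper from this commit: dataset type name -> collections whose
# summaries say they can contain it.
dataset_type_collections = butler.collections._group_by_dataset_type(
    {dt.name for dt in dataset_types}, collections_info
)

with butler.query() as query:
    for dt, dt_collections in dataset_type_collections.items():
        # Constrain each dataset search to only the relevant collections.
        query = query.join_dataset_search(dt, collections=dt_collections)
    for data_id in query.data_ids(["visit", "detector"]):
        print(data_id)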
28 changes: 18 additions & 10 deletions python/lsst/daf/butler/script/queryDatasets.py
Original file line number Diff line number Diff line change
@@ -240,31 +240,39 @@ def getDatasets(self) -> Iterator[list[DatasetRef]]:
Dataset references matching the given query criteria grouped
by dataset type.
"""
datasetTypes = self._dataset_type_glob or ...
datasetTypes = self._dataset_type_glob
query_collections: Iterable[str] = self._collections_wildcard or ["*"]

# Currently need to use old interface to get all the matching
# dataset types and loop over the dataset types executing a new
# query each time.
dataset_types: set[str] = {d.name for d in self.butler.registry.queryDatasetTypes(datasetTypes)}
dataset_types = set(self.butler.registry.queryDatasetTypes(datasetTypes or ...))
n_dataset_types = len(dataset_types)
if n_dataset_types == 0:
_LOG.info("The given dataset type, %s, is not known to this butler.", datasetTypes)
return

# Expand the collections query and include summary information.
query_collections_info = self.butler.collections.query_info(query_collections, include_summary=True)
query_collections_info = self.butler.collections.query_info(
query_collections,
include_summary=True,
flatten_chains=True,
include_chains=False,
summary_datasets=dataset_types,
)
expanded_query_collections = [c.name for c in query_collections_info]
if self._find_first and set(query_collections) != set(expanded_query_collections):
raise RuntimeError("Can not use wildcards in collections when find_first=True")
query_collections = expanded_query_collections

# Only iterate over dataset types that are relevant for the query.
dataset_types = set(
self.butler.collections._filter_dataset_types(dataset_types, query_collections_info)
dataset_type_names = {dataset_type.name for dataset_type in dataset_types}
dataset_type_collections = self.butler.collections._group_by_dataset_type(
dataset_type_names, query_collections_info
)

if (n_filtered := len(dataset_types)) != n_dataset_types:
if (n_filtered := len(dataset_type_collections)) != n_dataset_types:
_LOG.info("Filtered %d dataset types down to %d", n_dataset_types, n_filtered)
elif n_dataset_types == 0:
_LOG.info("The given dataset type, %s, is not known to this butler.", datasetTypes)
else:
_LOG.info("Processing %d dataset type%s", n_dataset_types, "" if n_dataset_types == 1 else "s")

@@ -278,7 +286,7 @@ def getDatasets(self) -> Iterator[list[DatasetRef]]:
# possible dataset types to query.
warn_limit = True
limit = abs(limit) + 1 # +1 to tell us we hit the limit.
for dt in sorted(dataset_types):
for dt, collections in sorted(dataset_type_collections.items()):
kwargs: dict[str, Any] = {}
if self._where:
kwargs["where"] = self._where
Expand All @@ -288,7 +296,7 @@ def getDatasets(self) -> Iterator[list[DatasetRef]]:
_LOG.debug("Querying dataset type %s with %s", dt, kwargs)
results = self.butler.query_datasets(
dt,
collections=query_collections,
collections=collections,
find_first=self._find_first,
with_dimension_records=True,
order_by=self._order_by,
24 changes: 12 additions & 12 deletions python/lsst/daf/butler/script/queryDimensionRecords.py
Original file line number Diff line number Diff line change
@@ -83,21 +83,21 @@

if datasets:
query_collections = collections or "*"
collections_info = butler.collections.query_info(query_collections, include_summary=True)
expanded_collections = [info.name for info in collections_info]
dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(datasets)]
dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections_info))

if not dataset_types:
dataset_types = butler.registry.queryDatasetTypes(datasets)
collections_info = butler.collections.query_info(
query_collections, include_summary=True, summary_datasets=dataset_types
)
dataset_type_collections = butler.collections._group_by_dataset_type(
{dt.name for dt in dataset_types}, collections_info
)

if not dataset_type_collections:
return None

sub_query = query.join_dataset_search(dataset_types.pop(0), collections=expanded_collections)
for dt in dataset_types:
sub_query = sub_query.join_dataset_search(dt, collections=expanded_collections)
for dt, dt_collections in dataset_type_collections.items():
query = query.join_dataset_search(dt, collections=dt_collections)

query_results = sub_query.dimension_records(element)
else:
query_results = query.dimension_records(element)
query_results = query.dimension_records(element)

if where:
query_results = query_results.where(where)