Merge pull request #876 from lsst/tickets/DM-38498
DM-38498: improve follow-up query support for QG generation use cases
TallJimbo committed Aug 24, 2023
2 parents 7c9a229 + fd2474a commit 1585a23
Showing 11 changed files with 531 additions and 56 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-38498.feature.md
@@ -0,0 +1 @@
Improve support for finding calibrations and spatially-joined datasets as follow-ups to data ID queries.
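A hedged sketch of the kind of follow-up query this entry refers to, using DataCoordinateQueryResults.findDatasets; the repository path, instrument, dataset type, and collection name below are placeholders, and whether a particular calibration lookup is supported can still depend on the collections searched:

    from lsst.daf.butler import Butler

    butler = Butler("my_repo")  # placeholder repository

    # Start from a data ID query, then follow up by searching for matching datasets.
    data_ids = butler.registry.queryDataIds(["exposure", "detector"], instrument="HSC")

    # Calibration lookup as a follow-up query.  findFirst=False sidesteps the
    # find-first-in-collection-order semantics that calibration collections
    # do not support.
    bias_refs = list(data_ids.findDatasets("bias", collections=["HSC/calib"], findFirst=False))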
37 changes: 32 additions & 5 deletions python/lsst/daf/butler/core/dimensions/_coordinate.py
@@ -340,12 +340,12 @@ def fromFullValues(graph: DimensionGraph, values: tuple[DataIdValue, ...]) -> Da
return _BasicTupleDataCoordinate(graph, values)

def __hash__(self) -> int:
return hash((self.graph,) + tuple(self[d.name] for d in self.graph.required))
return hash((self.graph,) + self.values_tuple())

def __eq__(self, other: Any) -> bool:
if not isinstance(other, DataCoordinate):
other = DataCoordinate.standardize(other, universe=self.universe)
return self.graph == other.graph and all(self[d.name] == other[d.name] for d in self.graph.required)
return self.graph == other.graph and self.values_tuple() == other.values_tuple()

def __repr__(self) -> str:
# We can't make repr yield something that could be exec'd here without
@@ -513,6 +513,7 @@ def hasFull(self) -> bool:
raise NotImplementedError()

@property
@abstractmethod
def full(self) -> NamedKeyMapping[Dimension, DataIdValue]:
"""Return mapping for all dimensions in ``self.graph``.
@@ -524,8 +525,17 @@ def full(self) -> NamedKeyMapping[Dimension, DataIdValue]:
when implied keys are accessed via the returned mapping, depending on
the implementation and whether assertions are enabled.
"""
assert self.hasFull(), "full may only be accessed if hasFull() returns True."
return _DataCoordinateFullView(self)
raise NotImplementedError()

@abstractmethod
def values_tuple(self) -> tuple[DataIdValue, ...]:
"""Return the required values (only) of this data ID as a tuple.
In contexts where all data IDs have the same dimensions, comparing and
hashing these tuples can be *much* faster than comparing the original
`DataCoordinate` instances.
"""
raise NotImplementedError()

@abstractmethod
def hasRecords(self) -> bool:
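The new values_tuple method targets hot paths that handle many data IDs with identical dimensions, where tuple comparison and hashing beat going through DataCoordinate equality. A minimal sketch of that use; the repository, collection, and instrument names are assumptions, not part of this change:

    from lsst.daf.butler import Butler

    butler = Butler("my_repo", collections=["HSC/defaults"])  # placeholder repo/collection

    # Data IDs from a single query all share the same dimensions, so their
    # required-value tuples are directly comparable and cheap to hash.
    data_ids = butler.registry.queryDataIds(["visit", "detector"], instrument="HSC")

    seen = set()
    unique = []
    for data_id in data_ids:
        key = data_id.values_tuple()  # required values only, in graph.required order
        if key not in seen:
            seen.add(key)
            unique.append(data_id)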
@@ -779,7 +789,7 @@ class _DataCoordinateFullView(NamedKeyMapping[Dimension, DataIdValue]):
The `DataCoordinate` instance this object provides a view of.
"""

def __init__(self, target: DataCoordinate):
def __init__(self, target: _BasicTupleDataCoordinate):
self._target = target

__slots__ = ("_target",)
@@ -892,6 +902,13 @@ def __getitem__(self, key: DataIdKey) -> DataIdValue:
# values for the required ones.
raise KeyError(key) from None

def byName(self) -> dict[str, DataIdValue]:
# Docstring inheritance.
# Reimplementation is for optimization; `values_tuple()` is much faster
# to iterate over than values() because it doesn't go through
# `__getitem__`.
return dict(zip(self.names, self.values_tuple(), strict=True))

def subset(self, graph: DimensionGraph) -> DataCoordinate:
# Docstring inherited from DataCoordinate.
if self._graph == graph:
@@ -933,6 +950,12 @@ def union(self, other: DataCoordinate) -> DataCoordinate:
values.update(other.full.byName() if other.hasFull() else other.byName())
return DataCoordinate.standardize(values, graph=graph)

@property
def full(self) -> NamedKeyMapping[Dimension, DataIdValue]:
# Docstring inherited.
assert self.hasFull(), "full may only be accessed if hasFull() returns True."
return _DataCoordinateFullView(self)

def expanded(
self, records: NameLookupMapping[DimensionElement, DimensionRecord | None]
) -> DataCoordinate:
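The full override above keeps the documented contract: full may only be accessed when hasFull() returns True, which in general requires the implied-dimension values filled in by expanding the data ID. A short sketch of that guard; the repository path and data ID values are placeholders:

    from lsst.daf.butler import Butler, DataCoordinate

    butler = Butler("my_repo")  # placeholder repository
    data_id = DataCoordinate.standardize(
        {"instrument": "HSC", "visit": 903334, "detector": 22},
        universe=butler.registry.dimensions,
    )

    if not data_id.hasFull():
        # Expanding queries the registry for implied dimension values (and records).
        data_id = butler.registry.expandDataId(data_id)

    # Implied dimensions such as physical_filter are now available through .full.
    print(data_id.full.byName())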
@@ -954,6 +977,10 @@ def hasRecords(self) -> bool:
# Docstring inherited from DataCoordinate.
return False

def values_tuple(self) -> tuple[DataIdValue, ...]:
# Docstring inherited from DataCoordinate.
return self._values[: len(self._graph.required)]

def _record(self, name: str) -> DimensionRecord | None:
# Docstring inherited from DataCoordinate.
raise AssertionError()
9 changes: 8 additions & 1 deletion python/lsst/daf/butler/core/dimensions/_records.py
@@ -342,6 +342,10 @@ def to_simple(self, minimal: bool = False) -> SerializedDimensionRecord:
# query. This may not be overly useful since to reconstruct
# a collection of records will require repeated registry queries.
# For now do not implement minimal form.
key = (id(self.definition), self.dataId)
cache = PersistenceContextVars.serializedDimensionRecordMapping.get()
if cache is not None and (result := cache.get(key)) is not None:
return result

mapping = {name: getattr(self, name) for name in self.__slots__}
# If the item in mapping supports simplification update it
@@ -360,7 +364,10 @@ def to_simple(self, minimal: bool = False) -> SerializedDimensionRecord:
# hash objects, encode it here to a hex string
mapping[k] = v.hex()
definition = self.definition.to_simple(minimal=minimal)
return SerializedDimensionRecord(definition=definition, record=mapping)
dimRec = SerializedDimensionRecord(definition=definition, record=mapping)
if cache is not None:
cache[key] = dimRec
return dimRec

@classmethod
def from_simple(
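The cache added to to_simple is consulted and populated only while a persistence context is active; outside one, the ContextVar holds None and serialization proceeds as before. A self-contained sketch of that pattern with invented names rather than the butler classes:

    from __future__ import annotations

    import contextlib
    from contextvars import ContextVar

    # None means "caching disabled"; a dict means "memoize within the current context".
    _serialization_cache: ContextVar[dict[tuple[int, str], str] | None] = ContextVar(
        "_serialization_cache", default=None
    )


    @contextlib.contextmanager
    def caching_enabled():
        """Enable memoization of serialize() for the duration of the block."""
        token = _serialization_cache.set({})
        try:
            yield
        finally:
            _serialization_cache.reset(token)


    def serialize(definition_id: int, data_id: str) -> str:
        key = (definition_id, data_id)
        cache = _serialization_cache.get()
        if cache is not None and (result := cache.get(key)) is not None:
            return result
        result = f"serialized({definition_id}, {data_id})"  # stand-in for the expensive work
        if cache is not None:
            cache[key] = result
        return result


    with caching_enabled():
        serialize(1, "visit=903334")  # computed
        serialize(1, "visit=903334")  # served from the per-context cache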
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/core/persistenceContext.py
@@ -94,7 +94,7 @@ class PersistenceContextVars:
"""

serializedDimensionRecordMapping: ContextVar[
dict[tuple[str, frozenset], SerializedDimensionRecord] | None
dict[tuple[str, frozenset] | tuple[int, DataCoordinate], SerializedDimensionRecord] | None
] = ContextVar("serializedDimensionRecordMapping", default=None)
r"""A cache of `SerializedDimensionRecord`\ s.
"""
47 changes: 34 additions & 13 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -728,9 +728,9 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
sqlalchemy.sql.select(
dataset.columns.id.label("dataset_id"),
dataset.columns.dataset_type_id.label("dataset_type_id"),
tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"),
dataset.columns[self._runKeyColumn].label("run"),
tmp_tags.columns[collFkName].label("new run"),
tmp_tags.columns[collFkName].label("new_run"),
)
.select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
.where(
@@ -742,21 +742,38 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
.limit(1)
)
with self._db.query(query) as result:
# Only include the first one in the exception message
if (row := result.first()) is not None:
# Only include the first one in the exception message
raise ConflictingDefinitionError(
f"Existing dataset type or run do not match new dataset: {row._asdict()}"
)
existing_run = self._collections[row.run].name
new_run = self._collections[row.new_run].name
if row.dataset_type_id == self._dataset_type_id:
if row.new_dataset_type_id == self._dataset_type_id:
raise ConflictingDefinitionError(
f"Current run {existing_run!r} and new run {new_run!r} do not agree for "
f"dataset {row.dataset_id}."
)
else:
raise ConflictingDefinitionError(
f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} "
f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} "
f"in run {run!r}."
)
else:
raise ConflictingDefinitionError(
f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} "
f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} "
f"in run {run!r}."
)

# Check that matching dataset in tags table has the same DataId.
query = (
sqlalchemy.sql.select(
tags.columns.dataset_id,
tags.columns.dataset_type_id.label("type_id"),
tmp_tags.columns.dataset_type_id.label("new type_id"),
tmp_tags.columns.dataset_type_id.label("new_type_id"),
*[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
*[
tmp_tags.columns[dim].label(f"new {dim}")
tmp_tags.columns[dim].label(f"new_{dim}")
for dim in self.datasetType.dimensions.required.names
],
)
@@ -783,12 +800,11 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
# Check that matching run+dataId have the same dataset ID.
query = (
sqlalchemy.sql.select(
tags.columns.dataset_type_id.label("dataset_type_id"),
*[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
tags.columns.dataset_id,
tmp_tags.columns.dataset_id.label("new dataset_id"),
tmp_tags.columns.dataset_id.label("new_dataset_id"),
tags.columns[collFkName],
tmp_tags.columns[collFkName].label(f"new {collFkName}"),
tmp_tags.columns[collFkName].label(f"new_{collFkName}"),
)
.select_from(
tags.join(
@@ -807,8 +823,13 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
.limit(1)
)
with self._db.query(query) as result:
# only include the first one in the exception message
if (row := result.first()) is not None:
# only include the first one in the exception message
data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names}
existing_collection = self._collections[getattr(row, collFkName)].name
new_collection = self._collections[getattr(row, f"new_{collFkName}")].name
raise ConflictingDefinitionError(
f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
f"Dataset with type {self.datasetType.name!r} and data ID {data_id} "
f"has ID {row.dataset_id} in existing collection {existing_collection!r} "
f"but ID {row.new_dataset_id} in new collection {new_collection!r}."
)
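The reworked conflict checks in _validateImport all follow one pattern: join the permanent table against the temporary import table, fetch at most one offending row, and turn it into a readable error rather than dumping row._asdict(). A standalone sketch of that pattern with plain SQLAlchemy and an invented two-table schema:

    import sqlalchemy

    engine = sqlalchemy.create_engine("sqlite://")
    metadata = sqlalchemy.MetaData()

    dataset = sqlalchemy.Table(
        "dataset", metadata,
        sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
        sqlalchemy.Column("run", sqlalchemy.String),
    )
    tmp_tags = sqlalchemy.Table(
        "tmp_tags", metadata,
        sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
        sqlalchemy.Column("run", sqlalchemy.String),
    )
    metadata.create_all(engine)

    with engine.begin() as connection:
        connection.execute(dataset.insert(), [{"id": 1, "run": "existing_run"}])
        connection.execute(tmp_tags.insert(), [{"dataset_id": 1, "run": "new_run"}])

    # Select at most one row where the incoming run disagrees with the stored one.
    query = (
        sqlalchemy.select(
            dataset.c.id.label("dataset_id"),
            dataset.c.run.label("run"),
            tmp_tags.c.run.label("new_run"),
        )
        .select_from(dataset.join(tmp_tags, dataset.c.id == tmp_tags.c.dataset_id))
        .where(dataset.c.run != tmp_tags.c.run)
        .limit(1)
    )

    with engine.connect() as connection:
        if (row := connection.execute(query).first()) is not None:
            # Report only the first conflict, with named fields instead of a raw dict dump.
            raise RuntimeError(
                f"Existing run {row.run!r} and new run {row.new_run!r} do not agree "
                f"for dataset {row.dataset_id}."
            )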
