Avoid using DimensionGraph->DimensionGroup transition helpers.
TallJimbo committed Jul 25, 2024
1 parent 1c47f99 commit 6c802e8
Showing 12 changed files with 31 additions and 46 deletions.
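
The substance of every hunk below is the same one-line rewrite: on the new DimensionGroup class, containers such as required, elements, and skypix are already set-like collections of dimension names (strings), so the .names accessor, a compatibility helper carried over from the old DimensionGraph API, is a redundant layer. A minimal sketch of the equivalence, assuming the default dimension universe and that the transition helpers are still available at this commit; the example is illustrative, not part of the commit:

    from lsst.daf.butler import DimensionUniverse

    universe = DimensionUniverse()  # default dimension configuration
    group = universe.conform(["visit", "detector"])  # illustrative dimensions

    # Old spelling, via the DimensionGraph -> DimensionGroup transition helper:
    old_names = set(group.required.names)
    # New spelling: the container already iterates over name strings.
    new_names = set(group.required)
    assert old_names == new_names  # e.g. {"instrument", "visit", "detector"}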
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/datastore/file_templates.py
@@ -819,7 +819,7 @@ def validateTemplate(self, entity: DatasetRef | DatasetType | StorageClass | None
         # Fall back to dataId keys if we have them but no links.
         # dataId keys must still be present in the template
         try:
-            minimal = set(entity.dimensions.required.names)
+            minimal = set(entity.dimensions.required)
             maximal = set(entity.dimensions.names)
         except AttributeError:
             try:
@@ -888,5 +888,5 @@ def _determine_skypix_alias(self, entity: DatasetRef | DatasetType) -> str | None:
         # update to be more like the real world while still providing our
         # only tests of important behavior.
         if len(entity.dimensions.skypix) == 1:
-            (alias,) = entity.dimensions.skypix.names
+            (alias,) = entity.dimensions.skypix
         return alias
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/dimensions/_coordinate.py
@@ -787,7 +787,7 @@ def __init__(self, target: DataCoordinate):
     __slots__ = ("_target",)

     def __repr__(self) -> str:
-        terms = [f"{d}: {self[d]!r}" for d in self._target.dimensions.elements.names]
+        terms = [f"{d}: {self[d]!r}" for d in self._target.dimensions.elements]
         return "{{{}}}".format(", ".join(terms))

     def __str__(self) -> str:
31 changes: 13 additions & 18 deletions python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -197,9 +197,9 @@ def _buildCalibOverlapQuery(
         )
         if data_ids is not None:
             relation = relation.join(
-                context.make_data_id_relation(
-                    data_ids, self.datasetType.dimensions.required.names
-                ).transferred_to(context.sql_engine),
+                context.make_data_id_relation(data_ids, self.datasetType.dimensions.required).transferred_to(
+                    context.sql_engine
+                ),
             )
         return relation

@@ -314,9 +314,7 @@ def decertify(
         calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
         dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
         timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
-        data_id_tags = [
-            (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
-        ]
+        data_id_tags = [(name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required]
         # Set up collections to populate with the rows we'll want to modify.
         # The insert rows will have the same values for collection and
         # dataset type.
@@ -509,7 +507,7 @@ def _finish_single_relation(
             value=collection_col,
         )
         # Add more column definitions, starting with the data ID.
-        for dimension_name in self.datasetType.dimensions.required.names:
+        for dimension_name in self.datasetType.dimensions.required:
             payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
                 dimension_name
             ]
@@ -692,7 +690,7 @@ def _finish_query_builder(
             collections, collection_col
         )
         # Add more column definitions, starting with the data ID.
-        sql_projection.joiner.extract_dimensions(self.datasetType.dimensions.required.names)
+        sql_projection.joiner.extract_dimensions(self.datasetType.dimensions.required)
         # We can always get the dataset_id from the tags/calibs table, even if
         # could also get it from the 'static' dataset table.
         if "dataset_id" in fields:
@@ -794,7 +792,7 @@ def getDataId(self, id: DatasetId) -> DataCoordinate:
         assert row is not None, "Should be guaranteed by caller and foreign key constraints."
         return DataCoordinate.from_required_values(
             self.datasetType.dimensions,
-            tuple(row[dimension] for dimension in self.datasetType.dimensions.required.names),
+            tuple(row[dimension] for dimension in self.datasetType.dimensions.required),
         )

     def refresh_collection_summaries(self) -> None:
@@ -1062,19 +1060,16 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
                     tags.columns.dataset_id,
                     tags.columns.dataset_type_id.label("type_id"),
                     tmp_tags.columns.dataset_type_id.label("new_type_id"),
-                    *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
-                    *[
-                        tmp_tags.columns[dim].label(f"new_{dim}")
-                        for dim in self.datasetType.dimensions.required.names
-                    ],
+                    *[tags.columns[dim] for dim in self.datasetType.dimensions.required],
+                    *[tmp_tags.columns[dim].label(f"new_{dim}") for dim in self.datasetType.dimensions.required],
                 )
                 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
                 .where(
                     sqlalchemy.sql.or_(
                         tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
                         *[
                             tags.columns[dim] != tmp_tags.columns[dim]
-                            for dim in self.datasetType.dimensions.required.names
+                            for dim in self.datasetType.dimensions.required
                         ],
                     )
                 )
@@ -1091,7 +1086,7 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
             # Check that matching run+dataId have the same dataset ID.
             query = (
                 sqlalchemy.sql.select(
-                    *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
+                    *[tags.columns[dim] for dim in self.datasetType.dimensions.required],
                     tags.columns.dataset_id,
                     tmp_tags.columns.dataset_id.label("new_dataset_id"),
                     tags.columns[collFkName],
@@ -1105,7 +1100,7 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
                         tags.columns[collFkName] == tmp_tags.columns[collFkName],
                         *[
                             tags.columns[dim] == tmp_tags.columns[dim]
-                            for dim in self.datasetType.dimensions.required.names
+                            for dim in self.datasetType.dimensions.required
                         ],
                     ),
                 )
@@ -1116,7 +1111,7 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) ->
             with self._db.query(query) as result:
                 # only include the first one in the exception message
                 if (row := result.first()) is not None:
-                    data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names}
+                    data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required}
                     existing_collection = self._collections[getattr(row, collFkName)].name
                     new_collection = self._collections[getattr(row, f"new_{collFkName}")].name
                     raise ConflictingDefinitionError(
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/datasets/byDimensions/tables.py
@@ -353,7 +353,7 @@ def makeTagTableSpec(
                 target=(collectionFieldSpec.name, "dataset_type_id"),
             )
         )
-    for dimension_name in datasetType.dimensions.required.names:
+    for dimension_name in datasetType.dimensions.required:
         dimension = datasetType.dimensions.universe.dimensions[dimension_name]
         fieldSpec = addDimensionForeignKey(
             tableSpec, dimension=dimension, nullable=False, primaryKey=False, constraint=constraints
@@ -439,7 +439,7 @@ def makeCalibTableSpec(
             )
         )
     # Add dimension fields (part of the temporal lookup index.constraint).
-    for dimension_name in datasetType.dimensions.required.names:
+    for dimension_name in datasetType.dimensions.required:
         dimension = datasetType.dimensions.universe.dimensions[dimension_name]
         fieldSpec = addDimensionForeignKey(tableSpec, dimension=dimension, nullable=False, primaryKey=False)
         index.append(fieldSpec.name)
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/dimensions/static.py
@@ -529,7 +529,7 @@ def _make_common_skypix_join_relation(
         payload.columns_available[DimensionKeyColumnTag(self.universe.commonSkyPix.name)] = (
             payload.from_clause.columns.skypix_index
         )
-        for dimension_name in element.minimal_group.required.names:
+        for dimension_name in element.minimal_group.required:
             payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
                 dimension_name
             ]
@@ -596,7 +596,7 @@ def _make_skypix_overlap_tables(
                     "skypix_system",
                     "skypix_level",
                     "skypix_index",
-                    *element.minimal_group.required.names,
+                    *element.minimal_group.required,
                 ),
             },
             foreignKeys=[
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/queries/_builder.py
@@ -227,8 +227,8 @@ def finish(self, joinMissing: bool = True) -> Query:
         for family1, family2 in itertools.combinations(self.summary.dimensions.spatial, 2):
             spatial_joins.append(
                 (
-                    family1.choose(self.summary.dimensions.elements.names, self.summary.universe).name,
-                    family2.choose(self.summary.dimensions.elements.names, self.summary.universe).name,
+                    family1.choose(self.summary.dimensions.elements, self.summary.universe).name,
+                    family2.choose(self.summary.dimensions.elements, self.summary.universe).name,
                 )
             )
         self.relation = self._backend.make_dimension_relation(
10 changes: 3 additions & 7 deletions python/lsst/daf/butler/registry/queries/_query.py
@@ -746,20 +746,16 @@ def find_datasets(
                 # present in each family (e.g. patch beats tract).
                 spatial_joins.append(
                     (
-                        lhs_spatial_family.choose(
-                            full_dimensions.elements.names, self.dimensions.universe
-                        ).name,
-                        rhs_spatial_family.choose(
-                            full_dimensions.elements.names, self.dimensions.universe
-                        ).name,
+                        lhs_spatial_family.choose(full_dimensions.elements, self.dimensions.universe).name,
+                        rhs_spatial_family.choose(full_dimensions.elements, self.dimensions.universe).name,
                     )
                 )
         # Set up any temporal join between the query dimensions and CALIBRATION
         # collection's validity ranges.
         temporal_join_on: set[ColumnTag] = set()
         if any(r.type is CollectionType.CALIBRATION for r in collection_records):
             for family in self._dimensions.temporal:
-                endpoint = family.choose(self._dimensions.elements.names, self.dimensions.universe)
+                endpoint = family.choose(self._dimensions.elements, self.dimensions.universe)
                 temporal_join_on.add(DimensionRecordColumnTag(endpoint.name, "timespan"))
         base_columns_required.update(temporal_join_on)
         # Note which of the many kinds of potentially-missing columns we have
4 changes: 1 addition & 3 deletions python/lsst/daf/butler/registry/queries/_query_backend.py
@@ -579,9 +579,7 @@ def make_doomed_dataset_relation(
         relation : `lsst.daf.relation.Relation`
             Relation with the requested columns and no rows.
         """
-        column_tags: set[ColumnTag] = set(
-            DimensionKeyColumnTag.generate(dataset_type.dimensions.required.names)
-        )
+        column_tags: set[ColumnTag] = set(DimensionKeyColumnTag.generate(dataset_type.dimensions.required))
         column_tags.update(DatasetColumnTag.generate(dataset_type.name, columns))
         return context.preferred_engine.make_doomed_relation(columns=column_tags, messages=list(messages))

2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/queries/_readers.py
@@ -142,7 +142,7 @@ class _BasicDataCoordinateReader(DataCoordinateReader):

     def __init__(self, dimensions: DimensionGroup):
         self._dimensions = dimensions
-        self._tags = tuple(DimensionKeyColumnTag(name) for name in self._dimensions.required.names)
+        self._tags = tuple(DimensionKeyColumnTag(name) for name in self._dimensions.required)

     __slots__ = ("_dimensions", "_tags")

2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/queries/_sql_query_backend.py
@@ -271,7 +271,7 @@ def make_dimension_relation(
                     "out if they have already been added or will be added later."
                 )
             for element_name in missing_columns.dimension_records:
-                if element_name not in dimensions.elements.names:
+                if element_name not in dimensions.elements:
                     raise ColumnError(
                         f"Cannot join dimension element {element_name} whose dimensions are not a "
                         f"subset of {dimensions}."
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/registry/queries/_structs.py
@@ -500,7 +500,7 @@ def _compute_columns_required(
         missing_common_skypix = False
         if region is not None:
             for family in dimensions.spatial:
-                element = family.choose(dimensions.elements.names, self.universe)
+                element = family.choose(dimensions.elements, self.universe)
                 tags.add(DimensionRecordColumnTag(element.name, "region"))
                 if (
                     not isinstance(element, SkyPixDimension)
8 changes: 2 additions & 6 deletions python/lsst/daf/butler/registry/tests/_registry.py
@@ -1211,9 +1211,7 @@ def testInstrumentDimensions(self):
         (ref,) = registry.insertDatasets(rawType, dataIds=[dataId], run=run2)
         registry.associate(tagged2, [ref])

-        dimensions = registry.dimensions.conform(
-            rawType.dimensions.required.names | calexpType.dimensions.required.names
-        )
+        dimensions = registry.dimensions.conform(rawType.dimensions.required | calexpType.dimensions.required)
         # Test that single dim string works as well as list of str
         rows = registry.queryDataIds("visit", datasets=rawType, collections=run1).expanded().toSet()
         rowsI = registry.queryDataIds(["visit"], datasets=rawType, collections=run1).expanded().toSet()
@@ -1337,9 +1335,7 @@ def testSkyMapDimensions(self):
         registry.registerDatasetType(measType)

         dimensions = registry.dimensions.conform(
-            calexpType.dimensions.required.names
-            | mergeType.dimensions.required.names
-            | measType.dimensions.required.names
+            calexpType.dimensions.required | mergeType.dimensions.required | measType.dimensions.required
         )

         # add pre-existing datasets
