From 52abca0f2f403a7439523f6453ef4bad400e657c Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 6 Aug 2024 15:28:10 -0700 Subject: [PATCH] Switch hybrid registry to use remote registry for queryDatasetAssociations. This enables testing of the remote registry implementation, which uncovered a problem with deserialization of general result pages. Had to add additional code to serialize/deserialize general result pages correctly. --- python/lsst/daf/butler/column_spec.py | 49 +++++++++++++++++++ .../daf/butler/remote_butler/_query_driver.py | 20 ++++++++ .../server/handlers/_query_serialization.py | 21 ++++++++ .../daf/butler/remote_butler/server_models.py | 15 +++++- .../butler/tests/hybrid_butler_registry.py | 2 +- 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/python/lsst/daf/butler/column_spec.py b/python/lsst/daf/butler/column_spec.py index 28f93d55f2..0a53031fbb 100644 --- a/python/lsst/daf/butler/column_spec.py +++ b/python/lsst/daf/butler/column_spec.py @@ -53,6 +53,7 @@ from . import arrow_utils, ddl from ._timespan import Timespan +from .pydantic_utils import SerializableRegion, SerializableTime if TYPE_CHECKING: from .name_shrinker import NameShrinker @@ -134,6 +135,18 @@ def to_arrow(self) -> arrow_utils.ToArrow: """ raise NotImplementedError() + @abstractmethod + def type_adapter(self) -> pydantic.TypeAdapter: + """Return pydantic type adapter that converts values of this column to + or from serializable format. + + Returns + ------- + type_adapter : `pydantic.TypeAdapter` + A converter instance. + """ + raise NotImplementedError() + def display(self, level: int = 0, tab: str = " ") -> list[str]: """Return a human-reader-focused string description of this column as a list of lines. @@ -178,6 +191,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.uint64(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. 
+ return pydantic.TypeAdapter(self.pytype) + @final class StringColumnSpec(_BaseColumnSpec): @@ -198,6 +215,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.string(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class HashColumnSpec(_BaseColumnSpec): @@ -224,6 +245,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: nullable=self.nullable, ) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class FloatColumnSpec(_BaseColumnSpec): @@ -238,6 +263,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_primitive(self.name, pa.float64(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class BoolColumnSpec(_BaseColumnSpec): @@ -251,6 +280,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.bool_(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class UUIDColumnSpec(_BaseColumnSpec): @@ -265,6 +298,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_uuid(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. 
+ return pydantic.TypeAdapter(self.pytype) + @final class RegionColumnSpec(_BaseColumnSpec): @@ -284,6 +321,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_region(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(SerializableRegion) + @final class TimespanColumnSpec(_BaseColumnSpec): @@ -299,6 +340,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_timespan(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class DateTimeColumnSpec(_BaseColumnSpec): @@ -315,6 +360,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_datetime(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. 
+ return pydantic.TypeAdapter(SerializableTime) + ColumnSpec = Annotated[ Union[ diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 9106fa86c9..6151e06c51 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -66,6 +66,7 @@ from .server_models import ( AdditionalQueryInput, DataCoordinateUpload, + GeneralResultModel, MaterializedQuery, QueryAnyRequestModel, QueryAnyResponseModel, @@ -260,5 +261,24 @@ def _convert_query_result_page( spec=result_spec, rows=[DatasetRef.from_simple(r, universe) for r in result.rows], ) + elif result_spec.result_type == "general": + assert result.type == "general" + return _convert_general_result(result_spec, result) else: raise NotImplementedError(f"Unhandled result type {result_spec.result_type}") + + +def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: + """Convert GeneralResultModel to a general result page.""" + columns = spec.get_result_columns() + type_adapters = [ + columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + ] + rows = [ + tuple( + value if value is None else type_adapter.validate_python(value) + for value, type_adapter in zip(row, type_adapters) + ) + for row in model.rows + ] + return GeneralResultPage(spec=spec, rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index 6511abec24..0c8f4e1dc1 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -34,6 +34,7 @@ DataCoordinateResultPage, DatasetRefResultPage, DimensionRecordResultPage, + GeneralResultPage, ResultPage, ResultSpec, ) @@ -42,6 +43,7 @@ DataCoordinateResultModel, 
DatasetRefResultModel, DimensionRecordsResultModel, + GeneralResultModel, QueryErrorResultModel, QueryExecuteResultData, ) @@ -86,5 +88,24 @@ def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResul case "dataset_ref": assert isinstance(page, DatasetRefResultPage) return DatasetRefResultModel(rows=[ref.to_simple() for ref in page.rows]) + case "general": + assert isinstance(page, GeneralResultPage) + return _convert_general_result(page) case _: raise NotImplementedError(f"Unhandled query result type {spec.result_type}") + + +def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: + """Convert GeneralResultPage to a serializable model.""" + columns = page.spec.get_result_columns() + type_adapters = [ + columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + ] + rows = [ + tuple( + value if value is None else type_adapter.dump_python(value) + for value, type_adapter in zip(row, type_adapters) + ) + for row in page.rows + ] + return GeneralResultModel(rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index f8b9bf261c..876bfe9839 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -37,7 +37,7 @@ "GetCollectionSummaryResponseModel", ] -from typing import Annotated, Literal, NewType, TypeAlias +from typing import Annotated, Any, Literal, NewType, TypeAlias from uuid import UUID import pydantic @@ -276,6 +276,13 @@ class DatasetRefResultModel(pydantic.BaseModel): rows: list[SerializedDatasetRef] +class GeneralResultModel(pydantic.BaseModel): + """Result model for /query/execute/ when user requested general results.""" + + type: Literal["general"] = "general" + rows: list[tuple[Any, ...]] + + class QueryErrorResultModel(pydantic.BaseModel): """Result model for /query/execute when an error occurs part-way through returning rows. 
@@ -293,7 +300,11 @@ class QueryErrorResultModel(pydantic.BaseModel): QueryExecuteResultData: TypeAlias = Annotated[ - DataCoordinateResultModel | DimensionRecordsResultModel | DatasetRefResultModel | QueryErrorResultModel, + DataCoordinateResultModel + | DimensionRecordsResultModel + | DatasetRefResultModel + | GeneralResultModel + | QueryErrorResultModel, pydantic.Field(discriminator="type"), ] diff --git a/python/lsst/daf/butler/tests/hybrid_butler_registry.py b/python/lsst/daf/butler/tests/hybrid_butler_registry.py index 83effc6eca..919792a920 100644 --- a/python/lsst/daf/butler/tests/hybrid_butler_registry.py +++ b/python/lsst/daf/butler/tests/hybrid_butler_registry.py @@ -376,7 +376,7 @@ def queryDatasetAssociations( collectionTypes: Iterable[CollectionType] = CollectionType.all(), flattenChains: bool = False, ) -> Iterator[DatasetAssociation]: - return self._direct.queryDatasetAssociations( + return self._remote.queryDatasetAssociations( datasetType, collections, collectionTypes=collectionTypes, flattenChains=flattenChains )