Skip to content

Commit

Permalink
Support pydantic v1 and v2
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Jul 19, 2023
1 parent 2210bd4 commit 06c375e
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions python/lsst/ci/middleware/repo_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
from collections.abc import Iterable
from typing import Any, cast

try:
import pydantic.v1 as pydantic
except ModuleNotFoundError:
import pydantic # type: ignore

from lsst.daf.butler import (
Butler,
CollectionType,
Expand All @@ -45,6 +40,7 @@
SerializedDatasetType,
SerializedDimensionRecord,
)
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.skymap import BaseSkyMap, DiscreteSkyMap
from lsst.sphgeom import ConvexPolygon
Expand Down Expand Up @@ -491,7 +487,23 @@ def make_skymap_instance(
return DiscreteSkyMap(config)


class InputDatasetTypes(pydantic.BaseModel):
if PYDANTIC_V2:
    # Pydantic v2 branch: custom-root models subclass ``RootModel`` and
    # expose their payload via the ``root`` attribute directly.
    from pydantic import RootModel  # type: ignore

    class _InputDatasetTypes(RootModel):
        # Mapping from RUN collection name to the serialized dataset types
        # associated with it.
        root: dict[str, list[SerializedDatasetType]]

else:

    # Pydantic v1 branch (via the butler compatibility shim imported above):
    # custom-root models declare ``__root__`` instead, so add a ``root``
    # property to present the same v2-style attribute to callers.  This lets
    # the subclass below use ``self.root`` unconditionally.
    class _InputDatasetTypes(_BaseModelCompat):  # type: ignore
        __root__: dict[str, list[SerializedDatasetType]]

        @property
        def root(self) -> dict[str, list[SerializedDatasetType]]:
            # Alias the v1 root field under the v2 attribute name.
            return self.__root__


class InputDatasetTypes(_InputDatasetTypes):
"""Datasets types used as overall inputs by most mocked pipelines.
This is not expected to be exhaustive for all pipelines; it's a common
Expand All @@ -502,12 +514,10 @@ class InputDatasetTypes(pydantic.BaseModel):
datasets should be inserted into.
"""

__root__: dict[str, list[SerializedDatasetType]]

@property
def runs(self) -> Iterable[str]:
"""The RUN collections datasets should be written to."""
return self.__root__.keys() - {"REGISTER_ONLY"}
return self.root.keys() - {"REGISTER_ONLY"}

@classmethod
def read(
Expand All @@ -530,7 +540,7 @@ def read(
uri = ResourcePath(uri)
with uri.open() as stream:
data = json.load(stream)
return cls.parse_obj(data)
return cls.model_validate(data)

def resolve(self, universe: DimensionUniverse) -> dict[str, list[DatasetType]]:
"""Return dataset type objects with resolved dimensions.
Expand All @@ -548,7 +558,7 @@ def resolve(self, universe: DimensionUniverse) -> dict[str, list[DatasetType]]:
"""
return {
run: [DatasetType.from_simple(s, universe=universe) for s in serialized_dataset_types]
for run, serialized_dataset_types in self.__root__.items()
for run, serialized_dataset_types in self.root.items()
}


Expand Down

0 comments on commit 06c375e

Please sign in to comment.