From c46560474ede48b5563edea4593ef6d82c7a4778 Mon Sep 17 00:00:00 2001 From: Scott Staniewicz Date: Wed, 23 Aug 2023 23:28:30 -0400 Subject: [PATCH] upgrade to Pydantic v2 (#118) * bump pydantic reqs to 2.1 * update interferogram.py models to V2 * update config.py for V2 * fix new test errors, redo yaml dumping * fix failing cli test * add failing print_yaml_schema test * filter new pydantic userwarning for `print_yaml_schema` * ignore pge warnings * bump changelog version for impending release * remove pipes due to type subscriptable error --- .github/workflows/test-build-push.yml | 2 +- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 9 +- conda-env.yml | 2 +- requirements.txt | 2 +- src/dolphin/interferogram.py | 141 +++++++++++------------- src/dolphin/workflows/_pge_runconfig.py | 32 +++--- src/dolphin/workflows/_yaml_model.py | 36 ++++-- src/dolphin/workflows/config.py | 108 +++++++++--------- src/dolphin/workflows/s1_disp.py | 6 +- src/dolphin/workflows/wrapped_phase.py | 4 +- tests/test_cli.py | 2 +- tests/test_workflows_config.py | 21 +++- tests/test_workflows_pge_runconfig.py | 9 +- 14 files changed, 199 insertions(+), 177 deletions(-) diff --git a/.github/workflows/test-build-push.yml b/.github/workflows/test-build-push.yml index 03f475b0..ec7463c2 100644 --- a/.github/workflows/test-build-push.yml +++ b/.github/workflows/test-build-push.yml @@ -32,7 +32,7 @@ jobs: numpy=1.20 numba=0.54 pillow==7.0 - pydantic=1.10 + pydantic=2.1 pymp-pypi=0.4.5 pyproj=3.3 rich=12.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56a96385..dd1189d1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: additional_dependencies: - types-pkg_resources - types-requests - - "pydantic>=1.10.0,<2" + - "pydantic>=2.1" - repo: https://github.com/PyCQA/pydocstyle rev: "6.3.0" diff --git a/CHANGELOG.md b/CHANGELOG.md index 228cbe16..c49a750e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,22 @@ # Unreleased +# [0.3.0](https://github.com/opera-adt/dolphin/compare/v0.2.0...v0.3.0) - 2023-08-23 + **Added** - Save a multilooked version of the PS mask for output inspection **Changed** -- Refectored the blockwise IO into `_blocks.py`. +- Pydantic models were upgraded to V2 +- Refactored the blockwise IO into `_blocks.py`. 
- The iteration now happens over the output grid for easier dilating/padding when using `strides` - New classes with `BlockIndices` and `BlockManager` for easier mangement of the different slices +**Dependencies** + +- pydantic >= 2.1 + # [0.2.0](https://github.com/opera-adt/dolphin/compare/v0.1.0...v0.2.0) - 2023-07-25 **Added** diff --git a/conda-env.yml b/conda-env.yml index 97c88929..78c27981 100644 --- a/conda-env.yml +++ b/conda-env.yml @@ -13,7 +13,7 @@ dependencies: - numba>=0.54 - numpy>=1.20 - pillow>=7.0 - - pydantic>=1.10,<2 + - pydantic>=2.1 - pymp-pypi>=0.4.5 - pyproj>=3.3 - rich>=12.0 diff --git a/requirements.txt b/requirements.txt index 8ef973e2..70e56090 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ h5py>=3.6 numba>=0.54 numpy>=1.20 pillow>=7.0 -pydantic>=1.10 +pydantic>=2.1 pymp-pypi>=0.4.5 pyproj>=3.3 rich>=12.0 diff --git a/src/dolphin/interferogram.py b/src/dolphin/interferogram.py index 6c2bbd4e..c5c87709 100644 --- a/src/dolphin/interferogram.py +++ b/src/dolphin/interferogram.py @@ -1,16 +1,21 @@ """Combine estimated DS phases with PS phases to form interferograms.""" from __future__ import annotations -import datetime import itertools from os import fspath from pathlib import Path -from typing import Iterable, Optional, Sequence, Tuple, Union +from typing import Iterable, Literal, Optional, Sequence, Union import numpy as np from numpy.typing import ArrayLike from osgeo import gdal -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import ( + BaseModel, + Field, + FieldValidationInfo, + field_validator, + model_validator, +) from dolphin import io, utils from dolphin._log import get_log @@ -21,14 +26,14 @@ logger = get_log(__name__) -class VRTInterferogram(BaseModel): +class VRTInterferogram(BaseModel, extra="allow"): """Create an interferogram using a VRTDerivedRasterBand. Attributes ---------- - ref_slc : Union[str, Path] + ref_slc : Union[Path, str] Path to reference SLC file - sec_slc : Union[str, Path] + sec_slc : Union[Path, str] Path to secondary SLC file path : Optional[Path], optional Path to output interferogram. Defaults to Path('_.vrt'), @@ -60,14 +65,15 @@ class VRTInterferogram(BaseModel): ) ref_slc: Union[Path, str] = Field(..., description="Path to reference SLC file") sec_slc: Union[Path, str] = Field(..., description="Path to secondary SLC file") - outdir: Optional[Union[str, Path]] = Field( + outdir: Optional[Path] = Field( None, description=( - "Directory to place output interferogram. Defaults to the same directory as" - " `ref_slc`. If only `outdir` is specified, the output interferogram will" - " be named '_.vrt', where the dates are parsed from the" - " inputs. If `path` is specified, this is ignored." + "Directory to place output interferogram. Defaults to the same" + " directory as `ref_slc`. If only `outdir` is specified, the output" + " interferogram will be named '_.vrt', where the dates" + " are parsed from the inputs. If `path` is specified, this is ignored." ), + validate_default=True, ) path: Optional[Path] = Field( None, @@ -76,11 +82,12 @@ class VRTInterferogram(BaseModel): " dates are parsed from the input files, placed in the same directory as" " `ref_slc`." 
), + validate_default=True, ) date_format: str = "%Y%m%d" write: bool = Field(True, description="Write the VRT file to disk") - pixel_function: str = "cmul" + pixel_function: Literal["cmul", "mul"] = "cmul" _template = """\ @@ -94,14 +101,11 @@ class VRTInterferogram(BaseModel): """ - dates: Optional[Tuple[datetime.date, datetime.date]] = None - class Config: - extra = Extra.forbid # raise error if extra fields passed in - - @validator("ref_slc", "sec_slc") - def _check_gdal_string(cls, v, values): - subdataset = values.get("subdataset") + @field_validator("ref_slc", "sec_slc") + @classmethod + def _check_gdal_string(cls, v: Union[Path, str], info: FieldValidationInfo): + subdataset = info.data.get("subdataset") # If we're using a subdataset, create a the GDAL-readable string gdal_str = io.format_nc_filename(v, subdataset) try: @@ -113,69 +117,53 @@ def _check_gdal_string(cls, v, values): # the file is absolute if ":" in str(gdal_str): try: - gdal_str = utils._resolve_gdal_path(gdal_str) + gdal_str = str(utils._resolve_gdal_path(gdal_str)) except Exception: # if the file had colons for some reason but # it didn't match, just ignore pass return gdal_str - @validator("pixel_function") - def _validate_pixel_func(cls, v): - if v not in ["mul", "cmul"]: - raise ValueError("pixel function must be 'mul' or 'cmul'") - return v.lower() - - @validator("outdir", always=True) - def _check_output_dir(cls, v, values): + @field_validator("outdir") + @classmethod + def _check_output_dir(cls, v, info: FieldValidationInfo): if v is not None: return Path(v) # If outdir is not set, use the directory of the reference SLC - ref_slc = values.get("ref_slc") + ref_slc = str(info.data.get("ref_slc")) return utils._get_path_from_gdal_str(ref_slc).parent - @validator("path", always=True) - def _remove_existing_file(cls, v, values): - if not v: - # No path was passed: try and make one. 
- # Form the output file name from the dates within input files - ref_slc, sec_slc = values.get("ref_slc"), values.get("sec_slc") - if not ref_slc or not sec_slc: - return v - - fmt = values.get("date_format", "%Y%m%d") - date1 = utils.get_dates(ref_slc, fmt=fmt)[0] - date2 = utils.get_dates(sec_slc, fmt=fmt)[0] - - outdir = values.get("outdir") - v = outdir / (io._format_date_pair(date1, date2, fmt) + ".vrt") - - if Path(v).exists(): - v.unlink() - return v - - @validator("dates") - def _check_dates_match(cls, v, values): - """Ensure passed dates match those parsed from the input files.""" - fmt = values.get("date_format", "%Y%m%d") - ref_slc, sec_slc = values.get("ref_slc"), values.get("sec_slc") + @model_validator(mode="after") + def _form_path(self) -> "VRTInterferogram": + """Create the filename (if not provided) from the provided SLCs.""" + ref_slc, sec_slc = self.ref_slc, self.sec_slc + + fmt = self.date_format date1 = utils.get_dates(ref_slc, fmt=fmt)[0] date2 = utils.get_dates(sec_slc, fmt=fmt)[0] - if v is not None: - if v != (date1, date2): - raise ValueError( - f"Dates {v} do not match dates parsed from input files: {date1}," - f" {date2}" - ) - @root_validator - def _validate_files(cls, values): + if self.path is not None: + return self + + if self.outdir is None: + # If outdir is not set, use the directory of the reference SLC + self.outdir = utils._get_path_from_gdal_str(ref_slc).parent + + path = self.outdir / (io._format_date_pair(date1, date2, fmt) + ".vrt") + if Path(path).exists(): + logger.info(f"Removing {path}") + path.unlink() + self.path = path + return self + + @model_validator(mode="after") + def _validate_files(self) -> "VRTInterferogram": """Check that the inputs are the same size and geotransform.""" - ref_slc = values.get("ref_slc") - sec_slc = values.get("sec_slc") + ref_slc = self.ref_slc + sec_slc = self.sec_slc if not ref_slc or not sec_slc: # Skip validation if files are not set - return values + return self ds1 = gdal.Open(fspath(ref_slc)) ds2 = gdal.Open(fspath(sec_slc)) xsize, ysize = ds1.RasterXSize, ds1.RasterYSize @@ -191,7 +179,7 @@ def _validate_files(cls, values): f"Input files {ref_slc} and {sec_slc} have different GeoTransforms" ) - return values + return self def __init__(self, **data): """Create a VRTInterferogram object and write the VRT file.""" @@ -248,16 +236,13 @@ def from_vrt_file(cls, path: Filename) -> "VRTInterferogram": if subdataset is not None: ref_slc = io.format_nc_filename(ref_slc, subdataset) sec_slc = io.format_nc_filename(sec_slc, subdataset) - # TODO: any good way/reason to store the date fmt? - date1 = utils.get_dates(ref_slc, fmt="%Y%m%d")[0] - date2 = utils.get_dates(sec_slc, fmt="%Y%m%d")[0] return cls.construct( ref_slc=ref_slc, sec_slc=sec_slc, path=Path(path).resolve(), subdataset=subdataset, - dates=(date1, date2), + date_format="%Y%m%d", ) @@ -272,13 +257,13 @@ class Network: list of dates corresponding to the SLCs. ifg_list : list[tuple[Filename, Filename]] list of `VRTInterferogram`s created from the SLCs. - max_bandwidth : Optional[int], optional + max_bandwidth : int | None, optional Maximum number of SLCs to include in an interferogram, by index distance. Defaults to None. max_temporal_baseline : Optional[float], optional Maximum temporal baseline to include in an interferogram, in days. Defaults to None. - reference_idx : Optional[int], optional + reference_idx : int | None, optional Index of the SLC to use as the reference for all interferograms. Defaults to None. 
""" @@ -287,9 +272,9 @@ def __init__( self, slc_list: Sequence[Filename], outdir: Optional[Filename] = None, - max_bandwidth: Optional[int] = None, + max_bandwidth: int | None = None, max_temporal_baseline: Optional[float] = None, - reference_idx: Optional[int] = None, + reference_idx: int | None = None, indexes: Optional[Sequence[tuple[int, int]]] = None, subdataset: Optional[Union[str, Sequence[str]]] = None, write: bool = True, @@ -303,13 +288,13 @@ def __init__( outdir : Optional[Filename], optional Directory to write the VRT files to. If not set, defaults to the directory of the reference SLC. - max_bandwidth : Optional[int], optional + max_bandwidth : int | None, optional Maximum number of SLCs to include in an interferogram, by index distance. Defaults to None. max_temporal_baseline : Optional[float] Maximum temporal baseline to include in an interferogram, in days. Defaults to None. - reference_idx : Optional[int] + reference_idx : int | None Index of the SLC to use as the reference for all interferograms. Defaults to None. indexes : Optional[Sequence[tuple[int, int]]] @@ -400,9 +385,9 @@ def __str__(self): @staticmethod def _make_ifg_pairs( slc_list: Sequence[Filename], - max_bandwidth: Optional[int] = None, + max_bandwidth: int | None = None, max_temporal_baseline: Optional[float] = None, - reference_idx: Optional[int] = None, + reference_idx: int | None = None, indexes: Optional[Sequence[tuple[int, int]]] = None, ) -> list[tuple]: """Form interferogram pairs from a list of SLC files sorted by date.""" diff --git a/src/dolphin/workflows/_pge_runconfig.py b/src/dolphin/workflows/_pge_runconfig.py index 68a0975b..73bd851e 100644 --- a/src/dolphin/workflows/_pge_runconfig.py +++ b/src/dolphin/workflows/_pge_runconfig.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import ClassVar, List, Optional -from pydantic import Extra, Field +from pydantic import ConfigDict, Field from ._yaml_model import YamlModel from .config import ( @@ -31,15 +31,12 @@ class InputFileGroup(YamlModel): ..., description="Frame ID of the bursts contained in `cslc_file_list`.", ) - - class Config: - """Pydantic config class.""" - - extra = Extra.forbid - schema_extra = {"required": ["cslc_file_list", "frame_id"]} + model_config = ConfigDict( + extra="forbid", json_schema_extra={"required": ["cslc_file_list", "frame_id"]} + ) -class DynamicAncillaryFileGroup(YamlModel, extra=Extra.forbid): +class DynamicAncillaryFileGroup(YamlModel, extra="forbid"): """A group of dynamic ancillary files.""" algorithm_parameters_file: Path = Field( # type: ignore @@ -101,7 +98,7 @@ class DynamicAncillaryFileGroup(YamlModel, extra=Extra.forbid): ) -class PrimaryExecutable(YamlModel, extra=Extra.forbid): +class PrimaryExecutable(YamlModel, extra="forbid"): """Group describing the primary executable.""" product_type: str = Field( @@ -110,7 +107,7 @@ class PrimaryExecutable(YamlModel, extra=Extra.forbid): ) -class ProductPathGroup(YamlModel, extra=Extra.forbid): +class ProductPathGroup(YamlModel, extra="forbid"): """Group describing the product paths.""" product_path: Path = Field( # type: ignore @@ -142,7 +139,7 @@ class ProductPathGroup(YamlModel, extra=Extra.forbid): ) -class AlgorithmParameters(YamlModel, extra=Extra.forbid): +class AlgorithmParameters(YamlModel, extra="forbid"): """Class containing all the other [`Workflow`][dolphin.workflows.config] classes.""" # Options for each step in the workflow @@ -159,7 +156,7 @@ class AlgorithmParameters(YamlModel, extra=Extra.forbid): ) -class RunConfig(YamlModel, 
extra=Extra.forbid): +class RunConfig(YamlModel, extra="forbid"): """A PGE run configuration.""" # Used for the top-level key @@ -178,9 +175,9 @@ class RunConfig(YamlModel, extra=Extra.forbid): description="Path to the output log file in addition to logging to stderr.", ) - # Override the constructor to allow recursively construct without validation + # Override the constructor to allow recursively model_construct without validation @classmethod - def construct(cls, **kwargs): + def model_construct(cls, **kwargs): if "input_file_group" not in kwargs: kwargs["input_file_group"] = InputFileGroup._construct_empty() if "dynamic_ancillary_file_group" not in kwargs: @@ -189,7 +186,7 @@ def construct(cls, **kwargs): ) if "product_path_group" not in kwargs: kwargs["product_path_group"] = ProductPathGroup._construct_empty() - return super().construct( + return super().model_construct( **kwargs, ) @@ -220,7 +217,7 @@ def to_workflow(self): algorithm_parameters = AlgorithmParameters.from_yaml( self.dynamic_ancillary_file_group.algorithm_parameters_file ) - param_dict = algorithm_parameters.dict() + param_dict = algorithm_parameters.model_dump() input_options = dict(subdataset=param_dict.pop("subdataset")) # This get's unpacked to load the rest of the parameters for the Workflow @@ -257,7 +254,8 @@ def from_workflow( This is mostly used as preliminary setup to further edit the fields. """ # Load the algorithm parameters from the file - alg_param_dict = workflow.dict(include=AlgorithmParameters.__fields__.keys()) + algo_keys = set(AlgorithmParameters.model_fields.keys()) + alg_param_dict = workflow.model_dump(include=algo_keys) AlgorithmParameters(**alg_param_dict).to_yaml(algorithm_parameters_file) # This get's unpacked to load the rest of the parameters for the Workflow diff --git a/src/dolphin/workflows/_yaml_model.py b/src/dolphin/workflows/_yaml_model.py index 785737aa..d4d2976f 100644 --- a/src/dolphin/workflows/_yaml_model.py +++ b/src/dolphin/workflows/_yaml_model.py @@ -1,6 +1,7 @@ import json import sys import textwrap +import warnings from io import StringIO from itertools import repeat from typing import Optional, TextIO, Union @@ -44,8 +45,9 @@ def to_yaml( if with_comments: _add_comments( yaml_obj, - self.schema(by_alias=by_alias), + self.model_json_schema(by_alias=by_alias), indent_per_level=indent_per_level, + pydantic_class=self.__class__, ) y = YAML() @@ -103,9 +105,13 @@ def print_yaml_schema( Number of spaces to indent per level. 
""" full_dict = cls._construct_empty() - cls.construct(**full_dict).to_yaml( - output_path, with_comments=True, indent_per_level=indent_per_level - ) + # UserWarning: Pydantic serializer warnings: + # New V2 warning, but seems harmless for just printing the schema + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + cls.model_construct(**full_dict).to_yaml( + output_path, with_comments=True, indent_per_level=indent_per_level + ) @classmethod def _construct_empty(cls): @@ -119,8 +125,10 @@ def _construct_empty(cls): so first, we manually make a dict of all the fields with None values then we update it with the default-filled values """ - all_none_vals = dict(zip(cls.schema()["properties"].keys(), repeat(None))) - all_none_vals.update(cls.construct().dict()) + all_none_vals = dict( + zip(cls.model_json_schema()["properties"].keys(), repeat(None)) + ) + all_none_vals.update(cls.model_construct().model_dump()) return all_none_vals def _to_yaml_obj(self, by_alias: bool = True) -> CommentedMap: @@ -128,7 +136,7 @@ def _to_yaml_obj(self, by_alias: bool = True) -> CommentedMap: # We can't just do `dumps` for some reason, need a stream y = YAML() ss = StringIO() - y.dump(json.loads(self.json(by_alias=by_alias)), ss) + y.dump(json.loads(self.model_dump_json(by_alias=by_alias)), ss) yaml_obj = y.load(ss.getvalue()) return yaml_obj @@ -140,11 +148,12 @@ def _add_comments( definitions: Optional[dict] = None, # variable specifying how much to indent per level indent_per_level: int = 2, + pydantic_class=None, ): """Add comments above each YAML field using the pydantic model schema.""" # Definitions are in schemas that contain nested pydantic Models if definitions is None: - definitions = schema.get("definitions") + definitions = schema.get("$defs") for key, val in schema["properties"].items(): reference = "" @@ -186,11 +195,18 @@ def _add_comments( subsequent_indent=" " * indent_per_level, ) ) - type_str = f"\n Type: {val['type']}." + if "anyOf" in val.keys(): + # 'anyOf': [{'type': 'string'}, {'type': 'null'}], + # Join the options with a pipe, like Python types + type_str = " | ".join(d["type"] for d in val["anyOf"]) + type_str.replace("null", "None") + else: + type_str = val["type"] + type_line = f"\n Type: {type_str}." choices = f"\n Options: {val['enum']}." if "enum" in val.keys() else "" # Combine the description/type/choices as the YAML comment - comment = f"{desc}{type_str}{choices}" + comment = f"{desc}{type_line}{choices}" comment = comment.replace("..", ".") # Remove double periods # Prepend the required label for fields that are required diff --git a/src/dolphin/workflows/config.py b/src/dolphin/workflows/config.py index 58e3eb98..ca4e02c2 100644 --- a/src/dolphin/workflows/config.py +++ b/src/dolphin/workflows/config.py @@ -6,7 +6,14 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + field_validator, + model_validator, +) from dolphin import __version__ as _dolphin_version from dolphin._log import get_log @@ -24,10 +31,7 @@ logger = get_log(__name__) # Specific to OPERA CSLC products: -# TODO: this will become f"/" in upcoming OPERA release -# We may want to keep ths old for compatibility for awhile? 
OPERA_DATASET_ROOT = "/" -# TODO: this will become f"{OPERA_DATASET_ROOT}/data/VV" OPERA_DATASET_NAME = f"{OPERA_DATASET_ROOT}/data/VV" OPERA_IDENTIFICATION = f"{OPERA_DATASET_ROOT}/identification" @@ -37,7 +41,7 @@ ) -class PsOptions(BaseModel, extra=Extra.forbid): +class PsOptions(BaseModel, extra="forbid"): """Options for the PS pixel selection portion of the workflow.""" _directory: Path = PrivateAttr(Path("PS")) @@ -52,7 +56,7 @@ class PsOptions(BaseModel, extra=Extra.forbid): ) -class HalfWindow(BaseModel, extra=Extra.forbid): +class HalfWindow(BaseModel, extra="forbid"): """Class to hold half-window size for multi-looking during phase linking.""" x: int = Field(11, description="Half window size (in pixels) for x direction", gt=0) @@ -68,7 +72,7 @@ def from_looks(cls, row_looks: int, col_looks: int): return cls(x=col_looks // 2, y=row_looks // 2) -class PhaseLinkingOptions(BaseModel, extra=Extra.forbid): +class PhaseLinkingOptions(BaseModel, extra="forbid"): """Configurable options for wrapped phase estimation.""" _directory: Path = PrivateAttr(Path("linked_phase")) @@ -94,7 +98,7 @@ class PhaseLinkingOptions(BaseModel, extra=Extra.forbid): ) -class InterferogramNetwork(BaseModel, extra=Extra.forbid): +class InterferogramNetwork(BaseModel, extra="forbid"): """Options to determine the type of network for interferogram formation.""" _directory: Path = PrivateAttr(Path("interferograms")) @@ -125,11 +129,11 @@ class InterferogramNetwork(BaseModel, extra=Extra.forbid): network_type: InterferogramNetworkType = InterferogramNetworkType.SINGLE_REFERENCE # validation - @root_validator - def _check_network_type(cls, values): - ref_idx = values.get("reference_idx") - max_bw = values.get("max_bandwidth") - max_tb = values.get("max_temporal_baseline") + @model_validator(mode="after") + def _check_network_type(self) -> "InterferogramNetwork": + ref_idx = self.reference_idx + max_bw = self.max_bandwidth + max_tb = self.max_temporal_baseline # Check if more than one has been set: if sum([ref_idx is not None, max_bw is not None, max_tb is not None]) > 1: raise ValueError( @@ -137,22 +141,22 @@ def _check_network_type(cls, values): " `max_temporal_baseline` can be set." ) if max_tb is not None: - values["network_type"] = InterferogramNetworkType.MAX_TEMPORAL_BASELINE - return values + self.network_type = InterferogramNetworkType.MAX_TEMPORAL_BASELINE + return self if max_bw is not None: - values["network_type"] = InterferogramNetworkType.MAX_BANDWIDTH - return values + self.network_type = InterferogramNetworkType.MAX_BANDWIDTH + return self # If nothing else specified, set to a single reference network - values["network_type"] = InterferogramNetworkType.SINGLE_REFERENCE + self.network_type = InterferogramNetworkType.SINGLE_REFERENCE # and make sure the reference index is set if ref_idx is None: - values["reference_idx"] = 0 - return values + self.reference_idx = 0 + return self -class UnwrapOptions(BaseModel, extra=Extra.forbid): +class UnwrapOptions(BaseModel, extra="forbid"): """Options for unwrapping after wrapped phase estimation.""" run_unwrap: bool = Field( @@ -163,7 +167,7 @@ class UnwrapOptions(BaseModel, extra=Extra.forbid): ) _directory: Path = PrivateAttr(Path("unwrapped")) unwrap_method: UnwrapMethod = UnwrapMethod.SNAPHU - tiles: List[int] = Field( + tiles: list[int] = Field( [1, 1], description=( "Number of tiles to split the unwrapping into (for multi-scale unwrapping)." 
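# Illustrative sketch (not part of the patch; field names are hypothetical) of the
# validator migration pattern used above: V1's `@root_validator` becomes
# `@model_validator(mode="after")`, which receives the constructed model and
# returns `self`, and V1's `@validator` with its `values` dict becomes
# `@field_validator`, where previously-validated fields live in `info.data`.
from typing import Optional

from pydantic import BaseModel, field_validator, model_validator


class _Network(BaseModel):
    reference_idx: Optional[int] = None
    max_bandwidth: Optional[int] = None

    @field_validator("max_bandwidth")
    @classmethod
    def _exclusive_options(cls, v, info):
        # `info.data` only holds fields declared (and validated) before this one.
        if v is not None and info.data.get("reference_idx") is not None:
            raise ValueError("only one of reference_idx/max_bandwidth may be set")
        return v

    @model_validator(mode="after")
    def _default_reference(self) -> "_Network":
        # Runs on the built instance; mutate and return `self`.
        if self.reference_idx is None and self.max_bandwidth is None:
            self.reference_idx = 0
        return self


assert _Network().reference_idx == 0
assert _Network(max_bandwidth=2).reference_idx is None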
@@ -175,7 +179,7 @@ class UnwrapOptions(BaseModel, extra=Extra.forbid): ) -class WorkerSettings(BaseModel, extra=Extra.forbid): +class WorkerSettings(BaseModel, extra="forbid"): """Settings for controlling CPU/GPU settings and parallelism.""" gpu_enabled: bool = Field( @@ -212,7 +216,7 @@ class WorkerSettings(BaseModel, extra=Extra.forbid): ) -class InputOptions(BaseModel, extra=Extra.forbid): +class InputOptions(BaseModel, extra="forbid"): """Options specifying input datasets for workflow.""" subdataset: Optional[str] = Field( @@ -229,7 +233,7 @@ class InputOptions(BaseModel, extra=Extra.forbid): ) -class OutputOptions(BaseModel, extra=Extra.forbid): +class OutputOptions(BaseModel, extra="forbid"): """Options for the output size/format/compressions.""" output_resolution: Optional[Dict[str, int]] = Field( @@ -246,6 +250,7 @@ class OutputOptions(BaseModel, extra=Extra.forbid): " strides of [4, 2] would turn an input resolution of [5, 10] into an" " output resolution of [20, 20]." ), + validate_default=True, ) bounds: Optional[Bbox] = Field( None, @@ -268,17 +273,12 @@ class OutputOptions(BaseModel, extra=Extra.forbid): ) # validators - @validator("output_resolution", "strides", pre=True, always=True) - def _check_resolution(cls, v): - """Allow the user to specify just one float, applying to both dimensions.""" - if isinstance(v, (int, float)): - return {"x": v, "y": v} - return v - - @validator("strides", always=True) - def _check_strides_against_res(cls, strides, values): + + @field_validator("strides") + @classmethod + def _check_strides_against_res(cls, strides, info): """Compute the output resolution from the strides.""" - resolution = values.get("output_resolution") + resolution = info.data.get("output_resolution") if strides is not None and resolution is not None: raise ValueError("Cannot specify both strides and output_resolution.") elif strides is None and resolution is None: @@ -329,10 +329,12 @@ class Workflow(YamlModel): scratch_directory: Path = Field( Path("scratch"), description="Name of sub-directory to use for scratch files", + validate_default=True, ) output_directory: Path = Field( Path("output"), description="Name of sub-directory to use for output files", + validate_default=True, ) # Options for each step in the workflow @@ -376,24 +378,23 @@ class Workflow(YamlModel): creation_time_utc: datetime = Field( default_factory=datetime.utcnow, description="Time the config file was created" ) - dolphin_version: str = Field( - _dolphin_version, description="Version of Dolphin used." 
- ) + _dolphin_version: str = PrivateAttr(_dolphin_version) # internal helpers # Stores the list of directories to be created by the workflow _directory_list: List[Path] = PrivateAttr(default_factory=list) + model_config = ConfigDict( + extra="forbid", json_schema_extra={"required": ["cslc_file_list"]} + ) - class Config: - extra = Extra.forbid - schema_extra = {"required": ["cslc_file_list"]} - - @validator("output_directory", "scratch_directory", always=True) - def _dir_is_absolute(cls, v): + @field_validator("output_directory", "scratch_directory") + @classmethod + def _make_dir_absolute(cls, v: Path): return v.resolve() # validators - @validator("cslc_file_list", pre=True) + @field_validator("cslc_file_list", mode="before") + @classmethod def _check_input_file_list(cls, v): if v is None: return [] @@ -422,13 +423,13 @@ def _is_opera_file_list(cslc_file_list): re.search(OPERA_BURST_RE, str(f)) is not None for f in cslc_file_list ) - @root_validator - def _check_slc_files_exist(cls, values): - file_list = values.get("cslc_file_list") + @model_validator(mode="after") + def _check_slc_files_exist(self) -> "Workflow": + file_list = self.cslc_file_list if not file_list: raise ValueError("Must specify list of input SLC files.") - input_options = values.get("input_options") + input_options = self.input_options date_fmt = input_options.cslc_date_fmt # Filter out files that don't have dates in the filename file_matching_date = [Path(f) for f in file_list if get_dates(f, fmt=date_fmt)] @@ -443,7 +444,7 @@ def _check_slc_files_exist(cls, values): if ext in [".h5", ".nc"]: subdataset = input_options.subdataset if subdataset is None: - if cls._is_opera_file_list(file_list): + if self._is_opera_file_list(file_list): # Assume that the user forgot to set the subdataset, and set it to the # default OPERA dataset name logger.info( @@ -457,9 +458,10 @@ def _check_slc_files_exist(cls, values): ) # Coerce the file_list to a sorted list of Path objects - file_list, _ = sort_files_by_date(file_list, file_date_fmt=date_fmt) - values["cslc_file_list"] = [Path(f) for f in file_list] - return values + self.cslc_file_list = [ + Path(f) for f in sort_files_by_date(file_list, file_date_fmt=date_fmt)[0] + ] + return self def __init__(self, *args: Any, **kwargs: Any) -> None: """After validation, set up properties for use during workflow run.""" diff --git a/src/dolphin/workflows/s1_disp.py b/src/dolphin/workflows/s1_disp.py index 88502580..a7e37942 100755 --- a/src/dolphin/workflows/s1_disp.py +++ b/src/dolphin/workflows/s1_disp.py @@ -39,7 +39,7 @@ def run( """ # Set the logging level for all `dolphin.` modules logger = get_log(name="dolphin", debug=debug, filename=cfg.log_file) - logger.debug(pformat(cfg.dict())) + logger.debug(pformat(cfg.model_dump())) cfg.create_dir_tree(debug=debug) set_num_threads(cfg.worker_settings.threads_per_worker) @@ -171,7 +171,7 @@ def run( # Print the maximum memory usage for each worker max_mem = get_max_memory_usage(units="GB") logger.info(f"Maximum memory usage: {max_mem:.2f} GB") - logger.info(f"Config file dolphin version: {cfg.dolphin_version}") + logger.info(f"Config file dolphin version: {cfg._dolphin_version}") logger.info(f"Current running dolphin version: {__version__}") @@ -182,7 +182,7 @@ def _create_burst_cfg( grouped_amp_mean_files: dict[str, list[Path]], grouped_amp_dispersion_files: dict[str, list[Path]], ) -> Workflow: - cfg_temp_dict = cfg.copy(deep=True, exclude={"cslc_file_list"}).dict() + cfg_temp_dict = cfg.model_dump(exclude={"cslc_file_list"}) # Just update 
the inputs and the scratch directory top_level_scratch = cfg_temp_dict["scratch_directory"] diff --git a/src/dolphin/workflows/wrapped_phase.py b/src/dolphin/workflows/wrapped_phase.py index b763bdf5..28979658 100644 --- a/src/dolphin/workflows/wrapped_phase.py +++ b/src/dolphin/workflows/wrapped_phase.py @@ -118,7 +118,7 @@ def run(cfg: Workflow, debug: bool = False) -> tuple[list[Path], Path, Path, Pat single.run_wrapped_phase_single( slc_vrt_file=vrt_stack.outfile, output_folder=pl_path, - half_window=cfg.phase_linking.half_window.dict(), + half_window=cfg.phase_linking.half_window.model_dump(), strides=strides, reference_idx=0, beta=cfg.phase_linking.beta, @@ -139,7 +139,7 @@ def run(cfg: Workflow, debug: bool = False) -> tuple[list[Path], Path, Path, Pat sequential.run_wrapped_phase_sequential( slc_vrt_file=vrt_stack.outfile, output_folder=pl_path, - half_window=cfg.phase_linking.half_window.dict(), + half_window=cfg.phase_linking.half_window.model_dump(), strides=strides, beta=cfg.phase_linking.beta, ministack_size=cfg.phase_linking.ministack_size, diff --git a/tests/test_cli.py b/tests/test_cli.py index a59c159e..6f0c317f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -33,7 +33,7 @@ def test_cli_config_basic(tmpdir, slc_file_list): [ "config", "--n-workers", - 1, + "1", "--slc-files", *list(map(str, slc_file_list)), ] diff --git a/tests/test_workflows_config.py b/tests/test_workflows_config.py index 9a01d063..83abfdde 100644 --- a/tests/test_workflows_config.py +++ b/tests/test_workflows_config.py @@ -15,7 +15,7 @@ def test_half_window_defaults(): hw = config.HalfWindow() assert hw.x == 11 assert hw.y == 5 - assert hw.dict() == dict(x=11, y=5) + assert hw.model_dump() == dict(x=11, y=5) def test_half_window_to_looks(): @@ -166,7 +166,7 @@ def test_input_find_slcs(slc_file_list_nc): opts2 = config.Workflow( cslc_file_list=cslc_dir / "slclist.txt", input_options={"subdataset": "data"} ) - opts2.dict() == opts.dict() + opts2.model_dump() == opts.model_dump() def test_input_glob_pattern(slc_file_list_nc): @@ -338,8 +338,8 @@ def test_config_roundtrip_dict(dir_with_1_slc): cslc_file_list=dir_with_1_slc / "slclist.txt", input_options={"subdataset": "data"}, ) - c_dict = c.dict() - c2 = config.Workflow.parse_obj(c_dict) + c_dict = c.model_dump() + c2 = config.Workflow(**c_dict) assert c == c2 @@ -348,8 +348,8 @@ def test_config_roundtrip_json(dir_with_1_slc): cslc_file_list=dir_with_1_slc / "slclist.txt", input_options={"subdataset": "data"}, ) - c_json = c.json() - c2 = config.Workflow.parse_raw(c_json) + c_json = c.model_dump_json() + c2 = config.Workflow.model_validate_json(c_json) assert c == c2 @@ -373,3 +373,12 @@ def test_config_roundtrip_yaml_with_comments(tmp_path, dir_with_1_slc): c.to_yaml(outfile, with_comments=True) c2 = config.Workflow.from_yaml(outfile) assert c == c2 + + +def test_config_print_yaml_schema(tmp_path, dir_with_1_slc): + outfile = tmp_path / "empty_schema.yaml" + c = config.Workflow( + cslc_file_list=dir_with_1_slc / "slclist.txt", + input_options={"subdataset": "data"}, + ) + c.print_yaml_schema(outfile) diff --git a/tests/test_workflows_pge_runconfig.py b/tests/test_workflows_pge_runconfig.py index 36512fc3..9a6b62b7 100644 --- a/tests/test_workflows_pge_runconfig.py +++ b/tests/test_workflows_pge_runconfig.py @@ -1,4 +1,5 @@ import sys +import warnings import pytest @@ -13,11 +14,15 @@ def test_algorithm_parameters_schema(): - AlgorithmParameters.print_yaml_schema() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", 
category=UserWarning) + AlgorithmParameters.print_yaml_schema() def test_run_config_schema(): - RunConfig.print_yaml_schema() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + RunConfig.print_yaml_schema() @pytest.fixture
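# A small round-trip sketch (illustrative only, hypothetical model) of the V1 -> V2
# API renames exercised by the tests above: .dict() -> .model_dump(),
# .json() -> .model_dump_json(), .parse_obj() -> .model_validate(),
# .parse_raw() -> .model_validate_json(), .construct() -> .model_construct(),
# and .schema() -> .model_json_schema().
from pydantic import BaseModel


class _Point(BaseModel):
    x: int = 11
    y: int = 5


p = _Point()
assert _Point.model_validate(p.model_dump()) == p            # was parse_obj(p.dict())
assert _Point.model_validate_json(p.model_dump_json()) == p  # was parse_raw(p.json())
assert _Point.model_construct().model_dump() == p.model_dump()  # construct() skips validation
assert "properties" in _Point.model_json_schema()            # was .schema()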