Skip to content

Commit

Permalink
upgrade to Pydantic v2 (#118)
Browse files Browse the repository at this point in the history
* bump pydantic reqs to 2.1

* update interferogram.py models to V2

* update config.py for V2

* fix new test errors, redo yaml dumping

* fix failing cli test

* add failing print_yaml_schema test

* filter new pydantic userwarning for `print_yaml_schema`

* ignore pge warnings

* bump changelog version for impending release

* remove pipes due to type subscriptable error
  • Loading branch information
scottstanie authored Aug 24, 2023
1 parent f9f2f87 commit c465604
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 177 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-build-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
numpy=1.20
numba=0.54
pillow==7.0
pydantic=1.10
pydantic=2.1
pymp-pypi=0.4.5
pyproj=3.3
rich=12.0
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
additional_dependencies:
- types-pkg_resources
- types-requests
- "pydantic>=1.10.0,<2"
- "pydantic>=2.1"

- repo: https://github.com/PyCQA/pydocstyle
rev: "6.3.0"
Expand Down
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
# Unreleased

# [0.3.0](https://github.com/opera-adt/dolphin/compare/v0.2.0...v0.3.0) - 2023-08-23

**Added**

- Save a multilooked version of the PS mask for output inspection

**Changed**

- Refectored the blockwise IO into `_blocks.py`.
- Pydantic models were upgraded to V2
- Refactored the blockwise IO into `_blocks.py`.
- The iteration now happens over the output grid for easier dilating/padding when using `strides`
- New classes with `BlockIndices` and `BlockManager` for easier management of the different slices

**Dependencies**

- pydantic >= 2.1

# [0.2.0](https://github.com/opera-adt/dolphin/compare/v0.1.0...v0.2.0) - 2023-07-25

**Added**
Expand Down
2 changes: 1 addition & 1 deletion conda-env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numba>=0.54
- numpy>=1.20
- pillow>=7.0
- pydantic>=1.10,<2
- pydantic>=2.1
- pymp-pypi>=0.4.5
- pyproj>=3.3
- rich>=12.0
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ h5py>=3.6
numba>=0.54
numpy>=1.20
pillow>=7.0
pydantic>=1.10
pydantic>=2.1
pymp-pypi>=0.4.5
pyproj>=3.3
rich>=12.0
Expand Down
141 changes: 63 additions & 78 deletions src/dolphin/interferogram.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""Combine estimated DS phases with PS phases to form interferograms."""
from __future__ import annotations

import datetime
import itertools
from os import fspath
from pathlib import Path
from typing import Iterable, Optional, Sequence, Tuple, Union
from typing import Iterable, Literal, Optional, Sequence, Union

import numpy as np
from numpy.typing import ArrayLike
from osgeo import gdal
from pydantic import BaseModel, Extra, Field, root_validator, validator
from pydantic import (
BaseModel,
Field,
FieldValidationInfo,
field_validator,
model_validator,
)

from dolphin import io, utils
from dolphin._log import get_log
Expand All @@ -21,14 +26,14 @@
logger = get_log(__name__)


class VRTInterferogram(BaseModel):
class VRTInterferogram(BaseModel, extra="allow"):
"""Create an interferogram using a VRTDerivedRasterBand.
Attributes
----------
ref_slc : Union[str, Path]
ref_slc : Union[Path, str]
Path to reference SLC file
sec_slc : Union[str, Path]
sec_slc : Union[Path, str]
Path to secondary SLC file
path : Optional[Path], optional
Path to output interferogram. Defaults to Path('<date1>_<date2>.vrt'),
Expand Down Expand Up @@ -60,14 +65,15 @@ class VRTInterferogram(BaseModel):
)
ref_slc: Union[Path, str] = Field(..., description="Path to reference SLC file")
sec_slc: Union[Path, str] = Field(..., description="Path to secondary SLC file")
outdir: Optional[Union[str, Path]] = Field(
outdir: Optional[Path] = Field(
None,
description=(
"Directory to place output interferogram. Defaults to the same directory as"
" `ref_slc`. If only `outdir` is specified, the output interferogram will"
" be named '<date1>_<date2>.vrt', where the dates are parsed from the"
" inputs. If `path` is specified, this is ignored."
"Directory to place output interferogram. Defaults to the same"
" directory as `ref_slc`. If only `outdir` is specified, the output"
" interferogram will be named '<date1>_<date2>.vrt', where the dates"
" are parsed from the inputs. If `path` is specified, this is ignored."
),
validate_default=True,
)
path: Optional[Path] = Field(
None,
Expand All @@ -76,11 +82,12 @@ class VRTInterferogram(BaseModel):
" dates are parsed from the input files, placed in the same directory as"
" `ref_slc`."
),
validate_default=True,
)
date_format: str = "%Y%m%d"
write: bool = Field(True, description="Write the VRT file to disk")

pixel_function: str = "cmul"
pixel_function: Literal["cmul", "mul"] = "cmul"
_template = """\
<VRTDataset rasterXSize="{xsize}" rasterYSize="{ysize}">
<VRTRasterBand dataType="CFloat32" band="1" subClass="VRTDerivedRasterBand">
Expand All @@ -94,14 +101,11 @@ class VRTInterferogram(BaseModel):
</VRTRasterBand>
</VRTDataset>
"""
dates: Optional[Tuple[datetime.date, datetime.date]] = None

class Config:
extra = Extra.forbid # raise error if extra fields passed in

@validator("ref_slc", "sec_slc")
def _check_gdal_string(cls, v, values):
subdataset = values.get("subdataset")
@field_validator("ref_slc", "sec_slc")
@classmethod
def _check_gdal_string(cls, v: Union[Path, str], info: FieldValidationInfo):
subdataset = info.data.get("subdataset")
    # If we're using a subdataset, create the GDAL-readable string
gdal_str = io.format_nc_filename(v, subdataset)
try:
Expand All @@ -113,69 +117,53 @@ def _check_gdal_string(cls, v, values):
# the file is absolute
if ":" in str(gdal_str):
try:
gdal_str = utils._resolve_gdal_path(gdal_str)
gdal_str = str(utils._resolve_gdal_path(gdal_str))
except Exception:
# if the file had colons for some reason but
# it didn't match, just ignore
pass
return gdal_str

@validator("pixel_function")
def _validate_pixel_func(cls, v):
if v not in ["mul", "cmul"]:
raise ValueError("pixel function must be 'mul' or 'cmul'")
return v.lower()

@validator("outdir", always=True)
def _check_output_dir(cls, v, values):
@field_validator("outdir")
@classmethod
def _check_output_dir(cls, v, info: FieldValidationInfo):
if v is not None:
return Path(v)
# If outdir is not set, use the directory of the reference SLC
ref_slc = values.get("ref_slc")
ref_slc = str(info.data.get("ref_slc"))
return utils._get_path_from_gdal_str(ref_slc).parent

@validator("path", always=True)
def _remove_existing_file(cls, v, values):
if not v:
# No path was passed: try and make one.
# Form the output file name from the dates within input files
ref_slc, sec_slc = values.get("ref_slc"), values.get("sec_slc")
if not ref_slc or not sec_slc:
return v

fmt = values.get("date_format", "%Y%m%d")
date1 = utils.get_dates(ref_slc, fmt=fmt)[0]
date2 = utils.get_dates(sec_slc, fmt=fmt)[0]

outdir = values.get("outdir")
v = outdir / (io._format_date_pair(date1, date2, fmt) + ".vrt")

if Path(v).exists():
v.unlink()
return v

@validator("dates")
def _check_dates_match(cls, v, values):
"""Ensure passed dates match those parsed from the input files."""
fmt = values.get("date_format", "%Y%m%d")
ref_slc, sec_slc = values.get("ref_slc"), values.get("sec_slc")
@model_validator(mode="after")
def _form_path(self) -> "VRTInterferogram":
"""Create the filename (if not provided) from the provided SLCs."""
ref_slc, sec_slc = self.ref_slc, self.sec_slc

fmt = self.date_format
date1 = utils.get_dates(ref_slc, fmt=fmt)[0]
date2 = utils.get_dates(sec_slc, fmt=fmt)[0]
if v is not None:
if v != (date1, date2):
raise ValueError(
f"Dates {v} do not match dates parsed from input files: {date1},"
f" {date2}"
)

@root_validator
def _validate_files(cls, values):
if self.path is not None:
return self

if self.outdir is None:
# If outdir is not set, use the directory of the reference SLC
self.outdir = utils._get_path_from_gdal_str(ref_slc).parent

path = self.outdir / (io._format_date_pair(date1, date2, fmt) + ".vrt")
if Path(path).exists():
logger.info(f"Removing {path}")
path.unlink()
self.path = path
return self

@model_validator(mode="after")
def _validate_files(self) -> "VRTInterferogram":
"""Check that the inputs are the same size and geotransform."""
ref_slc = values.get("ref_slc")
sec_slc = values.get("sec_slc")
ref_slc = self.ref_slc
sec_slc = self.sec_slc
if not ref_slc or not sec_slc:
# Skip validation if files are not set
return values
return self
ds1 = gdal.Open(fspath(ref_slc))
ds2 = gdal.Open(fspath(sec_slc))
xsize, ysize = ds1.RasterXSize, ds1.RasterYSize
Expand All @@ -191,7 +179,7 @@ def _validate_files(cls, values):
f"Input files {ref_slc} and {sec_slc} have different GeoTransforms"
)

return values
return self

def __init__(self, **data):
"""Create a VRTInterferogram object and write the VRT file."""
Expand Down Expand Up @@ -248,16 +236,13 @@ def from_vrt_file(cls, path: Filename) -> "VRTInterferogram":
if subdataset is not None:
ref_slc = io.format_nc_filename(ref_slc, subdataset)
sec_slc = io.format_nc_filename(sec_slc, subdataset)
# TODO: any good way/reason to store the date fmt?
date1 = utils.get_dates(ref_slc, fmt="%Y%m%d")[0]
date2 = utils.get_dates(sec_slc, fmt="%Y%m%d")[0]

return cls.construct(
ref_slc=ref_slc,
sec_slc=sec_slc,
path=Path(path).resolve(),
subdataset=subdataset,
dates=(date1, date2),
date_format="%Y%m%d",
)


Expand All @@ -272,13 +257,13 @@ class Network:
list of dates corresponding to the SLCs.
ifg_list : list[tuple[Filename, Filename]]
list of `VRTInterferogram`s created from the SLCs.
max_bandwidth : Optional[int], optional
max_bandwidth : int | None, optional
Maximum number of SLCs to include in an interferogram, by index distance.
Defaults to None.
max_temporal_baseline : Optional[float], optional
Maximum temporal baseline to include in an interferogram, in days.
Defaults to None.
reference_idx : Optional[int], optional
reference_idx : int | None, optional
Index of the SLC to use as the reference for all interferograms.
Defaults to None.
"""
Expand All @@ -287,9 +272,9 @@ def __init__(
self,
slc_list: Sequence[Filename],
outdir: Optional[Filename] = None,
max_bandwidth: Optional[int] = None,
max_bandwidth: int | None = None,
max_temporal_baseline: Optional[float] = None,
reference_idx: Optional[int] = None,
reference_idx: int | None = None,
indexes: Optional[Sequence[tuple[int, int]]] = None,
subdataset: Optional[Union[str, Sequence[str]]] = None,
write: bool = True,
Expand All @@ -303,13 +288,13 @@ def __init__(
outdir : Optional[Filename], optional
Directory to write the VRT files to.
If not set, defaults to the directory of the reference SLC.
max_bandwidth : Optional[int], optional
max_bandwidth : int | None, optional
Maximum number of SLCs to include in an interferogram, by index distance.
Defaults to None.
max_temporal_baseline : Optional[float]
Maximum temporal baseline to include in an interferogram, in days.
Defaults to None.
reference_idx : Optional[int]
reference_idx : int | None
Index of the SLC to use as the reference for all interferograms.
Defaults to None.
indexes : Optional[Sequence[tuple[int, int]]]
Expand Down Expand Up @@ -400,9 +385,9 @@ def __str__(self):
@staticmethod
def _make_ifg_pairs(
slc_list: Sequence[Filename],
max_bandwidth: Optional[int] = None,
max_bandwidth: int | None = None,
max_temporal_baseline: Optional[float] = None,
reference_idx: Optional[int] = None,
reference_idx: int | None = None,
indexes: Optional[Sequence[tuple[int, int]]] = None,
) -> list[tuple]:
"""Form interferogram pairs from a list of SLC files sorted by date."""
Expand Down
Loading

0 comments on commit c465604

Please sign in to comment.