Add ability to mask a subset of input SLCs #386

Merged: 6 commits, Aug 7, 2024
76 changes: 75 additions & 1 deletion src/dolphin/masking.py
@@ -1,15 +1,20 @@
 from __future__ import annotations

 import logging
+import tempfile
 from enum import IntEnum
 from os import fspath
 from pathlib import Path
 from typing import Optional, Sequence

 import numpy as np
 from osgeo import gdal
+from pyproj import CRS, Transformer
+from shapely import to_geojson
+from shapely.geometry import box

 from dolphin import io
-from dolphin._types import PathOrStr
+from dolphin._types import Bbox, PathOrStr

 gdal.UseExceptions()
@@ -158,3 +163,72 @@ def load_mask_as_numpy(mask_file: PathOrStr) -> np.ndarray:
     # invert the mask so Trues are the missing data pixels
     nodata_mask = ~nodata_mask
     return nodata_mask
+
+
+def create_bounds_mask(
+    bounds: Bbox | tuple[float, float, float, float],
+    output_filename: PathOrStr,
+    like_filename: PathOrStr,
+    bounds_epsg: int = 4326,
+    overwrite: bool = False,
+) -> None:
+    """Create a boolean raster mask where 1 is inside the given bounds and 0 is outside.
+
+    Parameters
+    ----------
+    bounds : tuple
+        (min x, min y, max x, max y) of the area of interest
+    output_filename : Filename
+        Output filename for the mask
+    like_filename : Filename
+        Reference file to copy the shape, extent, and projection.
+    bounds_epsg : int, optional
+        EPSG code of the coordinate system of the bounds.
+        Default is 4326 (lat/lon coordinates for the bounds).
+    overwrite : bool, optional
+        Overwrite the output file if it already exists, by default False
+
+    """
+    if Path(output_filename).exists():
+        if not overwrite:
+            logger.info(f"Skipping {output_filename} since it already exists.")
+            return
+        else:
+            logger.info(f"Overwriting {output_filename} since overwrite=True.")
+            Path(output_filename).unlink()
+
+    # Transform bounds if necessary
+    # The GeoJSON default is 4326, and GDAL handles the conversion to, e.g., UTM
+    if bounds_epsg != 4326:
+        transformer = Transformer.from_crs(
+            CRS.from_epsg(bounds_epsg), 4326, always_xy=True
+        )
+        bounds = transformer.transform_bounds(*bounds)
+
+    logger.info(f"Creating mask for bounds {bounds}")
+
+    # Create a polygon from the bounds
+    bounds_poly = box(*bounds)
+
+    # Create the output raster
+    io.write_arr(
+        arr=None,
+        output_name=output_filename,
+        dtype=bool,
+        nbands=1,
+        like_filename=like_filename,
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        temp_vector_file = Path(tmpdir) / "temp.geojson"
+        with open(temp_vector_file, "w") as f:
+            f.write(to_geojson(bounds_poly))
+
+        # Open the input vector file
+        src_ds = gdal.OpenEx(fspath(temp_vector_file), gdal.OF_VECTOR)
+        dst_ds = gdal.Open(fspath(output_filename), gdal.GA_Update)
+
+        # Now burn in the union of all polygons
+        gdal.Rasterize(dst_ds, src_ds, burnValues=[1])
+
+    logger.info(f"Created {output_filename}")
9 changes: 1 addition & 8 deletions src/dolphin/ps.py
@@ -115,16 +115,9 @@ def create_ps(
     # Initialize the intermediate arrays for the calculation
     magnitude = np.zeros((reader.shape[0], *block_shape), dtype=np.float32)

-    skip_empty = nodata_mask is None
-
     writer = io.BackgroundBlockWriter()
     # Make the generator for the blocks
-    block_gen = EagerLoader(
-        reader,
-        block_shape=block_shape,
-        nodata_mask=nodata_mask,
-        skip_empty=skip_empty,
-    )
+    block_gen = EagerLoader(reader, block_shape=block_shape, nodata_mask=nodata_mask)
     for cur_data, (rows, cols) in block_gen.iter_blocks(**tqdm_kwargs):
         cur_rows, cur_cols = cur_data.shape[-2:]
5 changes: 5 additions & 0 deletions src/dolphin/timeseries.py
@@ -511,12 +511,14 @@ def create_velocity(
             msg += f"{len(cor_file_list) = }, but {len(unw_file_list) = }"
             raise ValueError(msg)

+        logger.info("Using correlation to weight velocity fit")
         cor_reader = io.VRTStack(
             file_list=cor_file_list,
             outfile=out_dir / "cor_inputs.vrt",
             skip_size_check=True,
         )
     else:
+        logger.info("Using unweighted fit for velocity.")
         cor_reader = None

     # Read in the reference point
@@ -568,6 +570,7 @@ def read_and_fit(
     if add_overviews:
         logger.info("Creating overviews for velocity image")
         create_overviews([output_file])
+    logger.info("Completed create_velocity")


 class AverageFunc(Protocol):
@@ -763,6 +766,8 @@ def read_and_solve(
     if add_overviews:
         logger.info("Creating overviews for unwrapped images")
         create_overviews(out_paths, image_type=ImageType.UNWRAPPED)
+
+    logger.info("Completed invert_unw_network")
     return out_paths


9 changes: 9 additions & 0 deletions src/dolphin/workflows/_cli_config.py
@@ -31,6 +31,7 @@ def create_config(
     baseline_lag: Optional[int] = None,
     amp_dispersion_threshold: float = 0.25,
     strides: tuple[int, int],
+    output_bounds: Optional[tuple[float, float, float, float]] = None,
     block_shape: tuple[int, int] = (512, 512),
     threads_per_worker: int = 4,
     n_parallel_bursts: int = 1,
@@ -91,6 +92,7 @@ def create_config(
         interferogram_network=interferogram_network,
         output_options={
             "strides": {"x": strides[0], "y": strides[1]},
+            "bounds": output_bounds,
         },
         phase_linking={
             "ministack_size": ministack_size,
@@ -367,6 +369,13 @@ def get_parser(subparser=None, subcommand_name="run"):
             " output shape."
         ),
     )
+    out_group.add_argument(
+        "--output-bounds",
+        nargs=4,
+        type=float,
+        metavar=("left", "bottom", "right", "top"),
+        help="Requested bounding box (in lat/lon) for final output.",
+    )

     worker_group = parser.add_argument_group("Worker options")
     worker_group.add_argument(
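Usage note (hedged: this assumes the parser is wired to a "dolphin config" subcommand, and the --slc-files flag and paths are placeholders): the new flag takes four space-separated values,

    dolphin config --slc-files ./slcs/*.h5 --output-bounds -122.90 51.73 -122.68 51.95

which argparse parses as floats and create_config forwards to output_options["bounds"].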
96 changes: 72 additions & 24 deletions src/dolphin/workflows/wrapped_phase.py
@@ -9,7 +9,7 @@
 import numpy as np
 from opera_utils import get_dates, make_nodata_mask

-from dolphin import interferogram, ps, stack
+from dolphin import Bbox, Filename, interferogram, masking, ps, stack
 from dolphin._log import log_runtime, setup_logging
 from dolphin.io import VRTStack

@@ -73,6 +73,28 @@ def run(
         outfile=cfg.work_directory / "slc_stack.vrt",
     )

+    # Mark any files beginning with "compressed" as compressed
+    is_compressed = ["compressed" in str(f).lower() for f in input_file_list]
+    input_dates = _get_input_dates(
+        input_file_list, is_compressed, cfg.input_options.cslc_date_fmt
+    )
+    reference_date, reference_idx = _get_reference_date_idx(
+        input_file_list, is_compressed, input_dates
+    )
+
+    non_compressed_slcs = [
+        f for f, is_comp in zip(input_file_list, is_compressed) if not is_comp
+    ]
+
+    # Create a mask file from input bounding polygons and/or specified output bounds
+    mask_filename = _get_mask(
+        output_dir=cfg.work_directory,
+        output_bounds=cfg.output_options.bounds,
+        like_filename=vrt_stack.outfile,
+        cslc_file_list=non_compressed_slcs,
+    )
+
+    nodata_mask = masking.load_mask_as_numpy(mask_filename) if mask_filename else None
     # ###############
     # PS selection
     # ###############
@@ -97,6 +119,7 @@
             output_amp_dispersion_file=cfg.ps_options._amp_dispersion_file,
             amp_dispersion_threshold=cfg.ps_options.amp_dispersion_threshold,
             existing_amp_dispersion_file=existing_disp,
+            nodata_mask=nodata_mask,
             existing_amp_mean_file=existing_amp,
             block_shape=cfg.worker_settings.block_shape,
             **kwargs,
@@ -116,15 +139,6 @@
     pl_path = cfg.phase_linking._directory
     pl_path.mkdir(parents=True, exist_ok=True)

-    # Mark any files beginning with "compressed" as compressed
-    is_compressed = ["compressed" in str(f).lower() for f in input_file_list]
-    input_dates = _get_input_dates(
-        input_file_list, is_compressed, cfg.input_options.cslc_date_fmt
-    )
-    reference_date, reference_idx = _get_reference_date_idx(
-        input_file_list, is_compressed, input_dates
-    )
-
     ministack_planner = stack.MiniStackPlanner(
         file_list=input_file_list,
         dates=input_dates,
@@ -135,19 +149,6 @@
         reference_idx=reference_idx,
     )

-    # Make the nodata mask from the polygons, if we're using OPERA CSLCs
-    non_compressed_slcs = [
-        f for f, is_comp in zip(input_file_list, is_compressed) if not is_comp
-    ]
-    try:
-        nodata_mask_file = cfg.work_directory / "nodata_mask.tif"
-        make_nodata_mask(
-            non_compressed_slcs, out_file=nodata_mask_file, buffer_pixels=200
-        )
-    except Exception as e:
-        logger.warning(f"Could not make nodata mask: {e}")
-        nodata_mask_file = None
-
     phase_linked_slcs = sorted(pl_path.glob("2*.tif"))
     if len(phase_linked_slcs) > 0:
         logger.info(f"Skipping EVD step, {len(phase_linked_slcs)} files already exist")
@@ -171,7 +172,7 @@
             strides=strides,
             use_evd=cfg.phase_linking.use_evd,
             beta=cfg.phase_linking.beta,
-            mask_file=nodata_mask_file,
+            mask_file=mask_filename,
             ps_mask_file=ps_output,
             amp_mean_file=cfg.ps_options._amp_mean_file,
             amp_dispersion_file=cfg.ps_options._amp_dispersion_file,
@@ -394,3 +395,50 @@ def _get_input_dates(
         dates[:1] if not is_comp else dates
         for dates, is_comp in zip(input_dates, is_compressed)
     ]
+
+
+def _get_mask(
+    output_dir: Path,
+    output_bounds: Bbox | tuple[float, float, float, float] | None,
+    like_filename: Filename,
+    cslc_file_list: Sequence[Filename],
+) -> Path | None:
+    # Make the nodata mask from the polygons, if we're using OPERA CSLCs
+
+    try:
+        nodata_mask_file = output_dir / "nodata_mask.tif"
+        make_nodata_mask(
+            opera_file_list=cslc_file_list,
+            out_file=nodata_mask_file,
+            buffer_pixels=200,
+        )
+    except Exception as e:
+        logger.warning(f"Could not make nodata mask: {e}")
+        nodata_mask_file = None
+
+    mask_filename: Path | None = None
+    # Also mask outside the area of interest if we've specified a small bounds
+    if output_bounds is not None:
+        # Make a mask just from the bounds
+        bounds_mask_filename = output_dir / "bounds_mask.tif"
+        masking.create_bounds_mask(
+            bounds=output_bounds,
+            output_filename=bounds_mask_filename,
+            like_filename=like_filename,
+        )
+
+        # Then combine with the nodata mask
+        if nodata_mask_file is not None:
+            combined_mask_filename = output_dir / "combined_mask.tif"
+            masking.combine_mask_files(
+                mask_files=[bounds_mask_filename, nodata_mask_file],
+                output_file=combined_mask_filename,
+                output_convention=masking.MaskConvention.ZERO_IS_NODATA,
+            )
+            mask_filename = combined_mask_filename
+        else:
+            mask_filename = bounds_mask_filename
+    else:
+        mask_filename = nodata_mask_file
+
+    return mask_filename
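
A sketch of the possible outcomes (hypothetical call; every path and input file name below is a placeholder, while the result file names come from the function body above):

    from pathlib import Path

    mask_filename = _get_mask(
        output_dir=Path("work"),
        output_bounds=(-122.90, 51.73, -122.68, 51.95),
        like_filename="work/slc_stack.vrt",
        cslc_file_list=["t042_088905_iw1_20221119.h5"],
    )
    # bounds given,  polygon mask succeeded -> work/combined_mask.tif
    # bounds given,  polygon mask failed    -> work/bounds_mask.tif
    # bounds absent, polygon mask succeeded -> work/nodata_mask.tif
    # bounds absent, polygon mask failed    -> None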
Binary file added tests/data/dummy_like.tif.zip
36 changes: 35 additions & 1 deletion tests/test_masking.py
@@ -1,9 +1,10 @@
+import zipfile
 from pathlib import Path

 import numpy as np
 import pytest

-from dolphin import io, masking
+from dolphin import Bbox, io, masking


 @pytest.fixture()
@@ -42,3 +43,36 @@ def test_load_mask_as_numpy(mask_files):
     expected = np.ones((9, 9), dtype=bool)
     expected[:3] = False
     np.testing.assert_array_equal(arr, expected)
+
+
+@pytest.fixture
+def like_filename_zipped():
+    return Path(__file__).parent / "data/dummy_like.tif.zip"
+
+
+def test_bounds(tmp_path, like_filename_zipped):
+    # Unzip to tmp_path
+    with zipfile.ZipFile(like_filename_zipped, "r") as zip_ref:
+        zip_ref.extractall(tmp_path)
+
+    # Get the path of the extracted TIF file
+    extracted_tif = tmp_path / "dummy_like.tif"
+
+    output_filename = tmp_path / "mask_bounds.tif"
+    bounds = Bbox(
+        left=-122.90334860812246,
+        bottom=51.7323987260125,
+        right=-122.68416491724179,
+        top=51.95333755674119,
+    )
+    masking.create_bounds_mask(
+        bounds, like_filename=extracted_tif, output_filename=output_filename
+    )
+    # Check result
+    mask = io.load_gdal(output_filename)
+    assert (mask[1405:3856, 9681:12685] == 1).all()
+    # The WGS84 box is not a box in UTM
+    assert (mask[:1400, :] == 0).all()
+    assert (mask[4000:, :] == 0).all()
+    assert (mask[:, :9500] == 0).all()
+    assert (mask[:, 13000:] == 0).all()
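
Why the asserts leave a margin between the 1-region and the 0-regions: the bounds are axis-aligned in WGS84, but the mask is burned on the raster's projected grid, so the rectangle's edges are slanted in pixel space. A small sketch of the corner mismatch (assuming UTM zone 10N, EPSG:32610, for this longitude; the test raster's actual CRS is not shown here):

    from pyproj import Transformer

    t = Transformer.from_crs(4326, 32610, always_xy=True)
    print(t.transform(-122.90334860812246, 51.7323987260125))   # lower-left corner
    print(t.transform(-122.68416491724179, 51.95333755674119))  # upper-right corner
    # The corners land on different eastings/northings, so the test checks a
    # strictly interior block of 1s and strictly exterior bands of 0s.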