Skip to content

Commit

Permalink
Implement overlap trimming for blockwise IO (#105)
Browse files Browse the repository at this point in the history
- double the half window size for overlap
- change the `single.py` outputs to use a trimmed slice
- fix tests accordingly

This should close #73
  • Loading branch information
scottstanie authored Jul 25, 2023
1 parent 97bc534 commit 10ee9f5
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 26 deletions.
20 changes: 10 additions & 10 deletions src/dolphin/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,12 +823,12 @@ def _slice_iterator(
arr_shape : tuple[int, int]
(num_rows, num_cols), full size of array to access
block_shape : tuple[int, int]
(height, width), size of accessing blocks
overlaps : tuple[int, int]
(row_overlap, col_overlap), number of pixels to re-include
after sliding the block (default (0, 0))
start_offsets : tuple[int, int]
Offsets to start reading from (default (0, 0))
(height, width), size of blocks to load
overlaps : tuple[int, int], default = (0, 0)
(row_overlap, col_overlap), number of pixels to re-include from
the previous block after sliding
start_offsets : tuple[int, int], default = (0, 0)
Offsets from top left to start reading from
Yields
------
Expand Down Expand Up @@ -859,10 +859,10 @@ def _slice_iterator(
width = cols

# Check we're not moving backwards with the overlap:
if row_overlap >= height:
raise ValueError(f"row_overlap {row_overlap} must be less than {height}")
if col_overlap >= width:
raise ValueError(f"col_overlap {col_overlap} must be less than {width}")
if row_overlap >= height and height != rows:
raise ValueError(f"{row_overlap = } must be less than block height {height}")
if col_overlap >= width and width != cols:
raise ValueError(f"{col_overlap = } must be less than block width {width}")
while row_off < rows:
while col_off < cols:
row_end = min(row_off + height, rows) # Dont yield something OOB
Expand Down
49 changes: 39 additions & 10 deletions src/dolphin/workflows/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def run_wrapped_phase_single(
# Note: dividing by len(stack) since cov is shape (rows, cols, nslc, nslc)
# so we need to load less to not overflow memory
stack_max_bytes = max_bytes / len(vrt)
overlaps = (yhalf, xhalf)
overlaps = (2 * yhalf, 2 * xhalf)
block_gen = vrt.iter_blocks(
overlaps=overlaps,
max_bytes=stack_max_bytes,
Expand Down Expand Up @@ -250,22 +250,46 @@ def run_wrapped_phase_single(

# Save each of the MLE estimates (ignoring the compressed SLCs)
assert len(cur_mle_stack[first_non_comp_idx:]) == len(output_slc_files)

# Get the location within the output file, shrinking down the slices
out_row_start = rows.start // ys
out_col_start = cols.start // xs
for img, f in zip(cur_mle_stack[first_non_comp_idx:], output_slc_files):
# Move the starts forward by half the overlap to trim the incomplete
# data sections for each output
out_row_start = (rows.start + yhalf) // ys
out_col_start = (cols.start + xhalf) // xs
# Also need to trim the data blocks themselves
trim_row_slice = slice(yhalf // ys, -yhalf // ys)
trim_col_slice = slice(xhalf // xs, -xhalf // xs)

for img, f in zip(
cur_mle_stack[first_non_comp_idx:, trim_row_slice, trim_col_slice],
output_slc_files,
):
writer.queue_write(img, f, out_row_start, out_col_start)

# Save the temporal coherence blocks
writer.queue_write(tcorr, tcorr_file, out_row_start, out_col_start)
writer.queue_write(
tcorr[trim_row_slice, trim_col_slice],
tcorr_file,
out_row_start,
out_col_start,
)

# Save avg coh index
if avg_coh is not None:
writer.queue_write(avg_coh, avg_coh_file, out_row_start, out_col_start)
writer.queue_write(avg_coh, avg_coh_file, out_row_start, out_col_start)
writer.queue_write(
avg_coh[trim_row_slice, trim_col_slice],
avg_coh_file,
out_row_start,
out_col_start,
)
# Save the SHP counts for each pixel (if not using Rect window)
shp_counts = np.sum(neighbor_arrays[rows, cols], axis=(-2, -1))
writer.queue_write(shp_counts, shp_counts_file, out_row_start, out_col_start)
shp_counts = np.sum(neighbor_arrays, axis=(-2, -1))
writer.queue_write(
shp_counts[trim_row_slice, trim_col_slice],
shp_counts_file,
out_row_start,
out_col_start,
)

# Compress the ministack using only the non-compressed SLCs
cur_comp_slc = compress(
Expand All @@ -274,7 +298,12 @@ def run_wrapped_phase_single(
)
# Save the compressed SLC block
# TODO: make a flag? We don't always need to save the compressed SLCs
writer.queue_write(cur_comp_slc, comp_slc_file, rows.start, cols.start)
writer.queue_write(
cur_comp_slc[yhalf:-yhalf, xhalf:-xhalf],
comp_slc_file,
rows.start + yhalf,
cols.start + xhalf,
)
# logger.debug(f"Saved compressed block SLC to {cur_comp_slc_file}")

# Block until all the writers for this ministack have finished
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def raster_100_by_200(tmp_path):
d = tmp_path / "raster_100_by_200"
d.mkdir()
filename = str(d / "test.bin")
ds = driver.Create(filename, xsize, ysize, 1, gdal.GDT_CFloat32) # noqa
ds = driver.Create(filename, xsize, ysize, 1, gdal.GDT_CFloat32)
data = np.random.randn(ysize, xsize) + 1j * np.random.randn(ysize, xsize)
ds.WriteArray(data)
ds.FlushCache()
ds = None
return filename
Expand All @@ -180,6 +182,8 @@ def tiled_raster_100_by_200(tmp_path):
ds = driver.Create(
str(filename), xsize, ysize, 1, gdal.GDT_CFloat32, options=creation_options
)
data = np.random.randn(ysize, xsize) + 1j * np.random.randn(ysize, xsize)
ds.WriteArray(data)
ds.FlushCache()
ds = None
return filename
Expand Down
58 changes: 55 additions & 3 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def test_iter_blocks_rowcols(tiled_raster_100_by_200):
for rs, cs in slices:
assert rs.stop - rs.start == 10
assert cs.stop - cs.start == 20
loader.notify_finished()

# Non-multiple block size
loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=(32, 32))
Expand All @@ -368,31 +369,82 @@ def test_iter_nodata(
bs = io.get_max_block_shape(tiled_raster_100_by_200, 1, max_bytes=max_bytes)
loader = io.EagerLoader(filename=tiled_raster_100_by_200, block_shape=bs)
blocks, slices = zip(*list(loader.iter_blocks()))
loader.notify_finished()

row_blocks = 100 // 32 + 1
col_blocks = 200 // 32 + 1
expected_num_blocks = row_blocks * col_blocks
assert len(blocks) == expected_num_blocks
assert blocks[0].shape == (32, 32)
loader.notify_finished()

# One nan should be fine, will get loaded
loader = io.EagerLoader(filename=raster_with_nan, block_shape=bs)
blocks, slices = zip(*list(loader.iter_blocks()))
assert len(blocks) == expected_num_blocks
loader.notify_finished()
assert len(blocks) == expected_num_blocks

# Now check entire block for a skipped block
loader = io.EagerLoader(filename=raster_with_nan_block, block_shape=bs)
blocks, slices = zip(*list(loader.iter_blocks()))
assert len(blocks) == expected_num_blocks - 1
loader.notify_finished()
assert len(blocks) == expected_num_blocks - 1

# Now check entire block for a skipped block
loader = io.EagerLoader(filename=raster_with_zero_block, block_shape=bs)
blocks, slices = zip(*list(loader.iter_blocks()))
loader.notify_finished()
assert len(blocks) == expected_num_blocks - 1


def test_iter_blocks_overlap(tiled_raster_100_by_200):
    """Check that trimmed, overlapping blocks tile the raster interior exactly once.

    Uses an overlap of twice the half-window so that, after trimming `half`
    pixels from every side of each block, adjacent trimmed blocks abut with
    no gaps and no double-coverage (except the untouched outer border).
    """
    # Half-window sizes: x is the column direction, y is the row direction
    xhalf, yhalf = 4, 5
    check_out = np.zeros((100, 200))
    slices = list(
        io._slice_iterator((100, 200), (30, 30), overlaps=(2 * yhalf, 2 * xhalf))
    )

    for rs, cs in slices:
        trim_row = slice(rs.start + yhalf, rs.stop - yhalf)
        trim_col = slice(cs.start + xhalf, cs.stop - xhalf)
        check_out[trim_row, trim_col] += 1

    # Everywhere in the middle should have been touched exactly once
    assert np.all(check_out[yhalf:-yhalf, xhalf:-xhalf] == 1)
    # The border is never written: top/bottom rows...
    assert np.all(check_out[:yhalf] == 0)
    assert np.all(check_out[-yhalf:] == 0)
    # ...and left/right columns (slice along axis 1 — `xhalf` trims columns)
    assert np.all(check_out[:, :xhalf] == 0)
    assert np.all(check_out[:, -xhalf:] == 0)

    loader = io.EagerLoader(
        filename=tiled_raster_100_by_200,
        block_shape=(32, 32),
        overlaps=(2 * yhalf, 2 * xhalf),
    )
    assert hasattr(loader, "_finished_event")
    blocks, slices = zip(*list(loader.iter_blocks()))
    loader.notify_finished()
    check_out = np.zeros((100, 200), dtype="complex")
    xs, ys = 1, 1  # 1-by-1 strides
    for b, (rows, cols) in zip(blocks, slices):
        # Use the logic in `single.py`
        # TODO: figure out how to encapsulate so we test a function
        out_row_start = (rows.start + yhalf) // ys
        out_col_start = (cols.start + xhalf) // xs
        # Also need to trim the data blocks themselves
        trim_row_slice = slice(yhalf // ys, -yhalf // ys)
        trim_col_slice = slice(xhalf // xs, -xhalf // xs)
        b_trimmed = b[trim_row_slice, trim_col_slice]
        check_out[
            out_row_start : out_row_start + b_trimmed.shape[0],
            out_col_start : out_col_start + b_trimmed.shape[1],
        ] += b_trimmed

    # Reassembled trimmed blocks must exactly reproduce the raster interior
    expected = io.load_gdal(tiled_raster_100_by_200)
    npt.assert_allclose(
        check_out[yhalf:-yhalf, xhalf:-xhalf], expected[yhalf:-yhalf, xhalf:-xhalf]
    )


@pytest.mark.skip
Expand Down
4 changes: 2 additions & 2 deletions tests/test_workflows_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def test_sequential_gtiff(tmp_path, slc_file_list, gpu_enabled):

# Input is only (5, 10) so we can't use a larger window.
@pytest.mark.parametrize(
"half_window", [{"x": 1, "y": 1}, {"x": 2, "y": 3}, {"x": 4, "y": 3}]
"half_window", [{"x": 1, "y": 1}, {"x": 2, "y": 2}, {"x": 4, "y": 2}]
)
@pytest.mark.parametrize(
"strides", [{"x": 1, "y": 1}, {"x": 1, "y": 2}, {"x": 2, "y": 3}, {"x": 4, "y": 2}]
"strides", [{"x": 1, "y": 1}, {"x": 1, "y": 2}, {"x": 2, "y": 2}, {"x": 4, "y": 2}]
)
def test_sequential_nc(tmp_path, slc_file_list_nc, half_window, strides):
"""Check various strides/windows/ministacks with a NetCDF input stack."""
Expand Down

0 comments on commit 10ee9f5

Please sign in to comment.