Skip to content

Commit

Permalink
Fix error in LocalFiles dataset when providing item spec for raster.
Browse files Browse the repository at this point in the history
Resolves #128
  • Loading branch information
favyen2 committed Jan 31, 2025
1 parent cb7698b commit fcd109e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
2 changes: 1 addition & 1 deletion rslearn/data_sources/local_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def list_items(self, config: LayerConfig, src_dir: UPath) -> list[RasterItem]:
item_name = spec.fnames[0].name.split(".")[0]

logger.debug(
"RasterImporter.list_items: got bounds of %s: %s", path, geometry
"RasterImporter.list_items: got bounds of %s: %s", item_name, geometry
)
items.append(RasterItem(item_name, geometry, spec))

Expand Down
19 changes: 15 additions & 4 deletions rslearn/utils/raster_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def encode_raster(
projection: Projection,
bounds: PixelBounds,
array: npt.NDArray[Any],
fname: str | None = None,
) -> None:
"""Encodes raster data.
Expand All @@ -357,7 +358,11 @@ def encode_raster(
projection: the projection of the raster data
bounds: the bounds of the raster data in the projection
array: the raster data
fname: override the filename to save as
"""
if fname is None:
fname = self.fname

crs = projection.crs
transform = affine.Affine(
projection.x_resolution,
Expand Down Expand Up @@ -393,21 +398,27 @@ def encode_raster(
profile.update(self.geotiff_options)

path.mkdir(parents=True, exist_ok=True)
logger.info(f"Writing geotiff to {path / self.fname}")
with open_rasterio_upath_writer(path / self.fname, **profile) as dst:
logger.info(f"Writing geotiff to {path / fname}")
with open_rasterio_upath_writer(path / fname, **profile) as dst:
dst.write(array)

def decode_raster(self, path: UPath, bounds: PixelBounds) -> npt.NDArray[Any]:
def decode_raster(
self, path: UPath, bounds: PixelBounds, fname: str | None = None
) -> npt.NDArray[Any]:
"""Decodes raster data.
Args:
path: the directory to read from
bounds: the bounds of the raster to read
fname: override the filename to read from
Returns:
the raster data, or None if no image content is found
"""
with open_rasterio_upath_reader(path / self.fname) as src:
if fname is None:
fname = self.fname

with open_rasterio_upath_reader(path / fname) as src:
transform = src.transform
x_resolution = transform.a
y_resolution = transform.e
Expand Down
79 changes: 79 additions & 0 deletions tests/integration/data_sources/test_local_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pathlib

import numpy as np
import pytest
import shapely
from rasterio.crs import CRS
Expand All @@ -18,6 +19,7 @@
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import Projection, STGeometry
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
from rslearn.utils.raster_format import GeotiffRasterFormat
from rslearn.utils.vector_format import (
GeojsonCoordinateMode,
GeojsonVectorFormat,
Expand Down Expand Up @@ -132,6 +134,83 @@ def test_large_dataset(self, tmp_path: pathlib.Path) -> None:
assert features[0].properties is not None
assert features[0].properties["check"]

def test_raster_dataset_with_item_spec(self, tmp_path: pathlib.Path) -> None:
"""Test LocalFiles with directly provided item specs."""
ds_path = UPath(tmp_path)

# Create two source GeoTIFFs to read from.
source_dir_name = "source_data"
src_path = UPath(tmp_path / source_dir_name)
projection = Projection(CRS.from_epsg(3857), 1, -1)
bounds = (0, 0, 8, 8)
b1 = np.zeros((1, 8, 8), dtype=np.uint8)
b2 = np.ones((1, 8, 8), dtype=np.uint8)
GeotiffRasterFormat().encode_raster(
src_path, projection, bounds, b1, fname="b1.tif"
)
GeotiffRasterFormat().encode_raster(
src_path, projection, bounds, b2, fname="b2.tif"
)

# Make an rslearn dataset that uses LocalFiles to ingest the source data.
# We need to pass item specs because we have bands in two separate files.
layer_name = "local_file"
dataset_config = {
"layers": {
layer_name: {
"type": "raster",
"band_sets": [
{
"bands": ["b1", "b2"],
"dtype": "uint8",
}
],
"data_source": {
"name": "rslearn.data_sources.local_files.LocalFiles",
"src_dir": source_dir_name,
"item_specs": [
{
"fnames": ["b1.tif", "b2.tif"],
"bands": [["b1"], ["b2"]],
}
],
},
},
},
}
with (ds_path / "config.json").open("w") as f:
json.dump(dataset_config, f)

# Create a window and materialize it.
Window(
path=Window.get_window_root(ds_path, "default", "default"),
group="default",
name="default",
projection=projection,
bounds=bounds,
time_range=None,
).save()
dataset = Dataset(ds_path)
windows = dataset.load_windows()
prepare_dataset_windows(dataset, windows)
ingest_dataset_windows(dataset, windows)
materialize_dataset_windows(dataset, windows)

# Verify that b1 is 0s and b2 is 1s.
window = windows[0]
raster_dir = window.get_raster_dir(layer_name, ["b1", "b2"])
materialized_image = GeotiffRasterFormat().decode_raster(
raster_dir, window.bounds
)
assert (
materialized_image[0, :, :].min() == 0
and materialized_image[0, :, :].max() == 0
)
assert (
materialized_image[1, :, :].min() == 1
and materialized_image[1, :, :].max() == 1
)


class TestCoordinateModes:
"""Test LocalFiles again, focusing on using different coordinate modes.
Expand Down

0 comments on commit fcd109e

Please sign in to comment.