Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add class_names/class_property_name to dataset config file #139

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions rslearn/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
format: dict[str, Any] | None = None,
zoom_offset: int = 0,
remap: dict[str, Any] | None = None,
class_names: list[list[str]] | None = None,
) -> None:
"""Creates a new BandSetConfig instance.

Expand All @@ -133,12 +134,22 @@ def __init__(
negative, store data at the window resolution multiplied by
2^(-zoom_offset) (lower resolution).
remap: config dict for Remapper to remap pixel values
class_names: optional list of names for the different possible values of
each band. The length of this list must equal the number of bands. For
example, [["forest", "desert"]] means that it is a single-band raster
where values can be 0 (forest) or 1 (desert).
"""
if class_names is not None and len(bands) != len(class_names):
raise ValueError(
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
)

self.config_dict = config_dict
self.bands = bands
self.dtype = dtype
self.zoom_offset = zoom_offset
self.remap = remap
self.class_names = class_names

if format is None:
self.format = {"name": "geotiff"}
Expand All @@ -161,7 +172,7 @@ def from_config(config: dict[str, Any]) -> "BandSetConfig":
dtype=DType(config["dtype"]),
bands=config["bands"],
)
for k in ["format", "zoom_offset", "remap"]:
for k in ["format", "zoom_offset", "remap", "class_names"]:
if k in config:
kwargs[k] = config[k]
return BandSetConfig(**kwargs) # type: ignore
Expand Down Expand Up @@ -447,6 +458,8 @@ def __init__(
zoom_offset: int = 0,
format: VectorFormatConfig = VectorFormatConfig("geojson"),
alias: str | None = None,
class_property_name: str | None = None,
class_names: list[str] | None = None,
):
"""Initialize a new VectorLayerConfig.

Expand All @@ -456,10 +469,17 @@ def __init__(
zoom_offset: zoom offset at which to store the vector data
format: the VectorFormatConfig, default storing as GeoJSON
alias: alias for this layer to use in the tile store
class_property_name: optional metadata field indicating that the GeoJSON
features contain a property that corresponds to a class label, and this
is the name of that property.
class_names: the list of classes that the class_property_name property
could be set to.
"""
super().__init__(layer_type, data_source, alias)
self.zoom_offset = zoom_offset
self.format = format
self.class_property_name = class_property_name
self.class_names = class_names

@staticmethod
def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
Expand All @@ -471,12 +491,18 @@ def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
if "data_source" in config:
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
if "zoom_offset" in config:
kwargs["zoom_offset"] = config["zoom_offset"]
if "format" in config:
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
if "alias" in config:
kwargs["alias"] = config["alias"]

simple_optionals = [
"zoom_offset",
"alias",
"class_property_name",
"class_names",
]
for k in simple_optionals:
if k in config:
kwargs[k] = config[k]
return VectorLayerConfig(**kwargs) # type: ignore

def get_final_projection_and_bounds(
Expand Down
Empty file added tests/unit/config/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/unit/config/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test the dataset configuration file.

Mostly just makes sure there aren't runtime errors with the parsing.
"""

from rslearn.config import RasterLayerConfig, VectorLayerConfig


class TestBandSetConfig:
"""Test BandSetConfig."""

def test_class_names_option(self) -> None:
"""Verify that config parsing works when class_names option is set."""
class_names = ["class0", "class1", "class2"]
layer_cfg_dict = {
"type": "raster",
"band_sets": [
{
"dtype": "uint8",
"bands": ["class"],
"class_names": [class_names],
}
],
}
layer_cfg = RasterLayerConfig.from_config(layer_cfg_dict)
assert len(layer_cfg.band_sets) == 1
band_set = layer_cfg.band_sets[0]
assert len(band_set.bands) == 1
assert band_set.class_names is not None
assert band_set.class_names[0] == class_names


class TestVectorLayerConfig:
"""Test VectorLayerConfig."""

def test_class_names_option(self) -> None:
"""Verify that config parsing works when property_name/class_names are set."""
property_name = "my_class_prop"
class_names = ["class0", "class1", "class2"]
layer_cfg_dict = {
"type": "vector",
"class_property_name": property_name,
"class_names": class_names,
}
layer_cfg = VectorLayerConfig.from_config(layer_cfg_dict)
assert layer_cfg.class_property_name == property_name
assert layer_cfg.class_names == class_names
Loading