Skip to content

Commit

Permalink
Merge pull request #139 from allenai/favyen/add-classes-to-dataset-config
Browse files Browse the repository at this point in the history

Add class_names/class_property_name to dataset config file
  • Loading branch information
favyen2 authored Feb 19, 2025
2 parents 70de2f0 + 375650b commit 755094c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 5 deletions.
36 changes: 31 additions & 5 deletions rslearn/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
format: dict[str, Any] | None = None,
zoom_offset: int = 0,
remap: dict[str, Any] | None = None,
class_names: list[list[str]] | None = None,
) -> None:
"""Creates a new BandSetConfig instance.
Expand All @@ -133,12 +134,22 @@ def __init__(
negative, store data at the window resolution multiplied by
2^(-zoom_offset) (lower resolution).
remap: config dict for Remapper to remap pixel values
class_names: optional list of names for the different possible values of
each band. The length of this list must equal the number of bands. For
example, [["forest", "desert"]] means that it is a single-band raster
where values can be 0 (forest) or 1 (desert).
"""
if class_names is not None and len(bands) != len(class_names):
raise ValueError(
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
)

self.config_dict = config_dict
self.bands = bands
self.dtype = dtype
self.zoom_offset = zoom_offset
self.remap = remap
self.class_names = class_names

if format is None:
self.format = {"name": "geotiff"}
Expand All @@ -161,7 +172,7 @@ def from_config(config: dict[str, Any]) -> "BandSetConfig":
dtype=DType(config["dtype"]),
bands=config["bands"],
)
for k in ["format", "zoom_offset", "remap"]:
for k in ["format", "zoom_offset", "remap", "class_names"]:
if k in config:
kwargs[k] = config[k]
return BandSetConfig(**kwargs) # type: ignore
Expand Down Expand Up @@ -447,6 +458,8 @@ def __init__(
zoom_offset: int = 0,
format: VectorFormatConfig = VectorFormatConfig("geojson"),
alias: str | None = None,
class_property_name: str | None = None,
class_names: list[str] | None = None,
):
"""Initialize a new VectorLayerConfig.
Expand All @@ -456,10 +469,17 @@ def __init__(
zoom_offset: zoom offset at which to store the vector data
format: the VectorFormatConfig, default storing as GeoJSON
alias: alias for this layer to use in the tile store
class_property_name: optional metadata field indicating that the GeoJSON
features contain a property that corresponds to a class label, and this
is the name of that property.
class_names: the list of classes that the class_property_name property
could be set to.
"""
super().__init__(layer_type, data_source, alias)
self.zoom_offset = zoom_offset
self.format = format
self.class_property_name = class_property_name
self.class_names = class_names

@staticmethod
def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
Expand All @@ -471,12 +491,18 @@ def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
if "data_source" in config:
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
if "zoom_offset" in config:
kwargs["zoom_offset"] = config["zoom_offset"]
if "format" in config:
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
if "alias" in config:
kwargs["alias"] = config["alias"]

simple_optionals = [
"zoom_offset",
"alias",
"class_property_name",
"class_names",
]
for k in simple_optionals:
if k in config:
kwargs[k] = config[k]
return VectorLayerConfig(**kwargs) # type: ignore

def get_final_projection_and_bounds(
Expand Down
Empty file added tests/unit/config/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/unit/config/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test the dataset configuration file.
Mostly just makes sure there aren't runtime errors with the parsing.
"""

from rslearn.config import RasterLayerConfig, VectorLayerConfig


class TestBandSetConfig:
    """Tests covering BandSetConfig parsing."""

    def test_class_names_option(self) -> None:
        """Check that a band set carrying class_names parses without error.

        The raster layer config declares one band set with a single band, so
        its class_names entry holds exactly one list of names.
        """
        expected_names = ["class0", "class1", "class2"]
        band_set_dict = {
            "dtype": "uint8",
            "bands": ["class"],
            "class_names": [expected_names],
        }
        parsed_layer = RasterLayerConfig.from_config(
            {"type": "raster", "band_sets": [band_set_dict]}
        )

        band_sets = parsed_layer.band_sets
        assert len(band_sets) == 1

        parsed_band_set = band_sets[0]
        assert len(parsed_band_set.bands) == 1

        parsed_names = parsed_band_set.class_names
        assert parsed_names is not None
        # One entry per band; the only band should carry the names unchanged.
        assert parsed_names[0] == expected_names


class TestVectorLayerConfig:
    """Tests covering VectorLayerConfig parsing."""

    def test_class_names_option(self) -> None:
        """Check parsing when class_property_name and class_names are given."""
        expected = {
            "class_property_name": "my_class_prop",
            "class_names": ["class0", "class1", "class2"],
        }
        parsed_layer = VectorLayerConfig.from_config({"type": "vector", **expected})

        # Both optional fields should come back exactly as they were supplied.
        assert parsed_layer.class_property_name == expected["class_property_name"]
        assert parsed_layer.class_names == expected["class_names"]

0 comments on commit 755094c

Please sign in to comment.