Skip to content

Commit

Permalink
Implement deterministic GeoDataset (#1908)
Browse files Browse the repository at this point in the history
  • Loading branch information
DimitrisMantas authored and isaaccorley committed Mar 2, 2024
1 parent 342f662 commit eed722d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ def test_files_property_for_virtual_files(self) -> None:
]
assert len(CustomGeoDataset(paths=paths).files) == len(paths)

def test_files_property_ordered(self) -> None:
"""Ensure that the list of files is ordered."""
paths = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
assert CustomGeoDataset(paths=paths).files == sorted(paths)

def test_files_property_deterministic(self) -> None:
"""Ensure that the list of files is consistent regardless of their original
order.
"""
paths1 = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
paths2 = ["file://file2.tif", "file://file3.tif", "file://file1.tif"]
assert (
CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files
)


class TestRasterDataset:
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def res(self, new_res: float) -> None:
self._res = new_res

@property
def files(self) -> set[str]:
def files(self) -> list[str]:
"""A list of all files in the dataset.
Returns:
Expand Down Expand Up @@ -314,7 +314,8 @@ def files(self) -> set[str]:
UserWarning,
)

return files
# Sort the output to enforce deterministic behavior.
return sorted(files)


class RasterDataset(GeoDataset):
Expand Down

0 comments on commit eed722d

Please sign in to comment.