diff --git a/inbuilt_cfgs/example_GeoCLR_config.yml b/inbuilt_cfgs/example_GeoCLR_config.yml index 3f4c6c257..8f38d2a89 100644 --- a/inbuilt_cfgs/example_GeoCLR_config.yml +++ b/inbuilt_cfgs/example_GeoCLR_config.yml @@ -93,11 +93,18 @@ dataset_params: # Training Dataset train: image: - module: minerva.datasets - name: TstImgDataset - root: test_images - params: - res: 1.0 + image1: + module: minerva.datasets + name: TstImgDataset + root: test_images + params: + res: 1.0 + image2: + module: minerva.datasets + name: TstImgDataset + root: test_images + params: + res: 1.0 # Validation Dataset val: diff --git a/minerva/datasets.py b/minerva/datasets.py index 8645bfbd3..dac774b49 100644 --- a/minerva/datasets.py +++ b/minerva/datasets.py @@ -38,6 +38,7 @@ __copyright__ = "Copyright (C) 2023 Harry Baker" __all__ = [ "PairedDataset", + "PairedUnionDataset", "construct_dataloader", "get_collator", "get_manifest", @@ -166,6 +167,51 @@ def __getitem__( # type: ignore[override] queries[1] ) + def __and__(self, other: "PairedDataset") -> IntersectionDataset: + """Take the intersection of two :class:`PairedDataset`. + + Args: + other (PairedDataset): Another dataset. + + Returns: + IntersectionDataset: A single dataset. + + Raises: + ValueError: If other is not a :class:`PairedDataset` + + .. versionadded:: 0.24 + """ + print("paired intersect") + if not isinstance(other, PairedDataset): + raise ValueError( + f"Intersecting a dataset of {type(other)} and a PairedDataset is not supported!" + ) + + return IntersectionDataset( + self, other, collate_fn=utils.pair_collate(concat_samples) + ) + + def __or__(self, other: "PairedDataset") -> "PairedUnionDataset": + """Take the union of two :class:`PairedDataset`. + + Args: + other (PairedDataset): Another dataset. + + Returns: + PairedUnionDataset: A single dataset. + + Raises: + ValueError: If ``other`` is not a :class:`PairedDataset` + + .. versionadded:: 0.24 + """ + if not isinstance(other, PairedDataset): + raise ValueError( + f"Unionising a dataset of {type(other)} and a PairedDataset is not supported!" + ) + + return PairedUnionDataset(self, other) + def __getattr__(self, item): if item in self.dataset.__dict__: return getattr(self.dataset, item) # pragma: no cover @@ -242,6 +288,43 @@ def plot_random_sample( return self.plot(sample, show_titles, suptitle) +class PairedUnionDataset(UnionDataset): + """Adapted form of :class:`~torchgeo.datasets.UnionDataset` to handle paired samples. + + ..warning:: + + Do not use with :class:`PairedDataset` as this will essentially account for paired sampling twice + and cause a :class:`TypeError`. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + new_datasets = [] + for _dataset in self.datasets: + if isinstance(_dataset, PairedDataset): + new_datasets.append(_dataset.dataset) + + self.datasets = new_datasets + + def __getitem__( + self, query: Tuple[BoundingBox, BoundingBox] + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Retrieve image and metadata indexed by query. + + Uses :meth:`torchgeo.datasets.UnionDataset.__getitem__` to send each query of the pair off to get a + sample for each and returns as a tuple. + + Args: + query (tuple[~torchgeo.datasets.utils.BoundingBox, ~torchgeo.datasets.utils.BoundingBox]): Coordinates + to index in the form (minx, maxx, miny, maxy, mint, maxt). + + Returns: + tuple[dict[str, ~typing.Any], dict[str, ~typing.Any]]: Sample of data/labels and metadata at that index. + """ + return super().__getitem__(query[0]), super().__getitem__(query[1]) + + # ===================================================================================================================== # METHODS # ===================================================================================================================== @@ -287,64 +370,40 @@ def stack_sample_pairs( return stack_samples(a), stack_samples(b) -def intersect_datasets( - datasets: Sequence[GeoDataset], sample_pairs: bool = False -) -> IntersectionDataset: +def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset: r""" Intersects a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object. Args: datasets (list[~torchgeo.datasets.GeoDataset]): List of datasets to intersect together. Should have some geospatial overlap. - sample_pairs (bool): Optional; True if paired sampling. This will wrap the collation function - for paired samples. Returns: ~torchgeo.datasets.IntersectionDataset: Final dataset object representing an intersection of all the parsed datasets. """ - - def intersect_pair_datasets(a: GeoDataset, b: GeoDataset) -> IntersectionDataset: - if sample_pairs: - return IntersectionDataset( - a, b, collate_fn=utils.pair_collate(concat_samples) - ) - else: - return a & b - master_dataset: Union[GeoDataset, IntersectionDataset] = datasets[0] for i in range(len(datasets) - 1): - master_dataset = intersect_pair_datasets(master_dataset, datasets[i + 1]) + master_dataset = master_dataset & datasets[i + 1] assert isinstance(master_dataset, IntersectionDataset) return master_dataset -def unionise_datasets( - datasets: Sequence[GeoDataset], sample_pairs: bool = False -) -> UnionDataset: +def unionise_datasets(datasets: Sequence[GeoDataset]) -> UnionDataset: """Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object. Args: datasets (list[~torchgeo.datasets.GeoDataset]): List of datasets to unionise together. - sample_pairs (bool): Optional; Activates paired sampling. - This will wrap the collation function for paired samples. Returns: ~torchgeo.datasets.UnionDataset: Final dataset object representing an union of all the parsed datasets. """ - - def unionise_pair_datasets(a: GeoDataset, b: GeoDataset) -> UnionDataset: - if sample_pairs: - return UnionDataset(a, b, collate_fn=utils.pair_collate(concat_samples)) - else: - return a | b - master_dataset: Union[GeoDataset, UnionDataset] = datasets[0] for i in range(len(datasets) - 1): - master_dataset = unionise_pair_datasets(master_dataset, datasets[i + 1]) + master_dataset = master_dataset | datasets[i + 1] assert isinstance(master_dataset, UnionDataset) return master_dataset @@ -486,7 +545,7 @@ def create_subdataset( ) if multi_datasets_exist: - sub_datasets.append(unionise_datasets(type_subdatasets, sample_pairs)) + sub_datasets.append(unionise_datasets(type_subdatasets)) else: sub_datasets.append( create_subdataset( @@ -499,7 +558,7 @@ def create_subdataset( # Intersect sub-datasets to form single dataset if more than one sub-dataset exists. Else, just set that to dataset. dataset = sub_datasets[0] if len(sub_datasets) > 1: - dataset = intersect_datasets(sub_datasets, sample_pairs=sample_pairs) + dataset = intersect_datasets(sub_datasets) return dataset, sub_datasets diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 510b84d47..266664ef6 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -47,7 +47,12 @@ from torchgeo.samplers.utils import get_random_bounding_box from minerva import datasets as mdt -from minerva.datasets import PairedDataset, TstImgDataset, TstMaskDataset +from minerva.datasets import ( + PairedDataset, + PairedUnionDataset, + TstImgDataset, + TstMaskDataset, +) from minerva.utils.utils import CONFIG @@ -80,6 +85,13 @@ def test_tinydataset(img_root: Path, lc_root: Path) -> None: def test_paired_datasets(img_root: Path) -> None: dataset = PairedDataset(TstImgDataset, img_root) + dataset2 = TstImgDataset(img_root) + + with pytest.raises( + ValueError, + match=f"Intersecting a dataset of {type(dataset2)} and a PairedDataset is not supported!", + ): + _ = dataset & dataset2 bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12) query_1 = get_random_bounding_box(bounds, (32, 32), 10.0) @@ -105,6 +117,35 @@ def test_paired_datasets(img_root: Path) -> None: ) +def test_paired_union_datasets(img_root: Path) -> None: + def dataset_test(_dataset) -> None: + query_1 = get_random_bounding_box(bounds, (32, 32), 10.0) + query_2 = get_random_bounding_box(bounds, (32, 32), 10.0) + sample_1, sample_2 = _dataset[(query_1, query_2)] + + assert type(sample_1) == dict + assert type(sample_2) == dict + + bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12) + + dataset1 = TstImgDataset(img_root) + dataset2 = TstImgDataset(img_root) + dataset3 = PairedDataset(TstImgDataset, img_root) + dataset4 = PairedDataset(TstImgDataset, img_root) + + with pytest.raises( + ValueError, + match=f"Unionising a dataset of {type(dataset2)} and a PairedDataset is not supported!", + ): + _ = dataset3 | dataset2 + + union_dataset1 = PairedUnionDataset(dataset1, dataset2) + union_dataset2 = dataset3 | dataset4 + + for dataset in (union_dataset1, union_dataset2): + dataset_test(dataset) + + def test_get_collator() -> None: collator_params_1 = {"module": "torchgeo.datasets.utils", "name": "stack_samples"} collator_params_2 = {"name": "stack_sample_pairs"} @@ -154,10 +195,7 @@ def test_intersect_datasets(img_root: Path, lc_root: Path) -> None: imagery = PairedDataset(TstImgDataset, img_root) labels = PairedDataset(TstMaskDataset, lc_root) - assert isinstance( - mdt.intersect_datasets([imagery, labels], sample_pairs=True), - IntersectionDataset, - ) + assert isinstance(mdt.intersect_datasets([imagery, labels]), IntersectionDataset) def test_make_dataset() -> None: