Skip to content

Commit

Permalink
Merge pull request #170 from Pale-Blue-Dot-97/mcgonagall
Browse files Browse the repository at this point in the history
McGonagall Phase 1
  • Loading branch information
Pale-Blue-Dot-97 authored May 11, 2023
2 parents 4e07006 + 10c6dcb commit e48d1ca
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 40 deletions.
17 changes: 12 additions & 5 deletions inbuilt_cfgs/example_GeoCLR_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,18 @@ dataset_params:
# Training Dataset
train:
image:
module: minerva.datasets
name: TstImgDataset
root: test_images
params:
res: 1.0
image1:
module: minerva.datasets
name: TstImgDataset
root: test_images
params:
res: 1.0
image2:
module: minerva.datasets
name: TstImgDataset
root: test_images
params:
res: 1.0

# Validation Dataset
val:
Expand Down
119 changes: 89 additions & 30 deletions minerva/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
__copyright__ = "Copyright (C) 2023 Harry Baker"
__all__ = [
"PairedDataset",
"PairedUnionDataset",
"construct_dataloader",
"get_collator",
"get_manifest",
Expand Down Expand Up @@ -166,6 +167,51 @@ def __getitem__( # type: ignore[override]
queries[1]
)

def __and__(self, other: "PairedDataset") -> IntersectionDataset:
    """Take the intersection of two :class:`PairedDataset`.

    Args:
        other (PairedDataset): Another dataset.

    Returns:
        IntersectionDataset: A single dataset.

    Raises:
        ValueError: If ``other`` is not a :class:`PairedDataset`.

    .. versionadded:: 0.24
    """
    # Only paired datasets may be intersected with each other; reject anything else up front.
    if not isinstance(other, PairedDataset):
        raise ValueError(
            f"Intersecting a dataset of {type(other)} and a PairedDataset is not supported!"
        )

    # The sample-concatenating collator is wrapped with ``pair_collate``
    # (presumably so it is applied to each side of a sample pair -- confirm against ``utils``).
    return IntersectionDataset(
        self, other, collate_fn=utils.pair_collate(concat_samples)
    )

def __or__(self, other: "PairedDataset") -> "PairedUnionDataset":
    """Take the union of two :class:`PairedDataset`.

    Args:
        other (PairedDataset): Another dataset.

    Returns:
        PairedUnionDataset: A single dataset.

    Raises:
        ValueError: If ``other`` is not a :class:`PairedDataset`.

    .. versionadded:: 0.24
    """
    # Happy path first: two paired datasets are combined into a paired union.
    if isinstance(other, PairedDataset):
        return PairedUnionDataset(self, other)

    # Anything that is not a PairedDataset cannot be unionised with one.
    unsupported = type(other)
    raise ValueError(
        f"Unionising a dataset of {unsupported} and a PairedDataset is not supported!"
    )

def __getattr__(self, item):
if item in self.dataset.__dict__:
return getattr(self.dataset, item) # pragma: no cover
Expand Down Expand Up @@ -242,6 +288,43 @@ def plot_random_sample(
return self.plot(sample, show_titles, suptitle)


class PairedUnionDataset(UnionDataset):
    """Adapted form of :class:`~torchgeo.datasets.UnionDataset` to handle paired samples.

    Any :class:`PairedDataset` operand is replaced by the dataset it wraps (its ``dataset``
    attribute) after the parent constructor has run. All other operands are kept unchanged.

    .. warning::
        Do not use with :class:`PairedDataset` as this will essentially account for paired sampling twice
        and cause a :class:`TypeError`.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the union, then unwrap any :class:`PairedDataset` operands.

        Args:
            *args: Positional arguments forwarded to :class:`~torchgeo.datasets.UnionDataset`.
            **kwargs: Keyword arguments forwarded to :class:`~torchgeo.datasets.UnionDataset`.
        """
        super().__init__(*args, **kwargs)

        # Swap each PairedDataset for its wrapped dataset so that pairing is handled only once,
        # in ``__getitem__`` below. Operands that are not PairedDataset must be kept as they are:
        # previously they were dropped, leaving the union with fewer (possibly zero) datasets.
        self.datasets = [
            _dataset.dataset if isinstance(_dataset, PairedDataset) else _dataset
            for _dataset in self.datasets
        ]

    def __getitem__(
        self, query: Tuple[BoundingBox, BoundingBox]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Retrieve image and metadata indexed by query.

        Uses :meth:`torchgeo.datasets.UnionDataset.__getitem__` to send each query of the pair off to get a
        sample for each and returns as a tuple.

        Args:
            query (tuple[~torchgeo.datasets.utils.BoundingBox, ~torchgeo.datasets.utils.BoundingBox]): Coordinates
                to index in the form (minx, maxx, miny, maxy, mint, maxt).

        Returns:
            tuple[dict[str, ~typing.Any], dict[str, ~typing.Any]]: Sample of data/labels and metadata at that index.
        """
        # Each half of the pair is resolved independently by the parent union lookup.
        first_sample = super().__getitem__(query[0])
        second_sample = super().__getitem__(query[1])
        return first_sample, second_sample


# =====================================================================================================================
# METHODS
# =====================================================================================================================
Expand Down Expand Up @@ -287,64 +370,40 @@ def stack_sample_pairs(
return stack_samples(a), stack_samples(b)


def intersect_datasets(
datasets: Sequence[GeoDataset], sample_pairs: bool = False
) -> IntersectionDataset:
def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset:
r"""
Intersects a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object.
Args:
datasets (list[~torchgeo.datasets.GeoDataset]): List of datasets to intersect together.
Should have some geospatial overlap.
sample_pairs (bool): Optional; True if paired sampling. This will wrap the collation function
for paired samples.
Returns:
~torchgeo.datasets.IntersectionDataset: Final dataset object representing an intersection
of all the parsed datasets.
"""

def intersect_pair_datasets(a: GeoDataset, b: GeoDataset) -> IntersectionDataset:
if sample_pairs:
return IntersectionDataset(
a, b, collate_fn=utils.pair_collate(concat_samples)
)
else:
return a & b

master_dataset: Union[GeoDataset, IntersectionDataset] = datasets[0]

for i in range(len(datasets) - 1):
master_dataset = intersect_pair_datasets(master_dataset, datasets[i + 1])
master_dataset = master_dataset & datasets[i + 1]

assert isinstance(master_dataset, IntersectionDataset)
return master_dataset


def unionise_datasets(
datasets: Sequence[GeoDataset], sample_pairs: bool = False
) -> UnionDataset:
def unionise_datasets(datasets: Sequence[GeoDataset]) -> UnionDataset:
"""Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object.
Args:
datasets (list[~torchgeo.datasets.GeoDataset]): List of datasets to unionise together.
sample_pairs (bool): Optional; Activates paired sampling.
This will wrap the collation function for paired samples.
Returns:
~torchgeo.datasets.UnionDataset: Final dataset object representing an union of all the parsed datasets.
"""

def unionise_pair_datasets(a: GeoDataset, b: GeoDataset) -> UnionDataset:
if sample_pairs:
return UnionDataset(a, b, collate_fn=utils.pair_collate(concat_samples))
else:
return a | b

master_dataset: Union[GeoDataset, UnionDataset] = datasets[0]

for i in range(len(datasets) - 1):
master_dataset = unionise_pair_datasets(master_dataset, datasets[i + 1])
master_dataset = master_dataset | datasets[i + 1]

assert isinstance(master_dataset, UnionDataset)
return master_dataset
Expand Down Expand Up @@ -486,7 +545,7 @@ def create_subdataset(
)

if multi_datasets_exist:
sub_datasets.append(unionise_datasets(type_subdatasets, sample_pairs))
sub_datasets.append(unionise_datasets(type_subdatasets))
else:
sub_datasets.append(
create_subdataset(
Expand All @@ -499,7 +558,7 @@ def create_subdataset(
# Intersect sub-datasets to form single dataset if more than one sub-dataset exists. Else, just set that to dataset.
dataset = sub_datasets[0]
if len(sub_datasets) > 1:
dataset = intersect_datasets(sub_datasets, sample_pairs=sample_pairs)
dataset = intersect_datasets(sub_datasets)

return dataset, sub_datasets

Expand Down
48 changes: 43 additions & 5 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
from torchgeo.samplers.utils import get_random_bounding_box

from minerva import datasets as mdt
from minerva.datasets import PairedDataset, TstImgDataset, TstMaskDataset
from minerva.datasets import (
PairedDataset,
PairedUnionDataset,
TstImgDataset,
TstMaskDataset,
)
from minerva.utils.utils import CONFIG


Expand Down Expand Up @@ -80,6 +85,13 @@ def test_tinydataset(img_root: Path, lc_root: Path) -> None:

def test_paired_datasets(img_root: Path) -> None:
dataset = PairedDataset(TstImgDataset, img_root)
dataset2 = TstImgDataset(img_root)

with pytest.raises(
ValueError,
match=f"Intersecting a dataset of {type(dataset2)} and a PairedDataset is not supported!",
):
_ = dataset & dataset2

bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12)
query_1 = get_random_bounding_box(bounds, (32, 32), 10.0)
Expand All @@ -105,6 +117,35 @@ def test_paired_datasets(img_root: Path) -> None:
)


def test_paired_union_datasets(img_root: Path) -> None:
    """Check construction and pair-indexing of :class:`PairedUnionDataset`.

    Covers three things: unionising a :class:`PairedDataset` with a non-paired dataset raises
    ``ValueError``; a union can be built directly from two plain datasets; and a union can be
    built with ``|`` from two :class:`PairedDataset`. Both unions must return a pair of
    dictionaries when indexed with a pair of bounding boxes.
    """
    # Local import: only needed to escape the expected error message below.
    import re

    def dataset_test(_dataset) -> None:
        # Two independent random queries, drawn inside the same bounds, form the query pair.
        query_1 = get_random_bounding_box(bounds, (32, 32), 10.0)
        query_2 = get_random_bounding_box(bounds, (32, 32), 10.0)
        sample_1, sample_2 = _dataset[(query_1, query_2)]

        assert isinstance(sample_1, dict)
        assert isinstance(sample_2, dict)

    bounds = BoundingBox(411248.0, 412484.0, 4058102.0, 4059399.0, 0, 1e12)

    dataset1 = TstImgDataset(img_root)
    dataset2 = TstImgDataset(img_root)
    dataset3 = PairedDataset(TstImgDataset, img_root)
    dataset4 = PairedDataset(TstImgDataset, img_root)

    # ``match`` is a regular expression, and the type repr contains regex metacharacters
    # (e.g. ``.``), so the expected message is escaped to be matched literally.
    expected_msg = re.escape(
        f"Unionising a dataset of {type(dataset2)} and a PairedDataset is not supported!"
    )
    with pytest.raises(ValueError, match=expected_msg):
        _ = dataset3 | dataset2

    union_dataset1 = PairedUnionDataset(dataset1, dataset2)
    union_dataset2 = dataset3 | dataset4

    for dataset in (union_dataset1, union_dataset2):
        dataset_test(dataset)


def test_get_collator() -> None:
collator_params_1 = {"module": "torchgeo.datasets.utils", "name": "stack_samples"}
collator_params_2 = {"name": "stack_sample_pairs"}
Expand Down Expand Up @@ -154,10 +195,7 @@ def test_intersect_datasets(img_root: Path, lc_root: Path) -> None:
imagery = PairedDataset(TstImgDataset, img_root)
labels = PairedDataset(TstMaskDataset, lc_root)

assert isinstance(
mdt.intersect_datasets([imagery, labels], sample_pairs=True),
IntersectionDataset,
)
assert isinstance(mdt.intersect_datasets([imagery, labels]), IntersectionDataset)


def test_make_dataset() -> None:
Expand Down

0 comments on commit e48d1ca

Please sign in to comment.