diff --git a/dataset/create_dataset.py b/dataset/create_dataset.py index 062141f5..3007de5f 100644 --- a/dataset/create_dataset.py +++ b/dataset/create_dataset.py @@ -1,119 +1,169 @@ -import numpy as np -from pathlib import Path -from typing import Any, Dict, cast import sys +from pathlib import Path +from typing import Any, Dict, List, cast -from rasterio.windows import from_bounds +import kornia as K +import numpy as np +import pandas as pd import rasterio +import torch +from affine import Affine +from osgeo import ogr +# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__() +from rasterio.crs import CRS from rasterio.io import DatasetReader from rasterio.plot import reshape_as_image +from rasterio.vrt import WarpedVRT +from rasterio.windows import from_bounds from torch.utils.data import Dataset from torchgeo.datasets import GeoDataset -from rasterio.vrt import WarpedVRT from torchgeo.datasets.utils import BoundingBox -import torch -from osgeo import ogr from utils.logger import get_logger -# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__() -from rasterio.crs import CRS -from affine import Affine - # Set the logging file logging = get_logger(__name__) # import logging -def append_to_dataset(dataset, sample): - """ - Append a new sample to a provided dataset. The dataset has to be expanded before we can add value to it. - :param dataset: - :param sample: data to append - :return: Index of the newly added sample. - """ - old_size = dataset.shape[0] # this function always appends samples on the first axis - dataset.resize(old_size + 1, axis=0) - dataset[old_size, ...] 
= sample - return old_size - - class SegmentationDataset(Dataset): - """Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif.""" + """Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif. + + Args: + dataset_list_path (str): The path to the dataset list file. + num_bands (int): The number of bands in the imagery. + dontcare (Optional[int]): The value to be ignored in the label. + max_sample_count (Optional[int]): The maximum number of samples to load from the dataset. + radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples. + geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples. + totensor_transform (Optional[Callable]): The transform function to convert samples to tensors. + debug (bool): Whether to enable debug mode. + + Attributes: + max_sample_count (int): The maximum number of samples to load from the dataset. + num_bands (int): The number of bands in the imagery. + radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples. + geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples. + totensor_transform (Optional[Callable]): The transform function to convert samples to tensors. + debug (bool): Whether debug mode is enabled. + dontcare (Optional[int]): The value to be ignored in the label. + list_path (str): The path to the dataset list file. + assets (List[Dict[str, str]]): The list of filepaths to images and labels. 
+ + """ def __init__(self, dataset_list_path, - dataset_type, num_bands, + dontcare=None, max_sample_count=None, radiom_transform=None, geom_transform=None, totensor_transform=None, debug=False): - # note: if 'max_sample_count' is None, then it will be read from the dataset at runtime self.max_sample_count = max_sample_count - self.dataset_type = dataset_type self.num_bands = num_bands self.radiom_transform = radiom_transform self.geom_transform = geom_transform self.totensor_transform = totensor_transform self.debug = debug + self.dontcare = dontcare self.list_path = dataset_list_path + if not Path(self.list_path).is_file(): logging.error(f"Couldn't locate dataset list file: {self.list_path}.\n" f"If purposely omitting test set, this error can be ignored") - self.max_sample_count = 0 - else: - with open(self.list_path, 'r') as datafile: - datalist = datafile.readlines() - if self.max_sample_count is None: - self.max_sample_count = len(datalist) - + self.assets = self._load_data() + def __len__(self): - return self.max_sample_count - + return len(self.assets) + def __getitem__(self, index): - with open(self.list_path, 'r') as datafile: - datalist = datafile.readlines() - data_line = datalist[index] - with rasterio.open(data_line.split(';')[0], 'r') as sat_handle: - sat_img = reshape_as_image(sat_handle.read()) - metadata = sat_handle.meta - with rasterio.open(data_line.split(';')[1].rstrip('\n'), 'r') as label_handle: - map_img = reshape_as_image(label_handle.read()) - map_img = map_img[..., 0] - - assert self.num_bands <= sat_img.shape[-1] - - if isinstance(metadata, np.ndarray) and len(metadata) == 1: - metadata = metadata[0] - elif isinstance(metadata, bytes): - metadata = metadata.decode('UTF-8') - try: - metadata = eval(metadata) - except TypeError: - pass - + + sat_img, metadata = self._load_image(index) + map_img = self._load_label(index) + + if isinstance(metadata, np.ndarray) and len(metadata) == 1: + metadata = metadata[0] + elif isinstance(metadata, 
bytes): + metadata = metadata.decode('UTF-8') + try: + metadata = eval(metadata) + except TypeError: + pass + sample = {"image": sat_img, "mask": map_img, "metadata": metadata, "list_path": self.list_path} - - if self.radiom_transform: # radiometric transforms should always precede geometric ones + # radiometric transforms should always precede geometric ones + if self.radiom_transform: sample = self.radiom_transform(sample) - if self.geom_transform: # rotation, geometric scaling, flip and crop. Will also put channels first and convert to torch tensor from numpy. + # rotation, geometric scaling, flip and crop. + # Will also put channels first and convert to torch tensor from numpy. + if self.geom_transform: sample = self.geom_transform(sample) - - sample = self.totensor_transform(sample) + if self.totensor_transform: + sample = self.totensor_transform(sample) if self.debug: # assert no new class values in map_img initial_class_ids = set(np.unique(map_img)) - final_class_ids = set(np.unique(sample["mask"].numpy())) + if self.dontcare is not None: + initial_class_ids.add(self.dontcare) + final_class_ids = set(np.unique(sample['mask'].numpy())) if not final_class_ids.issubset(initial_class_ids): + logging.debug(f"WARNING: Class ids for label before and after augmentations don't match. " + f"Ignore if overwritting ignore_index in ToTensorTarget") logging.warning(f"\nWARNING: Class values for label before and after augmentations don't match." 
f"\nUnique values before: {initial_class_ids}" f"\nUnique values after: {final_class_ids}" f"\nIgnore if some augmentations have padded with dontcare value.") sample['index'] = index + return sample + + def _load_data(self) -> List[str]: + """Load the filepaths to images and labels + + Returns: + List[str]: a list of filepaths to train/test data + """ + df = pd.read_csv(self.list_path, sep=';', header=None, usecols=[i for i in range(2)]) + assets = [{"image": x, "label": y} for x, y in zip(df[0], df[1])] + + return assets + + def _load_image(self, index: int): + """ Load image + + Args: + index: position of image + Returns: + image array and metadata + """ + image_path = self.assets[index]["image"] + with rasterio.open(image_path, 'r') as image_handle: + image = reshape_as_image(image_handle.read()) + metadata = image_handle.meta + assert self.num_bands <= image.shape[-1] + + return image, metadata + + def _load_label(self, index: int): + """ Load label + + Args: + index: position of label + + Returns: + label array + """ + label_path = self.assets[index]["label"] + + with rasterio.open(label_path, 'r') as label_handle: + label = reshape_as_image(label_handle.read()) + label = label[..., 0] + + return label + class DRDataset(GeoDataset): def __init__(self, dr_ds: DatasetReader) -> None: diff --git a/inference_segmentation.py b/inference_segmentation.py index e7b3a746..9e75fd71 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -158,12 +158,12 @@ def segmentation(param, """ sample = {"image": None, "mask": None, 'metadata': None} - start_seg = time.time() print_log = True if logging.level == 20 else False # 20 is INFO model.eval() # switch to evaluate mode # initialize test time augmentation - transforms = tta.aliases.d4_transform() + # transforms = tta.aliases.d4_transform() + transforms = tta.Compose([]) tf_len = len(transforms) h_padded, w_padded = input_image.height + chunk_size, input_image.width + chunk_size patch_list = 
generate_patch_list(w_padded, h_padded, chunk_size, use_hanning) @@ -171,6 +171,7 @@ def segmentation(param, fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes)) img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size) single_class_mode = False if num_classes > 1 else True + start_time = time.time() for sub_image, h_idxs, w_idxs, hann_win in tqdm( img_gen, position=0, leave=True, desc='Inferring on patches', total=len(patch_list) @@ -230,9 +231,9 @@ def segmentation(param, arr1 = stretch_heatmap(heatmap_arr=arr1, out_max=heatmap_max) pred_heatmap[row:row + chunk_size, col:col + chunk_size, :] = arr1.astype(heatmap_dtype) - - end_seg = time.time() - start_seg - logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60)) + end_time = time.time() - start_time + # logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60)) + logging.info('Segmentation Completed in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60)) if debug: logging.debug(f'Bin count of final output: {np.unique(pred_heatmap, return_counts=True)}') diff --git a/tests/data/tiles/tiled_image_1.tif b/tests/data/tiles/tiled_image_1.tif new file mode 100644 index 00000000..18a1ee93 Binary files /dev/null and b/tests/data/tiles/tiled_image_1.tif differ diff --git a/tests/data/tiles/tiled_image_2.tif b/tests/data/tiles/tiled_image_2.tif new file mode 100644 index 00000000..a8d7c2d0 Binary files /dev/null and b/tests/data/tiles/tiled_image_2.tif differ diff --git a/tests/data/tiles/tiled_label_1.tif b/tests/data/tiles/tiled_label_1.tif new file mode 100644 index 00000000..ec99b983 Binary files /dev/null and b/tests/data/tiles/tiled_label_1.tif differ diff --git a/tests/data/tiles/tiled_label_2.tif b/tests/data/tiles/tiled_label_2.tif new file mode 100644 index 00000000..be081f08 Binary files /dev/null and b/tests/data/tiles/tiled_label_2.tif 
differ diff --git a/tests/data/tiles/tiles.csv b/tests/data/tiles/tiles.csv new file mode 100644 index 00000000..54cac1d5 --- /dev/null +++ b/tests/data/tiles/tiles.csv @@ -0,0 +1,2 @@ +tests/data/tiles/tiled_image_1.tif;tests/data/tiles/tiled_label_1.tif +tests/data/tiles/tiled_image_2.tif;tests/data/tiles/tiled_label_2.tif \ No newline at end of file diff --git a/tests/dataset/test_datasets.py b/tests/dataset/test_datasets.py index f6e2e8ee..011a710d 100644 --- a/tests/dataset/test_datasets.py +++ b/tests/dataset/test_datasets.py @@ -1,18 +1,53 @@ -from typing import List from tempfile import NamedTemporaryFile +from typing import List import pytest import rasterio -from rasterio.io import DatasetReader -from rasterio.crs import CRS -from torchgeo.datasets.utils import extract_archive -from torchgeo.datasets.utils import BoundingBox -from osgeo import ogr -from _pytest.fixtures import SubRequest import torch - -from dataset.create_dataset import DRDataset, GDLVectorDataset - +from _pytest.fixtures import SubRequest +from osgeo import ogr +from rasterio.crs import CRS +from rasterio.io import DatasetReader +from torchgeo.datasets.utils import BoundingBox, extract_archive + +from dataset.create_dataset import (DRDataset, GDLVectorDataset, + SegmentationDataset) + + +class TestSegmentationDataset: + @pytest.fixture + def data(self): + dataset_list_path = 'tests/data/tiles/tiles.csv' + num_bands = 3 + dataset = SegmentationDataset(dataset_list_path, num_bands) + return dataset + + def test_len(self, data): + expected_length = 2 + assert len(data) == expected_length + + def test_getitem(self, data): + sample = data[0] + assert "image" in sample + assert 'mask' in sample + assert 'metadata' in sample + assert 'list_path' in sample + + def test_load_data(self, data): + # Test that _load_data returns the expected number of assets + assets = data._load_data() + assert len(assets) == len(data) + + def test_load_image(self, data): + # Test that _load_image returns an image 
and metadata + image, metadata = data._load_image(0) + assert image is not None + assert metadata is not None + + def test_load_label(self, data): + # Test that _load_label returns a label + label = data._load_label(0) + assert label is not None class TestDRDataset: @pytest.fixture(params=["tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif", diff --git a/train_segmentation.py b/train_segmentation.py index 335bbc09..664abc82 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -100,7 +100,7 @@ def create_dataloader(patches_folder: Path, # TODO: should user point to the paths of these csvs directly? dataset_file, _ = Tiler.make_dataset_file_name(experiment_name, min_annot_perc, subset, attr_vals) dataset_filepath = patches_folder / dataset_file - datasets.append(dataset_constr(dataset_filepath, subset, num_bands, + datasets.append(dataset_constr(dataset_filepath, num_bands, max_sample_count=num_patches[subset], radiom_transform=aug.compose_transforms(params=cfg, dataset=subset,