Skip to content

Commit

Permalink
Merge pull request NRCan#554 from valhassan/553-feature-refactor-segm…
Browse files Browse the repository at this point in the history
…entationdataset-class
  • Loading branch information
valhassan authored Feb 22, 2024
2 parents 8b7cabc + 0ef8ed1 commit abcb9e9
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 80 deletions.
178 changes: 114 additions & 64 deletions dataset/create_dataset.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,169 @@
import numpy as np
from pathlib import Path
from typing import Any, Dict, cast
import sys
from pathlib import Path
from typing import Any, Dict, List, cast

from rasterio.windows import from_bounds
import kornia as K
import numpy as np
import pandas as pd
import rasterio
import torch
from affine import Affine
from osgeo import ogr
# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__()
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.plot import reshape_as_image
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds
from torch.utils.data import Dataset
from torchgeo.datasets import GeoDataset
from rasterio.vrt import WarpedVRT
from torchgeo.datasets.utils import BoundingBox
import torch
from osgeo import ogr

from utils.logger import get_logger

# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__()
from rasterio.crs import CRS
from affine import Affine

# Set the logging file
logging = get_logger(__name__) # import logging


def append_to_dataset(dataset, sample):
"""
Append a new sample to a provided dataset. The dataset has to be expanded before we can add value to it.
:param dataset:
:param sample: data to append
:return: Index of the newly added sample.
"""
old_size = dataset.shape[0] # this function always appends samples on the first axis
dataset.resize(old_size + 1, axis=0)
dataset[old_size, ...] = sample
return old_size


class SegmentationDataset(Dataset):
"""Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif."""
"""Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif.
Args:
dataset_list_path (str): The path to the dataset list file.
num_bands (int): The number of bands in the imagery.
dontcare (Optional[int]): The value to be ignored in the label.
max_sample_count (Optional[int]): The maximum number of samples to load from the dataset.
radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples.
geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples.
totensor_transform (Optional[Callable]): The transform function to convert samples to tensors.
debug (bool): Whether to enable debug mode.
Attributes:
max_sample_count (int): The maximum number of samples to load from the dataset.
num_bands (int): The number of bands in the imagery.
radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples.
geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples.
totensor_transform (Optional[Callable]): The transform function to convert samples to tensors.
debug (bool): Whether debug mode is enabled.
dontcare (Optional[int]): The value to be ignored in the label.
list_path (str): The path to the dataset list file.
assets (List[Dict[str, str]]): The list of filepaths to images and labels.
"""

def __init__(self,
dataset_list_path,
dataset_type,
num_bands,
dontcare=None,
max_sample_count=None,
radiom_transform=None,
geom_transform=None,
totensor_transform=None,
debug=False):
# note: if 'max_sample_count' is None, then it will be read from the dataset at runtime
self.max_sample_count = max_sample_count
self.dataset_type = dataset_type
self.num_bands = num_bands
self.radiom_transform = radiom_transform
self.geom_transform = geom_transform
self.totensor_transform = totensor_transform
self.debug = debug
self.dontcare = dontcare
self.list_path = dataset_list_path

if not Path(self.list_path).is_file():
logging.error(f"Couldn't locate dataset list file: {self.list_path}.\n"
f"If purposely omitting test set, this error can be ignored")
self.max_sample_count = 0
else:
with open(self.list_path, 'r') as datafile:
datalist = datafile.readlines()
if self.max_sample_count is None:
self.max_sample_count = len(datalist)

self.assets = self._load_data()

def __len__(self):
return self.max_sample_count

return len(self.assets)
def __getitem__(self, index):
with open(self.list_path, 'r') as datafile:
datalist = datafile.readlines()
data_line = datalist[index]
with rasterio.open(data_line.split(';')[0], 'r') as sat_handle:
sat_img = reshape_as_image(sat_handle.read())
metadata = sat_handle.meta
with rasterio.open(data_line.split(';')[1].rstrip('\n'), 'r') as label_handle:
map_img = reshape_as_image(label_handle.read())
map_img = map_img[..., 0]

assert self.num_bands <= sat_img.shape[-1]

if isinstance(metadata, np.ndarray) and len(metadata) == 1:
metadata = metadata[0]
elif isinstance(metadata, bytes):
metadata = metadata.decode('UTF-8')
try:
metadata = eval(metadata)
except TypeError:
pass


sat_img, metadata = self._load_image(index)
map_img = self._load_label(index)

if isinstance(metadata, np.ndarray) and len(metadata) == 1:
metadata = metadata[0]
elif isinstance(metadata, bytes):
metadata = metadata.decode('UTF-8')
try:
metadata = eval(metadata)
except TypeError:
pass

sample = {"image": sat_img, "mask": map_img, "metadata": metadata, "list_path": self.list_path}

if self.radiom_transform: # radiometric transforms should always precede geometric ones
# radiometric transforms should always precede geometric ones
if self.radiom_transform:
sample = self.radiom_transform(sample)
if self.geom_transform: # rotation, geometric scaling, flip and crop. Will also put channels first and convert to torch tensor from numpy.
# rotation, geometric scaling, flip and crop.
# Will also put channels first and convert to torch tensor from numpy.
if self.geom_transform:
sample = self.geom_transform(sample)

sample = self.totensor_transform(sample)
if self.totensor_transform:
sample = self.totensor_transform(sample)

if self.debug:
# assert no new class values in map_img
initial_class_ids = set(np.unique(map_img))
final_class_ids = set(np.unique(sample["mask"].numpy()))
if self.dontcare is not None:
initial_class_ids.add(self.dontcare)
final_class_ids = set(np.unique(sample['mask'].numpy()))
if not final_class_ids.issubset(initial_class_ids):
logging.debug(f"WARNING: Class ids for label before and after augmentations don't match. "
f"Ignore if overwritting ignore_index in ToTensorTarget")
logging.warning(f"\nWARNING: Class values for label before and after augmentations don't match."
f"\nUnique values before: {initial_class_ids}"
f"\nUnique values after: {final_class_ids}"
f"\nIgnore if some augmentations have padded with dontcare value.")
sample['index'] = index

return sample

def _load_data(self) -> List[str]:
"""Load the filepaths to images and labels
Returns:
List[str]: a list of filepaths to train/test data
"""
df = pd.read_csv(self.list_path, sep=';', header=None, usecols=[i for i in range(2)])
assets = [{"image": x, "label": y} for x, y in zip(df[0], df[1])]

return assets

def _load_image(self, index: int):
""" Load image
Args:
index: poosition of image
Returns:
image array and metadata
"""
image_path = self.assets[index]["image"]
with rasterio.open(image_path, 'r') as image_handle:
image = reshape_as_image(image_handle.read())
metadata = image_handle.meta
assert self.num_bands <= image.shape[-1]

return image, metadata

def _load_label(self, index: int):
""" Load label
Args:
index: poosition of label
Returns:
label array and metadata
"""
label_path = self.assets[index]["label"]

with rasterio.open(label_path, 'r') as label_handle:
label = reshape_as_image(label_handle.read())
label = label[..., 0]

return label


class DRDataset(GeoDataset):
def __init__(self, dr_ds: DatasetReader) -> None:
Expand Down
11 changes: 6 additions & 5 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,20 @@ def segmentation(param,
"""
sample = {"image": None, "mask": None, 'metadata': None}
start_seg = time.time()
print_log = True if logging.level == 20 else False # 20 is INFO
model.eval() # switch to evaluate mode

# initialize test time augmentation
transforms = tta.aliases.d4_transform()
# transforms = tta.aliases.d4_transform()
transforms = tta.Compose([])
tf_len = len(transforms)
h_padded, w_padded = input_image.height + chunk_size, input_image.width + chunk_size
patch_list = generate_patch_list(w_padded, h_padded, chunk_size, use_hanning)

fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes))
img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size)
single_class_mode = False if num_classes > 1 else True
start_time = time.time()
for sub_image, h_idxs, w_idxs, hann_win in tqdm(
img_gen, position=0, leave=True, desc='Inferring on patches',
total=len(patch_list)
Expand Down Expand Up @@ -230,9 +231,9 @@ def segmentation(param,
arr1 = stretch_heatmap(heatmap_arr=arr1, out_max=heatmap_max)

pred_heatmap[row:row + chunk_size, col:col + chunk_size, :] = arr1.astype(heatmap_dtype)

end_seg = time.time() - start_seg
logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60))
end_time = time.time() - start_time
# logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60))
logging.info('Segmentation Completed in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))

if debug:
logging.debug(f'Bin count of final output: {np.unique(pred_heatmap, return_counts=True)}')
Expand Down
Binary file added tests/data/tiles/tiled_image_1.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_image_2.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_label_1.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_label_2.tif
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/tiles/tiles.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tests/data/tiles/tiled_image_1.tif;tests/data/tiles/tiled_label_1.tif
tests/data/tiles/tiled_image_2.tif;tests/data/tiles/tiled_label_2.tif
55 changes: 45 additions & 10 deletions tests/dataset/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,53 @@
from typing import List
from tempfile import NamedTemporaryFile
from typing import List

import pytest
import rasterio
from rasterio.io import DatasetReader
from rasterio.crs import CRS
from torchgeo.datasets.utils import extract_archive
from torchgeo.datasets.utils import BoundingBox
from osgeo import ogr
from _pytest.fixtures import SubRequest
import torch

from dataset.create_dataset import DRDataset, GDLVectorDataset

from _pytest.fixtures import SubRequest
from osgeo import ogr
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from torchgeo.datasets.utils import BoundingBox, extract_archive

from dataset.create_dataset import (DRDataset, GDLVectorDataset,
SegmentationDataset)


class TestSegmentationDataset:
@pytest.fixture
def data(self):
dataset_list_path = 'tests/data/tiles/tiles.csv'
num_bands = 3
dataset = SegmentationDataset(dataset_list_path, num_bands)
return dataset

def test_len(self, data):
expected_length = 2
assert len(data) == expected_length

def test_getitem(self, data):
sample = data[0]
assert "image" in sample
assert 'mask' in sample
assert 'metadata' in sample
assert 'list_path' in sample

def test_load_data(self, data):
# Test that _load_data returns the expected number of assets
assets = data._load_data()
assert len(assets) == len(data)

def test_load_image(self, data):
# Test that _load_image returns an image and metadata
image, metadata = data._load_image(0)
assert image is not None
assert metadata is not None

def test_load_label(self, data):
# Test that _load_label returns a label
label = data._load_label(0)
assert label is not None

class TestDRDataset:
@pytest.fixture(params=["tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif",
Expand Down
2 changes: 1 addition & 1 deletion train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create_dataloader(patches_folder: Path,
# TODO: should user point to the paths of these csvs directly?
dataset_file, _ = Tiler.make_dataset_file_name(experiment_name, min_annot_perc, subset, attr_vals)
dataset_filepath = patches_folder / dataset_file
datasets.append(dataset_constr(dataset_filepath, subset, num_bands,
datasets.append(dataset_constr(dataset_filepath, num_bands,
max_sample_count=num_patches[subset],
radiom_transform=aug.compose_transforms(params=cfg,
dataset=subset,
Expand Down

0 comments on commit abcb9e9

Please sign in to comment.