Skip to content

Commit

Permalink
Merge pull request #37 from sentinel-hub/develop
Browse files Browse the repository at this point in the history
Add resunet-a architecture used for field delineation
  • Loading branch information
devisperessutti authored Oct 27, 2020
2 parents b47a3f4 + 7d8ea35 commit d7203fb
Show file tree
Hide file tree
Showing 5 changed files with 638 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Project also contains other folders:
Segmentation models for land cover semantic segmentation:
* **Fully-Convolutional-Network (FCN, a.k.a. U-net)**, vanilla implementation of method described in this [paper](https://arxiv.org/abs/1505.04597). This network expects 2D MSI images as inputs and predicts 2D label maps as output.
* **Temporal FCN**, where the whole time-series is considered as a 3D MSI volume and convolutions are performed along the temporal dimension as well spatial dimension. The output of the network is a 2D label map as in previous cases. More details can be found in this [paper](https://www.researchgate.net/publication/333262625_Spatio-Temporal_Deep_Learning_An_Application_to_Land_Cover_Classification).
* **ResUNet-a**, architecture proposed in Diakogiannis et al. ["ResUNet-a: A deep learning framework for semantic segmetnation of remotely sensed data"](https://www.sciencedirect.com/science/article/abs/pii/S0924271620300149). Original `mxnet` implementation can be found [here](https://github.com/feevos/resuneta).

Classification models for crop classification using time-series:
* **TCN**: Implementation of the TCN network taken from the [keras-TCN implementation by Philippe Remy](https://github.com/philipperemy/keras-tcn).
Expand Down
149 changes: 149 additions & 0 deletions eoflow/models/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import warnings

from typing import Any, Callable, List

from skimage import measure
from scipy import ndimage

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np


class InitializableMetric(tf.keras.metrics.Metric):
Expand Down Expand Up @@ -156,3 +164,144 @@ def get_config(self):
self.assert_initialized()

return self.metric.get_config()


class GeometricMetrics(InitializableMetric):
""""
Implementation of Geometric error metrics. Oversegmentation, Undersegmentation, Border, Fragmentation errors.
The error metrics are based on a paper by C. Persello, A Novel Protocol for Accuracy Assessment in Classification of
Very High Resolution Images (https://ieeexplore.ieee.org/document/5282610)
"""

@staticmethod
def _detect_edges(im: np.ndarray, thr: float = 0) -> np.ndarray:
""" Edge detection function using the sobel operator. """
sx = ndimage.sobel(im, axis=0, mode='constant')
sy = ndimage.sobel(im, axis=1, mode='constant')
sob = np.hypot(sx, sy)
return sob > thr

@staticmethod
def _segmentation_error(intersection_area: float, object_area: float) -> float:
return 1. - intersection_area / object_area

@staticmethod
def _intersection(mask1: np.ndarray, mask2: np.ndarray) -> float:
return np.sum(np.logical_and(mask1, mask2))

def _border_err(self, border_ref_edge: np.ndarray, border_meas_edge: np.ndarray) -> float:
ref_edge_size = np.sum(border_ref_edge)
intersection = self._intersection(border_ref_edge, border_meas_edge)
err = intersection / ref_edge_size if ref_edge_size != 0 else 0
be = 1. - err
return be

def _fragmentation_err(self, r: int, reference_mask: np.ndarray) -> float:
if r <= 1:
return 0
den = np.sum(reference_mask) - self.pixel_size
err = (r - 1.) / den if den > 0 else 0
return err

@staticmethod
def _validate_input(reference, measurement):
if np.ndim(reference) != np.ndim(measurement):
raise ValueError("Reference and measurement input shapes must match.")

def __init__(self, pixel_size: int = 1, edge_func: Callable = None, **edge_func_params: Any):

super().__init__(name='geometric_metrics', dtype=tf.float64)

self.oversegmentation_error = []
self.undersegmentation_error = []
self.border_error = []
self.fragmentation_error = []

self.edge_func = self._detect_edges if edge_func is None else edge_func
self.edge_func_params = edge_func_params
self.pixel_size = pixel_size

def update_state(self, reference: np.ndarray, measurement: np.ndarray, encode_reference: bool = True,
background_value: int = 0) -> None:
""" Calculate the error metrics for a measurement and reference arrays. For each .
If encode_reference is set to True, connected components will be used to label objects in the reference and
measurements.
"""

if not tf.executing_eagerly():
warnings.warn("Geometric metrics must be run with eager execution. If running as a compiled Keras model, "
"enable eager execution with model.run_eagerly = True")

reference = reference.numpy() if isinstance(reference, tf.Tensor) else reference
measurement = measurement.numpy() if isinstance(reference, tf.Tensor) else measurement

self._validate_input(reference, measurement)

for ref, meas in zip(reference, measurement):
ref = ref
meas = meas

if encode_reference:
cc_reference = measure.label(ref, background=background_value)
else:
cc_reference = ref

cc_measurement = measure.label(meas, background=background_value)
components_reference = set(np.unique(cc_reference)).difference([background_value])

ref_edges = self.edge_func(cc_reference)
meas_edges = self.edge_func(cc_measurement)
for component in components_reference:
reference_mask = cc_reference == component

uniq, count = np.unique(cc_measurement[reference_mask & (cc_measurement != background_value)],
return_counts=True)
ref_area = np.sum(reference_mask)

max_interecting_measurement = uniq[count.argmax()] if len(count) > 0 else background_value
meas_mask = cc_measurement == max_interecting_measurement
meas_area = np.count_nonzero(cc_measurement == max_interecting_measurement)
intersection_area = count.max() if len(count) > 0 else 0

self.oversegmentation_error.append(self._segmentation_error(intersection_area, ref_area))
self.undersegmentation_error.append(self._segmentation_error(intersection_area, meas_area))
border_ref_edge = ref_edges.squeeze() & reference_mask.squeeze()
border_meas_edge = meas_edges.squeeze() & meas_mask.squeeze()

self.border_error.append(self._border_err(border_ref_edge, border_meas_edge))
self.fragmentation_error.append(self._fragmentation_err(len(uniq), reference_mask))

def get_oversegmentation_error(self) -> float:
""" Return oversegmentation error. """
return np.array(self.oversegmentation_error).mean()

def get_undersegmentation_error(self) -> float:
""" Return undersegmentation error. """

return np.array(self.undersegmentation_error).mean()

def get_border_error(self) -> float:
""" Return border error. """

return np.array(self.border_error).mean()

def get_fragmentation_error(self) -> float:
""" Return fragmentation error. """

return np.array(self.fragmentation_error).mean()

def result(self) -> List[float]:
""" Return a list of values representing oversegmentation, undersegmentation, border, fragmentation errors. """

return [self.get_oversegmentation_error(),
self.get_undersegmentation_error(),
self.get_border_error(), self.get_fragmentation_error()]

def reset_states(self) -> None:
""" Empty all the error arrays. """
self.oversegmentation_error = []
self.undersegmentation_error = []
self.border_error = []
self.fragmentation_error = []
Loading

0 comments on commit d7203fb

Please sign in to comment.