Skip to content

Commit

Permalink
add dice evaluation metric (#225)
Browse files Browse the repository at this point in the history
* add dice evaluation metric

* add dice evaluation metric

* add dice evaluation metric

* support 2 metrics

* support 2 metrics

* support 2 metrics

* support 2 metrics

* fix docstring

* use np.round once for all
  • Loading branch information
Junjun2016 authored Nov 24, 2020
1 parent 90e8e38 commit 993be25
Show file tree
Hide file tree
Showing 9 changed files with 420 additions and 179 deletions.
5 changes: 3 additions & 2 deletions mmseg/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .mean_iou import mean_iou
from .metrics import eval_metrics, mean_dice, mean_iou

__all__ = [
'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
'get_classes', 'get_palette'
]
74 changes: 0 additions & 74 deletions mmseg/core/evaluation/mean_iou.py

This file was deleted.

176 changes: 176 additions & 0 deletions mmseg/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import numpy as np


def intersect_and_union(pred_label, label, num_classes, ignore_index):
"""Calculate intersection and Union.
Args:
pred_label (ndarray): Prediction segmentation map
label (ndarray): Ground truth segmentation map
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes
ndarray: The union of prediction and ground truth histogram on all
classes
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""

mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]

intersect = pred_label[pred_label == label]
area_intersect, _ = np.histogram(
intersect, bins=np.arange(num_classes + 1))
area_pred_label, _ = np.histogram(
pred_label, bins=np.arange(num_classes + 1))
area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
area_union = area_pred_label + area_label - area_intersect

return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
"""Calculate Total Intersection and Union.
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
Returns:
ndarray: The intersection of prediction and ground truth histogram
on all classes
ndarray: The union of prediction and ground truth histogram on all
classes
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""

num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
total_area_union = np.zeros((num_classes, ), dtype=np.float)
total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
total_area_label = np.zeros((num_classes, ), dtype=np.float)
for i in range(num_imgs):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(results[i], gt_seg_maps[i], num_classes,
ignore_index=ignore_index)
total_area_intersect += area_intersect
total_area_union += area_union
total_area_pred_label += area_pred_label
total_area_label += area_label
return total_area_intersect, total_area_union, \
total_area_pred_label, total_area_label


def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
"""Calculate Mean Intersection and Union (mIoU)
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category IoU, shape (num_classes, )
"""

all_acc, acc, iou = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mIoU'],
nan_to_num=nan_to_num)
return all_acc, acc, iou


def mean_dice(results,
gt_seg_maps,
num_classes,
ignore_index,
nan_to_num=None):
"""Calculate Mean Dice (mDice)
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category dice, shape (num_classes, )
"""

all_acc, acc, dice = eval_metrics(
results=results,
gt_seg_maps=gt_seg_maps,
num_classes=num_classes,
ignore_index=ignore_index,
metrics=['mDice'],
nan_to_num=nan_to_num)
return all_acc, acc, dice


def eval_metrics(results,
gt_seg_maps,
num_classes,
ignore_index,
metrics=['mIoU'],
nan_to_num=None):
"""Calculate evaluation metrics
Args:
results (list[ndarray]): List of prediction segmentation maps
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, )
ndarray: Per category evalution metrics, shape (num_classes, )
"""

if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))
total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(results, gt_seg_maps,
num_classes,
ignore_index=ignore_index)
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
ret_metrics = [all_acc, acc]
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
ret_metrics.append(iou)
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
ret_metrics.append(dice)
if nan_to_num is not None:
ret_metrics = [
np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
]
return ret_metrics
80 changes: 43 additions & 37 deletions mmseg/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from torch.utils.data import Dataset

from mmseg.core import mean_iou
from mmseg.core import eval_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
"""Custom dataset for semantic segmentation.
An example of file structure is as followed.
"""Custom dataset for semantic segmentation. An example of file structure
is as followed.
.. code-block:: none
Expand Down Expand Up @@ -315,57 +315,63 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
'mDice' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: Default metrics.
"""

if not isinstance(metric, str):
assert len(metric) == 1
metric = metric[0]
allowed_metrics = ['mIoU']
if metric not in allowed_metrics:
if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))

eval_results = {}
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)

all_acc, acc, iou = mean_iou(
results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
summary_str = ''
summary_str += 'per class results:\n'

line_format = '{:<15} {:>10} {:>10}\n'
summary_str += line_format.format('Class', 'IoU', 'Acc')
ret_metrics = eval_metrics(
results,
gt_seg_maps,
num_classes,
ignore_index=self.ignore_index,
metrics=metric)
class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
class_names = self.CLASSES
ret_metrics_round = [
np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
]
for i in range(num_classes):
iou_str = '{:.2f}'.format(iou[i] * 100)
acc_str = '{:.2f}'.format(acc[i] * 100)
summary_str += line_format.format(class_names[i], iou_str, acc_str)
summary_str += 'Summary:\n'
line_format = '{:<15} {:>10} {:>10} {:>10}\n'
summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')

iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
all_acc_str = '{:.2f}'.format(all_acc * 100)
summary_str += line_format.format('global', iou_str, acc_str,
all_acc_str)
print_log(summary_str, logger)

eval_results['mIoU'] = np.nanmean(iou)
eval_results['mAcc'] = np.nanmean(acc)
eval_results['aAcc'] = all_acc

class_table_data.append([class_names[i]] +
[m[i] for m in ret_metrics_round[2:]] +
[ret_metrics_round[1][i]])
summary_table_data = [['Scope'] +
['m' + head
for head in class_table_data[0][1:]] + ['aAcc']]
ret_metrics_mean = [
np.round(np.nanmean(ret_metric) * 100, 2)
for ret_metric in ret_metrics
]
summary_table_data.append(['global'] + ret_metrics_mean[2:] +
[ret_metrics_mean[1]] +
[ret_metrics_mean[0]])
print_log('per class results:', logger)
table = AsciiTable(class_table_data)
print_log('\n' + table.table, logger=logger)
print_log('Summary:', logger)
table = AsciiTable(summary_table_data)
print_log('\n' + table.table, logger=logger)

for i in range(1, len(summary_table_data[0])):
eval_results[summary_table_data[0]
[i]] = summary_table_data[1][i] / 100.0
return eval_results
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
matplotlib
numpy
terminaltables
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
Loading

0 comments on commit 993be25

Please sign in to comment.