diff --git a/contrib/PanopticDeepLab/README.md b/contrib/PanopticDeepLab/README.md
new file mode 100644
index 0000000000..d4faa744fa
--- /dev/null
+++ b/contrib/PanopticDeepLab/README.md
@@ -0,0 +1,144 @@
+
+# Panoptic DeepLab
+
+A PaddlePaddle implementation of the [Panoptic DeepLab](https://arxiv.org/abs/1911.10194) panoptic segmentation algorithm.
+
+Panoptic DeepLab was the first work to show that a bottom-up approach can achieve state-of-the-art panoptic segmentation results. It predicts three outputs: semantic segmentation, center prediction, and center regression. Pixels of thing classes are grouped to their nearest instance center to form the instance segmentation result. Finally, the semantic and instance segmentation results are fused by majority vote to produce the panoptic segmentation, so that every pixel is assigned either a class label or an instance id.
+![](./docs/panoptic_deeplab.jpg)
+
+## Model Baselines
+
+### Cityscapes
+| Backbone | Batch Size |Resolution | Training Iters | PQ | SQ | RQ | AP | mIoU | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|ResNet50_OS32| 8 | 2049x1025|90000|58.35%|80.03%|71.52%|25.80%|79.18%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/train.log)|
+|ResNet50_OS32| 64 | 1025x513|90000|60.32%|80.56%|73.56%|26.77%|79.67%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/train.log)|
+
+## Environment Setup
+
+1. System requirements
+* PaddlePaddle >= 2.0.0
+* Python >= 3.6
+The GPU build of PaddlePaddle is recommended. For detailed installation instructions, please refer to the official [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/windows-pip.html) website.
+
+2. Clone the PaddleSeg repo
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSeg
+```
+
+3. Install paddleseg
+```shell
+cd PaddleSeg
+pip install -e .
+```
+
+4. Enter the PaddleSeg/contrib/PanopticDeepLab directory
+```shell
+cd contrib/PanopticDeepLab
+```
+
+## Dataset Preparation
+
+Place datasets under the `data` directory.
+
+### Cityscapes
+
+Download the dataset from the [Cityscapes website](https://www.cityscapes-dataset.com/) and organize it into the following structure:
+
+```
+cityscapes/
+|--gtFine/
+| |--train/
+| | |--aachen/
+| | | |--*_color.png, *_instanceIds.png, *_labelIds.png, *_polygons.json,
+| | | |--*_labelTrainIds.png
+| | | |--...
+| |--val/
+| |--test/
+| |--cityscapes_panoptic_train_trainId.json
+| |--cityscapes_panoptic_train_trainId/
+| | |-- *_panoptic.png
+| |--cityscapes_panoptic_val_trainId.json
+| |--cityscapes_panoptic_val_trainId/
+| | |-- *_panoptic.png
+|--leftImg8bit/
+| |--train/
+| |--val/
+| |--test/
+
+```
+
+Install CityscapesScripts
+```shell
+pip install git+https://github.com/mcordts/cityscapesScripts.git
+```
+
+Generate the `*_panoptic.png` files with the following command (locate the `createPanopticImgs.py` file first):
+```shell
+python /path/to/cityscapesscripts/preparation/createPanopticImgs.py \
+    --dataset-folder data/cityscapes/gtFine/ \
+    --output-folder data/cityscapes/gtFine/ \
+    --use-train-id
+```
+
+## Training
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # set the GPU ids according to your environment
+python -m paddle.distributed.launch train.py \
+    --config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
+    --do_eval \
+    --use_vdl \
+    --save_interval 5000 \
+    --save_dir output
+```
+
+**Note:** `--do_eval` slows down training and increases GPU memory consumption; enable or disable it as needed.
+
+For more parameter details, run:
+```shell
+python train.py --help
+```
+
+## Evaluation
+```shell
+python val.py \
+    --config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
+    --model_path output/iter_90000/model.pdparams
+```
+You can also download our provided model and evaluate it directly.
+
+For more parameter details, run:
+```shell
+python val.py --help
+```
+
+## Prediction and Visualization
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # set the GPU ids according to your environment
+python -m paddle.distributed.launch predict.py \
+    --config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
+    --model_path output/iter_90000/model.pdparams \
+    --image_path data/cityscapes/leftImg8bit/val/ \
+    --save_dir ./output/result
+```
+You can also download our provided model and use it for prediction directly.
+
+For more parameter details, run:
+```shell
+python predict.py --help
+```
+Panoptic segmentation results:
+![](./docs/visualization_panoptic.png)
+
+Semantic segmentation results:
+![](./docs/visualization_semantic.png)
+
+Instance segmentation results:
+![](./docs/visualization_instance.png)
+
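+The saved panoptic label map encodes class and instance jointly as `panoptic_id = class_id * label_divisor + instance_id` (`label_divisor` is 1000 for Cityscapes, and stuff pixels keep the bare class id). A minimal sketch of decoding it, assuming a NumPy array loaded from the prediction (the helper below is illustrative and not part of this repo):
+
+```python
+import numpy as np
+
+def decode_panoptic(panoptic: np.ndarray, label_divisor: int = 1000):
+    """Split a panoptic id map into (class_id, instance_id) maps."""
+    is_thing = panoptic > label_divisor
+    class_id = np.where(is_thing, panoptic // label_divisor, panoptic)
+    # Stuff pixels carry no instance, so their instance_id stays 0.
+    instance_id = np.where(is_thing, panoptic % label_divisor, 0)
+    return class_id, instance_id
+```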
diff --git a/contrib/PanopticDeepLab/configs/_base_/cityscapes_panoptic.yml b/contrib/PanopticDeepLab/configs/_base_/cityscapes_panoptic.yml
new file mode 100644
index 0000000000..aa9466ac47
--- /dev/null
+++ b/contrib/PanopticDeepLab/configs/_base_/cityscapes_panoptic.yml
@@ -0,0 +1,55 @@
+train_dataset:
+  type: CityscapesPanoptic
+  dataset_root: data/cityscapes
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [2049, 1025]
+      label_padding_value: [0, 0, 0]
+    - type: RandomHorizontalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+  mode: train
+  ignore_stuff_in_offset: True
+  small_instance_area: 4096
+  small_instance_weight: 3
+
+val_dataset:
+  type: CityscapesPanoptic
+  dataset_root: data/cityscapes
+  transforms:
+    - type: Padding
+      target_size: [2049, 1025]
+      label_padding_value: [0, 0, 0]
+    - type: Normalize
+  mode: val
+  ignore_stuff_in_offset: True
+  small_instance_area: 4096
+  small_instance_weight: 3
+
+
+optimizer:
+  type: adam
+
+learning_rate:
+  value: 0.00005
+  decay:
+    type: poly
+    power: 0.9
+    end_lr: 0.0
+
+loss:
+  types:
+    - type: CrossEntropyLoss
+      top_k_percent_pixels: 0.2
+    - type: MSELoss
+      reduction: "none"
+    - type: L1Loss
+      reduction: "none"
+  coef: [1, 200, 0.001]
diff --git a/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml b/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml
new file mode 100644
index 0000000000..445b11fbdb
--- /dev/null
+++ b/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml
@@ -0,0 +1,19 @@
+_base_: ./panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml
+
+batch_size: 8
+
+train_dataset:
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [1025, 513]
+      label_padding_value: [0, 0, 0]
+    - type: RandomHorizontalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
diff --git a/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml b/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml
new file mode 100644
index 0000000000..d35e90d98c
--- /dev/null
+++ b/contrib/PanopticDeepLab/configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml
@@ -0,0 +1,23 @@
+_base_: ../_base_/cityscapes_panoptic.yml
+
+batch_size: 1
+iters: 90000
+
+model:
+  type: PanopticDeepLab
+  backbone:
+    type: ResNet50_vd
+    output_stride: 32
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
+  backbone_indices: [2, 1, 0, 3]
+  aspp_ratios: [1, 3, 6, 9]
+  aspp_out_channels: 256
+  decoder_channels: 256
+  low_level_channels_projects: [128, 64, 32]
+  align_corners: True
+  instance_aspp_out_channels: 256
+  instance_decoder_channels: 128
+  instance_low_level_channels_projects: [64, 32, 16]
+  instance_num_classes: [1, 2]
+  instance_head_channels: 32
+  instance_class_key: ["center", "offset"]
diff --git a/contrib/PanopticDeepLab/core/__init__.py b/contrib/PanopticDeepLab/core/__init__.py
new file mode 100644
index
0000000000..3358db4d38
--- /dev/null
+++ b/contrib/PanopticDeepLab/core/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .train import train
+from .val import evaluate
+from .predict import predict
+from . import infer
+
+__all__ = ['train', 'evaluate', 'predict']
diff --git a/contrib/PanopticDeepLab/core/infer.py b/contrib/PanopticDeepLab/core/infer.py
new file mode 100644
index 0000000000..8ac1d800fe
--- /dev/null
+++ b/contrib/PanopticDeepLab/core/infer.py
@@ -0,0 +1,351 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+
+def get_reverse_list(ori_shape, transforms):
+    """
+    Get the reverse list of transforms.
+
+    Args:
+        ori_shape (list): Origin shape of image.
+        transforms (list): List of transforms.
+
+    Returns:
+        list: A list of tuples in one of two formats:
+            ('resize', (h, w)): the image shape before resizing.
+            ('padding', (h, w)): the image shape before padding.
+    """
+    reverse_list = []
+    h, w = ori_shape[0], ori_shape[1]
+    for op in transforms:
+        if op.__class__.__name__ in ['Resize']:
+            reverse_list.append(('resize', (h, w)))
+            h, w = op.target_size[0], op.target_size[1]
+        if op.__class__.__name__ in ['ResizeByLong']:
+            reverse_list.append(('resize', (h, w)))
+            long_edge = max(h, w)
+            short_edge = min(h, w)
+            short_edge = int(round(short_edge * op.long_size / long_edge))
+            long_edge = op.long_size
+            if h > w:
+                h = long_edge
+                w = short_edge
+            else:
+                w = long_edge
+                h = short_edge
+        if op.__class__.__name__ in ['Padding']:
+            reverse_list.append(('padding', (h, w)))
+            w, h = op.target_size[0], op.target_size[1]
+        if op.__class__.__name__ in ['LimitLong']:
+            long_edge = max(h, w)
+            short_edge = min(h, w)
+            if ((op.max_long is not None) and (long_edge > op.max_long)):
+                reverse_list.append(('resize', (h, w)))
+                long_edge = op.max_long
+                short_edge = int(round(short_edge * op.max_long / long_edge))
+            elif ((op.min_long is not None) and (long_edge < op.min_long)):
+                reverse_list.append(('resize', (h, w)))
+                long_edge = op.min_long
+                short_edge = int(round(short_edge * op.min_long / long_edge))
+            if h > w:
+                h = long_edge
+                w = short_edge
+            else:
+                w = long_edge
+                h = short_edge
+    return reverse_list
+
+
+def reverse_transform(pred, ori_shape, transforms):
+    """Recover pred to its origin shape."""
+    reverse_list = get_reverse_list(ori_shape, transforms)
+    for item in reverse_list[::-1]:
+        if item[0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            pred = F.interpolate(pred, (h, w), mode='nearest')
+        elif item[0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            pred = pred[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return pred
+
+
+def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=3, top_k=None):
+    """
+    Find the center points from the center heatmap.
+
+    Args:
+        ctr_hmp (Tensor): A Tensor of shape [1, H, W] of raw center heatmap output.
+        threshold (float, optional): Threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): NMS max pooling kernel size. Default: 3.
+        top_k (int, optional): An Integer, top k centers to keep. Default: None.
+
+    Returns:
+        Tensor: A Tensor of shape [K, 2] where K is the number of center points. The order of the second dim is (y, x).
+    """
+    # Thresholding: set values below threshold to 0.
+    ctr_hmp = F.thresholded_relu(ctr_hmp, threshold)
+
+    # NMS via max pooling: keep only local maxima.
+    nms_padding = (nms_kernel - 1) // 2
+    ctr_hmp = ctr_hmp.unsqueeze(0)
+    ctr_hmp_max_pooled = F.max_pool2d(
+        ctr_hmp, kernel_size=nms_kernel, stride=1, padding=nms_padding)
+    ctr_hmp = ctr_hmp * (ctr_hmp_max_pooled == ctr_hmp)
+
+    ctr_hmp = ctr_hmp.squeeze((0, 1))
+    if len(ctr_hmp.shape) != 2:
+        raise ValueError('Something is wrong with center heatmap dimension.')
+
+    if top_k is None:
+        top_k_score = 0
+    else:
+        top_k_score, _ = paddle.topk(paddle.flatten(ctr_hmp), top_k)
+        top_k_score = top_k_score[-1]
+    # Non-zero points above the top-k score are candidate centers.
+    ctr_hmp_k = (ctr_hmp > top_k_score).astype('int64')
+    if ctr_hmp_k.sum() == 0:
+        ctr_all = None
+    else:
+        ctr_all = paddle.nonzero(ctr_hmp_k)
+    return ctr_all
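+
+# Illustrative example (added note, not in the original source): for a heatmap
+# with a single strong peak, find_instance_center returns that peak's (y, x):
+#   hmp = paddle.to_tensor([[0.0, 0.9, 0.0], [0.0, 0.05, 0.0]]).unsqueeze(0)
+#   find_instance_center(hmp, threshold=0.1)  # -> Tensor([[0, 1]])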
+
+
+def group_pixels(ctr, offsets):
+    """
+    Gives each pixel in the image an instance id.
+
+    Args:
+        ctr (Tensor): A Tensor of shape [K, 2] where K is the number of center points. The order of the second dim is (y, x).
+        offsets (Tensor): A Tensor of shape [2, H, W] of raw offset output. For consistency, only a batch size of 1 is supported. The order of the first dim is (offset_y, offset_x).
+
+    Returns:
+        Tensor: A Tensor of shape [1, H, W], with instance ids 1, 2, ...
+    """
+    height, width = offsets.shape[-2:]
+    y_coord = paddle.arange(height, dtype=offsets.dtype).reshape([1, -1, 1])
+    y_coord = paddle.concat([y_coord] * width, axis=2)
+    x_coord = paddle.arange(width, dtype=offsets.dtype).reshape([1, 1, -1])
+    x_coord = paddle.concat([x_coord] * height, axis=1)
+    coord = paddle.concat([y_coord, x_coord], axis=0)
+
+    ctr_loc = coord + offsets
+    ctr_loc = ctr_loc.reshape((2, height * width)).transpose((1, 0))
+
+    # ctr: [K, 2] -> [K, 1, 2]
+    # ctr_loc: [H*W, 2] -> [1, H*W, 2]
+    ctr = ctr.unsqueeze(1)
+    ctr_loc = ctr_loc.unsqueeze(0)
+
+    # distance: [K, H*W]
+    distance = paddle.norm((ctr - ctr_loc).astype('float32'), axis=-1)
+
+    # Find the center with minimum distance at each location, offset by 1 to reserve id=0 for stuff.
+    instance_id = paddle.argmin(
+        distance, axis=0).reshape((1, height, width)) + 1
+
+    return instance_id
+
+
+def get_instance_segmentation(semantic,
+                              ctr_hmp,
+                              offset,
+                              thing_list,
+                              threshold=0.1,
+                              nms_kernel=3,
+                              top_k=None):
+    """
+    Post-processing for instance segmentation; gets a class-agnostic instance id map.
+
+    Args:
+        semantic (Tensor): A Tensor of shape [1, H, W], predicted semantic label.
+        ctr_hmp (Tensor): A Tensor of shape [1, H, W] of raw center heatmap output. For consistency, only a batch size of 1 is supported.
+        offset (Tensor): A Tensor of shape [2, H, W] of raw offset output. For consistency, only a batch size of 1 is supported. The order of the first dim is (offset_y, offset_x).
+        thing_list (list): A List of thing class id.
+        threshold (float, optional): A Float, threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): An Integer, NMS max pooling kernel size. Default: 3.
+        top_k (int, optional): An Integer, top k centers to keep. Default: None.
+
+    Returns:
+        Tensor: Instance segmentation results whose shape is [1, H, W].
+        Tensor: A Tensor of shape [1, K, 2] where K is the number of center points. The order of the second dim is (y, x).
+    """
+    thing_seg = paddle.zeros_like(semantic)
+    for thing_class in thing_list:
+        thing_seg = thing_seg + (semantic == thing_class).astype('int64')
+    thing_seg = (thing_seg > 0).astype('int64')
+    center = find_instance_center(
+        ctr_hmp, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k)
+    if center is None:
+        return paddle.zeros_like(semantic), center
+    ins_seg = group_pixels(center, offset)
+    return thing_seg * ins_seg, center.unsqueeze(0)
+
+
+def merge_semantic_and_instance(semantic, instance, label_divisor, thing_list,
+                                stuff_area, ignore_index):
+    """
+    Post-processing for panoptic segmentation: merges the semantic segmentation
+    label and the class-agnostic instance segmentation label.
+
+    Args:
+        semantic (Tensor): A Tensor of shape [1, H, W], predicted semantic label.
+        instance (Tensor): A Tensor of shape [1, H, W], predicted instance label.
+        label_divisor (int): An Integer, used to convert panoptic id = semantic id * label_divisor + instance_id.
+        thing_list (list): A List of thing class id.
+        stuff_area (int): An Integer, remove stuff whose area is less than stuff_area.
+        ignore_index (int): Specifies a value that is ignored.
+
+    Returns:
+        Tensor: A Tensor of shape [1, H, W]. Pixels whose value equals ignore_index are ignored.
+            A stuff class is encoded as class_id, while a thing class is encoded as
+            class_id * label_divisor + ins_id, where ins_id begins from 1.
+    """
+    # In case the thing mask does not align with the semantic prediction.
+    pan_seg = paddle.zeros_like(semantic) + ignore_index
+    thing_seg = instance > 0
+    semantic_thing_seg = paddle.zeros_like(semantic)
+    for thing_class in thing_list:
+        semantic_thing_seg += semantic == thing_class
+
+    # Keep track of the instance id for each class.
+    class_id_tracker = {}
+
+    # Paste things by majority voting.
+    ins_ids = paddle.unique(instance)
+    for ins_id in ins_ids:
+        if ins_id == 0:
+            continue
+        # Make sure only to do majority voting within semantic_thing_seg.
+        thing_mask = paddle.logical_and(instance == ins_id,
+                                        semantic_thing_seg == 1)
+        if paddle.all(paddle.logical_not(thing_mask)):
+            continue
+        # Get the class id for the instance of ins_id.
+        sem_ins_id = paddle.gather(
+            semantic.reshape((-1, )), paddle.nonzero(
+                thing_mask.reshape((-1, ))))  # equal to semantic[thing_mask]
+        v, c = paddle.unique(sem_ins_id, return_counts=True)
+        class_id = paddle.gather(v, c.argmax())
+        class_id = class_id.numpy()[0]
+        if class_id in class_id_tracker:
+            new_ins_id = class_id_tracker[class_id]
+        else:
+            class_id_tracker[class_id] = 1
+            new_ins_id = 1
+        class_id_tracker[class_id] += 1
+
+        # pan_seg[thing_mask] = class_id * label_divisor + new_ins_id
+        pan_seg = pan_seg * (paddle.logical_not(thing_mask)) + (
+            class_id * label_divisor + new_ins_id) * thing_mask.astype('int64')
+
+    # Paste stuff to the unoccupied area.
+    class_ids = paddle.unique(semantic)
+    for class_id in class_ids:
+        if class_id.numpy() in thing_list:
+            # thing class
+            continue
+        # Calculate the stuff area.
+        stuff_mask = paddle.logical_and(semantic == class_id,
+                                        paddle.logical_not(thing_seg))
+        area = paddle.sum(stuff_mask.astype('int64'))
+        if area >= stuff_area:
+            # pan_seg[stuff_mask] = class_id
+            pan_seg = pan_seg * (paddle.logical_not(stuff_mask)
+                                 ) + stuff_mask.astype('int64') * class_id
+
+    return pan_seg
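+
+# Illustrative example (added note, not in the original source): with
+# label_divisor=1000 and the Cityscapes trainIds used in this repo, the merged
+# panoptic map contains values such as:
+#   0     -> stuff pixel of class 0 (road)
+#   13001 -> first instance of thing class 13 (car)
+#   13002 -> second car instance
+#   255   -> ignored region (ignore_index)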
+
+
+def inference(
+        model,
+        im,
+        transforms,
+        thing_list,
+        label_divisor,
+        stuff_area,
+        ignore_index,
+        threshold=0.1,
+        nms_kernel=3,
+        top_k=None,
+        ori_shape=None,
+):
+    """
+    Inference for an image.
+
+    Args:
+        model (paddle.nn.Layer): The model to get logits of the image.
+        im (Tensor): The input image.
+        transforms (list): Transforms for the image.
+        thing_list (list): A List of thing class id.
+        label_divisor (int): An Integer, used to convert panoptic id = semantic id * label_divisor + instance_id.
+        stuff_area (int): An Integer, remove stuff whose area is less than stuff_area.
+        ignore_index (int): Specifies a value that is ignored.
+        threshold (float, optional): A Float, threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): An Integer, NMS max pooling kernel size. Default: 3.
+        top_k (int, optional): An Integer, top k centers to keep. Default: None.
+        ori_shape (list, optional): Origin shape of image. Default: None.
+
+    Returns:
+        list: A list of [semantic, semantic_softmax, instance, panoptic, ctr_hmp].
+            semantic: Semantic segmentation results with shape [1, 1, H, W], whose values are 0, 1, 2, ...
+            semantic_softmax: A Tensor of the probabilities for each class, whose shape is [1, num_classes, H, W].
+            instance: Class-agnostic instance segmentation results, whose values are 0, 1, 2, ..., where 0 is stuff.
+            panoptic: Panoptic segmentation results whose values are ignore_index, stuff_id, or thing_id * label_divisor + ins_id, with ins_id >= 1.
+            ctr_hmp: The raw center heatmap.
+    """
+    logits = model(im)
+    # semantic: [1, c, h, w], center: [1, 1, h, w], offset: [1, 2, h, w]
+    semantic, ctr_hmp, offset = logits
+    semantic = paddle.argmax(semantic, axis=1, keepdim=True)
+    semantic = semantic.squeeze(0)  # shape: [1, h, w]
+    semantic_softmax = F.softmax(logits[0], axis=1).squeeze()
+    ctr_hmp = ctr_hmp.squeeze(0)  # shape: [1, h, w]
+    offset = offset.squeeze(0)  # shape: [2, h, w]
+
+    instance, center = get_instance_segmentation(
+        semantic=semantic,
+        ctr_hmp=ctr_hmp,
+        offset=offset,
+        thing_list=thing_list,
+        threshold=threshold,
+        nms_kernel=nms_kernel,
+        top_k=top_k)
+    panoptic = merge_semantic_and_instance(semantic, instance, label_divisor,
+                                           thing_list, stuff_area, ignore_index)
+
+    # Recover to the origin shape.
+    # semantic: 0, 1, 2, 3, ...
+    # instance: 0, 1, 2, 3, 4, 5, ..., where 0 is stuff.
+    # panoptic: ignore_index, stuff_id, thing_id * label_divisor + ins_id, with ins_id >= 1.
+    results = [semantic, semantic_softmax, instance, panoptic, ctr_hmp]
+    if ori_shape is not None:
+        results = [i.unsqueeze(0) for i in results]
+        results = [
+            reverse_transform(i, ori_shape=ori_shape, transforms=transforms)
+            for i in results
+        ]
+
+    return results
diff --git a/contrib/PanopticDeepLab/core/predict.py b/contrib/PanopticDeepLab/core/predict.py
new file mode 100644
index 0000000000..78b9b54ec2
--- /dev/null
+++ b/contrib/PanopticDeepLab/core/predict.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import paddle
+import paddleseg
+from paddleseg.utils import logger, progbar
+
+from core import infer
+import utils
+
+
+def mkdir(path):
+    sub_dir = os.path.dirname(path)
+    if not os.path.exists(sub_dir):
+        os.makedirs(sub_dir)
+
+
+def partition_list(arr, m):
+    """Split the list 'arr' into m pieces."""
+    n = int(math.ceil(len(arr) / float(m)))
+    return [arr[i:i + n] for i in range(0, len(arr), n)]
+
+
+def get_save_name(im_path, im_dir):
+    """Get the name for saving."""
+    if im_dir is not None:
+        im_file = im_path.replace(im_dir, '')
+    else:
+        im_file = os.path.basename(im_path)
+    if im_file[0] == '/':
+        im_file = im_file[1:]
+    return im_file
+
+
+def add_info_to_save_path(save_path, info):
+    """Add more information to the save path."""
+    fname, fextension = os.path.splitext(save_path)
+    fname = '_'.join([fname, info])
+    save_path = ''.join([fname, fextension])
+    return save_path
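+
+# Illustrative example (added note, not in the original source):
+#   add_info_to_save_path('output/semantic/a.png', 'added')
+#   -> 'output/semantic/a_added.png'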
+
+
+def predict(model,
+            model_path,
+            image_list,
+            transforms,
+            thing_list,
+            label_divisor,
+            stuff_area,
+            ignore_index,
+            image_dir=None,
+            save_dir='output',
+            threshold=0.1,
+            nms_kernel=7,
+            top_k=200):
+    """
+    Predict and visualize the images in image_list.
+
+    Args:
+        model (nn.Layer): Used to predict for the input image.
+        model_path (str): The path of the pretrained model.
+        image_list (list): A list of image paths to be predicted.
+        transforms (transform.Compose): Preprocessing for the input image.
+        thing_list (list): A List of thing class id.
+        label_divisor (int): An Integer, used to convert panoptic id = semantic id * label_divisor + instance_id.
+        stuff_area (int): An Integer, remove stuff whose area is less than stuff_area.
+        ignore_index (int): Specifies a value that is ignored.
+        image_dir (str, optional): The root directory of the images to predict. Default: None.
+        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
+        threshold (float, optional): Threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): NMS max pooling kernel size. Default: 7.
+        top_k (int, optional): Top k centers to keep. Default: 200.
+    """
+    paddleseg.utils.utils.load_entire_model(model, model_path)
+    model.eval()
+    nranks = paddle.distributed.get_world_size()
+    local_rank = paddle.distributed.get_rank()
+    if nranks > 1:
+        img_lists = partition_list(image_list, nranks)
+    else:
+        img_lists = [image_list]
+
+    semantic_save_dir = os.path.join(save_dir, 'semantic')
+    instance_save_dir = os.path.join(save_dir, 'instance')
+    panoptic_save_dir = os.path.join(save_dir, 'panoptic')
+
+    colormap = utils.cityscape_colormap()
+
+    logger.info("Start to predict...")
+    progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
+    with paddle.no_grad():
+        for i, im_path in enumerate(img_lists[local_rank]):
+            ori_im = cv2.imread(im_path)
+            ori_shape = ori_im.shape[:2]
+            im, _ = transforms(ori_im)
+            im = im[np.newaxis, ...]
+            im = paddle.to_tensor(im)
+
+            semantic, semantic_softmax, instance, panoptic, ctr_hmp = infer.inference(
+                model=model,
+                im=im,
+                transforms=transforms.transforms,
+                thing_list=thing_list,
+                label_divisor=label_divisor,
+                stuff_area=stuff_area,
+                ignore_index=ignore_index,
+                threshold=threshold,
+                nms_kernel=nms_kernel,
+                top_k=top_k,
+                ori_shape=ori_shape)
+            semantic = semantic.squeeze().numpy()
+            instance = instance.squeeze().numpy()
+            panoptic = panoptic.squeeze().numpy()
+
+            im_file = get_save_name(im_path, image_dir)
+
+            # Visualize semantic segmentation results.
+            save_path = os.path.join(semantic_save_dir, im_file)
+            mkdir(save_path)
+            utils.visualize_semantic(
+                semantic, save_path=save_path, colormap=colormap)
+            # Save the overlaid image for semantic segmentation results.
+            save_path_ = add_info_to_save_path(save_path, 'add')
+            utils.visualize_semantic(
+                semantic, save_path=save_path_, colormap=colormap, image=ori_im)
+            # panoptic to semantic
+            ins_mask = panoptic > label_divisor
+            pan_to_sem = panoptic.copy()
+            pan_to_sem[ins_mask] = pan_to_sem[ins_mask] // label_divisor
+            save_path_ = add_info_to_save_path(save_path,
+                                               'panoptic_to_semantic')
+            utils.visualize_semantic(
+                pan_to_sem, save_path=save_path_, colormap=colormap)
+            save_path_ = add_info_to_save_path(save_path,
+                                               'panoptic_to_semantic_added')
+            utils.visualize_semantic(
+                pan_to_sem,
+                save_path=save_path_,
+                colormap=colormap,
+                image=ori_im)
+
+            # Visualize instance segmentation results.
+            pan_to_ins = panoptic.copy()
+            ins_mask = pan_to_ins > label_divisor
+            pan_to_ins[~ins_mask] = 0
+            save_path = os.path.join(instance_save_dir, im_file)
+            mkdir(save_path)
+            utils.visualize_instance(pan_to_ins, save_path=save_path)
+            # Save the overlaid image for instance segmentation results.
+            save_path_ = add_info_to_save_path(save_path, 'added')
+            utils.visualize_instance(
+                pan_to_ins, save_path=save_path_, image=ori_im)
+
+            # Visualize panoptic segmentation results.
+            save_path = os.path.join(panoptic_save_dir, im_file)
+            mkdir(save_path)
+            utils.visualize_panoptic(
+                panoptic,
+                save_path=save_path,
+                label_divisor=label_divisor,
+                colormap=colormap,
+                ignore_index=ignore_index)
+            # Save the overlaid image for panoptic segmentation results.
+            save_path_ = add_info_to_save_path(save_path, 'added')
+            utils.visualize_panoptic(
+                panoptic,
+                save_path=save_path_,
+                label_divisor=label_divisor,
+                colormap=colormap,
+                image=ori_im,
+                ignore_index=ignore_index)
+
+            progbar_pred.update(i + 1)
diff --git a/contrib/PanopticDeepLab/core/train.py b/contrib/PanopticDeepLab/core/train.py
new file mode 100644
index 0000000000..a3bdaf966c
--- /dev/null
+++ b/contrib/PanopticDeepLab/core/train.py
@@ -0,0 +1,315 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from collections import deque
+import shutil
+
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, resume, logger
+
+from core.val import evaluate
+
+
+def check_logits_losses(logits_list, losses):
+    len_logits = len(logits_list)
+    len_losses = len(losses['types'])
+    if len_logits != len_losses:
+        raise RuntimeError(
+            'The length of logits_list should be equal to the types of loss config: {} != {}.'
+            .format(len_logits, len_losses))
+
+
+def loss_computation(logits_list, semantic, semantic_weights, center,
+                     center_weights, offset, offset_weights, losses):
+    # semantic loss
+    semantic_loss = losses['types'][0](logits_list[0], semantic,
+                                       semantic_weights)
+    semantic_loss = semantic_loss * losses['coef'][0]
+
+    # center loss
+    center_loss = losses['types'][1](logits_list[1], center)
+    center_weights = (center_weights.unsqueeze(1)).expand_as(center_loss)
+    center_loss = center_loss * center_weights
+    if center_loss.sum() > 0:
+        center_loss = center_loss.sum() / center_weights.sum()
+    else:
+        center_loss = center_loss.sum() * 0
+    center_loss = center_loss * losses['coef'][1]
+
+    # offset loss
+    offset_loss = losses['types'][2](logits_list[2], offset)
+    offset_weights = (offset_weights.unsqueeze(1)).expand_as(offset_loss)
+    offset_loss = offset_loss * offset_weights
+    if offset_weights.sum() > 0:
+        offset_loss = offset_loss.sum() / offset_weights.sum()
+    else:
+        offset_loss = offset_loss.sum() * 0
+    offset_loss = offset_loss * losses['coef'][2]
+
+    loss_list = [semantic_loss, center_loss, offset_loss]
+
+    return loss_list
+
+
+def train(model,
+          train_dataset,
+          val_dataset=None,
+          optimizer=None,
+          save_dir='output',
+          iters=10000,
+          batch_size=2,
+          resume_model=None,
+          save_interval=1000,
+          log_iters=10,
+          num_workers=0,
+          use_vdl=False,
+          losses=None,
+          keep_checkpoint_max=5,
+          threshold=0.1,
+          nms_kernel=7,
+          top_k=200):
+    """
+    Launch training.
+
+    Args:
+        model (nn.Layer): A semantic segmentation model.
+        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
+        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
+        optimizer (paddle.optimizer.Optimizer): The optimizer.
+        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
+        iters (int, optional): How many iters to train the model. Default: 10000.
+        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
+        resume_model (str, optional): The path of the model to resume from.
+        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
+        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
+        num_workers (int, optional): Num workers for data loader. Default: 0.
+        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
+        losses (dict): A dict including 'types' and 'coef'. The length of coef should be equal to 1 or len(losses['types']).
+            The 'types' item is a list of objects of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficients.
+        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
+        threshold (float, optional): A Float, threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): An Integer, NMS max pooling kernel size. Default: 7.
+        top_k (int, optional): An Integer, top k centers to keep. Default: 200.
+    """
+    model.train()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+
+    start_iter = 0
+    if resume_model is not None:
+        start_iter = resume(model, optimizer, resume_model)
+
+    if not os.path.isdir(save_dir):
+        if os.path.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+
+    if nranks > 1:
+        # Initialize the parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+        ):
+            paddle.distributed.init_parallel_env()
+            ddp_model = paddle.DataParallel(model)
+        else:
+            ddp_model = paddle.DataParallel(model)
+
+    batch_sampler = paddle.io.DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+
+    loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        return_list=True,
+    )
+
+    if use_vdl:
+        from visualdl import LogWriter
+        log_writer = LogWriter(save_dir)
+
+    avg_loss = 0.0
+    avg_loss_list = []
+    iters_per_epoch = len(batch_sampler)
+    best_pq = -1.0
+    best_model_iter = -1
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    save_models = deque()
+    batch_start = time.time()
+
+    iter = start_iter
+    while iter < iters:
+        for data in loader:
+            iter += 1
+            if iter > iters:
+                break
+            reader_cost_averager.record(time.time() - batch_start)
+            images = data[0]
+            semantic = data[1]
+            semantic_weights = data[2]
+            center = data[3]
+            center_weights = data[4]
+            offset = data[5]
+            offset_weights = data[6]
+            foreground = data[7]
+
+            if nranks > 1:
+                logits_list = ddp_model(images)
+            else:
+                logits_list = model(images)
+
+            loss_list = loss_computation(
+                logits_list=logits_list,
+                losses=losses,
+                semantic=semantic,
+                semantic_weights=semantic_weights,
+                center=center,
+                center_weights=center_weights,
+                offset=offset,
+                offset_weights=offset_weights)
+            loss = sum(loss_list)
+            loss.backward()
+
+            optimizer.step()
+            lr = optimizer.get_lr()
+            if isinstance(optimizer._learning_rate,
+                          paddle.optimizer.lr.LRScheduler):
+                optimizer._learning_rate.step()
+            model.clear_gradients()
+            avg_loss += loss.numpy()[0]
+            if not avg_loss_list:
+                avg_loss_list = [l.numpy() for l in loss_list]
+            else:
+                for i in range(len(loss_list)):
+                    avg_loss_list[i] += loss_list[i].numpy()
+            batch_cost_averager.record(
+                time.time() - batch_start,
+                num_samples=batch_size)
+
+            if (iter) % log_iters == 0 and local_rank == 0:
+                avg_loss /= log_iters
+                avg_loss_list = [l[0] / log_iters for l in avg_loss_list]
+                remain_iters = iters - iter
+                avg_train_batch_cost = batch_cost_averager.get_average()
+                avg_train_reader_cost = reader_cost_averager.get_average()
+                eta = calculate_eta(remain_iters, avg_train_batch_cost)
+                logger.info(
+                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}"
+                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
+                            avg_loss, lr, avg_train_batch_cost,
+                            avg_train_reader_cost,
+                            batch_cost_averager.get_ips_average(), eta))
+                logger.info(
+                    "[LOSS] loss={:.4f}, semantic_loss={:.4f}, center_loss={:.4f}, offset_loss={:.4f}"
+                    .format(avg_loss, avg_loss_list[0], avg_loss_list[1],
+                            avg_loss_list[2]))
+                if use_vdl:
+                    log_writer.add_scalar('Train/loss', avg_loss, iter)
+                    # Record each loss separately when there is more than one loss.
+                    if len(avg_loss_list) > 1:
+                        avg_loss_dict = {}
+                        for i, value in enumerate(avg_loss_list):
+                            avg_loss_dict['loss_' + str(i)] = value
+                        for key, value in avg_loss_dict.items():
+                            log_tag = 'Train/' + key
+                            log_writer.add_scalar(log_tag, value, iter)
+
+                    log_writer.add_scalar('Train/lr', lr, iter)
+                    log_writer.add_scalar('Train/batch_cost',
+                                          avg_train_batch_cost, iter)
+                    log_writer.add_scalar('Train/reader_cost',
+                                          avg_train_reader_cost, iter)
+
+                avg_loss = 0.0
+                avg_loss_list = []
+                reader_cost_averager.reset()
+                batch_cost_averager.reset()
+
+            # Save the model.
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                current_save_dir = os.path.join(save_dir,
+                                                "iter_{}".format(iter))
+                if not os.path.isdir(current_save_dir):
+                    os.makedirs(current_save_dir)
+                paddle.save(model.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdparams'))
+                paddle.save(optimizer.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdopt'))
+                save_models.append(current_save_dir)
+                if len(save_models) > keep_checkpoint_max > 0:
+                    model_to_remove = save_models.popleft()
+                    shutil.rmtree(model_to_remove)
+
+            # Evaluate the model.
+            if (iter % save_interval == 0 or iter == iters) and (
+                    val_dataset is
+                    not None) and local_rank == 0 and iter > iters // 2:
+                num_workers = 1 if num_workers > 0 else 0
+                panoptic_results, semantic_results, instance_results = evaluate(
+                    model,
+                    val_dataset,
+                    threshold=threshold,
+                    nms_kernel=nms_kernel,
+                    top_k=top_k,
+                    num_workers=num_workers,
+                    print_detail=False)
+                pq = panoptic_results['pan_seg']['All']['pq']
+                miou = semantic_results['sem_seg']['mIoU']
+                map = instance_results['ins_seg']['mAP']
+                map50 = instance_results['ins_seg']['mAP50']
+                logger.info(
+                    "[EVAL] PQ: {:.4f}, mIoU: {:.4f}, mAP: {:.4f}, mAP50: {:.4f}"
+                    .format(pq, miou, map, map50))
+                model.train()
+
+            # Save the best model and add evaluation results to vdl.
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                if val_dataset is not None and iter > iters // 2:
+                    if pq > best_pq:
+                        best_pq = pq
+                        best_model_iter = iter
+                        best_model_dir = os.path.join(save_dir, "best_model")
+                        paddle.save(
+                            model.state_dict(),
+                            os.path.join(best_model_dir, 'model.pdparams'))
+                    logger.info(
+                        '[EVAL] The model with the best validation pq ({:.4f}) was saved at iter {}.'
+                        .format(best_pq, best_model_iter))
+
+                    if use_vdl:
+                        log_writer.add_scalar('Evaluate/PQ', pq, iter)
+                        log_writer.add_scalar('Evaluate/mIoU', miou, iter)
+                        log_writer.add_scalar('Evaluate/mAP', map, iter)
+                        log_writer.add_scalar('Evaluate/mAP50', map50, iter)
+            batch_start = time.time()
+
+    # Calculate flops.
+    if local_rank == 0:
+
+        def count_syncbn(m, x, y):
+            x = x[0]
+            nelements = x.numel()
+            m.total_ops += int(2 * nelements)
+
+        _, c, h, w = images.shape
+        flops = paddle.flops(
+            model, [1, c, h, w],
+            custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})
+
+    # Sleep for half a second to let the dataloader release resources.
+    time.sleep(0.5)
+    if use_vdl:
+        log_writer.close()
diff --git a/contrib/PanopticDeepLab/core/val.py b/contrib/PanopticDeepLab/core/val.py
new file mode 100644
index 0000000000..9e0f90b97b
--- /dev/null
+++ b/contrib/PanopticDeepLab/core/val.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from collections import OrderedDict
+
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from utils.evaluation import SemanticEvaluator, InstanceEvaluator, PanopticEvaluator
+from core import infer
+
+np.set_printoptions(suppress=True)
+
+
+def evaluate(model,
+             eval_dataset,
+             threshold=0.1,
+             nms_kernel=7,
+             top_k=200,
+             num_workers=0,
+             print_detail=True):
+    """
+    Launch evaluation.
+
+    Args:
+        model (nn.Layer): A semantic segmentation model.
+        eval_dataset (paddle.io.Dataset): Used to read and process validation datasets.
+        threshold (float, optional): Threshold applied to center heatmap score. Default: 0.1.
+        nms_kernel (int, optional): NMS max pooling kernel size. Default: 7.
+        top_k (int, optional): Top k centers to keep. Default: 200.
+        num_workers (int, optional): Num workers for data loader. Default: 0.
+        print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.
+
+    Returns:
+        dict: Panoptic evaluation results, including PQ, RQ, and SQ for all classes, each class, things, and stuff.
+        dict: Semantic evaluation results, including mIoU, fwIoU, mACC, and pACC.
+        dict: Instance evaluation results, including mAP and mAP50, as well as AP and AP50 for each class.
+    """
+    model.eval()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+    if nranks > 1:
+        # Initialize the parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+        ):
+            paddle.distributed.init_parallel_env()
+    batch_sampler = paddle.io.DistributedBatchSampler(
+        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        return_list=True,
+    )
+
+    total_iters = len(loader)
+    semantic_metric = SemanticEvaluator(
+        eval_dataset.num_classes, ignore_index=eval_dataset.ignore_index)
+    instance_metric_AP50 = InstanceEvaluator(
+        eval_dataset.num_classes,
+        overlaps=0.5,
+        thing_list=eval_dataset.thing_list)
+    instance_metric_AP = InstanceEvaluator(
+        eval_dataset.num_classes,
+        overlaps=list(np.arange(0.5, 1.0, 0.05)),
+        thing_list=eval_dataset.thing_list)
+    panoptic_metric = PanopticEvaluator(
+        num_classes=eval_dataset.num_classes,
+        thing_list=eval_dataset.thing_list,
+        ignore_index=eval_dataset.ignore_index,
+        label_divisor=eval_dataset.label_divisor)
+
+    if print_detail:
+        logger.info(
+            "Start evaluating (total_samples={}, total_iters={})...".format(
+                len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+    with paddle.no_grad():
+        for iter, data in enumerate(loader):
+            reader_cost_averager.record(time.time() - batch_start)
+            im = data[0]
+            raw_semantic_label = data[1]  # raw semantic label
+            raw_instance_label = data[2]
+            raw_panoptic_label = data[3]
+            ori_shape = raw_semantic_label.shape[-2:]
+
+            semantic, semantic_softmax, instance, panoptic, ctr_hmp = infer.inference(
+                model=model,
+                im=im,
+                transforms=eval_dataset.transforms.transforms,
+                thing_list=eval_dataset.thing_list,
+                label_divisor=eval_dataset.label_divisor,
+                stuff_area=eval_dataset.stuff_area,
+                ignore_index=eval_dataset.ignore_index,
+                threshold=threshold,
+                nms_kernel=nms_kernel,
+                top_k=top_k,
+                ori_shape=ori_shape)
+            semantic = semantic.squeeze().numpy()
+            semantic_softmax = semantic_softmax.squeeze().numpy()
+            instance = instance.squeeze().numpy()
+            panoptic = panoptic.squeeze().numpy()
+            ctr_hmp = ctr_hmp.squeeze().numpy()
+            raw_semantic_label = raw_semantic_label.squeeze().numpy()
+            raw_instance_label = raw_instance_label.squeeze().numpy()
+            raw_panoptic_label = raw_panoptic_label.squeeze().numpy()
+
+            # Update metrics for semantic, instance, and panoptic segmentation.
+            semantic_metric.update(semantic, raw_semantic_label)
+
+            gts = instance_metric_AP.convert_gt_map(raw_semantic_label,
+                                                    raw_instance_label)
+            preds = instance_metric_AP.convert_pred_map(semantic_softmax,
+                                                        panoptic)
+            ignore_mask = raw_semantic_label == eval_dataset.ignore_index
+            instance_metric_AP.update(preds, gts, ignore_mask=ignore_mask)
+            instance_metric_AP50.update(preds, gts, ignore_mask=ignore_mask)
+
+            panoptic_metric.update(panoptic, raw_panoptic_label)
+
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=len(im))
+            batch_cost = batch_cost_averager.get_average()
+            reader_cost = reader_cost_averager.get_average()
+
+            if local_rank == 0:
+                progbar_val.update(iter + 1, [('batch_cost', batch_cost),
+                                              ('reader cost', reader_cost)])
+            reader_cost_averager.reset()
+            batch_cost_averager.reset()
+            batch_start = time.time()
+
+    semantic_results = semantic_metric.evaluate()
+    panoptic_results = panoptic_metric.evaluate()
+    instance_results = OrderedDict()
+    ins_ap = instance_metric_AP.evaluate()
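+    # Note (added comment): mAP averages AP over IoU thresholds 0.50:0.95 in
+    # steps of 0.05, while mAP50 uses the single IoU threshold 0.50.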
+    ins_ap50 = instance_metric_AP50.evaluate()
+    instance_results['ins_seg'] = OrderedDict()
+    instance_results['ins_seg']['mAP'] = ins_ap['ins_seg']['mAP']
+    instance_results['ins_seg']['AP'] = ins_ap['ins_seg']['AP']
+    instance_results['ins_seg']['mAP50'] = ins_ap50['ins_seg']['mAP']
+    instance_results['ins_seg']['AP50'] = ins_ap50['ins_seg']['AP']
+
+    if print_detail:
+        logger.info(panoptic_results)
+        print()
+        logger.info(semantic_results)
+        print()
+        logger.info(instance_results)
+        print()
+
+        pq = panoptic_results['pan_seg']['All']['pq']
+        miou = semantic_results['sem_seg']['mIoU']
+        map = instance_results['ins_seg']['mAP']
+        map50 = instance_results['ins_seg']['mAP50']
+        logger.info(
+            "PQ: {:.4f}, mIoU: {:.4f}, mAP: {:.4f}, mAP50: {:.4f}".format(
+                pq, miou, map, map50))
+
+    return panoptic_results, semantic_results, instance_results
diff --git a/contrib/PanopticDeepLab/datasets/__init__.py b/contrib/PanopticDeepLab/datasets/__init__.py
new file mode 100644
index 0000000000..4f0f3a9500
--- /dev/null
+++ b/contrib/PanopticDeepLab/datasets/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cityscapes_panoptic import CityscapesPanoptic
diff --git a/contrib/PanopticDeepLab/datasets/cityscapes_panoptic.py b/contrib/PanopticDeepLab/datasets/cityscapes_panoptic.py
new file mode 100644
index 0000000000..59141367c0
--- /dev/null
+++ b/contrib/PanopticDeepLab/datasets/cityscapes_panoptic.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import glob
+
+import numpy as np
+import paddle
+from paddleseg.cvlibs import manager
+from paddleseg.transforms import Compose
+import PIL.Image as Image
+
+from transforms import PanopticTargetGenerator, SemanticTargetGenerator, InstanceTargetGenerator, RawPanopticTargetGenerator
+
+
+@manager.DATASETS.add_component
+class CityscapesPanoptic(paddle.io.Dataset):
+    """
+    Cityscapes dataset `https://www.cityscapes-dataset.com/`.
+    The folder structure is as follows:
+
+        cityscapes/
+        |--gtFine/
+        | |--train/
+        | | |--aachen/
+        | | | |--*_color.png, *_instanceIds.png, *_labelIds.png, *_polygons.json,
+        | | | |--*_labelTrainIds.png
+        | | | |--...
+        | |--val/
+        | |--test/
+        | |--cityscapes_panoptic_train_trainId.json
+        | |--cityscapes_panoptic_train_trainId/
+        | | |-- *_panoptic.png
+        | |--cityscapes_panoptic_val_trainId.json
+        | |--cityscapes_panoptic_val_trainId/
+        | | |-- *_panoptic.png
+        |--leftImg8bit/
+        | |--train/
+        | |--val/
+        | |--test/
+
+    Args:
+        transforms (list): Transforms for image.
+        dataset_root (str): Cityscapes dataset directory.
+        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val'). Default: 'train'.
+        ignore_stuff_in_offset (bool, optional): Whether to ignore stuff regions when training the offset branch. Default: False.
+        small_instance_area (int, optional): Instances whose area is less than the given value are considered small. Default: 0.
+        small_instance_weight (int, optional): The loss weight for small instances. Default: 1.
+        stuff_area (int, optional): An Integer, remove stuff whose area is less than stuff_area. Default: 2048.
+    """
+
+    def __init__(self,
+                 transforms,
+                 dataset_root,
+                 mode='train',
+                 ignore_stuff_in_offset=False,
+                 small_instance_area=0,
+                 small_instance_weight=1,
+                 stuff_area=2048):
+        self.dataset_root = dataset_root
+        self.transforms = Compose(transforms)
+        self.file_list = list()
+        self.ins_list = []
+        mode = mode.lower()
+        self.mode = mode
+        self.num_classes = 19
+        self.ignore_index = 255
+        self.thing_list = [11, 12, 13, 14, 15, 16, 17, 18]
+        self.label_divisor = 1000
+        self.stuff_area = stuff_area
+
+        if mode not in ['train', 'val']:
+            raise ValueError(
+                "mode should be 'train' or 'val', but got {}.".format(mode))
+
+        if self.transforms is None:
+            raise ValueError("`transforms` is necessary, but it is None.")
+
+        img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
+        label_dir = os.path.join(self.dataset_root, 'gtFine')
+        if self.dataset_root is None or not os.path.isdir(
+                self.dataset_root) or not os.path.isdir(
+                    img_dir) or not os.path.isdir(label_dir):
+            raise ValueError(
+                "The dataset is not found or the folder structure is nonconformant."
+            )
+        json_filename = os.path.join(
+            self.dataset_root, 'gtFine',
+            'cityscapes_panoptic_{}_trainId.json'.format(mode))
+        dataset = json.load(open(json_filename))
+        img_files = []
+        label_files = []
+        for img in dataset['images']:
+            img_file_name = img['file_name']
+            img_files.append(
+                os.path.join(self.dataset_root, 'leftImg8bit', mode,
+                             img_file_name.split('_')[0],
+                             img_file_name.replace('_gtFine', '')))
+        for ann in dataset['annotations']:
+            ann_file_name = ann['file_name']
+            label_files.append(
+                os.path.join(self.dataset_root, 'gtFine',
+                             'cityscapes_panoptic_{}_trainId'.format(mode),
+                             ann_file_name))
+            self.ins_list.append(ann['segments_info'])
+
+        self.file_list = [[
+            img_path, label_path
+        ] for img_path, label_path in zip(img_files, label_files)]
+
+        self.target_transform = PanopticTargetGenerator(
+            self.ignore_index,
+            self.rgb2id,
+            self.thing_list,
+            sigma=8,
+            ignore_stuff_in_offset=ignore_stuff_in_offset,
+            small_instance_area=small_instance_area,
+            small_instance_weight=small_instance_weight)
+
+        self.raw_semantic_generator = SemanticTargetGenerator(
+            ignore_index=self.ignore_index, rgb2id=self.rgb2id)
+        self.raw_instance_generator = InstanceTargetGenerator(self.rgb2id)
+        self.raw_panoptic_generator = RawPanopticTargetGenerator(
+            ignore_index=self.ignore_index,
+            rgb2id=self.rgb2id,
+            label_divisor=self.label_divisor)
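+
+    # Illustrative note (added comment, not in the original source): rgb2id
+    # below inverts the Cityscapes panoptic PNG color encoding; e.g. a pixel
+    # colored (R, G, B) = (232, 3, 0) decodes to segment id 232 + 256 * 3 = 1000.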
+
+    @staticmethod
+    def rgb2id(color):
+        """Converts the color to the panoptic label.
+        The color is created by `color = [segmentId % 256, segmentId // 256, segmentId // 256 // 256]`.
+
+        Args:
+            color: An ndarray or a tuple, color encoded image.
+
+        Returns:
+            Panoptic label.
+        """
+        if isinstance(color, np.ndarray) and len(color.shape) == 3:
+            if color.dtype == np.uint8:
+                color = color.astype(np.int32)
+            return color[:, :,
+                         0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
+        return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
+
+    def __getitem__(self, idx):
+        image_path, label_path = self.file_list[idx]
+        dataset_dict = {}
+        im, label = self.transforms(im=image_path, label=label_path)
+        label_dict = self.target_transform(label, self.ins_list[idx])
+        for key in label_dict.keys():
+            dataset_dict[key] = label_dict[key]
+        dataset_dict['image'] = im
+        if self.mode == 'val':
+            raw_label = np.asarray(Image.open(label_path))
+            dataset_dict['raw_semantic_label'] = self.raw_semantic_generator(
+                raw_label, self.ins_list[idx])['semantic']
+            dataset_dict['raw_instance_label'] = self.raw_instance_generator(
+                raw_label)['instance']
+            dataset_dict['raw_panoptic_label'] = self.raw_panoptic_generator(
+                raw_label, self.ins_list[idx])['panoptic']
+
+        image = np.array(dataset_dict['image'])
+        semantic = np.array(dataset_dict['semantic'])
+        semantic_weights = np.array(dataset_dict['semantic_weights'])
+        center = np.array(dataset_dict['center'])
+        center_weights = np.array(dataset_dict['center_weights'])
+        offset = np.array(dataset_dict['offset'])
+        offset_weights = np.array(dataset_dict['offset_weights'])
+        foreground = np.array(dataset_dict['foreground'])
+        if self.mode == 'train':
+            return image, semantic, semantic_weights, center, center_weights, offset, offset_weights, foreground
+        elif self.mode == 'val':
+            raw_semantic_label = np.array(dataset_dict['raw_semantic_label'])
+            raw_instance_label = np.array(dataset_dict['raw_instance_label'])
+            raw_panoptic_label = np.array(dataset_dict['raw_panoptic_label'])
+            return image, raw_semantic_label, raw_instance_label, raw_panoptic_label
+        else:
+            raise ValueError(
+                '{} is not supported, please set it to one of ("train", "val").'
+                .format(self.mode))
+
+    def __len__(self):
+        return len(self.file_list)
diff --git a/contrib/PanopticDeepLab/docs/panoptic_deeplab.jpg b/contrib/PanopticDeepLab/docs/panoptic_deeplab.jpg
new file mode 100644
index 0000000000..ace44918e4
Binary files /dev/null and b/contrib/PanopticDeepLab/docs/panoptic_deeplab.jpg differ
diff --git a/contrib/PanopticDeepLab/docs/visualization_instance.png b/contrib/PanopticDeepLab/docs/visualization_instance.png
new file mode 100644
index 0000000000..ad9204f42a
Binary files /dev/null and b/contrib/PanopticDeepLab/docs/visualization_instance.png differ
diff --git a/contrib/PanopticDeepLab/docs/visualization_panoptic.png b/contrib/PanopticDeepLab/docs/visualization_panoptic.png
new file mode 100644
index 0000000000..a4198e33bd
Binary files /dev/null and b/contrib/PanopticDeepLab/docs/visualization_panoptic.png differ
diff --git a/contrib/PanopticDeepLab/docs/visualization_semantic.png b/contrib/PanopticDeepLab/docs/visualization_semantic.png
new file mode 100644
index 0000000000..d22f907a04
Binary files /dev/null and b/contrib/PanopticDeepLab/docs/visualization_semantic.png differ
diff --git a/contrib/PanopticDeepLab/models/__init__.py b/contrib/PanopticDeepLab/models/__init__.py
new file mode 100644
index 0000000000..44b46327e4
--- /dev/null
+++ b/contrib/PanopticDeepLab/models/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .panoptic_deeplab import PanopticDeepLab
diff --git a/contrib/PanopticDeepLab/models/panoptic_deeplab.py b/contrib/PanopticDeepLab/models/panoptic_deeplab.py
new file mode 100644
index 0000000000..27f041b9c7
--- /dev/null
+++ b/contrib/PanopticDeepLab/models/panoptic_deeplab.py
@@ -0,0 +1,436 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+from paddleseg.models import layers
+from paddleseg.utils import utils
+
+__all__ = ['PanopticDeepLab']
+
+
+@manager.MODELS.add_component
+class PanopticDeepLab(nn.Layer):
+    """
+    The Panoptic DeepLab implementation based on PaddlePaddle.
+
+    The original article refers to
+    Bowen Cheng, et al. "Panoptic-DeepLab: A Simple, Strong, and Fast Baseline for Bottom-Up Panoptic Segmentation"
+    (https://arxiv.org/abs/1911.10194)
+
+    Args:
+        num_classes (int): The unique number of target classes.
+        backbone (paddle.nn.Layer): Backbone network, currently supporting Resnet50_vd/Resnet101_vd/Xception65.
+        backbone_indices (tuple, optional): The values in the tuple indicate the indices of the backbone outputs.
+            Default: (2, 1, 0, 3).
+        aspp_ratios (tuple, optional): The dilation rates used in the ASPP module.
+            If output_stride=16, aspp_ratios should be set to (1, 6, 12, 18).
+            If output_stride=8, aspp_ratios should be set to (1, 12, 24, 36).
+            Default: (1, 6, 12, 18).
+        aspp_out_channels (int, optional): The output channels of the ASPP module. Default: 256.
+        decoder_channels (int, optional): The channels of the Decoder. Default: 256.
+        low_level_channels_projects (list, optional): The channels to which the low-level features are projected. Default: None.
+        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
+            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
+        pretrained (str, optional): The path or url of the pretrained model. Default: None.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 backbone,
+                 backbone_indices=(2, 1, 0, 3),
+                 aspp_ratios=(1, 6, 12, 18),
+                 aspp_out_channels=256,
+                 decoder_channels=256,
+                 low_level_channels_projects=None,
+                 align_corners=False,
+                 pretrained=None,
+                 **kwargs):
+        super().__init__()
+
+        self.backbone = backbone
+        backbone_channels = [
+            backbone.feat_channels[i] for i in backbone_indices
+        ]
+
+        self.head = PanopticDeepLabHead(
+            num_classes, backbone_indices, backbone_channels, aspp_ratios,
+            aspp_out_channels, decoder_channels, align_corners,
+            low_level_channels_projects, **kwargs)
+
+        self.align_corners = align_corners
+        self.pretrained = pretrained
+        self.init_weight()
+
+    def _upsample_predictions(self, pred, input_shape):
+        """Upsamples the final predictions, with special handling for `offset`.
+
+        Args:
+            pred (dict): Stores all outputs of the segmentation model.
+            input_shape (tuple): Spatial resolution of the desired shape.
+
+        Returns:
+            result (OrderedDict): The upsampled dictionary.
+        """
+        # Override upsample method to correctly handle `offset`: offset values
+        # are expressed in pixels, so they must be rescaled by the same factor
+        # as the spatial resolution.
+        result = OrderedDict()
+        for key in pred.keys():
+            out = F.interpolate(
+                pred[key],
+                size=input_shape,
+                mode='bilinear',
+                align_corners=self.align_corners)
+            if 'offset' in key:
+                if input_shape[0] % 2 == 0:
+                    scale = input_shape[0] // pred[key].shape[2]
+                else:
+                    scale = (input_shape[0] - 1) // (pred[key].shape[2] - 1)
+                out *= scale
+            result[key] = out
+        return result
+
+    def forward(self, x):
+        feat_list = self.backbone(x)
+        logit_dict = self.head(feat_list)
+        results = self._upsample_predictions(logit_dict, x.shape[-2:])
+
+        logit_list = [
+            results['semantic'], results['center'], results['offset']
+        ]
+        return logit_list
+
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_entire_model(self, self.pretrained)
+
+
+class PanopticDeepLabHead(nn.Layer):
+    """
+    The PanopticDeepLabHead implementation based on PaddlePaddle. It consists of
+    a semantic branch and an instance (center + offset) branch.
+
+    Args:
+        num_classes (int): The unique number of target classes.
+        backbone_indices (tuple): The indices of backbone outputs used by the head.
+            The last index is taken as the input of the ASPP module, and the remaining
+            indices, in order, are taken as low-level features for the decoder.
+            Usually the backbone consists of four downsampling stages and returns an
+            output of each stage.
+        backbone_channels (tuple): The same length as "backbone_indices". It indicates the channels of corresponding index.
+        aspp_ratios (tuple): The dilation rates used in the ASPP module.
+        aspp_out_channels (int): The output channels of the ASPP module.
+        decoder_channels (int, optional): The channels of the decoder. Default: 256.
+        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
+            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
+        low_level_channels_projects (list, optional): The output channels of each low-level feature projection. Default: None.
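+        **kwargs: Arguments for the instance branch, forwarded from PanopticDeepLab. As used in
+            __init__ below, the expected keys are 'instance_aspp_out_channels',
+            'instance_decoder_channels', 'instance_low_level_channels_projects',
+            'instance_num_classes', 'instance_head_channels' and
+            'instance_class_key' (typically ['center', 'offset']).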
+    """
+
+    def __init__(self, num_classes, backbone_indices, backbone_channels,
+                 aspp_ratios, aspp_out_channels, decoder_channels,
+                 align_corners, low_level_channels_projects, **kwargs):
+        super().__init__()
+        self.semantic_decoder = SinglePanopticDeepLabDecoder(
+            backbone_indices=backbone_indices,
+            backbone_channels=backbone_channels,
+            aspp_ratios=aspp_ratios,
+            aspp_out_channels=aspp_out_channels,
+            decoder_channels=decoder_channels,
+            align_corners=align_corners,
+            low_level_channels_projects=low_level_channels_projects)
+        self.semantic_head = SinglePanopticDeepLabHead(
+            num_classes=[num_classes],
+            decoder_channels=decoder_channels,
+            head_channels=decoder_channels,
+            class_key=['semantic'])
+        self.instance_decoder = SinglePanopticDeepLabDecoder(
+            backbone_indices=backbone_indices,
+            backbone_channels=backbone_channels,
+            aspp_ratios=aspp_ratios,
+            aspp_out_channels=kwargs['instance_aspp_out_channels'],
+            decoder_channels=kwargs['instance_decoder_channels'],
+            align_corners=align_corners,
+            low_level_channels_projects=kwargs[
+                'instance_low_level_channels_projects'])
+        self.instance_head = SinglePanopticDeepLabHead(
+            num_classes=kwargs['instance_num_classes'],
+            decoder_channels=kwargs['instance_decoder_channels'],
+            head_channels=kwargs['instance_head_channels'],
+            class_key=kwargs['instance_class_key'])
+
+    def forward(self, features):
+        pred = {}
+
+        # Semantic branch
+        semantic = self.semantic_decoder(features)
+        semantic = self.semantic_head(semantic)
+        for key in semantic.keys():
+            pred[key] = semantic[key]
+
+        # Instance branch
+        instance = self.instance_decoder(features)
+        instance = self.instance_head(instance)
+        for key in instance.keys():
+            pred[key] = instance[key]
+
+        return pred
+
+
+class SeparableConvBNReLU(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 padding='same',
+                 **kwargs):
+        super().__init__()
+        self.depthwise_conv = layers.ConvBNReLU(
+            in_channels,
+            out_channels=in_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=in_channels,
+            **kwargs)
+        self.pointwise_conv = layers.ConvBNReLU(
+            in_channels, out_channels, kernel_size=1, groups=1, bias_attr=False)
+
+    def forward(self, x):
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
+
+
+class ASPPModule(nn.Layer):
+    """
+    Atrous Spatial Pyramid Pooling.
+
+    Args:
+        aspp_ratios (tuple): The dilation rates used in the ASPP module.
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
+            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
+        use_sep_conv (bool, optional): Whether to use separable convolutions in the ASPP module. Default: False.
+        image_pooling (bool, optional): Whether to augment ASPP with image-level features. Default: False.
+        drop_rate (float, optional): The drop rate. Default: 0.1.
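+
+    A shape sketch (channel sizes are illustrative; assumes paddle is imported):
+
+        aspp = ASPPModule(
+            aspp_ratios=(1, 6, 12, 18),
+            in_channels=2048,
+            out_channels=256,
+            align_corners=False,
+            image_pooling=True)
+        y = aspp(paddle.rand([1, 2048, 33, 65]))  # shape: [1, 256, 33, 65]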
+    """
+
+    def __init__(self,
+                 aspp_ratios,
+                 in_channels,
+                 out_channels,
+                 align_corners,
+                 use_sep_conv=False,
+                 image_pooling=False,
+                 drop_rate=0.1):
+        super().__init__()
+
+        self.align_corners = align_corners
+        self.aspp_blocks = nn.LayerList()
+
+        for ratio in aspp_ratios:
+            if use_sep_conv and ratio > 1:
+                conv_func = SeparableConvBNReLU
+            else:
+                conv_func = layers.ConvBNReLU
+
+            block = conv_func(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1 if ratio == 1 else 3,
+                dilation=ratio,
+                padding=0 if ratio == 1 else ratio,
+                bias_attr=False)
+            self.aspp_blocks.append(block)
+
+        out_size = len(self.aspp_blocks)
+
+        if image_pooling:
+            self.global_avg_pool = nn.Sequential(
+                nn.AdaptiveAvgPool2D(output_size=(1, 1)),
+                layers.ConvBNReLU(
+                    in_channels, out_channels, kernel_size=1, bias_attr=False))
+            out_size += 1
+        self.image_pooling = image_pooling
+
+        self.conv_bn_relu = layers.ConvBNReLU(
+            in_channels=out_channels * out_size,
+            out_channels=out_channels,
+            kernel_size=1,
+            bias_attr=False)
+
+        self.dropout = nn.Dropout(p=drop_rate)
+
+    def forward(self, x):
+        outputs = []
+        interpolate_shape = x.shape[2:]
+        for block in self.aspp_blocks:
+            y = block(x)
+            y = F.interpolate(
+                y,
+                interpolate_shape,
+                mode='bilinear',
+                align_corners=self.align_corners)
+            outputs.append(y)
+
+        if self.image_pooling:
+            img_avg = self.global_avg_pool(x)
+            img_avg = F.interpolate(
+                img_avg,
+                interpolate_shape,
+                mode='bilinear',
+                align_corners=self.align_corners)
+            outputs.append(img_avg)
+
+        x = paddle.concat(outputs, axis=1)
+        x = self.conv_bn_relu(x)
+        x = self.dropout(x)
+
+        return x
+
+
+class SinglePanopticDeepLabDecoder(nn.Layer):
+    """
+    The decoder of a single Panoptic-DeepLab branch, implemented on PaddlePaddle.
+
+    Args:
+        backbone_indices (tuple): The indices of backbone outputs used by the decoder.
+            The last index is taken as the input of the ASPP module, and the remaining
+            indices, in order, are taken as low-level features.
+            Usually the backbone consists of four downsampling stages and returns an
+            output of each stage.
+        backbone_channels (tuple): The same length as "backbone_indices". It indicates the channels of corresponding index.
+        aspp_ratios (tuple): The dilation rates used in the ASPP module.
+        aspp_out_channels (int): The output channels of the ASPP module.
+        decoder_channels (int): The channels of the decoder.
+        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
+            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
+        low_level_channels_projects (list): The output channels of each low-level feature projection.
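+
+    A wiring sketch (channel numbers are illustrative, matching a ResNet-50
+    backbone whose feat_channels are [256, 512, 1024, 2048]):
+
+        decoder = SinglePanopticDeepLabDecoder(
+            backbone_indices=(2, 1, 0, 3),
+            backbone_channels=(1024, 512, 256, 2048),
+            aspp_ratios=(1, 6, 12, 18),
+            aspp_out_channels=256,
+            decoder_channels=256,
+            align_corners=False,
+            low_level_channels_projects=[128, 64, 32])
+        # Given the four backbone stage outputs in feat_list, the decoder runs
+        # ASPP on feat_list[3] and then fuses feat_list[2], feat_list[1] and
+        # feat_list[0] top-down, returning a map at feat_list[0]'s resolution.
+        out = decoder(feat_list)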
+    """
+
+    def __init__(self, backbone_indices, backbone_channels, aspp_ratios,
+                 aspp_out_channels, decoder_channels, align_corners,
+                 low_level_channels_projects):
+        super().__init__()
+        self.aspp = ASPPModule(
+            aspp_ratios,
+            backbone_channels[-1],
+            aspp_out_channels,
+            align_corners,
+            use_sep_conv=False,
+            image_pooling=True,
+            drop_rate=0.5)
+        self.backbone_indices = backbone_indices
+        self.decoder_stage = len(low_level_channels_projects)
+        if self.decoder_stage != len(self.backbone_indices) - 1:
+            raise ValueError(
+                "len(low_level_channels_projects) != len(backbone_indices) - 1, they are {} and {}"
+                .format(low_level_channels_projects, backbone_indices))
+        self.align_corners = align_corners
+
+        # Transform low-level features
+        project = []
+        # Fuse
+        fuse = []
+        # Top-down direction, i.e. starting from the largest stride
+        for i in range(self.decoder_stage):
+            project.append(
+                layers.ConvBNReLU(
+                    backbone_channels[i],
+                    low_level_channels_projects[i],
+                    1,
+                    bias_attr=False))
+            if i == 0:
+                fuse_in_channels = aspp_out_channels + low_level_channels_projects[
+                    i]
+            else:
+                fuse_in_channels = decoder_channels + low_level_channels_projects[
+                    i]
+            fuse.append(
+                SeparableConvBNReLU(
+                    fuse_in_channels,
+                    decoder_channels,
+                    5,
+                    padding=2,
+                    bias_attr=False))
+        self.project = nn.LayerList(project)
+        self.fuse = nn.LayerList(fuse)
+
+    def forward(self, feat_list):
+        x = feat_list[self.backbone_indices[-1]]
+        x = self.aspp(x)
+
+        for i in range(self.decoder_stage):
+            low_level = feat_list[self.backbone_indices[i]]
+            low_level = self.project[i](low_level)
+            x = F.interpolate(
+                x,
+                size=low_level.shape[-2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            x = paddle.concat([x, low_level], axis=1)
+            x = self.fuse[i](x)
+
+        return x
+
+
+class SinglePanopticDeepLabHead(nn.Layer):
+    """
+    The prediction head of a single Panoptic-DeepLab branch.
+
+    Args:
+        num_classes (list): The number of output channels of each classifier.
+        decoder_channels (int): The channels of the decoder.
+        head_channels (int): The channels of the head.
+        class_key (list): The key names of the classifier outputs.
+    """
+
+    def __init__(self, num_classes, decoder_channels, head_channels, class_key):
+        super(SinglePanopticDeepLabHead, self).__init__()
+        self.num_head = len(num_classes)
+        if self.num_head != len(class_key):
+            raise ValueError(
+                "len(num_classes) != len(class_key), they are {} and {}".format(
+                    num_classes, class_key))
+
+        classifier = []
+        for i in range(self.num_head):
+            classifier.append(
+                nn.Sequential(
+                    SeparableConvBNReLU(
+                        decoder_channels,
+                        head_channels,
+                        5,
+                        padding=2,
+                        bias_attr=False),
+                    nn.Conv2D(head_channels, num_classes[i], 1)))
+
+        self.classifier = nn.LayerList(classifier)
+        self.class_key = class_key
+
+    def forward(self, x):
+        pred = OrderedDict()
+        # Apply each classifier to the shared decoder feature
+        for i, key in enumerate(self.class_key):
+            pred[key] = self.classifier[i](x)
+
+        return pred
diff --git a/contrib/PanopticDeepLab/predict.py b/contrib/PanopticDeepLab/predict.py
new file mode 100644
index 0000000000..69b1c0b5f4
--- /dev/null
+++ b/contrib/PanopticDeepLab/predict.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import paddle
+from paddleseg.cvlibs import manager, Config
+from paddleseg.utils import get_sys_env, logger, config_check
+
+from core import predict
+from datasets import CityscapesPanoptic
+from models import PanopticDeepLab
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model prediction')
+
+    # params of prediction
+    parser.add_argument(
+        "--config", dest="cfg", help="The config file.", default=None, type=str)
+    parser.add_argument(
+        '--model_path',
+        dest='model_path',
+        help='The path of model for prediction',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--image_path',
+        dest='image_path',
+        help=
+        'The path of image, it can be a file or a directory including images',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the predicted results',
+        type=str,
+        default='./output/result')
+    parser.add_argument(
+        '--threshold',
+        dest='threshold',
+        help='Threshold applied to center heatmap score',
+        type=float,
+        default=0.1)
+    parser.add_argument(
+        '--nms_kernel',
+        dest='nms_kernel',
+        help='NMS max pooling kernel size',
+        type=int,
+        default=7)
+    parser.add_argument(
+        '--top_k',
+        dest='top_k',
+        help='Top k centers to keep',
+        type=int,
+        default=200)
+
+    return parser.parse_args()
+
+
+def get_image_list(image_path):
+    """Get image list"""
+    valid_suffix = [
+        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
+    ]
+    image_list = []
+    image_dir = None
+    if os.path.isfile(image_path):
+        if os.path.splitext(image_path)[-1] in valid_suffix:
+            image_list.append(image_path)
+    elif os.path.isdir(image_path):
+        image_dir = image_path
+        for root, dirs, files in os.walk(image_path):
+            for f in files:
+                if '.ipynb_checkpoints' in root:
+                    continue
+                if os.path.splitext(f)[-1] in valid_suffix:
+                    image_list.append(os.path.join(root, f))
+    else:
+        raise FileNotFoundError(
+            '`--image_path` is not found. It should be an image file or a directory including images'
+        )
+
+    if len(image_list) == 0:
+        raise RuntimeError('There is no image file in `--image_path`')
+
+    return image_list, image_dir
+
+
+def main(args):
+    env_info = get_sys_env()
+    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
+        'GPUs used'] else 'cpu'
+
+    paddle.set_device(place)
+    if not args.cfg:
+        raise RuntimeError('No configuration file specified.')
+
+    cfg = Config(args.cfg)
+    val_dataset = cfg.val_dataset
+    if not val_dataset:
+        raise RuntimeError(
+            'The validation dataset is not specified in the configuration file.'
+ ) + + msg = '\n---------------Config Information---------------\n' + msg += str(cfg) + msg += '------------------------------------------------' + logger.info(msg) + + model = cfg.model + transforms = val_dataset.transforms + image_list, image_dir = get_image_list(args.image_path) + logger.info('Number of predict images = {}'.format(len(image_list))) + + config_check(cfg, val_dataset=val_dataset) + + predict( + model, + model_path=args.model_path, + transforms=transforms, + thing_list=val_dataset.thing_list, + label_divisor=val_dataset.label_divisor, + stuff_area=val_dataset.stuff_area, + ignore_index=val_dataset.ignore_index, + image_list=image_list, + image_dir=image_dir, + save_dir=args.save_dir, + threshold=args.threshold, + nms_kernel=args.nms_kernel, + top_k=args.top_k) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/contrib/PanopticDeepLab/train.py b/contrib/PanopticDeepLab/train.py new file mode 100644 index 0000000000..7adf32edde --- /dev/null +++ b/contrib/PanopticDeepLab/train.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import paddle +from paddleseg.cvlibs import manager, Config +from paddleseg.utils import get_sys_env, logger, config_check + +from core import train +from datasets import CityscapesPanoptic +from models import PanopticDeepLab + + +def parse_args(): + parser = argparse.ArgumentParser(description='Model training') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--iters', + dest='iters', + help='iters for training', + type=int, + default=None) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--learning_rate', + dest='learning_rate', + help='Learning rate', + type=float, + default=None) + parser.add_argument( + '--save_interval', + dest='save_interval', + help='How many iters to save a model snapshot once during training.', + type=int, + default=1000) + parser.add_argument( + '--resume_model', + dest='resume_model', + help='The path of resume model', + type=str, + default=None) + parser.add_argument( + '--save_dir', + dest='save_dir', + help='The directory for saving the model snapshot', + type=str, + default='./output') + parser.add_argument( + '--keep_checkpoint_max', + dest='keep_checkpoint_max', + help='Maximum number of checkpoints to save', + type=int, + default=5) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=0) + parser.add_argument( + '--do_eval', + dest='do_eval', + help='Eval while training', + action='store_true') + parser.add_argument( + '--log_iters', + dest='log_iters', + help='Display logging information at every log_iters', + default=10, + type=int) + parser.add_argument( + '--use_vdl', + dest='use_vdl', + help='Whether to 
record the data to VisualDL during training', + action='store_true') + parser.add_argument( + '--threshold', + dest='threshold', + help='Threshold applied to center heatmap score', + type=float, + default=0.1) + parser.add_argument( + '--nms_kernel', + dest='nms_kernel', + help='NMS max pooling kernel size', + type=int, + default=7) + parser.add_argument( + '--top_k', + dest='top_k', + help='Top k centers to keep', + type=int, + default=200) + + return parser.parse_args() + + +def main(args): + env_info = get_sys_env() + info = ['{}: {}'.format(k, v) for k, v in env_info.items()] + info = '\n'.join(['', format('Environment Information', '-^48s')] + info + + ['-' * 48]) + logger.info(info) + + place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ + 'GPUs used'] else 'cpu' + + paddle.set_device(place) + if not args.cfg: + raise RuntimeError('No configuration file specified.') + + cfg = Config( + args.cfg, + learning_rate=args.learning_rate, + iters=args.iters, + batch_size=args.batch_size) + + train_dataset = cfg.train_dataset + if train_dataset is None: + raise RuntimeError( + 'The training dataset is not specified in the configuration file.') + elif len(train_dataset) == 0: + raise ValueError( + 'The length of train_dataset is 0. Please check if your dataset is valid' + ) + val_dataset = cfg.val_dataset if args.do_eval else None + losses = cfg.loss + + msg = '\n---------------Config Information---------------\n' + msg += str(cfg) + msg += '------------------------------------------------' + logger.info(msg) + + config_check(cfg, train_dataset=train_dataset, val_dataset=val_dataset) + + train( + cfg.model, + train_dataset, + val_dataset=val_dataset, + optimizer=cfg.optimizer, + save_dir=args.save_dir, + iters=cfg.iters, + batch_size=cfg.batch_size, + resume_model=args.resume_model, + save_interval=args.save_interval, + log_iters=args.log_iters, + num_workers=args.num_workers, + use_vdl=args.use_vdl, + losses=losses, + keep_checkpoint_max=args.keep_checkpoint_max, + threshold=args.threshold, + nms_kernel=args.nms_kernel, + top_k=args.top_k, + ) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/contrib/PanopticDeepLab/transforms/__init__.py b/contrib/PanopticDeepLab/transforms/__init__.py new file mode 100644 index 0000000000..67b27709ba --- /dev/null +++ b/contrib/PanopticDeepLab/transforms/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .target_transforms import PanopticTargetGenerator, SemanticTargetGenerator, InstanceTargetGenerator, RawPanopticTargetGenerator diff --git a/contrib/PanopticDeepLab/transforms/target_transforms.py b/contrib/PanopticDeepLab/transforms/target_transforms.py new file mode 100644 index 0000000000..8479093762 --- /dev/null +++ b/contrib/PanopticDeepLab/transforms/target_transforms.py @@ -0,0 +1,307 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class PanopticTargetGenerator(object):
+    """
+    Generates the panoptic training target for Panoptic-DeepLab.
+    Annotation is assumed to have Cityscapes format.
+
+    Args:
+        ignore_index (int): The ignore label for semantic segmentation.
+        rgb2id (function): The panoptic label is encoded in a colored image; this function converts a color
+            to the corresponding panoptic label.
+        thing_list (list): A list of thing classes.
+        sigma (int, optional): The sigma for the Gaussian kernel. Default: 8.
+        ignore_stuff_in_offset (bool, optional): Whether to ignore stuff regions when training the offset branch. Default: False.
+        small_instance_area (int, optional): Indicates the largest area for small instances. Default: 0.
+        small_instance_weight (int, optional): Indicates the semantic loss weight for small instances. Default: 1.
+        ignore_crowd_in_semantic (bool, optional): Whether to ignore crowd regions in the semantic segmentation branch;
+            crowd regions are ignored in the original TensorFlow implementation. Default: False.
+    """
+
+    def __init__(self,
+                 ignore_index,
+                 rgb2id,
+                 thing_list,
+                 sigma=8,
+                 ignore_stuff_in_offset=False,
+                 small_instance_area=0,
+                 small_instance_weight=1,
+                 ignore_crowd_in_semantic=False):
+        self.ignore_index = ignore_index
+        self.rgb2id = rgb2id
+        self.thing_list = thing_list
+        self.ignore_stuff_in_offset = ignore_stuff_in_offset
+        self.small_instance_area = small_instance_area
+        self.small_instance_weight = small_instance_weight
+        self.ignore_crowd_in_semantic = ignore_crowd_in_semantic
+
+        # Pre-compute the Gaussian kernel used to splat instance centers
+        # into the center heatmap.
+        self.sigma = sigma
+        size = 6 * sigma + 3
+        x = np.arange(0, size, 1, float)
+        y = x[:, np.newaxis]
+        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
+        self.g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
+
+    def __call__(self, panoptic, segments):
+        """Generates the training target.
+        reference: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createPanopticImgs.py
+        reference: https://github.com/facebookresearch/detectron2/blob/master/datasets/prepare_panoptic_fpn.py#L18
+
+        Args:
+            panoptic (np.ndarray): Colored image encoding the panoptic label.
+            segments (list): A list of dictionaries, each containing information about one segment, with fields:
+                - id: panoptic id, after decoding `panoptic`.
+                - category_id: semantic class id.
+                - area: segment area.
+                - bbox: segment bounding box.
+                - iscrowd: crowd region.
+
+        Returns:
+            A dictionary with fields:
+                - semantic: Tensor, semantic label, shape=(H, W).
+                - foreground: Tensor, foreground mask label, shape=(H, W).
+                - center: Tensor, center heatmap, shape=(1, H, W).
+                - center_points: List, center coordinates, with tuple (y-coord, x-coord).
+                - offset: Tensor, offset, shape=(2, H, W), first dim is (offset_y, offset_x).
+                - semantic_weights: Tensor, loss weight for semantic prediction, shape=(H, W).
+                - center_weights: Tensor, ignore region of center prediction, shape=(H, W), used as weights for center
+                    regression; 0 means ignore and 1 means valid. Multiply this mask with the loss.
+                - offset_weights: Tensor, ignore region of offset prediction, shape=(H, W), used as weights for offset
+                    regression; 0 means ignore and 1 means valid. Multiply this mask with the loss.
+        """
+        panoptic = self.rgb2id(panoptic)
+        height, width = panoptic.shape[0], panoptic.shape[1]
+        semantic = np.zeros_like(panoptic, dtype=np.uint8) + self.ignore_index
+        foreground = np.zeros_like(panoptic, dtype=np.uint8)
+        center = np.zeros((1, height, width), dtype=np.float32)
+        center_pts = []
+        offset = np.zeros((2, height, width), dtype=np.float32)
+        y_coord = np.ones_like(panoptic, dtype=np.float32)
+        x_coord = np.ones_like(panoptic, dtype=np.float32)
+        y_coord = np.cumsum(y_coord, axis=0) - 1
+        x_coord = np.cumsum(x_coord, axis=1) - 1
+        # Generate pixel-wise loss weights
+        semantic_weights = np.ones_like(panoptic, dtype=np.uint8)
+        # 0: ignore, 1: has instance
+        # three conditions for a region to be ignored for instance branches:
+        # (1) It is labeled as `ignore_index`
+        # (2) It is a crowd region (iscrowd=1)
+        # (3) (Optional) It is a stuff region (for the offset branch)
+        center_weights = np.zeros_like(panoptic, dtype=np.uint8)
+        offset_weights = np.zeros_like(panoptic, dtype=np.uint8)
+        for seg in segments:
+            cat_id = seg["category_id"]
+            if self.ignore_crowd_in_semantic:
+                if not seg['iscrowd']:
+                    semantic[panoptic == seg["id"]] = cat_id
+            else:
+                semantic[panoptic == seg["id"]] = cat_id
+            if cat_id in self.thing_list:
+                foreground[panoptic == seg["id"]] = 1
+            if not seg['iscrowd']:
+                # Ignored regions are not in `segments`.
+                # Handle crowd region.
+                center_weights[panoptic == seg["id"]] = 1
+                if self.ignore_stuff_in_offset:
+                    # Handle stuff region.
+                    if cat_id in self.thing_list:
+                        offset_weights[panoptic == seg["id"]] = 1
+                else:
+                    offset_weights[panoptic == seg["id"]] = 1
+            if cat_id in self.thing_list:
+                # find instance center
+                mask_index = np.where(panoptic == seg["id"])
+                if len(mask_index[0]) == 0:
+                    # the instance is completely cropped
+                    continue
+
+                # Find instance area
+                ins_area = len(mask_index[0])
+                if ins_area < self.small_instance_area:
+                    semantic_weights[panoptic ==
+                                     seg["id"]] = self.small_instance_weight
+
+                center_y, center_x = np.mean(mask_index[0]), np.mean(
+                    mask_index[1])
+                center_pts.append([center_y, center_x])
+
+                # generate center heatmap
+                y, x = int(center_y), int(center_x)
+                # outside image boundary
+                if x < 0 or y < 0 or \
+                        x >= width or y >= height:
+                    continue
+                sigma = self.sigma
+                # upper left
+                ul = int(np.round(x - 3 * sigma - 1)), int(
+                    np.round(y - 3 * sigma - 1))
+                # bottom right
+                br = int(np.round(x + 3 * sigma + 2)), int(
+                    np.round(y + 3 * sigma + 2))
+
+                # clip the Gaussian patch to the image boundary
+                c, d = max(0, -ul[0]), min(br[0], width) - ul[0]
+                a, b = max(0, -ul[1]), min(br[1], height) - ul[1]
+
+                cc, dd = max(0, ul[0]), min(br[0], width)
+                aa, bb = max(0, ul[1]), min(br[1], height)
+                center[0, aa:bb, cc:dd] = np.maximum(center[0, aa:bb, cc:dd],
+                                                     self.g[a:b, c:d])
+
+                # generate offset (2, h, w) -> (y-dir, x-dir)
+                offset_y_index = (np.zeros_like(mask_index[0]), mask_index[0],
+                                  mask_index[1])
+                offset_x_index = (np.ones_like(mask_index[0]), mask_index[0],
+                                  mask_index[1])
+                offset[offset_y_index] = center_y - y_coord[mask_index]
+                offset[offset_x_index] = center_x - x_coord[mask_index]
+
+        return dict(
+            semantic=semantic.astype('long'),
+            foreground=foreground.astype('long'),
+            center=center.astype(np.float32),
+            center_points=center_pts,
+            offset=offset.astype(np.float32),
+            semantic_weights=semantic_weights.astype(np.float32),
+            center_weights=center_weights.astype(np.float32),
+            offset_weights=offset_weights.astype(np.float32))
+
+
+class SemanticTargetGenerator(object):
+    """
+    Generates the semantic training target only for Panoptic-DeepLab (no instance).
+    Annotation is assumed to have Cityscapes format.
+
+    Args:
+        ignore_index (int): The ignore label for semantic segmentation.
+        rgb2id (function): The panoptic label is encoded in a colored image; this function converts a color
+            to the corresponding panoptic label.
+    """
+
+    def __init__(self, ignore_index, rgb2id):
+        self.ignore_index = ignore_index
+        self.rgb2id = rgb2id
+
+    def __call__(self, panoptic, segments):
+        """Generates the training target.
+        reference: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createPanopticImgs.py
+        reference: https://github.com/facebookresearch/detectron2/blob/master/datasets/prepare_panoptic_fpn.py#L18
+
+        Args:
+            panoptic (np.ndarray): Colored image encoding the panoptic label.
+            segments (list): A list of dictionaries, each containing information about one segment, with fields:
+                - id: panoptic id, after decoding `panoptic`.
+                - category_id: semantic class id.
+                - area: segment area.
+                - bbox: segment bounding box.
+                - iscrowd: crowd region.
+
+        Returns:
+            A dictionary with fields:
+                - semantic: Tensor, semantic label, shape=(H, W).
+        """
+        panoptic = self.rgb2id(panoptic)
+        semantic = np.zeros_like(panoptic, dtype=np.uint8) + self.ignore_index
+        for seg in segments:
+            cat_id = seg["category_id"]
+            semantic[panoptic == seg["id"]] = cat_id
+
+        return dict(semantic=semantic.astype('long'))
+
+
+class InstanceTargetGenerator(object):
+    """
+    Generates the instance target only for Panoptic-DeepLab.
+    Annotation is assumed to have Cityscapes format.
+
+    Args:
+        rgb2id (function): The panoptic label is encoded in a colored image; this function converts a color
+            to the corresponding panoptic label.
+    """
+
+    def __init__(self, rgb2id):
+        self.rgb2id = rgb2id
+
+    def __call__(self, panoptic):
+        """Generates the instance target.
+
+        Args:
+            panoptic (np.ndarray): Colored image encoding the panoptic label.
+
+        Returns:
+            A dictionary with fields:
+                - instance: Tensor, shape=(H, W). 0 is background and 1, 2, 3, ... are instances, so it is class agnostic.
+        """
+        panoptic = self.rgb2id(panoptic)
+        instance = np.zeros_like(panoptic, dtype=np.int64)
+        ids = np.unique(panoptic)
+        ins_id = 1
+        # Panoptic ids greater than 1000 encode thing instances
+        # (id = category_id * 1000 + instance_id).
+        for label_id in ids:
+            if label_id > 1000:
+                instance[panoptic == label_id] = ins_id
+                ins_id += 1
+
+        return dict(instance=instance)
+
+
+class RawPanopticTargetGenerator(object):
+    """
+    Generates the raw panoptic ground truth for evaluation, where values are 0, 1, 2, 3, ...,
+    11000, 11001, ..., 18000, 18001, and ignore_index (generally 255).
+
+    Args:
+        ignore_index (int): The ignore label for semantic segmentation.
+        rgb2id (function): The panoptic label is encoded in a colored image; this function converts a color
+            to the corresponding panoptic label.
+        label_divisor (int, optional): An integer used to convert panoptic id = semantic id * label_divisor + instance_id.
+            Default: 1000.
+    """
+
+    def __init__(self, ignore_index, rgb2id, label_divisor=1000):
+        self.ignore_index = ignore_index
+        self.rgb2id = rgb2id
+        self.label_divisor = label_divisor
+
+    def __call__(self, panoptic, segments):
+        """
+        Generates the raw panoptic target.
+
+        Args:
+            panoptic (np.ndarray): Colored image encoding the panoptic label.
+            segments (list): A list of dictionaries, each containing information about one segment, with fields:
+                - id: panoptic id, after decoding `panoptic`.
+                - category_id: semantic class id.
+                - area: segment area.
+                - bbox: segment bounding box.
+                - iscrowd: crowd region.
+
+        Returns:
+            A dictionary with fields:
+                - panoptic: Tensor, panoptic label, shape=(H, W).
+        """
+        panoptic = self.rgb2id(panoptic)
+        raw_panoptic = np.zeros_like(panoptic) + self.ignore_index
+        for seg in segments:
+            cat_id = seg['category_id']
+            if seg['id'] < self.label_divisor:
+                # Stuff or crowd segments carry no instance id.
+                raw_panoptic[panoptic == seg['id']] = cat_id
+            else:
+                ins_id = seg['id'] % self.label_divisor
+                raw_panoptic[panoptic ==
+                             seg['id']] = cat_id * self.label_divisor + ins_id
+        return dict(panoptic=raw_panoptic.astype('long'))
diff --git a/contrib/PanopticDeepLab/utils/__init__.py b/contrib/PanopticDeepLab/utils/__init__.py
new file mode 100644
index 0000000000..894d8a7adf
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .visualize import visualize_semantic, visualize_instance, visualize_panoptic, cityscape_colormap
diff --git a/contrib/PanopticDeepLab/utils/evaluation/__init__.py b/contrib/PanopticDeepLab/utils/evaluation/__init__.py
new file mode 100644
index 0000000000..8cd9f71a3b
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/evaluation/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .semantic import SemanticEvaluator
+from .instance import InstanceEvaluator
+from .panoptic import PanopticEvaluator
diff --git a/contrib/PanopticDeepLab/utils/evaluation/instance.py b/contrib/PanopticDeepLab/utils/evaluation/instance.py
new file mode 100644
index 0000000000..97e27d6018
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/evaluation/instance.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict, OrderedDict
+
+import numpy as np
+
+
+class InstanceEvaluator(object):
+    """
+    Refer to 'https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py'
+    Calculates the matching results for each image, each class and each IoU threshold, accumulates
+    them into the final matching results of each class and each IoU threshold over the dataset, and,
+    based on the accumulated matching results, computes the AP and mAP.
+    We need two vectors for each class and for each overlap threshold:
+    The first vector (y_true) is binary and is 1 where the ground truth says true,
+    and is 0 otherwise.
+    The second vector (y_score) is float [0...1] and represents the confidence of
+    the prediction.
+    We represent the following cases as:
+                                            | y_true | y_score
+        gt instance with matched prediction | 1      | confidence
+        gt instance w/o matched prediction  | 1      | 0.0
+        false positive prediction           | 0      | confidence
+    The current implementation only makes sense for an overlap threshold >= 0.5,
+    since only then a single prediction can either be ignored or matched, but
+    never both; further, it can never match two gt instances.
+    For matching, we vary the overlap and do the following steps:
+        1.) remove all predictions that satisfy the overlap criterion with an ignore region (either void or *group)
+        2.) remove matches that do not satisfy the overlap
+        3.) mark non-matched predictions as false positive
+    During processing, 0 represents the first 'thing' class, so the labels here are one
+    less than the labels of the dataset.
+
+    Args:
+        num_classes (int): The unique number of target classes. Excludes the background class, which is usually labeled 0.
+        overlaps (float|list, optional): The threshold(s) of IoU. Default: 0.5.
+        thing_list (list|None, optional): A list of thing classes; AP is only calculated for these classes. Default: None.
+    """
+
+    def __init__(self, num_classes, overlaps=0.5, thing_list=None):
+        super().__init__()
+        self.num_classes = num_classes
+        if isinstance(overlaps, float):
+            overlaps = [overlaps]
+        self.overlaps = overlaps
+        self.y_true = [[np.empty(0) for _i in range(len(overlaps))]
+                       for _j in range(num_classes)]
+        self.y_score = [[np.empty(0) for _i in range(len(overlaps))]
+                        for _j in range(num_classes)]
+        self.hard_fns = [[0] * len(overlaps) for _ in range(num_classes)]
+
+        if thing_list is None:
+            self.thing_list = list(range(num_classes))
+        else:
+            self.thing_list = thing_list
+
+    def update(self, preds, gts, ignore_mask=None):
+        """
+        Compute y_true and y_score for one image.
+
+        Args:
+            preds (list): Tuple list [(label, confidence, mask), ...].
+            gts (list): Tuple list [(label, mask), ...].
+            ignore_mask (np.ndarray): Mask of regions to ignore.
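+
+        A call sketch (masks are binary arrays of the image size; the helper
+        methods convert_gt_map and convert_pred_map below produce these tuple
+        lists from label maps; the names here are illustrative):
+
+            evaluator = InstanceEvaluator(num_classes=8, overlaps=0.5)
+            gts = [(0, gt_mask_a), (3, gt_mask_b)]
+            preds = [(0, 0.9, pred_mask)]
+            evaluator.update(preds, gts)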
+        """
+
+        pred_instances, gt_instances = self.get_instances(
+            preds, gts, ignore_mask=ignore_mask)
+
+        for i in range(self.num_classes):
+            if i not in self.thing_list:
+                continue
+            for oi, oth in enumerate(self.overlaps):
+                cur_true = np.ones((len(gt_instances[i])))
+                cur_score = np.ones(len(gt_instances[i])) * (-float("inf"))
+                cur_match = np.zeros(len(gt_instances[i]), dtype=bool)
+                for gti, gt_instance in enumerate(gt_instances[i]):
+                    found_match = False
+                    for pred_instance in gt_instance['matched_pred']:
+                        overlap = float(pred_instance['intersection']) / (
+                            gt_instance['pixel_count'] +
+                            pred_instance['pixel_count'] -
+                            pred_instance['intersection'])
+                        if overlap > oth:
+                            confidence = pred_instance['confidence']
+
+                            # if we already have a prediction for this ground truth,
+                            # the prediction with the lower score is automatically a false positive
+                            if cur_match[gti]:
+                                max_score = max(cur_score[gti], confidence)
+                                min_score = min(cur_score[gti], confidence)
+                                cur_score[gti] = max_score
+                                # append false positive
+                                cur_true = np.append(cur_true, 0)
+                                cur_score = np.append(cur_score, min_score)
+                                cur_match = np.append(cur_match, True)
+                            # otherwise set score
+                            else:
+                                found_match = True
+                                cur_match[gti] = True
+                                cur_score[gti] = confidence
+
+                    if not found_match:
+                        self.hard_fns[i][oi] += 1
+                # remove not-matched ground truth instances
+                cur_true = cur_true[cur_match == True]
+                cur_score = cur_score[cur_match == True]
+
+                # collect not-matched predictions as false positive
+                for pred_instance in pred_instances[i]:
+                    found_gt = False
+                    for gt_instance in pred_instance['matched_gt']:
+                        overlap = float(gt_instance['intersection']) / (
+                            gt_instance['pixel_count'] +
+                            pred_instance['pixel_count'] -
+                            gt_instance['intersection'])
+                        if overlap > oth:
+                            found_gt = True
+                            break
+                    if not found_gt:
+                        proportion_ignore = 0
+                        if ignore_mask is not None:
+                            nb_ignore_pixels = pred_instance[
+                                'void_intersection']
+                            proportion_ignore = float(
+                                nb_ignore_pixels) / pred_instance['pixel_count']
+                        if proportion_ignore <= oth:
+                            cur_true = np.append(cur_true, 0)
+                            cur_score = np.append(cur_score,
+                                                  pred_instance['confidence'])
+                self.y_true[i][oi] = np.append(self.y_true[i][oi], cur_true)
+                self.y_score[i][oi] = np.append(self.y_score[i][oi], cur_score)
+
+    def evaluate(self):
+        ap = self.cal_ap()
+        map = self.cal_map()
+
+        res = {}
+        res["AP"] = [{i: ap[i] * 100} for i in self.thing_list]
+        res["mAP"] = 100 * map
+
+        results = OrderedDict({"ins_seg": res})
+        return results
+
+    def cal_ap(self):
+        """
+        Calculate AP for every class.
+        """
+        self.ap = [0] * self.num_classes
+        self.ap_overlap = [[0] * len(self.overlaps)
+                           for _ in range(self.num_classes)]
+        for i in range(self.num_classes):
+            if i not in self.thing_list:
+                continue
+            for j in range(len(self.overlaps)):
+                y_true = self.y_true[i][j]
+                y_score = self.y_score[i][j]
+                if len(y_true) == 0:
+                    self.ap_overlap[i][j] = 0
+                    continue
+                score_argsort = np.argsort(y_score)
+                y_score_sorted = y_score[score_argsort]
+                y_true_sorted = y_true[score_argsort]
+                y_true_sorted_cumsum = np.cumsum(y_true_sorted)
+
+                # unique thresholds
+                thresholds, unique_indices = np.unique(
+                    y_score_sorted, return_index=True)
+
+                # since we need to add an artificial point to the precision-recall curve
+                # increase its length by 1
+                nb_pr = len(unique_indices) + 1
+
+                # calculate precision and recall
+                nb_examples = len(y_score_sorted)
+                nb_true_examples = y_true_sorted_cumsum[-1]
+                precision = np.zeros(nb_pr)
+                recall = np.zeros(nb_pr)
+
+                # deal with the first point:
+                # the only thing we need to do is append a zero to the cumsum at the end;
+                # an index of -1 then uses that zero
+                y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
+
+                # deal with remaining
+                for idx_res, idx_scores in enumerate(unique_indices):
+                    cumsum = y_true_sorted_cumsum[idx_scores - 1]
+                    tp = nb_true_examples - cumsum
+                    fp = nb_examples - idx_scores - tp
+                    fn = cumsum + self.hard_fns[i][j]
+                    p = float(tp) / (tp + fp)
+                    r = float(tp) / (tp + fn)
+                    precision[idx_res] = p
+                    recall[idx_res] = r
+
+                # add the first point of the curve
+                precision[-1] = 1.
+                # Some implementations additionally make precision monotonically
+                # non-increasing from this point of the curve:
+                # precision = [np.max(precision[:i + 1]) for i in range(len(precision))]
+                recall[-1] = 0.
+
+                # compute average of precision-recall curve
+                # integration is performed via zero order, or equivalently step-wise integration
+                # first compute the widths of each step:
+                # use a convolution with appropriate kernel, manually deal with the boundaries first
+                recall_for_conv = np.copy(recall)
+                recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
+                recall_for_conv = np.append(recall_for_conv, 0.)
+
+                step_widths = np.convolve(recall_for_conv, [-0.5, 0, 0.5],
+                                          'valid')
+
+                # integration is now simply a dot product
+                ap_current = np.dot(precision, step_widths)
+                self.ap_overlap[i][j] = ap_current
+
+        ap = [np.average(i) for i in self.ap_overlap]
+        self.ap = ap
+
+        return ap
+
+    def cal_map(self):
+        """
+        Calculate mAP over all thing classes.
+        """
+        self.cal_ap()
+        valid_ap = [self.ap[i] for i in self.thing_list]
+        map = np.mean(valid_ap)
+        self.map = map
+
+        return map
+
+    def get_instances(self, preds, gts, ignore_mask=None):
+        """
+        In this method, we create two dicts of lists:
+            - pred_instances: contains all predictions and their associated gt
+            - gt_instances: contains all gt instances and their associated predictions
+
+        Args:
+            preds (list): Predictions of the image.
+            gts (list): Ground truth of the image.
+            ignore_mask (np.ndarray, optional): Ignore mask. Default: None.
+
+        Returns:
+            dict: pred_instances, the type is dict(list(dict)), e.g. {0: [{'pred_id': 0, 'label': 0,
+                'pixel_count': 100, 'confidence': 0.9, 'void_intersection': 0,
+                'matched_gt': [gt_instance0, gt_instance1, ...]}, ...], 1: ...}
+            dict: gt_instances, the type is dict(list(dict)), e.g. {0: [{'inst_id': 0, 'label': 0,
+                'pixel_count': 100, 'mask': np.ndarray, 'matched_pred': [pred_instance0, pred_instance1, ...]}, ...], 1: ...}
+        """
+
+        pred_instances = defaultdict(list)
+        gt_instances = defaultdict(list)
+
+        gt_inst_count = 0
+        for gt in gts:
+            label, mask = gt
+            gt_instance = defaultdict(list)
+            gt_instance['inst_id'] = gt_inst_count
+            gt_instance['label'] = label
+            gt_instance['pixel_count'] = np.count_nonzero(mask)
+            gt_instance['mask'] = mask
+            gt_instances[label].append(gt_instance)
+            gt_inst_count += 1
+
+        pred_inst_count = 0
+        for pred in preds:
+            label, conf, mask = pred
+            pred_instance = defaultdict(list)
+            pred_instance['label'] = label
+            pred_instance['pred_id'] = pred_inst_count
+            pred_instance['pixel_count'] = np.count_nonzero(mask)
+            pred_instance['confidence'] = conf
+            if ignore_mask is not None:
+                pred_instance['void_intersection'] = np.count_nonzero(
+                    np.logical_and(mask, ignore_mask))
+
+            # Loop through all ground truth instances with matching label
+            matched_gt = []
+            for gt_num, gt_instance in enumerate(gt_instances[label]):
+                intersection = np.count_nonzero(
+                    np.logical_and(mask, gt_instances[label][gt_num]['mask']))
+                if intersection > 0:
+                    gt_copy = gt_instance.copy()
+                    pred_copy = pred_instance.copy()
+
+                    gt_copy['intersection'] = intersection
+                    pred_copy['intersection'] = intersection
+
+                    matched_gt.append(gt_copy)
+                    gt_instances[label][gt_num]['matched_pred'].append(
+                        pred_copy)
+
+            pred_instance['matched_gt'] = matched_gt
+            pred_inst_count += 1
+            pred_instances[label].append(pred_instance)
+
+        return pred_instances, gt_instances
+
+    @staticmethod
+    def convert_gt_map(seg_map, ins_map):
+        """
+        Convert the ground truth with format (h, w) to the format required by the AP calculation.
+
+        Args:
+            seg_map (np.ndarray): The semantic segmentation map with shape H * W. Value is 0, 1, 2, ...
+            ins_map (np.ndarray): The instance segmentation map with shape H * W. Value is 0, 1, 2, ...
+
+        Returns:
+            list: Tuple list like [(label, mask), ...].
+        """
+        gts = []
+        instance_cnt = np.unique(ins_map)
+        for i in instance_cnt:
+            if i == 0:
+                continue
+            mask = ins_map == i
+            label = seg_map[mask][0]
+            gts.append((label, mask.astype('int32')))
+        return gts
+
+    @staticmethod
+    def convert_pred_map(seg_pred, pan_pred):
+        """
+        Convert the predictions with format (h, w) to the format required by the AP calculation.
+
+        Args:
+            seg_pred (np.ndarray): The semantic segmentation probabilities with shape C * H * W.
+            pan_pred (np.ndarray): Panoptic predictions. Values are void_label, stuff_id * label_divisor,
+                or thing_id * label_divisor + ins_id, where ins_id >= 1.
+
+        Returns:
+            list: Tuple list like [(label, score, mask), ...].
+        """
+        preds = []
+        instance_cnt = np.unique(pan_pred)
+        for i in instance_cnt:
+            if (i < 1000) or (i % 1000 == 0):
+                continue
+            mask = pan_pred == i
+            label = i // 1000
+            score = np.mean(seg_pred[label][mask])
+            preds.append((label, score, mask.astype('int32')))
+        return preds
diff --git a/contrib/PanopticDeepLab/utils/evaluation/panoptic.py b/contrib/PanopticDeepLab/utils/evaluation/panoptic.py
new file mode 100644
index 0000000000..01fd6f75b9
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/evaluation/panoptic.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ------------------------------------------------------------------------------
+# Reference: https://github.com/mcordts/cityscapesScripts/blob/aeb7b82531f86185ce287705be28f452ba3ddbb8/cityscapesscripts/evaluation/evalPanopticSemanticLabeling.py
+# Modified by Guowei Chen
+# ------------------------------------------------------------------------------
+
+from collections import defaultdict, OrderedDict
+
+import numpy as np
+
+OFFSET = 256 * 256 * 256
+
+
+class PQStatCat():
+    def __init__(self):
+        self.iou = 0.0
+        self.tp = 0
+        self.fp = 0
+        self.fn = 0
+
+    def __iadd__(self, pq_stat_cat):
+        self.iou += pq_stat_cat.iou
+        self.tp += pq_stat_cat.tp
+        self.fp += pq_stat_cat.fp
+        self.fn += pq_stat_cat.fn
+        return self
+
+    def __repr__(self):
+        s = 'iou: ' + str(self.iou) + ' tp: ' + str(self.tp) + ' fp: ' + str(
+            self.fp) + ' fn: ' + str(self.fn)
+        return s
+
+
+class PQStat():
+    def __init__(self, num_classes):
+        self.pq_per_cat = defaultdict(PQStatCat)
+        self.num_classes = num_classes
+
+    def __getitem__(self, i):
+        return self.pq_per_cat[i]
+
+    def __iadd__(self, pq_stat):
+        for label, pq_stat_cat in pq_stat.pq_per_cat.items():
+            self.pq_per_cat[label] += pq_stat_cat
+        return self
+
+    def pq_average(self, isthing=None, thing_list=None):
+        """
+        Calculate the average pq over classes as well as the pq of every class.
+
+        Args:
+            isthing (bool|None, optional): Calculate the average pq for thing classes if isthing is True,
+                for stuff classes if isthing is False, and for all classes if isthing is None. Default: None.
+            thing_list (list|None, optional): A list of thing classes. It should be provided when isthing
+                is True or False. Default: None.
+        """
+        pq, sq, rq, n = 0, 0, 0, 0
+        per_class_results = {}
+        for label in range(self.num_classes):
+            if isthing is not None:
+                if isthing:
+                    if label not in thing_list:
+                        continue
+                else:
+                    if label in thing_list:
+                        continue
+            iou = self.pq_per_cat[label].iou
+            tp = self.pq_per_cat[label].tp
+            fp = self.pq_per_cat[label].fp
+            fn = self.pq_per_cat[label].fn
+            if tp + fp + fn == 0:
+                per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
+                continue
+            n += 1
+            pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
+            sq_class = iou / tp if tp != 0 else 0
+            rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)
+
+            per_class_results[label] = {
+                'pq': pq_class,
+                'sq': sq_class,
+                'rq': rq_class
+            }
+            pq += pq_class
+            sq += sq_class
+            rq += rq_class
+
+        return {
+            'pq': pq / n,
+            'sq': sq / n,
+            'rq': rq / n,
+            'n': n
+        }, per_class_results
+
+
+class PanopticEvaluator:
+    """
+    Evaluate panoptic segmentation.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 thing_list,
+                 ignore_index=255,
+                 label_divisor=1000):
+        self.pq_stat = PQStat(num_classes)
+        self.num_classes = num_classes
+        self.thing_list = thing_list
+        self.ignore_index = ignore_index
+        self.label_divisor = label_divisor
+
+    def update(self, pred, gt):
+        # get the labels and counts for the pred and gt.
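+        # Both maps use panoptic ids: a thing segment is encoded as
+        # category_id * label_divisor + instance_id, while a stuff segment
+        # keeps its raw category_id (< label_divisor).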
+        gt_labels, gt_labels_counts = np.unique(gt, return_counts=True)
+        pred_labels, pred_labels_counts = np.unique(pred, return_counts=True)
+        gt_segms = defaultdict(dict)
+        pred_segms = defaultdict(dict)
+        for label, label_count in zip(gt_labels, gt_labels_counts):
+            category_id = label // self.label_divisor if label > self.label_divisor else label
+            gt_segms[label]['area'] = label_count
+            gt_segms[label]['category_id'] = category_id
+            # A thing label without an instance id marks a crowd region.
+            gt_segms[label]['iscrowd'] = 1 if label in self.thing_list else 0
+        for label, label_count in zip(pred_labels, pred_labels_counts):
+            category_id = label // self.label_divisor if label > self.label_divisor else label
+            pred_segms[label]['area'] = label_count
+            pred_segms[label]['category_id'] = category_id
+
+        # confusion matrix calculation
+        pan_gt_pred = gt.astype(np.uint64) * OFFSET + pred.astype(np.uint64)
+        gt_pred_map = {}
+        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
+        for label, intersection in zip(labels, labels_cnt):
+            gt_id = label // OFFSET
+            pred_id = label % OFFSET
+            gt_pred_map[(gt_id, pred_id)] = intersection
+
+        # count all matched pairs
+        gt_matched = set()
+        pred_matched = set()
+        for label_tuple, intersection in gt_pred_map.items():
+            gt_label, pred_label = label_tuple
+            if gt_label == self.ignore_index or pred_label == self.ignore_index:
+                continue
+            if gt_segms[gt_label]['iscrowd'] == 1:
+                continue
+            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
+                    'category_id']:
+                continue
+            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
+                'area'] - intersection - gt_pred_map.get(
+                    (self.ignore_index, pred_label), 0)
+            iou = intersection / union
+            if iou > 0.5:
+                self.pq_stat[gt_segms[gt_label]['category_id']].tp += 1
+                self.pq_stat[gt_segms[gt_label]['category_id']].iou += iou
+                gt_matched.add(gt_label)
+                pred_matched.add(pred_label)
+
+        # count false negatives
+        crowd_labels_dict = {}
+        for gt_label, gt_info in gt_segms.items():
+            if gt_label in gt_matched:
+                continue
+            if gt_label == self.ignore_index:
+                continue
+            # ignore crowd regions
+            if gt_info['iscrowd'] == 1:
+                crowd_labels_dict[gt_info['category_id']] = gt_label
+                continue
+            self.pq_stat[gt_info['category_id']].fn += 1
+
+        # count false positives
+        for pred_label, pred_info in pred_segms.items():
+            if pred_label in pred_matched:
+                continue
+            if pred_label == self.ignore_index:
+                continue
+            # intersection of the segment with self.ignore_index
+            intersection = gt_pred_map.get((self.ignore_index, pred_label), 0)
+            if pred_info['category_id'] in crowd_labels_dict:
+                intersection += gt_pred_map.get(
+                    (crowd_labels_dict[pred_info['category_id']], pred_label),
+                    0)
+            # a predicted segment is ignored if more than half of it corresponds to self.ignore_index regions
+            if intersection / pred_info['area'] > 0.5:
+                continue
+            self.pq_stat[pred_info['category_id']].fp += 1
+
+    def evaluate(self):
+        metrics = [("All", None), ("Things", True), ("Stuff", False)]
+        results = {}
+        for name, isthing in metrics:
+            results[name], per_class_results = self.pq_stat.pq_average(
+                isthing=isthing, thing_list=self.thing_list)
+            if name == 'All':
+                results['per_class'] = per_class_results
+        return OrderedDict(pan_seg=results)
diff --git a/contrib/PanopticDeepLab/utils/evaluation/semantic.py b/contrib/PanopticDeepLab/utils/evaluation/semantic.py
new file mode 100644
index 0000000000..79a004124d
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/evaluation/semantic.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ------------------------------------------------------------------------------
+# Reference: https://github.com/bowenc0221/panoptic-deeplab/blob/master/segmentation/evaluation/semantic.py
+# Modified by Guowei Chen
+# ------------------------------------------------------------------------------
+
+from collections import OrderedDict
+
+import numpy as np
+
+
+class SemanticEvaluator:
+    """
+    Evaluate semantic segmentation.
+
+    Args:
+        num_classes (int): The number of classes.
+        ignore_index (int, optional): The value in the semantic segmentation ground truth to be ignored;
+            predictions for the corresponding pixels are excluded. Default: 255.
+    """
+
+    def __init__(self, num_classes, ignore_index=255):
+        self._num_classes = num_classes
+        self._ignore_index = ignore_index
+        self._N = num_classes + 1  # store ignore label in the last class
+
+        self._conf_matrix = np.zeros((self._N, self._N), dtype=np.int64)
+
+    def update(self, pred, gt):
+        pred = pred.astype(np.int64)
+        gt = gt.astype(np.int64)
+        gt[gt == self._ignore_index] = self._num_classes
+
+        # row: pred, column: gt
+        self._conf_matrix += np.bincount(
+            self._N * pred.reshape(-1) + gt.reshape(-1),
+            minlength=self._N**2).reshape(self._N, self._N)
+
+    def evaluate(self):
+        """
+        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
+        * Mean intersection-over-union averaged across classes (mIoU)
+        * Frequency Weighted IoU (fwIoU)
+        * Mean pixel accuracy averaged across classes (mACC)
+        * Pixel Accuracy (pACC)
+        """
+        acc = np.zeros(self._num_classes, dtype=np.float64)
+        iou = np.zeros(self._num_classes, dtype=np.float64)
+        tp = self._conf_matrix.diagonal()[:-1].astype(np.float64)
+        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float64)
+        class_weights = pos_gt / np.sum(pos_gt)
+        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float64)
+
+        acc_valid = pos_pred > 0
+        acc[acc_valid] = tp[acc_valid] / pos_pred[acc_valid]
+        iou_valid = (pos_gt + pos_pred) > 0
+        union = pos_gt + pos_pred - tp
+        iou[acc_valid] = tp[acc_valid] / union[acc_valid]
+        macc = np.sum(acc) / np.sum(acc_valid)
+        miou = np.sum(iou) / np.sum(iou_valid)
+        fiou = np.sum(iou * class_weights)
+        pacc = np.sum(tp) / np.sum(pos_gt)
+
+        res = {}
+        res["mIoU"] = 100 * miou
+        res["fwIoU"] = 100 * fiou
+        res["mACC"] = 100 * macc
+        res["pACC"] = 100 * pacc
+
+        results = OrderedDict({"sem_seg": res})
+        return results
diff --git a/contrib/PanopticDeepLab/utils/visualize.py b/contrib/PanopticDeepLab/utils/visualize.py
new file mode 100644
index 0000000000..6b14215c87
--- /dev/null
+++ b/contrib/PanopticDeepLab/utils/visualize.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Reference: https://github.com/bowenc0221/panoptic-deeplab/blob/master/segmentation/utils/save_annotation.py
+
+import os
+
+import cv2
+import numpy as np
+from PIL import Image as PILImage
+
+# Reference: https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/colormap.py#L14
+_COLORS = np.array([
+    0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184,
+    0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300,
+    0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000,
+    0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000,
+    1.000, 0.333, 0.333, 0.000, 0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667,
+    0.333, 0.000, 0.667, 0.667, 0.000, 0.667, 1.000, 0.000, 1.000, 0.333, 0.000,
+    1.000, 0.667, 0.000, 1.000, 1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667,
+    0.500, 0.000, 1.000, 0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333,
+    0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500,
+    0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
+    0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, 0.000,
+    0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, 0.333, 1.000,
+    0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333,
+    1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000,
+    0.333, 1.000, 1.000, 0.667, 1.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000,
+    0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167,
+    0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
+    0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333,
+    0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000,
+    1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.857, 0.857, 0.857, 1.000,
+    1.000, 1.000
+]).astype(np.float32).reshape(-1, 3)
+
+
+def random_color(rgb=False, maximum=255):
+    """
+    Reference: https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/colormap.py#L111
+
+    Args:
+        rgb (bool, optional): whether to return RGB colors or BGR colors. Default: False.
+        maximum (int, optional): either 255 or 1. Default: 255.
+
+    Returns:
+        ndarray: a vector of 3 numbers
+    """
+    idx = np.random.randint(0, len(_COLORS))
+    ret = _COLORS[idx] * maximum
+    if not rgb:
+        ret = ret[::-1]
+    return ret
+
+
+def cityscape_colormap():
+    """Get the Cityscapes colormap (in BGR order for OpenCV)."""
+    colormap = np.zeros((256, 3), dtype=np.uint8)
+    colormap[0] = [128, 64, 128]
+    colormap[1] = [244, 35, 232]
+    colormap[2] = [70, 70, 70]
+    colormap[3] = [102, 102, 156]
+    colormap[4] = [190, 153, 153]
+    colormap[5] = [153, 153, 153]
+    colormap[6] = [250, 170, 30]
+    colormap[7] = [220, 220, 0]
+    colormap[8] = [107, 142, 35]
+    colormap[9] = [152, 251, 152]
+    colormap[10] = [70, 130, 180]
+    colormap[11] = [220, 20, 60]
+    colormap[12] = [255, 0, 0]
+    colormap[13] = [0, 0, 142]
+    colormap[14] = [0, 0, 70]
+    colormap[15] = [0, 60, 100]
+    colormap[16] = [0, 80, 100]
+    colormap[17] = [0, 0, 230]
+    colormap[18] = [119, 11, 32]
+    colormap = colormap[:, ::-1]
+    return colormap
+
+
+def visualize_semantic(semantic, save_path, colormap, image=None, weight=0.5):
+    """
+    Save semantic segmentation results.
+
+    Args:
+        semantic(np.ndarray): The semantic segmentation result, with shape (h, w).
+        save_path(str): The save path.
+        colormap(np.ndarray): A color map for visualization.
+        image(np.ndarray, optional): The original image; if provided, the semantic
+            result is blended with it. Default: None.
+        weight(float, optional): The weight of the image when blending it with the semantic result. Default: 0.5.
+    """
+    semantic = semantic.astype('uint8')
+    colored_semantic = colormap[semantic]
+    if image is not None:
+        colored_semantic = cv2.addWeighted(image, weight, colored_semantic,
+                                           1 - weight, 0)
+    cv2.imwrite(save_path, colored_semantic)
+
+
+def visualize_instance(instance, save_path, stuff_id=0, image=None, weight=0.5):
+    """
+    Save instance segmentation results.
+
+    Args:
+        instance(np.ndarray): The instance segmentation result, with shape (h, w).
+        save_path(str): The save path.
+        stuff_id(int, optional): Id of the stuff (background) region that should not be plotted. Default: 0.
+        image(np.ndarray, optional): The original image; if provided, the instance
+            result is blended with it. Default: None.
+        weight(float, optional): The weight of the image when blending it with the instance result. Default: 0.5.
+    """
+    # Add a color map for the instance segmentation result.
+    ids = np.unique(instance)
+    num_colors = len(ids)
+    colormap = np.zeros((num_colors, 3), dtype=np.uint8)
+    # Map labels to continuous values. np.unique returns sorted ids, so i <= ids[i]
+    # and the relabeling cannot collide with ids that are still to be replaced.
+    for i in range(num_colors):
+        instance[instance == ids[i]] = i
+        colormap[i, :] = random_color(maximum=255)
+        if ids[i] == stuff_id:
+            colormap[i, :] = np.array([0, 0, 0])
+    colored_instance = colormap[instance]
+
+    if image is not None:
+        colored_instance = cv2.addWeighted(image, weight, colored_instance,
+                                           1 - weight, 0)
+    cv2.imwrite(save_path, colored_instance)
+
+
+def visualize_panoptic(panoptic,
+                       save_path,
+                       label_divisor,
+                       colormap,
+                       image=None,
+                       weight=0.5,
+                       ignore_index=255):
+    """
+    Save panoptic segmentation results.
+
+    Args:
+        panoptic(np.ndarray): The panoptic segmentation result, with shape (h, w).
+        save_path(str): The save path.
+        label_divisor(int): Used to decode the panoptic id: panoptic id = semantic id * label_divisor + instance id.
+        colormap(np.ndarray): A color map for visualization.
+        image(np.ndarray, optional): The original image; if provided, the panoptic
+            result is blended with it. Default: None.
+        weight(float, optional): The weight of the image when blending it with the panoptic result. Default: 0.5.
+        ignore_index(int, optional): Specifies a target value that is ignored. Default: 255.
+ """ + colored_panoptic = np.zeros((panoptic.shape[0], panoptic.shape[1], 3), + dtype=np.uint8) + taken_colors = set((0, 0, 0)) + + def _random_color(base, max_dist=30): + color = base + np.random.randint( + low=-max_dist, high=max_dist + 1, size=3) + return tuple(np.maximum(0, np.minimum(255, color))) + + for lab in np.unique(panoptic): + mask = panoptic == lab + + ignore_mask = panoptic == ignore_index + ins_mask = panoptic > label_divisor + if lab > label_divisor: + base_color = colormap[lab // label_divisor] + elif lab != ignore_index: + base_color = colormap[lab] + else: + continue + if tuple(base_color) not in taken_colors: + taken_colors.add(tuple(base_color)) + color = base_color + else: + while True: + color = _random_color(base_color) + if color not in taken_colors: + taken_colors.add(color) + break + colored_panoptic[mask] = color + + if image is not None: + colored_panoptic = cv2.addWeighted(image, weight, colored_panoptic, + 1 - weight, 0) + cv2.imwrite(save_path, colored_panoptic) diff --git a/contrib/PanopticDeepLab/val.py b/contrib/PanopticDeepLab/val.py new file mode 100644 index 0000000000..5d33ecb8db --- /dev/null +++ b/contrib/PanopticDeepLab/val.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle +import paddleseg +from paddleseg.cvlibs import manager, Config +from paddleseg.utils import get_sys_env, logger, config_check + +from core import evaluate +from datasets import CityscapesPanoptic +from models import PanopticDeepLab + + +def parse_args(): + parser = argparse.ArgumentParser(description='Model evaluation') + + # params of evaluate + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--model_path', + dest='model_path', + help='The path of model for evaluation', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=0) + parser.add_argument( + '--threshold', + dest='threshold', + help='Threshold applied to center heatmap score', + type=float, + default=0.1) + parser.add_argument( + '--nms_kernel', + dest='nms_kernel', + help='NMS max pooling kernel size', + type=int, + default=7) + parser.add_argument( + '--top_k', + dest='top_k', + help='Top k centers to keep', + type=int, + default=200) + + return parser.parse_args() + + +def main(args): + env_info = get_sys_env() + place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ + 'GPUs used'] else 'cpu' + + paddle.set_device(place) + if not args.cfg: + raise RuntimeError('No configuration file specified.') + + cfg = Config(args.cfg) + val_dataset = cfg.val_dataset + if val_dataset is None: + raise RuntimeError( + 'The verification dataset is not specified in the configuration file.' + ) + elif len(val_dataset) == 0: + raise ValueError( + 'The length of val_dataset is 0. 
Please check if your dataset is valid' + ) + + msg = '\n---------------Config Information---------------\n' + msg += str(cfg) + msg += '------------------------------------------------' + logger.info(msg) + + model = cfg.model + if args.model_path: + paddleseg.utils.utils.load_entire_model(model, args.model_path) + logger.info('Loaded trained params of model successfully') + + config_check(cfg, val_dataset=val_dataset) + + evaluate( + model, + val_dataset, + threshold=args.threshold, + nms_kernel=args.nms_kernel, + top_k=args.top_k, + num_workers=args.num_workers, + ) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/paddleseg/models/losses/__init__.py b/paddleseg/models/losses/__init__.py index a0410448e8..b704e3dc4a 100644 --- a/paddleseg/models/losses/__init__.py +++ b/paddleseg/models/losses/__init__.py @@ -23,3 +23,5 @@ from .ohem_cross_entropy_loss import OhemCrossEntropyLoss from .decoupledsegnet_relax_boundary_loss import RelaxBoundaryLoss from .ohem_edge_attention_loss import OhemEdgeAttentionLoss +from .l1_loss import L1Loss +from .mean_square_error_loss import MSELoss diff --git a/paddleseg/models/losses/cross_entropy_loss.py b/paddleseg/models/losses/cross_entropy_loss.py index 9502e507b2..40117ba1ff 100644 --- a/paddleseg/models/losses/cross_entropy_loss.py +++ b/paddleseg/models/losses/cross_entropy_loss.py @@ -30,17 +30,20 @@ class CrossEntropyLoss(nn.Layer): Default ``None``. ignore_index (int64, optional): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``. + top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0]. When its value < 1.0, only compute the loss for + the top k percent pixels (e.g., the top 20% pixels). This is useful for hard pixel mining. """ - def __init__(self, weight=None, ignore_index=255): + def __init__(self, weight=None, ignore_index=255, top_k_percent_pixels=1.0): super(CrossEntropyLoss, self).__init__() if weight is not None: weight = paddle.to_tensor(weight, dtype='float32') self.weight = weight self.ignore_index = ignore_index + self.top_k_percent_pixels = top_k_percent_pixels self.EPS = 1e-8 - def forward(self, logit, label): + def forward(self, logit, label, semantic_weights=None): """ Forward computation. @@ -74,8 +77,17 @@ def forward(self, logit, label): mask = label != self.ignore_index mask = paddle.cast(mask, 'float32') loss = loss * mask - avg_loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS) + if semantic_weights is not None: + loss = loss * semantic_weights label.stop_gradient = True mask.stop_gradient = True - return avg_loss + if self.top_k_percent_pixels == 1.0: + avg_loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS) + return avg_loss + + loss = loss.reshape((-1, )) + top_k_pixels = int(self.top_k_percent_pixels * loss.numel()) + loss, _ = paddle.topk(loss, top_k_pixels) + + return loss.mean() diff --git a/paddleseg/models/losses/l1_loss.py b/paddleseg/models/losses/l1_loss.py new file mode 100644 index 0000000000..f0f58454b8 --- /dev/null +++ b/paddleseg/models/losses/l1_loss.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class L1Loss(nn.L1Loss):
+    r"""
+    This interface is used to construct a callable object of the ``L1Loss`` class.
+    The L1Loss layer calculates the L1 loss of ``input`` and ``label`` as follows.
+
+    If `reduction` is set to ``'none'``, the loss is:
+
+    .. math::
+        Out = \lvert input - label \rvert
+
+    If `reduction` is set to ``'mean'``, the loss is:
+
+    .. math::
+        Out = MEAN(\lvert input - label \rvert)
+
+    If `reduction` is set to ``'sum'``, the loss is:
+
+    .. math::
+        Out = SUM(\lvert input - label \rvert)
+
+    Args:
+        reduction (str, optional): Indicates the reduction to apply to the loss;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If `reduction` is ``'none'``, the unreduced loss is returned;
+            if `reduction` is ``'mean'``, the reduced mean loss is returned;
+            if `reduction` is ``'sum'``, the reduced sum loss is returned.
+            Default is ``'mean'``.
+        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. Default: 255.
+
+    Shape:
+        input (Tensor): The input tensor. The shape is [N, *], where N is the batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
+        label (Tensor): The label. The shape is [N, *], the same as ``input``. Its data type should be float32, float64, int32 or int64.
+        output (Tensor): The L1 loss of ``input`` and ``label``.
+            If `reduction` is ``'none'``, the shape of the output loss is [N, *], the same as ``input``.
+            If `reduction` is ``'mean'`` or ``'sum'``, the shape of the output loss is [1].
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
+            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
+            input = paddle.to_tensor(input_data)
+            label = paddle.to_tensor(label_data)
+
+            l1_loss = paddle.nn.L1Loss()
+            output = l1_loss(input, label)
+            print(output.numpy())
+            # [0.35]
+
+            l1_loss = paddle.nn.L1Loss(reduction='sum')
+            output = l1_loss(input, label)
+            print(output.numpy())
+            # [1.4]
+
+            l1_loss = paddle.nn.L1Loss(reduction='none')
+            output = l1_loss(input, label)
+            print(output)
+            # [[0.20000005 0.19999999]
+            #  [0.2        0.79999995]]
+    """
+
+    def __init__(self, reduction='mean', ignore_index=255):
+        super().__init__(reduction=reduction)
diff --git a/paddleseg/models/losses/mean_square_error_loss.py b/paddleseg/models/losses/mean_square_error_loss.py
new file mode 100644
index 0000000000..e6fc8918c2
--- /dev/null
+++ b/paddleseg/models/losses/mean_square_error_loss.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class MSELoss(nn.MSELoss):
+    r"""
+    **Mean Square Error Loss**
+
+    Computes the mean square error (squared L2 norm) of the given input and label.
+
+    If :attr:`reduction` is set to ``'none'``, the loss is:
+
+    .. math::
+        Out = (input - label)^2
+
+    If :attr:`reduction` is set to ``'mean'``, the loss is:
+
+    .. math::
+        Out = \operatorname{mean}((input - label)^2)
+
+    If :attr:`reduction` is set to ``'sum'``, the loss is:
+
+    .. math::
+        Out = \operatorname{sum}((input - label)^2)
+
+    where `input` and `label` are `float32` tensors of the same shape.
+
+    Args:
+        reduction (str, optional): The reduction method for the output,
+            could be ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
+            if :attr:`reduction` is ``'none'``, the unreduced loss is returned.
+            Default is ``'mean'``.
+        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. Default: 255.
+
+    Shape:
+        input (Tensor): The input tensor, the data type is float32 or float64.
+        label (Tensor): The label tensor, the data type is float32 or float64.
+        output (Tensor): The tensor storing the MSE loss of input and label; its data type is the same as input.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            input_data = np.array([1.5]).astype("float32")
+            label_data = np.array([1.7]).astype("float32")
+            mse_loss = paddle.nn.MSELoss()
+            input = paddle.to_tensor(input_data)
+            label = paddle.to_tensor(label_data)
+            output = mse_loss(input, label)
+            print(output)
+            # [0.04000002]
+    """
+
+    def __init__(self, reduction='mean', ignore_index=255):
+        super().__init__(reduction=reduction)
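
For a quick sanity check of the `SemanticEvaluator` added in `contrib/PanopticDeepLab/utils/evaluation/semantic.py`, it can be exercised with toy NumPy arrays. This is a minimal sketch, assuming it is run from `contrib/PanopticDeepLab/` so that `utils.evaluation.semantic` is importable; the toy arrays and the printed numbers are illustrative only.

```python
import numpy as np

# Assumed import path; run from contrib/PanopticDeepLab/.
from utils.evaluation.semantic import SemanticEvaluator

# Two classes (0 and 1); 255 marks ignored pixels.
evaluator = SemanticEvaluator(num_classes=2, ignore_index=255)

pred = np.array([[0, 0, 1, 1],
                 [0, 1, 1, 1]])
gt = np.array([[0, 0, 0, 1],
               [0, 1, 1, 255]])  # bottom-right pixel is ignored

evaluator.update(pred, gt)
print(evaluator.evaluate())
# For these toy inputs the accumulated confusion matrix gives
# mIoU = 75.0, fwIoU = 75.0, mACC = 87.5, pACC ~= 85.7 (percent).
```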