Skip to content

Commit

Permalink
Add Panoptic Segmentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
wuyefeilin committed May 19, 2021
1 parent a79cb15 commit e37b7c6
Show file tree
Hide file tree
Showing 32 changed files with 3,755 additions and 4 deletions.
144 changes: 144 additions & 0 deletions contrib/PanopticDeepLab/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@

# Panoptic DeepLab

基于PaddlePaddle实现[Panoptic Deeplab](https://arxiv.org/abs/1911.10194)全景分割算法。

Panoptic DeepLab首次证实了bottem-up算法能够达到state-of-the-art的效果。Panoptic DeepLab预测三个输出:Semantic Segmentation, Center Prediction 和 Center Regression。实例类别像素根据最近距离原则聚集到实例中心点得到实例分割结果。最后按照majority-vote规则融合语义分割结果和实例分割结果,得到最终的全景分割结果。
其通过将每一个像素赋值给每一个类别或实例达到分割的效果。
![](./docs/panoptic_deeplab.jpg)

## Model Baselines

### Cityscapes
| Backbone | Batch Size |Resolution | Training Iters | PQ | SQ | RQ | AP | mIoU | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|ResNet50_OS32| 8 | 2049x1025|90000|58.35%|80.03%|71.52%|25.80%|79.18%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/train.log)|
|ResNet50_OS32| 64 | 1025x513|90000|60.32%|80.56%|73.56%|26.77%|79.67%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/train.log)|

## 环境准备

1. 系统环境
* PaddlePaddle >= 2.0.0
* Python >= 3.6+
推荐使用GPU版本的PaddlePaddle版本。详细安装教程请参考官方网站[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/windows-pip.html)

2. 下载PaddleSeg repo
```shell
git clone https://github.com/PaddlePaddle/PaddleSeg
```

3. 安装paddleseg
```shell
cd PaddleSeg
pip install -e .
```

4. 进入PaddleSeg/contrib/PanopticDeepLab目录
```shell
cd contrib/PanopticDeepLab
```

## 数据集准备

将数据集放置于`data`目录下。

### Cityscapes

前往[CityScapes官网](https://www.cityscapes-dataset.com/)下载数据集并整理成如下结构:

```
cityscapes/
|--gtFine/
| |--train/
| | |--aachen/
| | | |--*_color.png, *_instanceIds.png, *_labelIds.png, *_polygons.json,
| | | |--*_labelTrainIds.png
| | | |--...
| |--val/
| |--test/
| |--cityscapes_panoptic_train_trainId.json
| |--cityscapes_panoptic_train_trainId/
| | |-- *_panoptic.png
| |--cityscapes_panoptic_val_trainId.json
| |--cityscapes_panoptic_val_trainId/
| | |-- *_panoptic.png
|--leftImg8bit/
| |--train/
| |--val/
| |--test/
```

安装CityscapesScripts
```shell
pip install git+https://github.com/mcordts/cityscapesScripts.git
```

`*_panoptic.png` 生成命令(需找到`createPanopticImgs.py`文件):
```shell
python /path/to/cityscapesscripts/preparation/createPanopticImgs.py \
--dataset-folder data/cityscapes/gtFine/ \
--output-folder data/cityscapes/gtFine/ \
--use-train-id
```

## 训练
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # 根据实际情况进行显卡数量的设置
python -m paddle.distributed.launch train.py \
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
--do_eval \
--use_vdl \
--save_interval 5000 \
--save_dir output
```

**note:** 使用--do_eval会影响训练速度及增加显存消耗,根据选择进行开闭。

更多参数信息请运行如下命令进行查看:
```shell
python train.py --help
```

## 评估
```shell
python val.py \
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
--model_path output/iter_90000/model.pdparams
```
你可以直接下载我们提供的模型进行评估。

更多参数信息请运行如下命令进行查看:
```shell
python val.py --help
```

## 预测及可视化结果保存
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # 根据实际情况进行显卡数量的设置
python -m paddle.distributed.launch predict.py \
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_120k.yml \
--model_path output/iter_90000/model.pdparams \
--image_path data/cityscapes/leftImg8bit/val/ \
--save_dir ./output/result
```
你可以直接下载我们提供的模型进行预测。

更多参数信息请运行如下命令进行查看:
```shell
python predict.py --help
```
全景分割结果:
<center>
<img src="docs/visualization_panoptic.png">
</center>

语义分割结果:
<center>
<img src="docs/visualization_semantic.png">
</center>

实例分割结果:
<center>
<img src="docs/visualization_instance.png">
</center>
55 changes: 55 additions & 0 deletions contrib/PanopticDeepLab/configs/_base_/cityscapes_panoptic.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
train_dataset:
type: CityscapesPanoptic
dataset_root: data/cityscapes
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [2049, 1025]
label_padding_value: [0, 0, 0]
- type: RandomHorizontalFlip
- type: RandomDistort
brightness_range: 0.4
contrast_range: 0.4
saturation_range: 0.4
- type: Normalize
mode: train
ignore_stuff_in_offset: True
small_instance_area: 4096
small_instance_weight: 3

val_dataset:
type: CityscapesPanoptic
dataset_root: data/cityscapes
transforms:
- type: Padding
target_size: [2049, 1025]
label_padding_value: [0, 0, 0]
- type: Normalize
mode: val
ignore_stuff_in_offset: True
small_instance_area: 4096
small_instance_weight: 3


optimizer:
type: adam

learning_rate:
value: 0.00005
decay:
type: poly
power: 0.9
end_lr: 0.0

loss:
types:
- type: CrossEntropyLoss
top_k_percent_pixels: 0.2
- type: MSELoss
reduction: "none"
- type: L1Loss
reduction: "none"
coef: [1, 200, 0.001]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_base_: ./panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml

batch_size: 8

train_dataset:
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1025, 513]
label_padding_value: [0, 0, 0]
- type: RandomHorizontalFlip
- type: RandomDistort
brightness_range: 0.4
contrast_range: 0.4
saturation_range: 0.4
- type: Normalize
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
_base_: ../_base_/cityscapes_panoptic.yml

batch_size: 1
iters: 90000

model:
type: PanopticDeepLab
backbone:
type: ResNet50_vd
output_stride: 32
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
backbone_indices: [2,1,0,3]
aspp_ratios: [1, 3, 6, 9]
aspp_out_channels: 256
decoder_channels: 256
low_level_channels_projects: [128, 64, 32]
align_corners: True
instance_aspp_out_channels: 256
instance_decoder_channels: 128
instance_low_level_channels_projects: [64, 32, 16]
instance_num_classes: [1, 2]
instance_head_channels: 32
instance_class_key: ["center", "offset"]
20 changes: 20 additions & 0 deletions contrib/PanopticDeepLab/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .train import train
from .val import evaluate
from .predict import predict
from . import infer

__all__ = ['train', 'evaluate', 'predict']
Loading

0 comments on commit e37b7c6

Please sign in to comment.