From 43e1eda70aa4eaa852233901d1675ac86cfd2749 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Mar 2023 23:18:02 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mmsegmentation/.dev/benchmark_inference.py | 82 +- mmsegmentation/.dev/check_urls.py | 65 +- .../gather_benchmark_evaluation_results.py | 56 +- .../.dev/gather_benchmark_train_results.py | 45 +- mmsegmentation/.dev/gather_models.py | 130 ++- .../generate_benchmark_evaluation_script.py | 105 +- .../.dev/generate_benchmark_train_script.py | 77 +- .../.dev/log_collector/example_config.py | 18 +- .../.dev/log_collector/log_collector.py | 94 +- mmsegmentation/.dev/log_collector/utils.py | 7 +- mmsegmentation/.dev/md2yml.py | 297 ++--- mmsegmentation/.dev/upload_modelzoo.py | 34 +- mmsegmentation/_git/packed-refs | 2 +- .../configs/_base_/datasets/ade20k.py | 66 +- .../configs/_base_/datasets/ade20k_640x640.py | 66 +- .../configs/_base_/datasets/chase_db1.py | 69 +- .../configs/_base_/datasets/cityscapes.py | 66 +- .../_base_/datasets/cityscapes_1024x1024.py | 45 +- .../_base_/datasets/cityscapes_768x768.py | 45 +- .../_base_/datasets/cityscapes_769x769.py | 45 +- .../_base_/datasets/cityscapes_832x832.py | 45 +- .../configs/_base_/datasets/coco-stuff10k.py | 66 +- .../configs/_base_/datasets/coco-stuff164k.py | 66 +- .../configs/_base_/datasets/drive.py | 69 +- mmsegmentation/configs/_base_/datasets/hrf.py | 69 +- .../configs/_base_/datasets/isaid.py | 66 +- .../configs/_base_/datasets/loveda.py | 66 +- .../configs/_base_/datasets/occlude_face.py | 89 +- .../configs/_base_/datasets/pascal_context.py | 72 +- .../_base_/datasets/pascal_context_59.py | 72 +- .../configs/_base_/datasets/pascal_voc12.py | 72 +- .../_base_/datasets/pascal_voc12_aug.py | 11 +- .../configs/_base_/datasets/potsdam.py | 66 +- .../configs/_base_/datasets/stare.py | 69 +- .../configs/_base_/datasets/vaihingen.py | 66 +- .../configs/_base_/default_runtime.py | 11 +- .../configs/_base_/models/ann_r50-d8.py | 30 +- .../configs/_base_/models/apcnet_r50-d8.py | 30 +- .../_base_/models/bisenetv1_r18-d32.py | 37 +- .../configs/_base_/models/bisenetv2.py | 42 +- .../configs/_base_/models/ccnet_r50-d8.py | 28 +- mmsegmentation/configs/_base_/models/cgnet.py | 43 +- .../configs/_base_/models/danet_r50-d8.py | 28 +- .../configs/_base_/models/deeplabv3_r50-d8.py | 28 +- .../_base_/models/deeplabv3_unet_s5-d16.py | 28 +- .../_base_/models/deeplabv3plus_r50-d8.py | 28 +- .../configs/_base_/models/dmnet_r50-d8.py | 30 +- .../configs/_base_/models/dnl_r50-d8.py | 30 +- .../configs/_base_/models/dpt_vit-b16.py | 24 +- .../configs/_base_/models/emanet_r50-d8.py | 28 +- .../configs/_base_/models/encnet_r50-d8.py | 31 +- .../configs/_base_/models/erfnet_fcn.py | 18 +- .../configs/_base_/models/fast_scnn.py | 30 +- .../_base_/models/fastfcn_r50-d32_jpu_psp.py | 33 +- .../configs/_base_/models/fcn_hr18.py | 43 +- .../configs/_base_/models/fcn_r50-d8.py | 28 +- .../configs/_base_/models/fcn_unet_s5-d16.py | 28 +- .../_base_/models/fpn_poolformer_s12.py | 34 +- .../configs/_base_/models/fpn_r50.py | 28 +- .../configs/_base_/models/gcnet_r50-d8.py | 34 +- .../configs/_base_/models/icnet_r50-d8.py | 39 +- .../configs/_base_/models/isanet_r50-d8.py | 28 +- .../configs/_base_/models/lraspp_m-v3-d8.py | 23 +- .../configs/_base_/models/nonlocal_r50-d8.py | 30 +- .../configs/_base_/models/ocrnet_hr18.py | 51 +- .../configs/_base_/models/ocrnet_r50-d8.py | 28 +- .../configs/_base_/models/pointrend_r50.py | 41 +- .../configs/_base_/models/psanet_r50-d8.py | 30 +- .../configs/_base_/models/pspnet_r50-d8.py | 28 +- .../_base_/models/pspnet_unet_s5-d16.py | 28 +- .../configs/_base_/models/segformer_mit-b0.py | 18 +- .../_base_/models/segmenter_vit-b16_mask.py | 17 +- .../configs/_base_/models/setr_mla.py | 49 +- .../configs/_base_/models/setr_naive.py | 39 +- .../configs/_base_/models/setr_pup.py | 39 +- mmsegmentation/configs/_base_/models/stdc.py | 61 +- .../_base_/models/twins_pcpvt-s_fpn.py | 32 +- .../_base_/models/twins_pcpvt-s_upernet.py | 32 +- .../configs/_base_/models/upernet_beit.py | 30 +- .../configs/_base_/models/upernet_convnext.py | 32 +- .../configs/_base_/models/upernet_mae.py | 30 +- .../configs/_base_/models/upernet_r50.py | 28 +- .../configs/_base_/models/upernet_swin.py | 32 +- .../_base_/models/upernet_vit-b16_ln_mln.py | 35 +- .../configs/_base_/schedules/schedule_160k.py | 8 +- .../configs/_base_/schedules/schedule_20k.py | 8 +- .../configs/_base_/schedules/schedule_320k.py | 8 +- .../configs/_base_/schedules/schedule_40k.py | 8 +- .../configs/_base_/schedules/schedule_80k.py | 8 +- .../ann_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../ann_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../ann/ann_r101-d8_512x512_160k_ade20k.py | 4 +- .../ann/ann_r101-d8_512x512_20k_voc12aug.py | 4 +- .../ann/ann_r101-d8_512x512_40k_voc12aug.py | 4 +- .../ann/ann_r101-d8_512x512_80k_ade20k.py | 4 +- .../ann/ann_r101-d8_769x769_40k_cityscapes.py | 4 +- .../ann/ann_r101-d8_769x769_80k_cityscapes.py | 4 +- .../ann/ann_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../ann/ann_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../ann/ann_r50-d8_512x512_160k_ade20k.py | 9 +- .../ann/ann_r50-d8_512x512_20k_voc12aug.py | 9 +- .../ann/ann_r50-d8_512x512_40k_voc12aug.py | 9 +- .../ann/ann_r50-d8_512x512_80k_ade20k.py | 9 +- .../ann/ann_r50-d8_769x769_40k_cityscapes.py | 10 +- .../ann/ann_r50-d8_769x769_80k_cityscapes.py | 10 +- .../apcnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../apcnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../apcnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../apcnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../apcnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../apcnet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../apcnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../apcnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../apcnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../apcnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../apcnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../apcnet_r50-d8_769x769_80k_cityscapes.py | 10 +- ...pernet_beit-base_640x640_160k_ade20k_ms.py | 25 +- ...ernet_beit-base_8x2_640x640_160k_ade20k.py | 25 +- ..._beit-large_fp16_640x640_160k_ade20k_ms.py | 25 +- ...beit-large_fp16_8x1_640x640_160k_ade20k.py | 36 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 8 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 29 +- ...1_r18-d32_4x4_1024x1024_160k_cityscapes.py | 9 +- ..._in1k-pre_4x4_1024x1024_160k_cityscapes.py | 15 +- ..._in1k-pre_4x8_1024x1024_160k_cityscapes.py | 2 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 8 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 26 +- ...1_r50-d32_4x4_1024x1024_160k_cityscapes.py | 34 +- ..._in1k-pre_4x4_1024x1024_160k_cityscapes.py | 10 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 8 +- ..._lr5e-3_4x4_512x512_160k_coco-stuff164k.py | 29 +- ...netv2_fcn_4x4_1024x1024_160k_cityscapes.py | 9 +- ...netv2_fcn_4x8_1024x1024_160k_cityscapes.py | 9 +- ..._fcn_fp16_4x4_1024x1024_160k_cityscapes.py | 4 +- ..._fcn_ohem_4x4_1024x1024_160k_cityscapes.py | 46 +- .../ccnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../ccnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../ccnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../ccnet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../ccnet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../ccnet/ccnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../ccnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../ccnet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../ccnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../ccnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../ccnet/ccnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../ccnet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../ccnet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../ccnet/ccnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../ccnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../ccnet_r50-d8_769x769_80k_cityscapes.py | 10 +- .../cgnet/cgnet_512x1024_60k_cityscapes.py | 74 +- .../cgnet/cgnet_680x680_60k_cityscapes.py | 52 +- ..._convnext_base_fp16_512x512_160k_ade20k.py | 28 +- ..._convnext_base_fp16_640x640_160k_ade20k.py | 40 +- ...convnext_large_fp16_640x640_160k_ade20k.py | 40 +- ...convnext_small_fp16_512x512_160k_ade20k.py | 39 +- ..._convnext_tiny_fp16_512x512_160k_ade20k.py | 39 +- ...onvnext_xlarge_fp16_640x640_160k_ade20k.py | 40 +- .../danet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../danet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../danet_r101-d8_512x512_160k_ade20k.py | 4 +- .../danet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../danet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../danet/danet_r101-d8_512x512_80k_ade20k.py | 4 +- .../danet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../danet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../danet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../danet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../danet/danet_r50-d8_512x512_160k_ade20k.py | 9 +- .../danet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../danet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../danet/danet_r50-d8_512x512_80k_ade20k.py | 9 +- .../danet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../danet_r50-d8_769x769_80k_cityscapes.py | 10 +- ..._r101-d16-mg124_512x1024_40k_cityscapes.py | 15 +- ..._r101-d16-mg124_512x1024_80k_cityscapes.py | 15 +- ...abv3_r101-d8_480x480_40k_pascal_context.py | 4 +- ...3_r101-d8_480x480_40k_pascal_context_59.py | 4 +- ...abv3_r101-d8_480x480_80k_pascal_context.py | 4 +- ...3_r101-d8_480x480_80k_pascal_context_59.py | 4 +- ...eplabv3_r101-d8_512x1024_40k_cityscapes.py | 4 +- ...eplabv3_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../deeplabv3_r101-d8_512x512_160k_ade20k.py | 4 +- .../deeplabv3_r101-d8_512x512_20k_voc12aug.py | 4 +- .../deeplabv3_r101-d8_512x512_40k_voc12aug.py | 4 +- ...r101-d8_512x512_4x4_160k_coco-stuff164k.py | 4 +- ...3_r101-d8_512x512_4x4_20k_coco-stuff10k.py | 4 +- ...r101-d8_512x512_4x4_320k_coco-stuff164k.py | 4 +- ...3_r101-d8_512x512_4x4_40k_coco-stuff10k.py | 4 +- ..._r101-d8_512x512_4x4_80k_coco-stuff164k.py | 4 +- .../deeplabv3_r101-d8_512x512_80k_ade20k.py | 4 +- ...eeplabv3_r101-d8_769x769_40k_cityscapes.py | 4 +- ...eeplabv3_r101-d8_769x769_80k_cityscapes.py | 4 +- ...v3_r101-d8_fp16_512x1024_80k_cityscapes.py | 4 +- ...plabv3_r101b-d8_512x1024_80k_cityscapes.py | 6 +- ...eplabv3_r101b-d8_769x769_80k_cityscapes.py | 6 +- ...eeplabv3_r18-d8_512x1024_80k_cityscapes.py | 7 +- ...deeplabv3_r18-d8_769x769_80k_cityscapes.py | 7 +- ...eplabv3_r18b-d8_512x1024_80k_cityscapes.py | 9 +- ...eeplabv3_r18b-d8_769x769_80k_cityscapes.py | 9 +- ...labv3_r50-d8_480x480_40k_pascal_context.py | 12 +- ...v3_r50-d8_480x480_40k_pascal_context_59.py | 12 +- ...labv3_r50-d8_480x480_80k_pascal_context.py | 12 +- ...v3_r50-d8_480x480_80k_pascal_context_59.py | 12 +- ...eeplabv3_r50-d8_512x1024_40k_cityscapes.py | 6 +- ...eeplabv3_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../deeplabv3_r50-d8_512x512_160k_ade20k.py | 9 +- .../deeplabv3_r50-d8_512x512_20k_voc12aug.py | 10 +- .../deeplabv3_r50-d8_512x512_40k_voc12aug.py | 10 +- ..._r50-d8_512x512_4x4_160k_coco-stuff164k.py | 10 +- ...v3_r50-d8_512x512_4x4_20k_coco-stuff10k.py | 10 +- ..._r50-d8_512x512_4x4_320k_coco-stuff164k.py | 10 +- ...v3_r50-d8_512x512_4x4_40k_coco-stuff10k.py | 10 +- ...3_r50-d8_512x512_4x4_80k_coco-stuff164k.py | 10 +- .../deeplabv3_r50-d8_512x512_80k_ade20k.py | 9 +- ...deeplabv3_r50-d8_769x769_40k_cityscapes.py | 10 +- ...deeplabv3_r50-d8_769x769_80k_cityscapes.py | 10 +- ...eplabv3_r50b-d8_512x1024_80k_cityscapes.py | 4 +- ...eeplabv3_r50b-d8_769x769_80k_cityscapes.py | 4 +- ..._r101-d16-mg124_512x1024_40k_cityscapes.py | 15 +- ..._r101-d16-mg124_512x1024_80k_cityscapes.py | 15 +- ...plus_r101-d8_480x480_40k_pascal_context.py | 4 +- ...s_r101-d8_480x480_40k_pascal_context_59.py | 4 +- ...plus_r101-d8_480x480_80k_pascal_context.py | 4 +- ...s_r101-d8_480x480_80k_pascal_context_59.py | 4 +- ...3plus_r101-d8_4x4_512x512_80k_vaihingen.py | 4 +- ...bv3plus_r101-d8_512x1024_40k_cityscapes.py | 4 +- ...bv3plus_r101-d8_512x1024_80k_cityscapes.py | 4 +- ...eplabv3plus_r101-d8_512x512_160k_ade20k.py | 4 +- ...plabv3plus_r101-d8_512x512_20k_voc12aug.py | 4 +- ...plabv3plus_r101-d8_512x512_40k_voc12aug.py | 4 +- ...eeplabv3plus_r101-d8_512x512_80k_ade20k.py | 4 +- ...eeplabv3plus_r101-d8_512x512_80k_loveda.py | 7 +- ...eplabv3plus_r101-d8_512x512_80k_potsdam.py | 4 +- ...abv3plus_r101-d8_769x769_40k_cityscapes.py | 4 +- ...abv3plus_r101-d8_769x769_80k_cityscapes.py | 4 +- ...us_r101-d8_fp16_512x1024_80k_cityscapes.py | 4 +- ...3plus_r101_512x512_C-CM+C-WO-NatOcc-SOT.py | 56 +- ...v3plus_r101b-d8_512x1024_80k_cityscapes.py | 6 +- ...bv3plus_r101b-d8_769x769_80k_cityscapes.py | 6 +- ...v3plus_r18-d8_4x4_512x512_80k_vaihingen.py | 7 +- ...plabv3plus_r18-d8_4x4_896x896_80k_isaid.py | 7 +- ...abv3plus_r18-d8_512x1024_80k_cityscapes.py | 7 +- ...deeplabv3plus_r18-d8_512x512_80k_loveda.py | 9 +- ...eeplabv3plus_r18-d8_512x512_80k_potsdam.py | 7 +- ...labv3plus_r18-d8_769x769_80k_cityscapes.py | 7 +- ...bv3plus_r18b-d8_512x1024_80k_cityscapes.py | 9 +- ...abv3plus_r18b-d8_769x769_80k_cityscapes.py | 9 +- ...3plus_r50-d8_480x480_40k_pascal_context.py | 12 +- ...us_r50-d8_480x480_40k_pascal_context_59.py | 12 +- ...3plus_r50-d8_480x480_80k_pascal_context.py | 12 +- ...us_r50-d8_480x480_80k_pascal_context_59.py | 12 +- ...v3plus_r50-d8_4x4_512x512_80k_vaihingen.py | 10 +- ...plabv3plus_r50-d8_4x4_896x896_80k_isaid.py | 9 +- ...abv3plus_r50-d8_512x1024_40k_cityscapes.py | 7 +- ...abv3plus_r50-d8_512x1024_80k_cityscapes.py | 7 +- ...eeplabv3plus_r50-d8_512x512_160k_ade20k.py | 9 +- ...eplabv3plus_r50-d8_512x512_20k_voc12aug.py | 10 +- ...eplabv3plus_r50-d8_512x512_40k_voc12aug.py | 10 +- ...deeplabv3plus_r50-d8_512x512_80k_ade20k.py | 9 +- ...deeplabv3plus_r50-d8_512x512_80k_loveda.py | 9 +- ...eeplabv3plus_r50-d8_512x512_80k_potsdam.py | 10 +- ...labv3plus_r50-d8_769x769_40k_cityscapes.py | 10 +- ...labv3plus_r50-d8_769x769_80k_cityscapes.py | 10 +- ...bv3plus_r50b-d8_512x1024_80k_cityscapes.py | 4 +- ...abv3plus_r50b-d8_769x769_80k_cityscapes.py | 4 +- .../dmnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../dmnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../dmnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../dmnet/dmnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../dmnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../dmnet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../dmnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../dmnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../dmnet/dmnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../dmnet/dmnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../dmnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../dmnet_r50-d8_769x769_80k_cityscapes.py | 10 +- .../dnl_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../dnl_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../dnlnet/dnl_r101-d8_512x512_160k_ade20k.py | 4 +- .../dnlnet/dnl_r101-d8_512x512_80k_ade20k.py | 4 +- .../dnl_r101-d8_769x769_40k_cityscapes.py | 4 +- .../dnl_r101-d8_769x769_80k_cityscapes.py | 4 +- .../dnl_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../dnl_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../dnlnet/dnl_r50-d8_512x512_160k_ade20k.py | 9 +- .../dnlnet/dnl_r50-d8_512x512_80k_ade20k.py | 9 +- .../dnl_r50-d8_769x769_40k_cityscapes.py | 10 +- .../dnl_r50-d8_769x769_80k_cityscapes.py | 14 +- .../dpt/dpt_vit-b16_512x512_160k_ade20k.py | 25 +- .../emanet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../emanet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../emanet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../emanet_r50-d8_769x769_80k_cityscapes.py | 10 +- .../encnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../encnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../encnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../encnet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../encnet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../encnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../encnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../encnet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../encnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../encnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../encnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../encnet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../encnet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../encnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../encnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../encnet_r50-d8_769x769_80k_cityscapes.py | 10 +- .../encnet_r50s-d8_512x512_80k_ade20k.py | 9 +- ...erfnet_fcn_4x4_512x1024_160k_cityscapes.py | 6 +- ...32_jpu_aspp_4x4_512x1024_80k_cityscapes.py | 2 +- ...50-d32_jpu_aspp_512x1024_80k_cityscapes.py | 13 +- ...cn_r50-d32_jpu_aspp_512x512_160k_ade20k.py | 13 +- ...fcn_r50-d32_jpu_aspp_512x512_80k_ade20k.py | 13 +- ...d32_jpu_enc_4x4_512x1024_80k_cityscapes.py | 2 +- ...r50-d32_jpu_enc_512x1024_80k_cityscapes.py | 16 +- ...fcn_r50-d32_jpu_enc_512x512_160k_ade20k.py | 16 +- ...tfcn_r50-d32_jpu_enc_512x512_80k_ade20k.py | 16 +- ...d32_jpu_psp_4x4_512x1024_80k_cityscapes.py | 7 +- ...r50-d32_jpu_psp_512x1024_80k_cityscapes.py | 7 +- ...fcn_r50-d32_jpu_psp_512x512_160k_ade20k.py | 10 +- ...tfcn_r50-d32_jpu_psp_512x512_80k_ade20k.py | 10 +- .../fast_scnn_lr0.12_8x4_160k_cityscapes.py | 8 +- ...fcn_d6_r101-d16_512x1024_40k_cityscapes.py | 4 +- ...fcn_d6_r101-d16_512x1024_80k_cityscapes.py | 4 +- .../fcn_d6_r101-d16_769x769_40k_cityscapes.py | 4 +- .../fcn_d6_r101-d16_769x769_80k_cityscapes.py | 4 +- ...cn_d6_r101b-d16_512x1024_80k_cityscapes.py | 6 +- ...fcn_d6_r101b-d16_769x769_80k_cityscapes.py | 6 +- .../fcn_d6_r50-d16_512x1024_40k_cityscapes.py | 9 +- .../fcn_d6_r50-d16_512x1024_80k_cityscapes.py | 9 +- .../fcn_d6_r50-d16_769x769_40k_cityscapes.py | 10 +- .../fcn_d6_r50-d16_769x769_80k_cityscapes.py | 10 +- ...fcn_d6_r50b-d16_512x1024_80k_cityscapes.py | 4 +- .../fcn_d6_r50b-d16_769x769_80k_cityscapes.py | 4 +- .../fcn_r101-d8_480x480_40k_pascal_context.py | 4 +- ...n_r101-d8_480x480_40k_pascal_context_59.py | 4 +- .../fcn_r101-d8_480x480_80k_pascal_context.py | 4 +- ...n_r101-d8_480x480_80k_pascal_context_59.py | 4 +- .../fcn_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../fcn_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../fcn/fcn_r101-d8_512x512_160k_ade20k.py | 4 +- .../fcn/fcn_r101-d8_512x512_20k_voc12aug.py | 4 +- .../fcn/fcn_r101-d8_512x512_40k_voc12aug.py | 4 +- .../fcn/fcn_r101-d8_512x512_80k_ade20k.py | 4 +- .../fcn/fcn_r101-d8_769x769_40k_cityscapes.py | 4 +- .../fcn/fcn_r101-d8_769x769_80k_cityscapes.py | 4 +- ...cn_r101-d8_fp16_512x1024_80k_cityscapes.py | 4 +- .../fcn_r101b-d8_512x1024_80k_cityscapes.py | 6 +- .../fcn_r101b-d8_769x769_80k_cityscapes.py | 6 +- .../fcn/fcn_r18-d8_512x1024_80k_cityscapes.py | 7 +- .../fcn/fcn_r18-d8_769x769_80k_cityscapes.py | 7 +- .../fcn_r18b-d8_512x1024_80k_cityscapes.py | 9 +- .../fcn/fcn_r18b-d8_769x769_80k_cityscapes.py | 9 +- .../fcn_r50-d8_480x480_40k_pascal_context.py | 11 +- ...cn_r50-d8_480x480_40k_pascal_context_59.py | 12 +- .../fcn_r50-d8_480x480_80k_pascal_context.py | 11 +- ...cn_r50-d8_480x480_80k_pascal_context_59.py | 12 +- .../fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../fcn/fcn_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../fcn/fcn_r50-d8_512x512_160k_ade20k.py | 9 +- .../fcn/fcn_r50-d8_512x512_20k_voc12aug.py | 9 +- .../fcn/fcn_r50-d8_512x512_40k_voc12aug.py | 9 +- .../fcn/fcn_r50-d8_512x512_80k_ade20k.py | 9 +- .../fcn/fcn_r50-d8_769x769_40k_cityscapes.py | 10 +- .../fcn/fcn_r50-d8_769x769_80k_cityscapes.py | 10 +- .../fcn_r50b-d8_512x1024_80k_cityscapes.py | 4 +- .../fcn/fcn_r50b-d8_769x769_80k_cityscapes.py | 4 +- .../gcnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../gcnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../gcnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../gcnet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../gcnet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../gcnet/gcnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../gcnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../gcnet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../gcnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../gcnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../gcnet/gcnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../gcnet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../gcnet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../gcnet/gcnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../gcnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../gcnet_r50-d8_769x769_80k_cityscapes.py | 10 +- .../fcn_hr18_480x480_40k_pascal_context.py | 11 +- .../fcn_hr18_480x480_40k_pascal_context_59.py | 11 +- .../fcn_hr18_480x480_80k_pascal_context.py | 11 +- .../fcn_hr18_480x480_80k_pascal_context_59.py | 11 +- .../fcn_hr18_4x4_512x512_80k_vaihingen.py | 6 +- .../hrnet/fcn_hr18_4x4_896x896_80k_isaid.py | 6 +- .../fcn_hr18_512x1024_160k_cityscapes.py | 6 +- .../hrnet/fcn_hr18_512x1024_40k_cityscapes.py | 6 +- .../hrnet/fcn_hr18_512x1024_80k_cityscapes.py | 6 +- .../hrnet/fcn_hr18_512x512_160k_ade20k.py | 6 +- .../hrnet/fcn_hr18_512x512_20k_voc12aug.py | 6 +- .../hrnet/fcn_hr18_512x512_40k_voc12aug.py | 6 +- .../hrnet/fcn_hr18_512x512_80k_ade20k.py | 6 +- .../hrnet/fcn_hr18_512x512_80k_loveda.py | 6 +- .../hrnet/fcn_hr18_512x512_80k_potsdam.py | 6 +- .../fcn_hr18s_480x480_40k_pascal_context.py | 11 +- ...fcn_hr18s_480x480_40k_pascal_context_59.py | 11 +- .../fcn_hr18s_480x480_80k_pascal_context.py | 11 +- ...fcn_hr18s_480x480_80k_pascal_context_59.py | 11 +- .../fcn_hr18s_4x4_512x512_80k_vaihingen.py | 11 +- .../hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py | 11 +- .../fcn_hr18s_512x1024_160k_cityscapes.py | 11 +- .../fcn_hr18s_512x1024_40k_cityscapes.py | 11 +- .../fcn_hr18s_512x1024_80k_cityscapes.py | 11 +- .../hrnet/fcn_hr18s_512x512_160k_ade20k.py | 11 +- .../hrnet/fcn_hr18s_512x512_20k_voc12aug.py | 11 +- .../hrnet/fcn_hr18s_512x512_40k_voc12aug.py | 11 +- .../hrnet/fcn_hr18s_512x512_80k_ade20k.py | 11 +- .../hrnet/fcn_hr18s_512x512_80k_loveda.py | 13 +- .../hrnet/fcn_hr18s_512x512_80k_potsdam.py | 11 +- .../fcn_hr48_480x480_40k_pascal_context.py | 12 +- .../fcn_hr48_480x480_40k_pascal_context_59.py | 12 +- .../fcn_hr48_480x480_80k_pascal_context.py | 12 +- .../fcn_hr48_480x480_80k_pascal_context_59.py | 12 +- .../fcn_hr48_4x4_512x512_80k_vaihingen.py | 12 +- .../hrnet/fcn_hr48_4x4_896x896_80k_isaid.py | 12 +- .../fcn_hr48_512x1024_160k_cityscapes.py | 12 +- .../hrnet/fcn_hr48_512x1024_40k_cityscapes.py | 12 +- .../hrnet/fcn_hr48_512x1024_80k_cityscapes.py | 12 +- .../hrnet/fcn_hr48_512x512_160k_ade20k.py | 12 +- .../hrnet/fcn_hr48_512x512_20k_voc12aug.py | 12 +- .../hrnet/fcn_hr48_512x512_40k_voc12aug.py | 12 +- .../hrnet/fcn_hr48_512x512_80k_ade20k.py | 12 +- .../hrnet/fcn_hr48_512x512_80k_loveda.py | 13 +- .../hrnet/fcn_hr48_512x512_80k_potsdam.py | 12 +- .../icnet_r101-d8_832x832_160k_cityscapes.py | 2 +- .../icnet_r101-d8_832x832_80k_cityscapes.py | 2 +- ...101-d8_in1k-pre_832x832_160k_cityscapes.py | 8 +- ...r101-d8_in1k-pre_832x832_80k_cityscapes.py | 8 +- .../icnet_r18-d8_832x832_160k_cityscapes.py | 5 +- .../icnet_r18-d8_832x832_80k_cityscapes.py | 5 +- ...r18-d8_in1k-pre_832x832_160k_cityscapes.py | 8 +- ..._r18-d8_in1k-pre_832x832_80k_cityscapes.py | 8 +- .../icnet_r50-d8_832x832_160k_cityscapes.py | 7 +- .../icnet_r50-d8_832x832_80k_cityscapes.py | 7 +- ...r50-d8_in1k-pre_832x832_160k_cityscapes.py | 8 +- ..._r50-d8_in1k-pre_832x832_80k_cityscapes.py | 8 +- .../isanet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../isanet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../isanet_r101-d8_512x512_160k_ade20k.py | 4 +- .../isanet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../isanet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../isanet_r101-d8_512x512_80k_ade20k.py | 4 +- .../isanet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../isanet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../isanet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../isanet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../isanet_r50-d8_512x512_160k_ade20k.py | 9 +- .../isanet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../isanet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../isanet_r50-d8_512x512_80k_ade20k.py | 9 +- .../isanet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../isanet_r50-d8_769x769_80k_cityscapes.py | 10 +- ...bv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py | 61 +- ...fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py | 61 +- ...net_r50-d8_8x2_512x512_adamw_80k_ade20k.py | 61 +- ...net_r50-d8_8x2_512x512_adamw_80k_ade20k.py | 61 +- ...net_swin-l_8x2_512x512_adamw_80k_ade20k.py | 13 +- ...net_swin-l_8x2_640x640_adamw_80k_ade20k.py | 56 +- ...net_swin-t_8x2_512x512_adamw_80k_ade20k.py | 42 +- ...et_mae-base_fp16_512x512_160k_ade20k_ms.py | 25 +- ...t_mae-base_fp16_8x2_512x512_160k_ade20k.py | 35 +- ...eplabv3_m-v2-d8_512x1024_80k_cityscapes.py | 14 +- .../deeplabv3_m-v2-d8_512x512_160k_ade20k.py | 14 +- ...bv3plus_m-v2-d8_512x1024_80k_cityscapes.py | 14 +- ...eplabv3plus_m-v2-d8_512x512_160k_ade20k.py | 14 +- .../fcn_m-v2-d8_512x1024_80k_cityscapes.py | 14 +- .../fcn_m-v2-d8_512x512_160k_ade20k.py | 14 +- .../pspnet_m-v2-d8_512x1024_80k_cityscapes.py | 14 +- .../pspnet_m-v2-d8_512x512_160k_ade20k.py | 14 +- ...lraspp_m-v3-d8_512x1024_320k_cityscapes.py | 10 +- ...-v3-d8_scratch_512x1024_320k_cityscapes.py | 8 +- ...raspp_m-v3s-d8_512x1024_320k_cityscapes.py | 25 +- ...v3s-d8_scratch_512x1024_320k_cityscapes.py | 23 +- ...onlocal_r101-d8_512x1024_40k_cityscapes.py | 4 +- ...onlocal_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../nonlocal_r101-d8_512x512_160k_ade20k.py | 4 +- .../nonlocal_r101-d8_512x512_20k_voc12aug.py | 4 +- .../nonlocal_r101-d8_512x512_40k_voc12aug.py | 4 +- .../nonlocal_r101-d8_512x512_80k_ade20k.py | 4 +- ...nonlocal_r101-d8_769x769_40k_cityscapes.py | 4 +- ...nonlocal_r101-d8_769x769_80k_cityscapes.py | 4 +- ...nonlocal_r50-d8_512x1024_40k_cityscapes.py | 6 +- ...nonlocal_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../nonlocal_r50-d8_512x512_160k_ade20k.py | 9 +- .../nonlocal_r50-d8_512x512_20k_voc12aug.py | 10 +- .../nonlocal_r50-d8_512x512_40k_voc12aug.py | 10 +- .../nonlocal_r50-d8_512x512_80k_ade20k.py | 9 +- .../nonlocal_r50-d8_769x769_40k_cityscapes.py | 10 +- .../nonlocal_r50-d8_769x769_80k_cityscapes.py | 10 +- .../ocrnet_hr18_512x1024_160k_cityscapes.py | 6 +- .../ocrnet_hr18_512x1024_40k_cityscapes.py | 6 +- .../ocrnet_hr18_512x1024_80k_cityscapes.py | 6 +- .../ocrnet/ocrnet_hr18_512x512_160k_ade20k.py | 74 +- .../ocrnet_hr18_512x512_20k_voc12aug.py | 75 +- .../ocrnet_hr18_512x512_40k_voc12aug.py | 75 +- .../ocrnet/ocrnet_hr18_512x512_80k_ade20k.py | 74 +- .../ocrnet_hr18s_512x1024_160k_cityscapes.py | 11 +- .../ocrnet_hr18s_512x1024_40k_cityscapes.py | 11 +- .../ocrnet_hr18s_512x1024_80k_cityscapes.py | 11 +- .../ocrnet_hr18s_512x512_160k_ade20k.py | 11 +- .../ocrnet_hr18s_512x512_20k_voc12aug.py | 11 +- .../ocrnet_hr18s_512x512_40k_voc12aug.py | 11 +- .../ocrnet/ocrnet_hr18s_512x512_80k_ade20k.py | 11 +- .../ocrnet_hr48_512x1024_160k_cityscapes.py | 29 +- .../ocrnet_hr48_512x1024_40k_cityscapes.py | 29 +- .../ocrnet_hr48_512x1024_80k_cityscapes.py | 29 +- .../ocrnet/ocrnet_hr48_512x512_160k_ade20k.py | 29 +- .../ocrnet_hr48_512x512_20k_voc12aug.py | 29 +- .../ocrnet_hr48_512x512_40k_voc12aug.py | 29 +- .../ocrnet/ocrnet_hr48_512x512_80k_ade20k.py | 29 +- ...net_r101-d8_512x1024_40k_b16_cityscapes.py | 8 +- ...rnet_r101-d8_512x1024_40k_b8_cityscapes.py | 8 +- ...net_r101-d8_512x1024_80k_b16_cityscapes.py | 8 +- .../pointrend_r101_512x1024_80k_cityscapes.py | 4 +- .../pointrend_r101_512x512_160k_ade20k.py | 4 +- .../pointrend_r50_512x1024_80k_cityscapes.py | 8 +- .../pointrend_r50_512x512_160k_ade20k.py | 68 +- ...n_poolformer_m36_8x4_512x512_40k_ade20k.py | 14 +- ...n_poolformer_m48_8x4_512x512_40k_ade20k.py | 14 +- ...n_poolformer_s12_8x4_512x512_40k_ade20k.py | 84 +- ...n_poolformer_s24_8x4_512x512_40k_ade20k.py | 12 +- ...n_poolformer_s36_8x4_512x512_40k_ade20k.py | 12 +- .../psanet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../psanet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../psanet_r101-d8_512x512_160k_ade20k.py | 4 +- .../psanet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../psanet_r101-d8_512x512_40k_voc12aug.py | 4 +- .../psanet_r101-d8_512x512_80k_ade20k.py | 4 +- .../psanet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../psanet_r101-d8_769x769_80k_cityscapes.py | 4 +- .../psanet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../psanet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../psanet_r50-d8_512x512_160k_ade20k.py | 9 +- .../psanet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../psanet_r50-d8_512x512_40k_voc12aug.py | 10 +- .../psanet_r50-d8_512x512_80k_ade20k.py | 9 +- .../psanet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../psanet_r50-d8_769x769_80k_cityscapes.py | 10 +- ...pnet_r101-d8_480x480_40k_pascal_context.py | 4 +- ...t_r101-d8_480x480_40k_pascal_context_59.py | 4 +- ...pnet_r101-d8_480x480_80k_pascal_context.py | 4 +- ...t_r101-d8_480x480_80k_pascal_context_59.py | 4 +- .../pspnet_r101-d8_4x4_512x512_80k_potsdam.py | 4 +- ...spnet_r101-d8_4x4_512x512_80k_vaihingen.py | 4 +- .../pspnet_r101-d8_512x1024_40k_cityscapes.py | 4 +- .../pspnet_r101-d8_512x1024_40k_dark.py | 4 +- ...pnet_r101-d8_512x1024_40k_night_driving.py | 4 +- .../pspnet_r101-d8_512x1024_80k_cityscapes.py | 4 +- .../pspnet_r101-d8_512x512_160k_ade20k.py | 4 +- .../pspnet_r101-d8_512x512_20k_voc12aug.py | 4 +- .../pspnet_r101-d8_512x512_40k_voc12aug.py | 4 +- ...r101-d8_512x512_4x4_160k_coco-stuff164k.py | 4 +- ...t_r101-d8_512x512_4x4_20k_coco-stuff10k.py | 4 +- ...r101-d8_512x512_4x4_320k_coco-stuff164k.py | 4 +- ...t_r101-d8_512x512_4x4_40k_coco-stuff10k.py | 4 +- ..._r101-d8_512x512_4x4_80k_coco-stuff164k.py | 4 +- .../pspnet_r101-d8_512x512_80k_ade20k.py | 4 +- .../pspnet_r101-d8_512x512_80k_loveda.py | 7 +- .../pspnet_r101-d8_769x769_40k_cityscapes.py | 4 +- .../pspnet_r101-d8_769x769_80k_cityscapes.py | 4 +- ...et_r101-d8_fp16_512x1024_80k_cityscapes.py | 4 +- ...pspnet_r101b-d8_512x1024_80k_cityscapes.py | 6 +- .../pspnet_r101b-d8_512x1024_80k_dark.py | 6 +- ...net_r101b-d8_512x1024_80k_night_driving.py | 6 +- .../pspnet_r101b-d8_769x769_80k_cityscapes.py | 6 +- .../pspnet_r18-d8_4x4_512x512_80k_potsdam.py | 7 +- ...pspnet_r18-d8_4x4_512x512_80k_vaihingen.py | 7 +- .../pspnet_r18-d8_4x4_896x896_80k_isaid.py | 7 +- .../pspnet_r18-d8_512x1024_80k_cityscapes.py | 7 +- .../pspnet_r18-d8_512x512_80k_loveda.py | 9 +- .../pspnet_r18-d8_769x769_80k_cityscapes.py | 7 +- .../pspnet_r18b-d8_512x1024_80k_cityscapes.py | 9 +- .../pspnet_r18b-d8_769x769_80k_cityscapes.py | 9 +- .../pspnet_r50-d32_512x1024_80k_cityscapes.py | 6 +- ...-pretrain_512x1024_adamw_80k_cityscapes.py | 26 +- ...spnet_r50-d8_480x480_40k_pascal_context.py | 12 +- ...et_r50-d8_480x480_40k_pascal_context_59.py | 12 +- ...spnet_r50-d8_480x480_80k_pascal_context.py | 12 +- ...et_r50-d8_480x480_80k_pascal_context_59.py | 12 +- .../pspnet_r50-d8_4x4_512x512_80k_potsdam.py | 9 +- ...pspnet_r50-d8_4x4_512x512_80k_vaihingen.py | 9 +- .../pspnet_r50-d8_4x4_896x896_80k_isaid.py | 9 +- .../pspnet_r50-d8_512x1024_40k_cityscapes.py | 6 +- .../pspnet/pspnet_r50-d8_512x1024_40k_dark.py | 38 +- ...spnet_r50-d8_512x1024_40k_night_driving.py | 38 +- .../pspnet_r50-d8_512x1024_80k_cityscapes.py | 6 +- .../pspnet/pspnet_r50-d8_512x1024_80k_dark.py | 38 +- ...spnet_r50-d8_512x1024_80k_night_driving.py | 38 +- .../pspnet_r50-d8_512x512_160k_ade20k.py | 9 +- .../pspnet_r50-d8_512x512_20k_voc12aug.py | 10 +- .../pspnet_r50-d8_512x512_40k_voc12aug.py | 10 +- ..._r50-d8_512x512_4x4_160k_coco-stuff164k.py | 10 +- ...et_r50-d8_512x512_4x4_20k_coco-stuff10k.py | 9 +- ..._r50-d8_512x512_4x4_320k_coco-stuff164k.py | 10 +- ...et_r50-d8_512x512_4x4_40k_coco-stuff10k.py | 9 +- ...t_r50-d8_512x512_4x4_80k_coco-stuff164k.py | 10 +- .../pspnet_r50-d8_512x512_80k_ade20k.py | 9 +- .../pspnet_r50-d8_512x512_80k_loveda.py | 9 +- .../pspnet_r50-d8_769x769_40k_cityscapes.py | 10 +- .../pspnet_r50-d8_769x769_80k_cityscapes.py | 10 +- ...-pretrain_512x1024_adamw_80k_cityscapes.py | 24 +- ...pspnet_r50b-d32_512x1024_80k_cityscapes.py | 11 +- .../pspnet_r50b-d8_512x1024_80k_cityscapes.py | 4 +- .../pspnet_r50b-d8_769x769_80k_cityscapes.py | 4 +- ...eplabv3_s101-d8_512x1024_80k_cityscapes.py | 10 +- .../deeplabv3_s101-d8_512x512_160k_ade20k.py | 10 +- ...bv3plus_s101-d8_512x1024_80k_cityscapes.py | 10 +- ...eplabv3plus_s101-d8_512x512_160k_ade20k.py | 10 +- .../fcn_s101-d8_512x1024_80k_cityscapes.py | 10 +- .../fcn_s101-d8_512x512_160k_ade20k.py | 10 +- .../pspnet_s101-d8_512x1024_80k_cityscapes.py | 10 +- .../pspnet_s101-d8_512x512_160k_ade20k.py | 10 +- .../segformer_mit-b0_512x512_160k_ade20k.py | 27 +- ...er_mit-b0_8x1_1024x1024_160k_cityscapes.py | 33 +- .../segformer_mit-b1_512x512_160k_ade20k.py | 10 +- ...er_mit-b1_8x1_1024x1024_160k_cityscapes.py | 11 +- .../segformer_mit-b2_512x512_160k_ade20k.py | 10 +- ...er_mit-b2_8x1_1024x1024_160k_cityscapes.py | 12 +- .../segformer_mit-b3_512x512_160k_ade20k.py | 10 +- ...er_mit-b3_8x1_1024x1024_160k_cityscapes.py | 12 +- .../segformer_mit-b4_512x512_160k_ade20k.py | 10 +- ...er_mit-b4_8x1_1024x1024_160k_cityscapes.py | 12 +- .../segformer_mit-b5_512x512_160k_ade20k.py | 10 +- .../segformer_mit-b5_640x640_160k_ade20k.py | 53 +- ...er_mit-b5_8x1_1024x1024_160k_cityscapes.py | 12 +- ...nter_vit-b_mask_8x1_512x512_160k_ade20k.py | 50 +- ...nter_vit-l_mask_8x1_640x640_160k_ade20k.py | 65 +- ...er_vit-s_linear_8x1_512x512_160k_ade20k.py | 9 +- ...nter_vit-s_mask_8x1_512x512_160k_ade20k.py | 61 +- ...nter_vit-t_mask_8x1_512x512_160k_ade20k.py | 58 +- .../fpn_r101_512x1024_80k_cityscapes.py | 4 +- .../sem_fpn/fpn_r101_512x512_160k_ade20k.py | 4 +- .../fpn_r50_512x1024_80k_cityscapes.py | 6 +- .../sem_fpn/fpn_r50_512x512_160k_ade20k.py | 6 +- .../setr/setr_mla_512x512_160k_b16_ade20k.py | 2 +- .../setr/setr_mla_512x512_160k_b8_ade20k.py | 51 +- .../setr_naive_512x512_160k_b16_ade20k.py | 43 +- .../setr/setr_pup_512x512_160k_b16_ade20k.py | 43 +- ...it-large_mla_8x1_768x768_80k_cityscapes.py | 16 +- ...-large_naive_8x1_768x768_80k_cityscapes.py | 20 +- ...it-large_pup_8x1_768x768_80k_cityscapes.py | 39 +- .../stdc/stdc1_512x1024_80k_cityscapes.py | 8 +- .../stdc1_in1k-pre_512x1024_80k_cityscapes.py | 9 +- .../stdc/stdc2_512x1024_80k_cityscapes.py | 4 +- .../stdc2_in1k-pre_512x1024_80k_cityscapes.py | 9 +- ...512x512_160k_ade20k_pretrain_384x384_1K.py | 13 +- ...12x512_160k_ade20k_pretrain_384x384_22K.py | 9 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 13 +- ...12x512_160k_ade20k_pretrain_224x224_22K.py | 9 +- ...12x512_pretrain_384x384_22K_160k_ade20k.py | 11 +- ...12x512_pretrain_224x224_22K_160k_ade20k.py | 13 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 13 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 35 +- ...vt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 9 +- ...cpvt-b_uperhead_8x2_512x512_160k_ade20k.py | 10 +- ...vt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 9 +- ...cpvt-l_uperhead_8x2_512x512_160k_ade20k.py | 10 +- ...vt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 8 +- ...cpvt-s_uperhead_8x4_512x512_160k_ade20k.py | 24 +- ...vt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 9 +- ..._svt-b_uperhead_8x2_512x512_160k_ade20k.py | 12 +- ...vt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 9 +- ..._svt-l_uperhead_8x2_512x512_160k_ade20k.py | 12 +- ...vt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py | 17 +- ..._svt-s_uperhead_8x2_512x512_160k_ade20k.py | 36 +- ...labv3_unet_s5-d16_128x128_40k_chase_db1.py | 9 +- ...deeplabv3_unet_s5-d16_128x128_40k_stare.py | 8 +- .../deeplabv3_unet_s5-d16_256x256_40k_hrf.py | 8 +- .../deeplabv3_unet_s5-d16_64x64_40k_drive.py | 8 +- ...6_ce-1.0-dice-3.0_128x128_40k_chase-db1.py | 13 +- ...5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py | 13 +- .../fcn_unet_s5-d16_128x128_40k_chase_db1.py | 8 +- .../unet/fcn_unet_s5-d16_128x128_40k_stare.py | 8 +- .../unet/fcn_unet_s5-d16_256x256_40k_hrf.py | 8 +- ...net_s5-d16_4x4_512x1024_160k_cityscapes.py | 9 +- .../unet/fcn_unet_s5-d16_64x64_40k_drive.py | 8 +- ...6_ce-1.0-dice-3.0_128x128_40k_chase-db1.py | 13 +- ...5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py | 13 +- ...spnet_unet_s5-d16_128x128_40k_chase_db1.py | 9 +- .../pspnet_unet_s5-d16_128x128_40k_stare.py | 8 +- .../pspnet_unet_s5-d16_256x256_40k_hrf.py | 8 +- .../pspnet_unet_s5-d16_64x64_40k_drive.py | 8 +- ...6_ce-1.0-dice-3.0_128x128_40k_chase-db1.py | 13 +- ...5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py | 13 +- ..._s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py | 13 +- .../upernet_r101_512x1024_40k_cityscapes.py | 4 +- .../upernet_r101_512x1024_80k_cityscapes.py | 4 +- .../upernet_r101_512x512_160k_ade20k.py | 4 +- .../upernet_r101_512x512_20k_voc12aug.py | 4 +- .../upernet_r101_512x512_40k_voc12aug.py | 4 +- .../upernet_r101_512x512_80k_ade20k.py | 4 +- .../upernet_r101_769x769_40k_cityscapes.py | 4 +- .../upernet_r101_769x769_80k_cityscapes.py | 4 +- .../upernet_r18_512x1024_40k_cityscapes.py | 7 +- .../upernet_r18_512x1024_80k_cityscapes.py | 7 +- .../upernet_r18_512x512_160k_ade20k.py | 11 +- .../upernet_r18_512x512_20k_voc12aug.py | 12 +- .../upernet_r18_512x512_40k_voc12aug.py | 12 +- .../upernet/upernet_r18_512x512_80k_ade20k.py | 11 +- .../upernet_r50_512x1024_40k_cityscapes.py | 6 +- .../upernet_r50_512x1024_80k_cityscapes.py | 6 +- .../upernet_r50_512x512_160k_ade20k.py | 9 +- .../upernet_r50_512x512_20k_voc12aug.py | 10 +- .../upernet_r50_512x512_40k_voc12aug.py | 10 +- .../upernet/upernet_r50_512x512_80k_ade20k.py | 9 +- .../upernet_r50_769x769_40k_cityscapes.py | 10 +- .../upernet_r50_769x769_80k_cityscapes.py | 10 +- .../upernet_deit-b16_512x512_160k_ade20k.py | 7 +- .../upernet_deit-b16_512x512_80k_ade20k.py | 7 +- ...net_deit-b16_ln_mln_512x512_160k_ade20k.py | 7 +- ...pernet_deit-b16_mln_512x512_160k_ade20k.py | 4 +- .../upernet_deit-s16_512x512_160k_ade20k.py | 7 +- .../upernet_deit-s16_512x512_80k_ade20k.py | 7 +- ...net_deit-s16_ln_mln_512x512_160k_ade20k.py | 10 +- ...pernet_deit-s16_mln_512x512_160k_ade20k.py | 7 +- ...rnet_vit-b16_ln_mln_512x512_160k_ade20k.py | 31 +- ...upernet_vit-b16_mln_512x512_160k_ade20k.py | 31 +- .../upernet_vit-b16_mln_512x512_80k_ade20k.py | 31 +- mmsegmentation/demo/image_demo.py | 28 +- mmsegmentation/demo/video_demo.py | 85 +- mmsegmentation/docs/en/conf.py | 88 +- mmsegmentation/docs/en/stat.py | 34 +- mmsegmentation/docs/zh_cn/conf.py | 88 +- mmsegmentation/docs/zh_cn/stat.py | 34 +- mmsegmentation/mmseg/__init__.py | 23 +- mmsegmentation/mmseg/apis/__init__.py | 15 +- mmsegmentation/mmseg/apis/inference.py | 61 +- mmsegmentation/mmseg/apis/test.py | 94 +- mmsegmentation/mmseg/apis/train.py | 118 +- mmsegmentation/mmseg/core/__init__.py | 7 +- mmsegmentation/mmseg/core/builder.py | 20 +- .../mmseg/core/evaluation/__init__.py | 23 +- .../mmseg/core/evaluation/class_names.py | 1025 ++++++++++++++--- .../mmseg/core/evaluation/eval_hooks.py | 61 +- .../mmseg/core/evaluation/metrics.py | 267 +++-- mmsegmentation/mmseg/core/hook/__init__.py | 2 +- .../mmseg/core/hook/wandblogger_hook.py | 157 +-- .../mmseg/core/optimizers/__init__.py | 8 +- .../layer_decay_optimizer_constructor.py | 136 ++- mmsegmentation/mmseg/core/seg/__init__.py | 2 +- mmsegmentation/mmseg/core/seg/builder.py | 2 +- .../mmseg/core/seg/sampler/__init__.py | 2 +- .../core/seg/sampler/ohem_pixel_sampler.py | 12 +- mmsegmentation/mmseg/core/utils/__init__.py | 2 +- mmsegmentation/mmseg/core/utils/dist_util.py | 2 +- mmsegmentation/mmseg/core/utils/misc.py | 2 +- mmsegmentation/mmseg/datasets/__init__.py | 37 +- mmsegmentation/mmseg/datasets/ade.py | 394 +++++-- mmsegmentation/mmseg/datasets/builder.py | 85 +- mmsegmentation/mmseg/datasets/chase_db1.py | 11 +- mmsegmentation/mmseg/datasets/cityscapes.py | 124 +- mmsegmentation/mmseg/datasets/coco_stuff.py | 423 +++++-- mmsegmentation/mmseg/datasets/coco_trash.py | 29 +- mmsegmentation/mmseg/datasets/custom.py | 178 +-- mmsegmentation/mmseg/datasets/dark_zurich.py | 5 +- .../mmseg/datasets/dataset_wrappers.py | 114 +- mmsegmentation/mmseg/datasets/drive.py | 11 +- mmsegmentation/mmseg/datasets/face.py | 7 +- mmsegmentation/mmseg/datasets/hrf.py | 10 +- mmsegmentation/mmseg/datasets/isaid.py | 75 +- mmsegmentation/mmseg/datasets/isprs.py | 29 +- mmsegmentation/mmseg/datasets/loveda.py | 40 +- .../mmseg/datasets/night_driving.py | 7 +- .../mmseg/datasets/pascal_context.py | 313 ++++- .../mmseg/datasets/pipelines/__init__.py | 60 +- .../mmseg/datasets/pipelines/compose.py | 12 +- .../mmseg/datasets/pipelines/formating.py | 10 +- .../mmseg/datasets/pipelines/formatting.py | 74 +- .../mmseg/datasets/pipelines/loading.py | 82 +- .../mmseg/datasets/pipelines/test_time_aug.py | 77 +- .../mmseg/datasets/pipelines/transforms.py | 607 +++++----- mmsegmentation/mmseg/datasets/potsdam.py | 29 +- .../mmseg/datasets/samplers/__init__.py | 2 +- .../datasets/samplers/distributed_sampler.py | 26 +- mmsegmentation/mmseg/datasets/stare.py | 11 +- mmsegmentation/mmseg/datasets/voc.py | 60 +- mmsegmentation/mmseg/models/__init__.py | 22 +- .../mmseg/models/backbones/__init__.py | 30 +- mmsegmentation/mmseg/models/backbones/beit.py | 263 +++-- .../mmseg/models/backbones/bisenetv1.py | 184 +-- .../mmseg/models/backbones/bisenetv2.py | 294 +++-- .../mmseg/models/backbones/cgnet.py | 149 +-- .../mmseg/models/backbones/erfnet.py | 200 ++-- .../mmseg/models/backbones/fast_scnn.py | 207 ++-- .../mmseg/models/backbones/hrnet.py | 358 +++--- .../mmseg/models/backbones/icnet.py | 83 +- mmsegmentation/mmseg/models/backbones/mae.py | 115 +- mmsegmentation/mmseg/models/backbones/mit.py | 241 ++-- .../mmseg/models/backbones/mobilenet_v2.py | 96 +- .../mmseg/models/backbones/mobilenet_v3.py | 160 +-- .../mmseg/models/backbones/resnest.py | 116 +- .../mmseg/models/backbones/resnet.py | 265 +++-- .../mmseg/models/backbones/resnext.py | 47 +- mmsegmentation/mmseg/models/backbones/stdc.py | 243 ++-- mmsegmentation/mmseg/models/backbones/swin.py | 402 ++++--- .../mmseg/models/backbones/timm_backbone.py | 20 +- .../mmseg/models/backbones/twins.py | 426 ++++--- mmsegmentation/mmseg/models/backbones/unet.py | 287 ++--- mmsegmentation/mmseg/models/backbones/vit.py | 259 +++-- mmsegmentation/mmseg/models/builder.py | 22 +- .../mmseg/models/decode_heads/__init__.py | 39 +- .../mmseg/models/decode_heads/ann_head.py | 96 +- .../mmseg/models/decode_heads/apc_head.py | 53 +- .../mmseg/models/decode_heads/aspp_head.py | 26 +- .../decode_heads/cascade_decode_head.py | 6 +- .../mmseg/models/decode_heads/cc_head.py | 7 +- .../mmseg/models/decode_heads/da_head.py | 45 +- .../mmseg/models/decode_heads/decode_head.py | 113 +- .../mmseg/models/decode_heads/dm_head.py | 38 +- .../mmseg/models/decode_heads/dnl_head.py | 44 +- .../mmseg/models/decode_heads/dpt_head.py | 200 ++-- .../mmseg/models/decode_heads/ema_head.py | 55 +- .../mmseg/models/decode_heads/enc_head.py | 73 +- .../mmseg/models/decode_heads/fcn_head.py | 18 +- .../mmseg/models/decode_heads/fpn_head.py | 23 +- .../mmseg/models/decode_heads/gc_head.py | 13 +- .../mmseg/models/decode_heads/isa_head.py | 29 +- .../mmseg/models/decode_heads/knet_head.py | 196 ++-- .../mmseg/models/decode_heads/lraspp_head.py | 43 +- .../mmseg/models/decode_heads/nl_head.py | 11 +- .../mmseg/models/decode_heads/ocr_head.py | 24 +- .../mmseg/models/decode_heads/point_head.py | 141 ++- .../mmseg/models/decode_heads/psa_head.py | 106 +- .../mmseg/models/decode_heads/psp_head.py | 33 +- .../models/decode_heads/segformer_head.py | 15 +- .../decode_heads/segmenter_mask_head.py | 54 +- .../models/decode_heads/sep_aspp_head.py | 30 +- .../mmseg/models/decode_heads/sep_fcn_head.py | 11 +- .../models/decode_heads/setr_mla_head.py | 16 +- .../mmseg/models/decode_heads/setr_up_head.py | 40 +- .../mmseg/models/decode_heads/stdc_head.py | 66 +- .../mmseg/models/decode_heads/uper_head.py | 34 +- .../mmseg/models/losses/__init__.py | 25 +- .../mmseg/models/losses/accuracy.py | 12 +- .../mmseg/models/losses/cross_entropy_loss.py | 159 +-- .../mmseg/models/losses/dice_loss.py | 54 +- .../mmseg/models/losses/focal_loss.py | 180 +-- .../mmseg/models/losses/lovasz_loss.py | 140 ++- .../mmseg/models/losses/tversky_loss.py | 57 +- mmsegmentation/mmseg/models/losses/utils.py | 15 +- mmsegmentation/mmseg/models/necks/__init__.py | 4 +- .../mmseg/models/necks/featurepyramid.py | 36 +- mmsegmentation/mmseg/models/necks/fpn.py | 70 +- mmsegmentation/mmseg/models/necks/ic_neck.py | 71 +- mmsegmentation/mmseg/models/necks/jpu.py | 57 +- mmsegmentation/mmseg/models/necks/mla_neck.py | 53 +- .../mmseg/models/necks/multilevel_neck.py | 32 +- .../mmseg/models/segmentors/__init__.py | 2 +- .../mmseg/models/segmentors/base.py | 99 +- .../segmentors/cascade_encoder_decoder.py | 51 +- .../models/segmentors/encoder_decoder.py | 126 +- mmsegmentation/mmseg/models/utils/__init__.py | 18 +- mmsegmentation/mmseg/models/utils/embed.py | 132 ++- .../mmseg/models/utils/inverted_residual.py | 125 +- .../mmseg/models/utils/res_layer.py | 65 +- mmsegmentation/mmseg/models/utils/se_layer.py | 21 +- .../models/utils/self_attention_block.py | 59 +- .../mmseg/models/utils/shape_convert.py | 4 +- .../mmseg/models/utils/up_conv_block.py | 47 +- mmsegmentation/mmseg/ops/__init__.py | 2 +- mmsegmentation/mmseg/ops/encoding.py | 35 +- mmsegmentation/mmseg/ops/wrappers.py | 42 +- mmsegmentation/mmseg/utils/__init__.py | 9 +- mmsegmentation/mmseg/utils/collect_env.py | 6 +- mmsegmentation/mmseg/utils/logger.py | 2 +- mmsegmentation/mmseg/utils/misc.py | 14 +- mmsegmentation/mmseg/utils/set_env.py | 38 +- .../mmseg/utils/util_distribution.py | 43 +- mmsegmentation/mmseg/version.py | 10 +- mmsegmentation/setup.py | 135 +-- .../tests/test_apis/test_single_gpu.py | 23 +- mmsegmentation/tests/test_config.py | 71 +- .../test_layer_decay_optimizer_constructor.py | 206 ++-- .../tests/test_core/test_optimizer.py | 22 +- .../tests/test_data/test_dataset.py | 768 ++++++------ .../tests/test_data/test_dataset_builder.py | 101 +- .../tests/test_data/test_loading.py | 149 ++- .../tests/test_data/test_transform.py | 584 +++++----- mmsegmentation/tests/test_data/test_tta.py | 173 +-- mmsegmentation/tests/test_digit_version.py | 32 +- mmsegmentation/tests/test_eval_hook.py | 125 +- mmsegmentation/tests/test_inference.py | 9 +- mmsegmentation/tests/test_metrics.py | 242 ++-- .../test_models/test_backbones/__init__.py | 2 +- .../test_models/test_backbones/test_beit.py | 25 +- .../test_backbones/test_bisenetv1.py | 36 +- .../test_backbones/test_bisenetv2.py | 3 +- .../test_models/test_backbones/test_blocks.py | 39 +- .../test_models/test_backbones/test_cgnet.py | 6 +- .../test_models/test_backbones/test_erfnet.py | 7 +- .../test_backbones/test_fast_scnn.py | 10 +- .../test_models/test_backbones/test_hrnet.py | 43 +- .../test_models/test_backbones/test_icnet.py | 17 +- .../test_models/test_backbones/test_mae.py | 29 +- .../test_models/test_backbones/test_mit.py | 28 +- .../test_backbones/test_mobilenet_v3.py | 4 +- .../test_backbones/test_resnest.py | 8 +- .../test_models/test_backbones/test_resnet.py | 157 +-- .../test_backbones/test_resnext.py | 15 +- .../test_models/test_backbones/test_stdc.py | 74 +- .../test_models/test_backbones/test_swin.py | 3 +- .../test_backbones/test_timm_backbone.py | 45 +- .../test_models/test_backbones/test_twins.py | 45 +- .../test_models/test_backbones/test_unet.py | 172 +-- .../test_models/test_backbones/test_vit.py | 27 +- .../tests/test_models/test_backbones/utils.py | 12 +- .../tests/test_models/test_forward.py | 116 +- .../test_models/test_heads/test_ann_head.py | 4 +- .../test_models/test_heads/test_apc_head.py | 21 +- .../test_models/test_heads/test_aspp_head.py | 17 +- .../test_models/test_heads/test_cc_head.py | 4 +- .../test_models/test_heads/test_da_head.py | 1 - .../test_heads/test_decode_head.py | 89 +- .../test_models/test_heads/test_dm_head.py | 21 +- .../test_models/test_heads/test_dnl_head.py | 10 +- .../test_models/test_heads/test_dpt_head.py | 20 +- .../test_models/test_heads/test_ema_head.py | 5 +- .../test_models/test_heads/test_enc_head.py | 10 +- .../test_models/test_heads/test_fcn_head.py | 28 +- .../test_models/test_heads/test_gc_head.py | 2 +- .../test_models/test_heads/test_isa_head.py | 8 +- .../test_models/test_heads/test_knet_head.py | 79 +- .../test_heads/test_lraspp_head.py | 34 +- .../test_models/test_heads/test_nl_head.py | 2 +- .../test_models/test_heads/test_ocr_head.py | 4 +- .../test_models/test_heads/test_point_head.py | 33 +- .../test_models/test_heads/test_psa_head.py | 41 +- .../test_models/test_heads/test_psp_head.py | 10 +- .../test_heads/test_segformer_head.py | 12 +- .../test_heads/test_segmenter_mask_head.py | 3 +- .../test_heads/test_setr_mla_head.py | 14 +- .../test_heads/test_setr_up_head.py | 9 +- .../test_models/test_heads/test_stdc_head.py | 15 +- .../test_models/test_heads/test_uper_head.py | 12 +- .../test_models/test_losses/test_ce_loss.py | 205 ++-- .../test_models/test_losses/test_dice_loss.py | 50 +- .../test_losses/test_focal_loss.py | 62 +- .../test_losses/test_lovasz_loss.py | 83 +- .../test_losses/test_tversky_loss.py | 40 +- .../test_models/test_losses/test_utils.py | 30 +- .../test_necks/test_feature2pyramid.py | 9 +- .../tests/test_models/test_necks/test_fpn.py | 6 +- .../test_models/test_necks/test_ic_neck.py | 17 +- .../tests/test_models/test_necks/test_jpu.py | 4 +- .../test_necks/test_multilevel_neck.py | 1 - .../test_cascade_encoder_decoder.py | 45 +- .../test_segmentors/test_encoder_decoder.py | 37 +- .../test_models/test_segmentors/utils.py | 68 +- .../test_models/test_utils/test_embed.py | 104 +- .../test_utils/test_shape_convert.py | 3 +- mmsegmentation/tests/test_sampler.py | 17 +- mmsegmentation/tests/test_utils/test_misc.py | 20 +- .../tests/test_utils/test_set_env.py | 88 +- .../test_utils/test_util_distribution.py | 34 +- mmsegmentation/tools/analyze_logs.py | 64 +- mmsegmentation/tools/benchmark.py | 70 +- mmsegmentation/tools/browse_dataset.py | 128 +- mmsegmentation/tools/confusion_matrix.py | 102 +- .../tools/convert_datasets/chase_db1.py | 80 +- .../tools/convert_datasets/cityscapes.py | 34 +- .../tools/convert_datasets/coco_stuff10k.py | 116 +- .../tools/convert_datasets/coco_stuff164k.py | 88 +- .../tools/convert_datasets/drive.py | 105 +- mmsegmentation/tools/convert_datasets/hrf.py | 109 +- .../tools/convert_datasets/isaid.py | 195 ++-- .../tools/convert_datasets/loveda.py | 64 +- .../tools/convert_datasets/pascal_context.py | 126 +- .../tools/convert_datasets/potsdam.py | 187 +-- .../tools/convert_datasets/stare.py | 158 +-- .../tools/convert_datasets/vaihingen.py | 179 ++- .../tools/convert_datasets/voc_aug.py | 67 +- mmsegmentation/tools/deploy_test.py | 255 ++-- mmsegmentation/tools/get_flops.py | 53 +- mmsegmentation/tools/inference.py | 8 +- .../tools/model_converters/beit2mmseg.py | 41 +- .../tools/model_converters/mit2mmseg.py | 77 +- .../tools/model_converters/stdc2mmseg.py | 72 +- .../tools/model_converters/swin2mmseg.py | 58 +- .../tools/model_converters/twins2mmseg.py | 69 +- .../tools/model_converters/vit2mmseg.py | 58 +- .../tools/model_converters/vitjax2mmseg.py | 114 +- mmsegmentation/tools/model_ensemble.py | 53 +- mmsegmentation/tools/onnx2tensorrt.py | 258 +++-- mmsegmentation/tools/print_config.py | 58 +- mmsegmentation/tools/publish_model.py | 21 +- mmsegmentation/tools/pytorch2onnx.py | 284 ++--- mmsegmentation/tools/pytorch2torchscript.py | 99 +- mmsegmentation/tools/test.py | 293 ++--- mmsegmentation/tools/test_jwseo.py | 293 ++--- .../tools/torchserve/mmseg2torchserve.py | 83 +- .../tools/torchserve/mmseg_handler.py | 21 +- .../tools/torchserve/test_torchserve.py | 37 +- mmsegmentation/tools/train.py | 196 ++-- mmsegmentation/tools/train_jwseo.py | 5 +- src/inference.py | 4 +- src/train.py | 4 +- src/utils/stratified_kfold.py | 1 - 1012 files changed, 19244 insertions(+), 14743 deletions(-) diff --git a/mmsegmentation/.dev/benchmark_inference.py b/mmsegmentation/.dev/benchmark_inference.py index 3ab681b..1a6e5dd 100644 --- a/mmsegmentation/.dev/benchmark_inference.py +++ b/mmsegmentation/.dev/benchmark_inference.py @@ -13,29 +13,29 @@ from mmseg.utils import get_root_logger # ignore warnings when segmentors inference -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir): """Download checkpoint and check if hash code is true.""" - url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}' # noqa + url = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}" # noqa r = requests.get(url) - assert r.status_code != 403, f'{url} Access denied.' + assert r.status_code != 403, f"{url} Access denied." - with open(osp.join(collect_dir, checkpoint_name), 'wb') as code: + with open(osp.join(collect_dir, checkpoint_name), "wb") as code: code.write(r.content) - true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1] + true_hash_code = osp.splitext(checkpoint_name)[0].split("-")[1] # check hash code - with open(osp.join(collect_dir, checkpoint_name), 'rb') as fp: + with open(osp.join(collect_dir, checkpoint_name), "rb") as fp: sha256_cal = hashlib.sha256() sha256_cal.update(fp.read()) cur_hash_code = sha256_cal.hexdigest()[:8] - assert true_hash_code == cur_hash_code, f'{url} download failed, ' - 'incomplete downloaded file or url invalid.' + assert true_hash_code == cur_hash_code, f"{url} download failed, " + "incomplete downloaded file or url invalid." if cur_hash_code != true_hash_code: os.remove(osp.join(collect_dir, checkpoint_name)) @@ -43,32 +43,31 @@ def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir): def parse_args(): parser = ArgumentParser() - parser.add_argument('config', help='test config file path') - parser.add_argument('checkpoint_root', help='Checkpoint file root path') + parser.add_argument("config", help="test config file path") + parser.add_argument("checkpoint_root", help="Checkpoint file root path") + parser.add_argument("-i", "--img", default="demo/demo.png", help="Image file") + parser.add_argument("-a", "--aug", action="store_true", help="aug test") + parser.add_argument("-m", "--model-name", help="model name to inference") + parser.add_argument("-s", "--show", action="store_true", help="show results") parser.add_argument( - '-i', '--img', default='demo/demo.png', help='Image file') - parser.add_argument('-a', '--aug', action='store_true', help='aug test') - parser.add_argument('-m', '--model-name', help='model name to inference') - parser.add_argument( - '-s', '--show', action='store_true', help='show results') - parser.add_argument( - '-d', '--device', default='cuda:0', help='Device used for inference') + "-d", "--device", default="cuda:0", help="Device used for inference" + ) return parser.parse_args() def inference_model(config_name, checkpoint, args, logger=None): cfg = Config.fromfile(config_name) if args.aug: - if 'flip' in cfg.data.test.pipeline[ - 1] and 'img_scale' in cfg.data.test.pipeline[1]: - cfg.data.test.pipeline[1].img_ratios = [ - 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 - ] + if ( + "flip" in cfg.data.test.pipeline[1] + and "img_scale" in cfg.data.test.pipeline[1] + ): + cfg.data.test.pipeline[1].img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] cfg.data.test.pipeline[1].flip = True elif logger is None: - print(f'{config_name}: unable to start aug test', flush=True) + print(f"{config_name}: unable to start aug test", flush=True) else: - logger.error(f'{config_name}: unable to start aug test') + logger.error(f"{config_name}: unable to start aug test") model = init_segmentor(cfg, checkpoint, device=args.device) # test a single image @@ -94,23 +93,25 @@ def main(args): if not isinstance(model_infos, list): model_infos = [model_infos] for model_info in model_infos: - config_name = model_info['config'].strip() - print(f'processing: {config_name}', flush=True) - checkpoint = osp.join(args.checkpoint_root, - model_info['checkpoint'].strip()) + config_name = model_info["config"].strip() + print(f"processing: {config_name}", flush=True) + checkpoint = osp.join( + args.checkpoint_root, model_info["checkpoint"].strip() + ) try: # build the model from a config file and a checkpoint file inference_model(config_name, checkpoint, args) except Exception: - print(f'{config_name} test failed!') + print(f"{config_name} test failed!") continue return else: - raise RuntimeError('model name input error.') + raise RuntimeError("model name input error.") # test all model logger = get_root_logger( - log_file='benchmark_inference_image.log', log_level=logging.ERROR) + log_file="benchmark_inference_image.log", log_level=logging.ERROR + ) for model_name in config: model_infos = config[model_name] @@ -118,20 +119,23 @@ def main(args): if not isinstance(model_infos, list): model_infos = [model_infos] for model_info in model_infos: - print('processing: ', model_info['config'], flush=True) - config_path = model_info['config'].strip() + print("processing: ", model_info["config"], flush=True) + config_path = model_info["config"].strip() config_name = osp.splitext(osp.basename(config_path))[0] - checkpoint_name = model_info['checkpoint'].strip() + checkpoint_name = model_info["checkpoint"].strip() checkpoint = osp.join(args.checkpoint_root, checkpoint_name) # ensure checkpoint exists try: if not osp.exists(checkpoint): - download_checkpoint(checkpoint_name, model_name, - config_name.rstrip('.py'), - args.checkpoint_root) + download_checkpoint( + checkpoint_name, + model_name, + config_name.rstrip(".py"), + args.checkpoint_root, + ) except Exception: - logger.error(f'{checkpoint_name} download error') + logger.error(f"{checkpoint_name} download error") continue # test model inference with checkpoint @@ -142,6 +146,6 @@ def main(args): logger.error(f'{config_path} " : {repr(e)}') -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/mmsegmentation/.dev/check_urls.py b/mmsegmentation/.dev/check_urls.py index c98d0a1..a8f12cb 100644 --- a/mmsegmentation/.dev/check_urls.py +++ b/mmsegmentation/.dev/check_urls.py @@ -25,12 +25,10 @@ def check_url(url): def parse_args(): - parser = ArgumentParser('url valid check.') + parser = ArgumentParser("url valid check.") parser.add_argument( - '-m', - '--model-name', - type=str, - help='Select the model needed to check') + "-m", "--model-name", type=str, help="Select the model needed to check" + ) return parser.parse_args() @@ -42,56 +40,63 @@ def main(): # yml path generate. # If model_name is not set, script will check all of the models. if model_name is not None: - yml_list = [(model_name, f'configs/{model_name}/{model_name}.yml')] + yml_list = [(model_name, f"configs/{model_name}/{model_name}.yml")] else: # check all - yml_list = [(x, f'configs/{x}/{x}.yml') for x in os.listdir('configs/') - if x != '_base_'] + yml_list = [ + (x, f"configs/{x}/{x}.yml") for x in os.listdir("configs/") if x != "_base_" + ] - logger = get_root_logger(log_file='url_check.log', log_level=logging.ERROR) + logger = get_root_logger(log_file="url_check.log", log_level=logging.ERROR) for model_name, yml_path in yml_list: # Default yaml loader unsafe. - model_infos = yml.load( - open(yml_path, 'r'), Loader=yml.CLoader)['Models'] + model_infos = yml.load(open(yml_path), Loader=yml.CLoader)["Models"] for model_info in model_infos: - config_name = model_info['Name'] - checkpoint_url = model_info['Weights'] + config_name = model_info["Name"] + checkpoint_url = model_info["Weights"] # checkpoint url check status_code, flag = check_url(checkpoint_url) if flag: - logger.info(f'checkpoint | {config_name} | {checkpoint_url} | ' - f'{status_code} valid') + logger.info( + f"checkpoint | {config_name} | {checkpoint_url} | " + f"{status_code} valid" + ) else: logger.error( - f'checkpoint | {config_name} | {checkpoint_url} | ' - f'{status_code} | error') + f"checkpoint | {config_name} | {checkpoint_url} | " + f"{status_code} | error" + ) # log_json check - checkpoint_name = checkpoint_url.split('/')[-1] - model_time = '-'.join(checkpoint_name.split('-')[:-1]).replace( - f'{config_name}_', '') + checkpoint_name = checkpoint_url.split("/")[-1] + model_time = "-".join(checkpoint_name.split("-")[:-1]).replace( + f"{config_name}_", "" + ) # two style of log_json name # use '_' to link model_time (will be deprecated) - log_json_url_1 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json' # noqa + log_json_url_1 = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json" # noqa status_code_1, flag_1 = check_url(log_json_url_1) # use '-' to link model_time - log_json_url_2 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json' # noqa + log_json_url_2 = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json" # noqa status_code_2, flag_2 = check_url(log_json_url_2) if flag_1 or flag_2: if flag_1: logger.info( - f'log.json | {config_name} | {log_json_url_1} | ' - f'{status_code_1} | valid') + f"log.json | {config_name} | {log_json_url_1} | " + f"{status_code_1} | valid" + ) else: logger.info( - f'log.json | {config_name} | {log_json_url_2} | ' - f'{status_code_2} | valid') + f"log.json | {config_name} | {log_json_url_2} | " + f"{status_code_2} | valid" + ) else: logger.error( - f'log.json | {config_name} | {log_json_url_1} & ' - f'{log_json_url_2} | {status_code_1} & {status_code_2} | ' - 'error') + f"log.json | {config_name} | {log_json_url_1} & " + f"{log_json_url_2} | {status_code_1} & {status_code_2} | " + "error" + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/.dev/gather_benchmark_evaluation_results.py b/mmsegmentation/.dev/gather_benchmark_evaluation_results.py index a8bfb4c..0a11b44 100644 --- a/mmsegmentation/.dev/gather_benchmark_evaluation_results.py +++ b/mmsegmentation/.dev/gather_benchmark_evaluation_results.py @@ -9,24 +9,24 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Gather benchmarked model evaluation results') - parser.add_argument('config', help='test config file path') + description="Gather benchmarked model evaluation results" + ) + parser.add_argument("config", help="test config file path") parser.add_argument( - 'root', - type=str, - help='root path of benchmarked models to be gathered') + "root", type=str, help="root path of benchmarked models to be gathered" + ) parser.add_argument( - '--out', + "--out", type=str, - default='benchmark_evaluation_info.json', - help='output path of gathered metrics and compared ' - 'results to be stored') + default="benchmark_evaluation_info.json", + help="output path of gathered metrics and compared " "results to be stored", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() root_path = args.root @@ -40,52 +40,54 @@ def parse_args(): if not isinstance(model_infos, list): model_infos = [model_infos] for model_info in model_infos: - previous_metrics = model_info['metric'] - config = model_info['config'].strip() + previous_metrics = model_info["metric"] + config = model_info["config"].strip() fname, _ = osp.splitext(osp.basename(config)) # Load benchmark evaluation json metric_json_dir = osp.join(root_path, fname) if not osp.exists(metric_json_dir): - print(f'{metric_json_dir} not existed.') + print(f"{metric_json_dir} not existed.") continue - json_list = glob.glob(osp.join(metric_json_dir, '*.json')) + json_list = glob.glob(osp.join(metric_json_dir, "*.json")) if len(json_list) == 0: - print(f'There is no eval json in {metric_json_dir}.') + print(f"There is no eval json in {metric_json_dir}.") continue log_json_path = list(sorted(json_list))[-1] metric = mmcv.load(log_json_path) - if config not in metric.get('config', {}): - print(f'{config} not included in {log_json_path}') + if config not in metric.get("config", {}): + print(f"{config} not included in {log_json_path}") continue # Compare between new benchmark results and previous metrics differential_results = {} new_metrics = {} for record_metric_key in previous_metrics: - if record_metric_key not in metric['metric']: - raise KeyError('record_metric_key not exist, please ' - 'check your config') + if record_metric_key not in metric["metric"]: + raise KeyError( + "record_metric_key not exist, please " "check your config" + ) old_metric = previous_metrics[record_metric_key] - new_metric = round(metric['metric'][record_metric_key] * 100, - 2) + new_metric = round(metric["metric"][record_metric_key] * 100, 2) differential = new_metric - old_metric - flag = '+' if differential > 0 else '-' + flag = "+" if differential > 0 else "-" differential_results[ - record_metric_key] = f'{flag}{abs(differential):.2f}' + record_metric_key + ] = f"{flag}{abs(differential):.2f}" new_metrics[record_metric_key] = new_metric result_dict[config] = dict( differential=differential_results, previous=previous_metrics, - new=new_metrics) + new=new_metrics, + ) if metrics_out: mmcv.dump(result_dict, metrics_out, indent=4) - print('===================================') + print("===================================") for config_name, metrics in result_dict.items(): print(config_name, metrics) - print('===================================') + print("===================================") diff --git a/mmsegmentation/.dev/gather_benchmark_train_results.py b/mmsegmentation/.dev/gather_benchmark_train_results.py index e729ca2..e47cf90 100644 --- a/mmsegmentation/.dev/gather_benchmark_train_results.py +++ b/mmsegmentation/.dev/gather_benchmark_train_results.py @@ -9,23 +9,24 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Gather benchmarked models train results') - parser.add_argument('config', help='test config file path') + description="Gather benchmarked models train results" + ) + parser.add_argument("config", help="test config file path") parser.add_argument( - 'root', - type=str, - help='root path of benchmarked models to be gathered') + "root", type=str, help="root path of benchmarked models to be gathered" + ) parser.add_argument( - '--out', + "--out", type=str, - default='benchmark_train_info.json', - help='output path of gathered metrics to be stored') + default="benchmark_train_info.json", + help="output path of gathered metrics to be stored", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() root_path = args.root @@ -39,14 +40,14 @@ def parse_args(): if not isinstance(model_infos, list): model_infos = [model_infos] for model_info in model_infos: - config = model_info['config'] + config = model_info["config"] # benchmark train dir model_name = osp.split(osp.dirname(config))[1] config_name = osp.splitext(osp.basename(config))[0] exp_dir = osp.join(root_path, model_name, config_name) if not osp.exists(exp_dir): - print(f'{config} hasn\'t {exp_dir}') + print(f"{config} hasn't {exp_dir}") continue # parse config @@ -57,34 +58,32 @@ def parse_args(): exp_metrics = [exp_metric] # determine whether total_iters ckpt exists - ckpt_path = f'iter_{total_iters}.pth' + ckpt_path = f"iter_{total_iters}.pth" if not osp.exists(osp.join(exp_dir, ckpt_path)): - print(f'{config} hasn\'t {ckpt_path}') + print(f"{config} hasn't {ckpt_path}") continue # only the last log json counts - log_json_path = list( - sorted(glob.glob(osp.join(exp_dir, '*.log.json'))))[-1] + log_json_path = list(sorted(glob.glob(osp.join(exp_dir, "*.log.json"))))[-1] # extract metric value model_performance = get_final_results(log_json_path, total_iters) if model_performance is None: - print(f'log file error: {log_json_path}') + print(f"log file error: {log_json_path}") continue differential_results = {} old_results = {} new_results = {} for metric_key in model_performance: - if metric_key in ['mIoU']: + if metric_key in ["mIoU"]: metric = round(model_performance[metric_key] * 100, 2) - old_metric = model_info['metric'][metric_key] + old_metric = model_info["metric"][metric_key] old_results[metric_key] = old_metric new_results[metric_key] = metric differential = metric - old_metric - flag = '+' if differential > 0 else '-' - differential_results[ - metric_key] = f'{flag}{abs(differential):.2f}' + flag = "+" if differential > 0 else "-" + differential_results[metric_key] = f"{flag}{abs(differential):.2f}" result_dict[config] = dict( differential_results=differential_results, old_results=old_results, @@ -94,7 +93,7 @@ def parse_args(): # 4 save or print results if metrics_out: mmcv.dump(result_dict, metrics_out, indent=4) - print('===================================') + print("===================================") for config_name, metrics in result_dict.items(): print(config_name, metrics) - print('===================================') + print("===================================") diff --git a/mmsegmentation/.dev/gather_models.py b/mmsegmentation/.dev/gather_models.py index 6158623..c1c18a2 100644 --- a/mmsegmentation/.dev/gather_models.py +++ b/mmsegmentation/.dev/gather_models.py @@ -11,29 +11,29 @@ import torch # build schedule look-up table to automatically find the final model -RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc'] +RESULTS_LUT = ["mIoU", "mAcc", "aAcc"] def calculate_file_sha256(file_path): """calculate file sha256 hash code.""" - with open(file_path, 'rb') as fp: + with open(file_path, "rb") as fp: sha256_cal = hashlib.sha256() sha256_cal.update(fp.read()) return sha256_cal.hexdigest() def process_checkpoint(in_file, out_file): - checkpoint = torch.load(in_file, map_location='cpu') + checkpoint = torch.load(in_file, map_location="cpu") # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] + if "optimizer" in checkpoint: + del checkpoint["optimizer"] # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. torch.save(checkpoint, out_file) # The hash code calculation and rename command differ on different system # platform. sha = calculate_file_sha256(out_file) - final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth' + final_file = out_file.rstrip(".pth") + f"-{sha[:8]}.pth" os.rename(out_file, final_file) # Remove prefix and suffix @@ -44,50 +44,53 @@ def process_checkpoint(in_file, out_file): def get_final_iter(config): - iter_num = config.split('_')[-2] - assert iter_num.endswith('k') + iter_num = config.split("_")[-2] + assert iter_num.endswith("k") return int(iter_num[:-1]) * 1000 def get_final_results(log_json_path, iter_num): result_dict = {} last_iter = 0 - with open(log_json_path, 'r') as f: + with open(log_json_path) as f: for line in f: log_line = json.loads(line) - if 'mode' not in log_line.keys(): + if "mode" not in log_line.keys(): continue # When evaluation, the 'iter' of new log json is the evaluation # steps on single gpu. - flag1 = 'aAcc' in log_line or log_line['mode'] == 'val' + flag1 = "aAcc" in log_line or log_line["mode"] == "val" flag2 = last_iter in [iter_num - 50, iter_num] if flag1 and flag2: - result_dict.update({ - key: log_line[key] - for key in RESULTS_LUT if key in log_line - }) + result_dict.update( + {key: log_line[key] for key in RESULTS_LUT if key in log_line} + ) return result_dict - last_iter = log_line['iter'] + last_iter = log_line["iter"] def parse_args(): - parser = argparse.ArgumentParser(description='Gather benchmarked models') + parser = argparse.ArgumentParser(description="Gather benchmarked models") parser.add_argument( - '-f', '--config-name', type=str, help='Process the selected config.') + "-f", "--config-name", type=str, help="Process the selected config." + ) parser.add_argument( - '-w', - '--work-dir', - default='work_dirs/', + "-w", + "--work-dir", + default="work_dirs/", type=str, - help='Ckpt storage root folder of benchmarked models to be gathered.') + help="Ckpt storage root folder of benchmarked models to be gathered.", + ) parser.add_argument( - '-c', - '--collect-dir', - default='work_dirs/gather', + "-c", + "--collect-dir", + default="work_dirs/gather", type=str, - help='Ckpt collect root folder of gathered models.') + help="Ckpt collect root folder of gathered models.", + ) parser.add_argument( - '--all', action='store_true', help='whether include .py and .log') + "--all", action="store_true", help="whether include .py and .log" + ) args = parser.parse_args() return args @@ -101,17 +104,16 @@ def main(): mmcv.mkdir_or_exist(collect_dir) # find all models in the root directory to be gathered - raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True)) + raw_configs = list(mmcv.scandir("./configs", ".py", recursive=True)) # filter configs that is not trained in the experiments dir used_configs = [] for raw_config in raw_configs: config_name = osp.splitext(osp.basename(raw_config))[0] if osp.exists(osp.join(work_dir, config_name)): - if (selected_config_name is None - or selected_config_name == config_name): + if selected_config_name is None or selected_config_name == config_name: used_configs.append(raw_config) - print(f'Find {len(used_configs)} models to be gathered') + print(f"Find {len(used_configs)} models to be gathered") # find final_ckpt and log file for trained each config # and parse the best performance @@ -121,16 +123,16 @@ def main(): exp_dir = osp.join(work_dir, config_name) # check whether the exps is finished final_iter = get_final_iter(used_config) - final_model = f'iter_{final_iter}.pth' + final_model = f"iter_{final_iter}.pth" model_path = osp.join(exp_dir, final_model) # skip if the model is still training if not osp.exists(model_path): - print(f'{used_config} train not finished yet') + print(f"{used_config} train not finished yet") continue # get logs - log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json')) + log_json_paths = glob.glob(osp.join(exp_dir, "*.log.json")) log_json_path = log_json_paths[0] model_performance = None for _log_json_path in log_json_paths: @@ -140,71 +142,77 @@ def main(): break if model_performance is None: - print(f'{used_config} model_performance is None') + print(f"{used_config} model_performance is None") continue - model_time = osp.split(log_json_path)[-1].split('.')[0] + model_time = osp.split(log_json_path)[-1].split(".")[0] model_infos.append( dict( config_name=config_name, results=model_performance, iters=final_iter, model_time=model_time, - log_json_path=osp.split(log_json_path)[-1])) + log_json_path=osp.split(log_json_path)[-1], + ) + ) # publish model for each checkpoint publish_model_infos = [] for model in model_infos: - config_name = model['config_name'] + config_name = model["config_name"] model_publish_dir = osp.join(collect_dir, config_name) - publish_model_path = osp.join(model_publish_dir, - f'{config_name}_' + model['model_time']) + publish_model_path = osp.join( + model_publish_dir, f"{config_name}_" + model["model_time"] + ) - trained_model_path = osp.join(work_dir, config_name, - f'iter_{model["iters"]}.pth') + trained_model_path = osp.join( + work_dir, config_name, f'iter_{model["iters"]}.pth' + ) if osp.exists(model_publish_dir): for file in os.listdir(model_publish_dir): - if file.endswith('.pth'): - print(f'model {file} found') - model['model_path'] = osp.abspath( - osp.join(model_publish_dir, file)) + if file.endswith(".pth"): + print(f"model {file} found") + model["model_path"] = osp.abspath(osp.join(model_publish_dir, file)) break - if 'model_path' not in model: - print(f'dir {model_publish_dir} exists, no model found') + if "model_path" not in model: + print(f"dir {model_publish_dir} exists, no model found") else: mmcv.mkdir_or_exist(model_publish_dir) # convert model - final_model_path = process_checkpoint(trained_model_path, - publish_model_path) - model['model_path'] = final_model_path + final_model_path = process_checkpoint( + trained_model_path, publish_model_path + ) + model["model_path"] = final_model_path new_json_path = f'{config_name}_{model["log_json_path"]}' # copy log shutil.copy( - osp.join(work_dir, config_name, model['log_json_path']), - osp.join(model_publish_dir, new_json_path)) + osp.join(work_dir, config_name, model["log_json_path"]), + osp.join(model_publish_dir, new_json_path), + ) if args.all: - new_txt_path = new_json_path.rstrip('.json') + new_txt_path = new_json_path.rstrip(".json") shutil.copy( - osp.join(work_dir, config_name, - model['log_json_path'].rstrip('.json')), - osp.join(model_publish_dir, new_txt_path)) + osp.join(work_dir, config_name, model["log_json_path"].rstrip(".json")), + osp.join(model_publish_dir, new_txt_path), + ) if args.all: # copy config to guarantee reproducibility - raw_config = osp.join('./configs', f'{config_name}.py') + raw_config = osp.join("./configs", f"{config_name}.py") mmcv.Config.fromfile(raw_config).dump( - osp.join(model_publish_dir, osp.basename(raw_config))) + osp.join(model_publish_dir, osp.basename(raw_config)) + ) publish_model_infos.append(model) models = dict(models=publish_model_infos) - mmcv.dump(models, osp.join(collect_dir, 'model_infos.json'), indent=4) + mmcv.dump(models, osp.join(collect_dir, "model_infos.json"), indent=4) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/.dev/generate_benchmark_evaluation_script.py b/mmsegmentation/.dev/generate_benchmark_evaluation_script.py index fd49f2b..d859710 100644 --- a/mmsegmentation/.dev/generate_benchmark_evaluation_script.py +++ b/mmsegmentation/.dev/generate_benchmark_evaluation_script.py @@ -7,64 +7,69 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert benchmark test model list to script') - parser.add_argument('config', help='test config file path') - parser.add_argument('--port', type=int, default=28171, help='dist port') + description="Convert benchmark test model list to script" + ) + parser.add_argument("config", help="test config file path") + parser.add_argument("--port", type=int, default=28171, help="dist port") parser.add_argument( - '--work-dir', - default='work_dirs/benchmark_evaluation', - help='the dir to save metric') + "--work-dir", + default="work_dirs/benchmark_evaluation", + help="the dir to save metric", + ) parser.add_argument( - '--out', + "--out", type=str, - default='.dev/benchmark_evaluation.sh', - help='path to save model benchmark script') + default=".dev/benchmark_evaluation.sh", + help="path to save model benchmark script", + ) return parser.parse_args() def process_model_info(model_info, work_dir): - config = model_info['config'].strip() + config = model_info["config"].strip() fname, _ = osp.splitext(osp.basename(config)) job_name = fname - checkpoint = model_info['checkpoint'].strip() + checkpoint = model_info["checkpoint"].strip() work_dir = osp.join(work_dir, fname) - evals = model_info['eval'] if isinstance(model_info['eval'], - list) else [model_info['eval']] + evals = ( + model_info["eval"] + if isinstance(model_info["eval"], list) + else [model_info["eval"]] + ) - eval = ' '.join(evals) + eval = " ".join(evals) return dict( config=config, job_name=job_name, checkpoint=checkpoint, work_dir=work_dir, - eval=eval) + eval=eval, + ) -def create_test_bash_info(commands, model_test_dict, port, script_name, - partition): - config = model_test_dict['config'] - job_name = model_test_dict['job_name'] - checkpoint = model_test_dict['checkpoint'] - work_dir = model_test_dict['work_dir'] - eval = model_test_dict['eval'] +def create_test_bash_info(commands, model_test_dict, port, script_name, partition): + config = model_test_dict["config"] + job_name = model_test_dict["job_name"] + checkpoint = model_test_dict["checkpoint"] + work_dir = model_test_dict["work_dir"] + eval = model_test_dict["eval"] - echo_info = f'\necho \'{config}\' &' + echo_info = f"\necho '{config}' &" commands.append(echo_info) - commands.append('\n') + commands.append("\n") - command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \ - f'CPUS_PER_TASK=2 {script_name} ' + command_info = f"GPUS=4 GPUS_PER_NODE=4 " f"CPUS_PER_TASK=2 {script_name} " - command_info += f'{partition} ' - command_info += f'{job_name} ' - command_info += f'{config} ' - command_info += f'$CHECKPOINT_DIR/{checkpoint} ' + command_info += f"{partition} " + command_info += f"{job_name} " + command_info += f"{config} " + command_info += f"$CHECKPOINT_DIR/{checkpoint} " - command_info += f'--eval {eval} ' - command_info += f'--work-dir {work_dir} ' - command_info += f'--cfg-options dist_params.port={port} ' - command_info += '&' + command_info += f"--eval {eval} " + command_info += f"--work-dir {work_dir} " + command_info += f"--cfg-options dist_params.port={port} " + command_info += "&" commands.append(command_info) @@ -72,20 +77,21 @@ def create_test_bash_info(commands, model_test_dict, port, script_name, def main(): args = parse_args() if args.out: - out_suffix = args.out.split('.')[-1] - assert args.out.endswith('.sh'), \ - f'Expected out file path suffix is .sh, but get .{out_suffix}' + out_suffix = args.out.split(".")[-1] + assert args.out.endswith( + ".sh" + ), f"Expected out file path suffix is .sh, but get .{out_suffix}" commands = [] - partition_name = 'PARTITION=$1' + partition_name = "PARTITION=$1" commands.append(partition_name) - commands.append('\n') + commands.append("\n") - checkpoint_root = 'CHECKPOINT_DIR=$2' + checkpoint_root = "CHECKPOINT_DIR=$2" commands.append(checkpoint_root) - commands.append('\n') + commands.append("\n") - script_name = osp.join('tools', 'slurm_test.sh') + script_name = osp.join("tools", "slurm_test.sh") port = args.port work_dir = args.work_dir @@ -96,17 +102,18 @@ def main(): if not isinstance(model_infos, list): model_infos = [model_infos] for model_info in model_infos: - print('processing: ', model_info['config']) + print("processing: ", model_info["config"]) model_test_dict = process_model_info(model_info, work_dir) - create_test_bash_info(commands, model_test_dict, port, script_name, - '$PARTITION') + create_test_bash_info( + commands, model_test_dict, port, script_name, "$PARTITION" + ) port += 1 - command_str = ''.join(commands) + command_str = "".join(commands) if args.out: - with open(args.out, 'w') as f: - f.write(command_str + '\n') + with open(args.out, "w") as f: + f.write(command_str + "\n") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/.dev/generate_benchmark_train_script.py b/mmsegmentation/.dev/generate_benchmark_train_script.py index 32d0a71..2276686 100644 --- a/mmsegmentation/.dev/generate_benchmark_train_script.py +++ b/mmsegmentation/.dev/generate_benchmark_train_script.py @@ -4,23 +4,26 @@ # Default using 4 gpu when training config_8gpu_list = [ - 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py', # noqa - 'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py', - 'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py', + "configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py", # noqa + "configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py", + "configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py", ] def parse_args(): parser = argparse.ArgumentParser( - description='Convert benchmark model json to script') + description="Convert benchmark model json to script" + ) parser.add_argument( - 'txt_path', type=str, help='txt path output by benchmark_filter') - parser.add_argument('--port', type=int, default=24727, help='dist port') + "txt_path", type=str, help="txt path output by benchmark_filter" + ) + parser.add_argument("--port", type=int, default=24727, help="dist port") parser.add_argument( - '--out', + "--out", type=str, - default='.dev/benchmark_train.sh', - help='path to save model benchmark script') + default=".dev/benchmark_train.sh", + help="path to save model benchmark script", + ) args = parser.parse_args() return args @@ -30,59 +33,59 @@ def create_train_bash_info(commands, config, script_name, partition, port): cfg = config.strip() # print cfg name - echo_info = f'echo \'{cfg}\' &' + echo_info = f"echo '{cfg}' &" commands.append(echo_info) - commands.append('\n') + commands.append("\n") _, model_name = osp.split(osp.dirname(cfg)) config_name, _ = osp.splitext(osp.basename(cfg)) # default setting if cfg in config_8gpu_list: - command_info = f'GPUS=8 GPUS_PER_NODE=8 ' \ - f'CPUS_PER_TASK=2 {script_name} ' + command_info = f"GPUS=8 GPUS_PER_NODE=8 " f"CPUS_PER_TASK=2 {script_name} " else: - command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \ - f'CPUS_PER_TASK=2 {script_name} ' - command_info += f'{partition} ' - command_info += f'{config_name} ' - command_info += f'{cfg} ' - command_info += f'--cfg-options ' \ - f'checkpoint_config.max_keep_ckpts=1 ' \ - f'dist_params.port={port} ' - command_info += f'--work-dir work_dirs/{model_name}/{config_name} ' + command_info = f"GPUS=4 GPUS_PER_NODE=4 " f"CPUS_PER_TASK=2 {script_name} " + command_info += f"{partition} " + command_info += f"{config_name} " + command_info += f"{cfg} " + command_info += ( + f"--cfg-options " + f"checkpoint_config.max_keep_ckpts=1 " + f"dist_params.port={port} " + ) + command_info += f"--work-dir work_dirs/{model_name}/{config_name} " # Let the script shut up - command_info += '>/dev/null &' + command_info += ">/dev/null &" commands.append(command_info) - commands.append('\n') + commands.append("\n") def main(): args = parse_args() if args.out: - out_suffix = args.out.split('.')[-1] - assert args.out.endswith('.sh'), \ - f'Expected out file path suffix is .sh, but get .{out_suffix}' + out_suffix = args.out.split(".")[-1] + assert args.out.endswith( + ".sh" + ), f"Expected out file path suffix is .sh, but get .{out_suffix}" - root_name = './tools' - script_name = osp.join(root_name, 'slurm_train.sh') + root_name = "./tools" + script_name = osp.join(root_name, "slurm_train.sh") port = args.port - partition_name = 'PARTITION=$1' + partition_name = "PARTITION=$1" - commands = [partition_name, '\n', '\n'] + commands = [partition_name, "\n", "\n"] - with open(args.txt_path, 'r') as f: + with open(args.txt_path) as f: model_cfgs = f.readlines() for cfg in model_cfgs: - create_train_bash_info(commands, cfg, script_name, '$PARTITION', - port) + create_train_bash_info(commands, cfg, script_name, "$PARTITION", port) port += 1 - command_str = ''.join(commands) + command_str = "".join(commands) if args.out: - with open(args.out, 'w') as f: + with open(args.out, "w") as f: f.write(command_str) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/.dev/log_collector/example_config.py b/mmsegmentation/.dev/log_collector/example_config.py index bc2b4d6..1aea861 100644 --- a/mmsegmentation/.dev/log_collector/example_config.py +++ b/mmsegmentation/.dev/log_collector/example_config.py @@ -1,18 +1,18 @@ -work_dir = '../../work_dirs' -metric = 'mIoU' +work_dir = "../../work_dirs" +metric = "mIoU" # specify the log files we would like to collect in `log_items` log_items = [ - 'segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup', - 'segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr', - 'segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr', - 'segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr' + "segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup", + "segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr", + "segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr", + "segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr", ] # or specify ignore_keywords, then the folders whose name contain # `'segformer'` won't be collected # ignore_keywords = ['segformer'] # should not include metric -other_info_keys = ['mAcc'] -markdown_file = 'markdowns/lr_in_trans.json.md' -json_file = 'jsons/trans_in_cnn.json' +other_info_keys = ["mAcc"] +markdown_file = "markdowns/lr_in_trans.json.md" +json_file = "jsons/trans_in_cnn.json" diff --git a/mmsegmentation/.dev/log_collector/log_collector.py b/mmsegmentation/.dev/log_collector/log_collector.py index cc7b413..81e80ed 100644 --- a/mmsegmentation/.dev/log_collector/log_collector.py +++ b/mmsegmentation/.dev/log_collector/log_collector.py @@ -25,8 +25,8 @@ def parse_args(): - parser = argparse.ArgumentParser(description='extract info from log.json') - parser.add_argument('config_dir') + parser = argparse.ArgumentParser(description="extract info from log.json") + parser.add_argument("config_dir") return parser.parse_args() @@ -37,23 +37,23 @@ def has_keyword(name: str, keywords: list): def main(): args = parse_args() cfg = load_config(args.config_dir) - work_dir = cfg['work_dir'] - metric = cfg['metric'] - log_items = cfg.get('log_items', []) - ignore_keywords = cfg.get('ignore_keywords', []) - other_info_keys = cfg.get('other_info_keys', []) - markdown_file = cfg.get('markdown_file', None) - json_file = cfg.get('json_file', None) - - if json_file and osp.split(json_file)[0] != '': + work_dir = cfg["work_dir"] + metric = cfg["metric"] + log_items = cfg.get("log_items", []) + ignore_keywords = cfg.get("ignore_keywords", []) + other_info_keys = cfg.get("other_info_keys", []) + markdown_file = cfg.get("markdown_file", None) + json_file = cfg.get("json_file", None) + + if json_file and osp.split(json_file)[0] != "": os.makedirs(osp.split(json_file)[0], exist_ok=True) - if markdown_file and osp.split(markdown_file)[0] != '': + if markdown_file and osp.split(markdown_file)[0] != "": os.makedirs(osp.split(markdown_file)[0], exist_ok=True) - assert not (log_items and ignore_keywords), \ - 'log_items and ignore_keywords cannot be specified at the same time' - assert metric not in other_info_keys, \ - 'other_info_keys should not contain metric' + assert not ( + log_items and ignore_keywords + ), "log_items and ignore_keywords cannot be specified at the same time" + assert metric not in other_info_keys, "other_info_keys should not contain metric" if ignore_keywords and isinstance(ignore_keywords, str): ignore_keywords = [ignore_keywords] @@ -64,7 +64,8 @@ def main(): if not log_items: log_items = [ - item for item in sorted(os.listdir(work_dir)) + item + for item in sorted(os.listdir(work_dir)) if not has_keyword(item, ignore_keywords) ] @@ -72,50 +73,54 @@ def main(): for config_dir in log_items: preceding_path = os.path.join(work_dir, config_dir) log_list = [ - item for item in os.listdir(preceding_path) - if item.endswith('.log.json') + item for item in os.listdir(preceding_path) if item.endswith(".log.json") ] log_list = sorted( log_list, key=lambda time_str: datetime.datetime.strptime( - time_str, '%Y%m%d_%H%M%S.log.json')) + time_str, "%Y%m%d_%H%M%S.log.json" + ), + ) val_list = [] last_iter = 0 for log_name in log_list: - with open(os.path.join(preceding_path, log_name), 'r') as f: + with open(os.path.join(preceding_path, log_name)) as f: # ignore the info line f.readline() all_lines = f.readlines() - val_list.extend([ - json.loads(line) for line in all_lines - if json.loads(line)['mode'] == 'val' - ]) + val_list.extend( + [ + json.loads(line) + for line in all_lines + if json.loads(line)["mode"] == "val" + ] + ) for index in range(len(all_lines) - 1, -1, -1): line_dict = json.loads(all_lines[index]) - if line_dict['mode'] == 'train': - last_iter = max(last_iter, line_dict['iter']) + if line_dict["mode"] == "train": + last_iter = max(last_iter, line_dict["iter"]) break - new_log_dict = dict( - method=config_dir, metric_used=metric, last_iter=last_iter) + new_log_dict = dict(method=config_dir, metric_used=metric, last_iter=last_iter) for index, log in enumerate(val_list, 1): new_ordered_dict = OrderedDict() - new_ordered_dict['eval_index'] = index + new_ordered_dict["eval_index"] = index new_ordered_dict[metric] = log[metric] for key in other_info_keys: if key in log: new_ordered_dict[key] = log[key] val_list[index - 1] = new_ordered_dict - assert len(val_list) >= 1, \ - f"work dir {config_dir} doesn't contain any evaluation." - new_log_dict['last eval'] = val_list[-1] - new_log_dict['best eval'] = max(val_list, key=lambda x: x[metric]) + assert ( + len(val_list) >= 1 + ), f"work dir {config_dir} doesn't contain any evaluation." + new_log_dict["last eval"] = val_list[-1] + new_log_dict["best eval"] = max(val_list, key=lambda x: x[metric]) experiment_info_list.append(new_log_dict) - print(f'{config_dir} is processed') + print(f"{config_dir} is processed") if json_file: - with open(json_file, 'w') as f: + with open(json_file, "w") as f: json.dump(experiment_info_list, f, indent=4) if markdown_file: @@ -125,15 +130,18 @@ def main(): f"|{index}|{log['method']}|{log['best eval'][metric]}" f"|{log['best eval']['eval_index']}|" f"{log['last eval'][metric]}|" - f"{log['last eval']['eval_index']}|{log['last_iter']}|\n") - with open(markdown_file, 'w') as f: - f.write(f'|exp_num|method|{metric} best|best index|' - f'{metric} last|last index|last iter num|\n') - f.write('|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n') + f"{log['last eval']['eval_index']}|{log['last_iter']}|\n" + ) + with open(markdown_file, "w") as f: + f.write( + f"|exp_num|method|{metric} best|best index|" + f"{metric} last|last index|last iter num|\n" + ) + f.write("|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n") f.writelines(lines_to_write) - print('processed successfully') + print("processed successfully") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/.dev/log_collector/utils.py b/mmsegmentation/.dev/log_collector/utils.py index 848516a..356eee6 100644 --- a/mmsegmentation/.dev/log_collector/utils.py +++ b/mmsegmentation/.dev/log_collector/utils.py @@ -6,15 +6,12 @@ def load_config(cfg_dir: str) -> dict: - assert cfg_dir.endswith('.py') + assert cfg_dir.endswith(".py") root_path, file_name = osp.split(cfg_dir) temp_module = osp.splitext(file_name)[0] sys.path.insert(0, root_path) mod = import_module(temp_module) sys.path.pop(0) - cfg_dict = { - k: v - for k, v in mod.__dict__.items() if not k.startswith('__') - } + cfg_dict = {k: v for k, v in mod.__dict__.items() if not k.startswith("__")} del sys.modules[temp_module] return cfg_dict diff --git a/mmsegmentation/.dev/md2yml.py b/mmsegmentation/.dev/md2yml.py index 1d68498..ef50c86 100755 --- a/mmsegmentation/.dev/md2yml.py +++ b/mmsegmentation/.dev/md2yml.py @@ -15,14 +15,43 @@ from lxml import etree from mmcv.fileio import dump -MMSEG_ROOT = osp.dirname(osp.dirname((osp.dirname(__file__)))) +MMSEG_ROOT = osp.dirname(osp.dirname(osp.dirname(__file__))) COLLECTIONS = [ - 'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet', 'CGNet', 'DANet', - 'DeepLabV3', 'DeepLabV3+', 'DMNet', 'DNLNet', 'DPT', 'EMANet', 'EncNet', - 'ERFNet', 'FastFCN', 'FastSCNN', 'FCN', 'GCNet', 'ICNet', 'ISANet', 'KNet', - 'NonLocalNet', 'OCRNet', 'PointRend', 'PSANet', 'PSPNet', 'Segformer', - 'Segmenter', 'FPN', 'SETR', 'STDC', 'UNet', 'UPerNet' + "ANN", + "APCNet", + "BiSeNetV1", + "BiSeNetV2", + "CCNet", + "CGNet", + "DANet", + "DeepLabV3", + "DeepLabV3+", + "DMNet", + "DNLNet", + "DPT", + "EMANet", + "EncNet", + "ERFNet", + "FastFCN", + "FastSCNN", + "FCN", + "GCNet", + "ICNet", + "ISANet", + "KNet", + "NonLocalNet", + "OCRNet", + "PointRend", + "PSANet", + "PSPNet", + "Segformer", + "Segmenter", + "FPN", + "SETR", + "STDC", + "UNet", + "UPerNet", ] COLLECTIONS_TEMP = [] @@ -39,10 +68,10 @@ def dump_yaml_and_check_difference(obj, filename, sort_keys=False): Bool: If the target YAML file is different from the original. """ - str_dump = dump(obj, None, file_format='yaml', sort_keys=sort_keys) + str_dump = dump(obj, None, file_format="yaml", sort_keys=sort_keys) if osp.isfile(filename): file_exists = True - with open(filename, 'r', encoding='utf-8') as f: + with open(filename, encoding="utf-8") as f: str_orig = f.read() else: file_exists = False @@ -52,7 +81,7 @@ def dump_yaml_and_check_difference(obj, filename, sort_keys=False): is_different = False else: is_different = True - with open(filename, 'w', encoding='utf-8') as f: + with open(filename, "w", encoding="utf-8") as f: f.write(str_dump) return is_different @@ -71,17 +100,12 @@ def parse_md(md_file): collection = dict( Name=collection_name, - Metadata={'Training Data': []}, - Paper={ - 'URL': '', - 'Title': '' - }, + Metadata={"Training Data": []}, + Paper={"URL": "", "Title": ""}, README=md_file, - Code={ - 'URL': '', - 'Version': '' - }) - collection.update({'Converted From': {'Weights': '', 'Code': ''}}) + Code={"URL": "", "Version": ""}, + ) + collection.update({"Converted From": {"Weights": "", "Code": ""}}) models = [] datasets = [] paper_url = None @@ -97,110 +121,118 @@ def parse_md(md_file): # should be set with head or neck of this config file. is_backbone = None - with open(md_file, 'r', encoding='UTF-8') as md: + with open(md_file, encoding="UTF-8") as md: lines = md.readlines() i = 0 - current_dataset = '' + current_dataset = "" while i < len(lines): line = lines[i].strip() # In latest README.md the title and url are in the third line. if i == 2: - paper_url = lines[i].split('](')[1].split(')')[0] - paper_title = lines[i].split('](')[0].split('[')[1] + paper_url = lines[i].split("](")[1].split(")")[0] + paper_title = lines[i].split("](")[0].split("[")[1] if len(line) == 0: i += 1 continue - elif line[:3] == ' batch_size: 8 samples_per_gpu=1, train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), - test=dict(pipeline=test_pipeline)) + test=dict(pipeline=test_pipeline), +) diff --git a/mmsegmentation/configs/segmenter/segmenter_vit-l_mask_8x1_640x640_160k_ade20k.py b/mmsegmentation/configs/segmenter/segmenter_vit-l_mask_8x1_640x640_160k_ade20k.py index 4e6a0b1..bcaf068 100644 --- a/mmsegmentation/configs/segmenter/segmenter_vit-l_mask_8x1_640x640_160k_ade20k.py +++ b/mmsegmentation/configs/segmenter/segmenter_vit-l_mask_8x1_640x640_160k_ade20k.py @@ -1,61 +1,66 @@ _base_ = [ - '../_base_/models/segmenter_vit-b16_mask.py', - '../_base_/datasets/ade20k_640x640.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/segmenter_vit-b16_mask.py", + "../_base_/datasets/ade20k_640x640.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth" # noqa model = dict( pretrained=checkpoint, backbone=dict( - type='VisionTransformer', + type="VisionTransformer", img_size=(640, 640), embed_dims=1024, num_layers=24, - num_heads=16), + num_heads=16, + ), decode_head=dict( - type='SegmenterMaskTransformerHead', + type="SegmenterMaskTransformerHead", in_channels=1024, channels=1024, num_heads=16, - embed_dims=1024), - test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(608, 608))) + embed_dims=1024, + ), + test_cfg=dict(mode="slide", crop_size=(640, 640), stride=(608, 608)), +) optimizer = dict(lr=0.001, weight_decay=0.0) -img_norm_cfg = dict( - mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) +img_norm_cfg = dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) crop_size = (640, 640) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', reduce_zero_label=True), - dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)), - dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', prob=0.5), - dict(type='PhotoMetricDistortion'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_semantic_seg']) + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations", reduce_zero_label=True), + dict(type="Resize", img_scale=(2560, 640), ratio_range=(0.5, 2.0)), + dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), + dict(type="RandomFlip", prob=0.5), + dict(type="PhotoMetricDistortion"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), + dict(type="DefaultFormatBundle"), + dict(type="Collect", keys=["img", "gt_semantic_seg"]), ] test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type="LoadImageFromFile"), dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(2560, 640), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']) - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), ] data = dict( # num_gpus: 8 -> batch_size: 8 samples_per_gpu=1, train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), - test=dict(pipeline=test_pipeline)) + test=dict(pipeline=test_pipeline), +) diff --git a/mmsegmentation/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py b/mmsegmentation/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py index adc8c1b..95e01a1 100644 --- a/mmsegmentation/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py @@ -1,14 +1,15 @@ -_base_ = './segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py' +_base_ = "./segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py" model = dict( decode_head=dict( _delete_=True, - type='FCNHead', + type="FCNHead", in_channels=384, channels=384, num_convs=0, dropout_ratio=0.0, concat_input=False, num_classes=150, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))) + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ) +) diff --git a/mmsegmentation/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py b/mmsegmentation/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py index 7e0eeb1..46d868a 100644 --- a/mmsegmentation/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py @@ -1,12 +1,13 @@ _base_ = [ - '../_base_/models/segmenter_vit-b16_mask.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/segmenter_vit-b16_mask.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_small_p16_384_20220308-410f6037.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_small_p16_384_20220308-410f6037.pth" # noqa -backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) +backbone_norm_cfg = dict(type="LN", eps=1e-6, requires_grad=True) model = dict( pretrained=checkpoint, backbone=dict( @@ -15,7 +16,7 @@ num_heads=6, ), decode_head=dict( - type='SegmenterMaskTransformerHead', + type="SegmenterMaskTransformerHead", in_channels=384, channels=384, num_classes=150, @@ -23,44 +24,46 @@ num_heads=6, embed_dims=384, dropout_ratio=0.0, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))) + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ), +) optimizer = dict(lr=0.001, weight_decay=0.0) -img_norm_cfg = dict( - mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) +img_norm_cfg = dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) crop_size = (512, 512) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', reduce_zero_label=True), - dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), - dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', prob=0.5), - dict(type='PhotoMetricDistortion'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_semantic_seg']) + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations", reduce_zero_label=True), + dict(type="Resize", img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), + dict(type="RandomFlip", prob=0.5), + dict(type="PhotoMetricDistortion"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), + dict(type="DefaultFormatBundle"), + dict(type="Collect", keys=["img", "gt_semantic_seg"]), ] test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type="LoadImageFromFile"), dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(2048, 512), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']) - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), ] data = dict( # num_gpus: 8 -> batch_size: 8 samples_per_gpu=1, train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), - test=dict(pipeline=test_pipeline)) + test=dict(pipeline=test_pipeline), +) diff --git a/mmsegmentation/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py b/mmsegmentation/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py index ec0107d..4c70698 100644 --- a/mmsegmentation/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py @@ -1,56 +1,60 @@ _base_ = [ - '../_base_/models/segmenter_vit-b16_mask.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/segmenter_vit-b16_mask.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_tiny_p16_384_20220308-cce8c795.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_tiny_p16_384_20220308-cce8c795.pth" # noqa model = dict( pretrained=checkpoint, backbone=dict(embed_dims=192, num_heads=3), decode_head=dict( - type='SegmenterMaskTransformerHead', + type="SegmenterMaskTransformerHead", in_channels=192, channels=192, num_heads=3, - embed_dims=192)) + embed_dims=192, + ), +) optimizer = dict(lr=0.001, weight_decay=0.0) -img_norm_cfg = dict( - mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) +img_norm_cfg = dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) crop_size = (512, 512) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', reduce_zero_label=True), - dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), - dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', prob=0.5), - dict(type='PhotoMetricDistortion'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_semantic_seg']) + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations", reduce_zero_label=True), + dict(type="Resize", img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), + dict(type="RandomFlip", prob=0.5), + dict(type="PhotoMetricDistortion"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), + dict(type="DefaultFormatBundle"), + dict(type="Collect", keys=["img", "gt_semantic_seg"]), ] test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type="LoadImageFromFile"), dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(2048, 512), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']) - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), ] data = dict( # num_gpus: 8 -> batch_size: 8 samples_per_gpu=1, train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), - test=dict(pipeline=test_pipeline)) + test=dict(pipeline=test_pipeline), +) diff --git a/mmsegmentation/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py b/mmsegmentation/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py index 7f8710d..1d0c7e0 100644 --- a/mmsegmentation/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './fpn_r50_512x1024_80k_cityscapes.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./fpn_r50_512x1024_80k_cityscapes.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py b/mmsegmentation/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py index 2654096..3597dc1 100644 --- a/mmsegmentation/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py @@ -1,2 +1,2 @@ -_base_ = './fpn_r50_512x512_160k_ade20k.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./fpn_r50_512x512_160k_ade20k.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py b/mmsegmentation/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py index 4bf3edd..de5af01 100644 --- a/mmsegmentation/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py @@ -1,4 +1,6 @@ _base_ = [ - '../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/fpn_r50.py", + "../_base_/datasets/cityscapes.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] diff --git a/mmsegmentation/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py b/mmsegmentation/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py index 5cdfc8c..65119d5 100644 --- a/mmsegmentation/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py @@ -1,5 +1,7 @@ _base_ = [ - '../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/fpn_r50.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] model = dict(decode_head=dict(num_classes=150)) diff --git a/mmsegmentation/configs/setr/setr_mla_512x512_160k_b16_ade20k.py b/mmsegmentation/configs/setr/setr_mla_512x512_160k_b16_ade20k.py index c8418c6..9c6ae88 100644 --- a/mmsegmentation/configs/setr/setr_mla_512x512_160k_b16_ade20k.py +++ b/mmsegmentation/configs/setr/setr_mla_512x512_160k_b16_ade20k.py @@ -1,4 +1,4 @@ -_base_ = ['./setr_mla_512x512_160k_b8_ade20k.py'] +_base_ = ["./setr_mla_512x512_160k_b8_ade20k.py"] # num_gpus: 8 -> batch_size: 16 data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/setr/setr_mla_512x512_160k_b8_ade20k.py b/mmsegmentation/configs/setr/setr_mla_512x512_160k_b8_ade20k.py index e1a07ce..e35c57d 100644 --- a/mmsegmentation/configs/setr/setr_mla_512x512_160k_b8_ade20k.py +++ b/mmsegmentation/configs/setr/setr_mla_512x512_160k_b8_ade20k.py @@ -1,85 +1,96 @@ _base_ = [ - '../_base_/models/setr_mla.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/setr_mla.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -norm_cfg = dict(type='SyncBN', requires_grad=True) +norm_cfg = dict(type="SyncBN", requires_grad=True) model = dict( pretrained=None, backbone=dict( img_size=(512, 512), - drop_rate=0., - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), + drop_rate=0.0, + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), decode_head=dict(num_classes=150), auxiliary_head=[ dict( - type='FCNHead', + type="FCNHead", in_channels=256, channels=256, in_index=0, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=0, kernel_size=1, concat_input=False, num_classes=150, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='FCNHead', + type="FCNHead", in_channels=256, channels=256, in_index=1, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=0, kernel_size=1, concat_input=False, num_classes=150, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='FCNHead', + type="FCNHead", in_channels=256, channels=256, in_index=2, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=0, kernel_size=1, concat_input=False, num_classes=150, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='FCNHead', + type="FCNHead", in_channels=256, channels=256, in_index=3, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=0, kernel_size=1, concat_input=False, num_classes=150, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), ], - test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)), + test_cfg=dict(mode="slide", crop_size=(512, 512), stride=(341, 341)), ) optimizer = dict( lr=0.001, weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}), +) # num_gpus: 8 -> batch_size: 8 data = dict(samples_per_gpu=1) diff --git a/mmsegmentation/configs/setr/setr_naive_512x512_160k_b16_ade20k.py b/mmsegmentation/configs/setr/setr_naive_512x512_160k_b16_ade20k.py index 8ad8c9f..dadf93a 100644 --- a/mmsegmentation/configs/setr/setr_naive_512x512_160k_b16_ade20k.py +++ b/mmsegmentation/configs/setr/setr_naive_512x512_160k_b16_ade20k.py @@ -1,67 +1,76 @@ _base_ = [ - '../_base_/models/setr_naive.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/setr_naive.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -norm_cfg = dict(type='SyncBN', requires_grad=True) +norm_cfg = dict(type="SyncBN", requires_grad=True) model = dict( pretrained=None, backbone=dict( img_size=(512, 512), - drop_rate=0., - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), + drop_rate=0.0, + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), decode_head=dict(num_classes=150), auxiliary_head=[ dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=0, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=1, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=1, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=1, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=2, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=1, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)) + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), ], - test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)), + test_cfg=dict(mode="slide", crop_size=(512, 512), stride=(341, 341)), ) optimizer = dict( lr=0.01, weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}), +) # num_gpus: 8 -> batch_size: 16 data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/setr/setr_pup_512x512_160k_b16_ade20k.py b/mmsegmentation/configs/setr/setr_pup_512x512_160k_b16_ade20k.py index 83997a2..f70679c 100644 --- a/mmsegmentation/configs/setr/setr_pup_512x512_160k_b16_ade20k.py +++ b/mmsegmentation/configs/setr/setr_pup_512x512_160k_b16_ade20k.py @@ -1,67 +1,76 @@ _base_ = [ - '../_base_/models/setr_pup.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/setr_pup.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -norm_cfg = dict(type='SyncBN', requires_grad=True) +norm_cfg = dict(type="SyncBN", requires_grad=True) model = dict( pretrained=None, backbone=dict( img_size=(512, 512), - drop_rate=0., - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), + drop_rate=0.0, + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), decode_head=dict(num_classes=150), auxiliary_head=[ dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=0, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=1, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=2, num_classes=150, dropout_ratio=0, norm_cfg=norm_cfg, - act_cfg=dict(type='ReLU'), + act_cfg=dict(type="ReLU"), num_convs=2, kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), ], - test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)), + test_cfg=dict(mode="slide", crop_size=(512, 512), stride=(341, 341)), ) optimizer = dict( lr=0.001, weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}), +) # num_gpus: 8 -> batch_size: 16 data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/setr/setr_vit-large_mla_8x1_768x768_80k_cityscapes.py b/mmsegmentation/configs/setr/setr_vit-large_mla_8x1_768x768_80k_cityscapes.py index 4237cd5..0565f20 100644 --- a/mmsegmentation/configs/setr/setr_vit-large_mla_8x1_768x768_80k_cityscapes.py +++ b/mmsegmentation/configs/setr/setr_vit-large_mla_8x1_768x768_80k_cityscapes.py @@ -1,17 +1,21 @@ _base_ = [ - '../_base_/models/setr_mla.py', '../_base_/datasets/cityscapes_768x768.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/setr_mla.py", + "../_base_/datasets/cityscapes_768x768.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] model = dict( pretrained=None, backbone=dict( drop_rate=0, - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), - test_cfg=dict(mode='slide', crop_size=(768, 768), stride=(512, 512))) + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), + test_cfg=dict(mode="slide", crop_size=(768, 768), stride=(512, 512)), +) optimizer = dict( lr=0.002, weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}), +) data = dict(samples_per_gpu=1) diff --git a/mmsegmentation/configs/setr/setr_vit-large_naive_8x1_768x768_80k_cityscapes.py b/mmsegmentation/configs/setr/setr_vit-large_naive_8x1_768x768_80k_cityscapes.py index 0c6621e..2c53ec9 100644 --- a/mmsegmentation/configs/setr/setr_vit-large_naive_8x1_768x768_80k_cityscapes.py +++ b/mmsegmentation/configs/setr/setr_vit-large_naive_8x1_768x768_80k_cityscapes.py @@ -1,18 +1,20 @@ _base_ = [ - '../_base_/models/setr_naive.py', - '../_base_/datasets/cityscapes_768x768.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_80k.py' + "../_base_/models/setr_naive.py", + "../_base_/datasets/cityscapes_768x768.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] model = dict( pretrained=None, backbone=dict( - drop_rate=0., - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), - test_cfg=dict(mode='slide', crop_size=(768, 768), stride=(512, 512))) + drop_rate=0.0, + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), + test_cfg=dict(mode="slide", crop_size=(768, 768), stride=(512, 512)), +) optimizer = dict( - weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + weight_decay=0.0, paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}) +) data = dict(samples_per_gpu=1) diff --git a/mmsegmentation/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py b/mmsegmentation/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py index e108988..ac15a6e 100644 --- a/mmsegmentation/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py +++ b/mmsegmentation/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py @@ -1,19 +1,21 @@ _base_ = [ - '../_base_/models/setr_pup.py', '../_base_/datasets/cityscapes_768x768.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/setr_pup.py", + "../_base_/datasets/cityscapes_768x768.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] -norm_cfg = dict(type='SyncBN', requires_grad=True) +norm_cfg = dict(type="SyncBN", requires_grad=True) crop_size = (768, 768) model = dict( pretrained=None, backbone=dict( - drop_rate=0., - init_cfg=dict( - type='Pretrained', checkpoint='pretrain/vit_large_p16.pth')), + drop_rate=0.0, + init_cfg=dict(type="Pretrained", checkpoint="pretrain/vit_large_p16.pth"), + ), auxiliary_head=[ dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=0, @@ -25,9 +27,11 @@ kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=1, @@ -39,9 +43,11 @@ kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), dict( - type='SETRUPHead', + type="SETRUPHead", in_channels=1024, channels=256, in_index=2, @@ -53,12 +59,15 @@ kernel_size=3, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)) + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=0.4 + ), + ), ], - test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512))) + test_cfg=dict(mode="slide", crop_size=crop_size, stride=(512, 512)), +) optimizer = dict( - weight_decay=0.0, - paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)})) + weight_decay=0.0, paramwise_cfg=dict(custom_keys={"head": dict(lr_mult=10.0)}) +) data = dict(samples_per_gpu=1) diff --git a/mmsegmentation/configs/stdc/stdc1_512x1024_80k_cityscapes.py b/mmsegmentation/configs/stdc/stdc1_512x1024_80k_cityscapes.py index 849e771..babec13 100644 --- a/mmsegmentation/configs/stdc/stdc1_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/stdc/stdc1_512x1024_80k_cityscapes.py @@ -1,8 +1,10 @@ _base_ = [ - '../_base_/models/stdc.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/stdc.py", + "../_base_/datasets/cityscapes.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] -lr_config = dict(warmup='linear', warmup_iters=1000) +lr_config = dict(warmup="linear", warmup_iters=1000) data = dict( samples_per_gpu=12, workers_per_gpu=4, diff --git a/mmsegmentation/configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py b/mmsegmentation/configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py index f295bf4..78d792b 100644 --- a/mmsegmentation/configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py @@ -1,6 +1,7 @@ -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth' # noqa -_base_ = './stdc1_512x1024_80k_cityscapes.py' +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth" # noqa +_base_ = "./stdc1_512x1024_80k_cityscapes.py" model = dict( backbone=dict( - backbone_cfg=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint)))) + backbone_cfg=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint)) + ) +) diff --git a/mmsegmentation/configs/stdc/stdc2_512x1024_80k_cityscapes.py b/mmsegmentation/configs/stdc/stdc2_512x1024_80k_cityscapes.py index f7afb50..cc190d2 100644 --- a/mmsegmentation/configs/stdc/stdc2_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/stdc/stdc2_512x1024_80k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './stdc1_512x1024_80k_cityscapes.py' -model = dict(backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2'))) +_base_ = "./stdc1_512x1024_80k_cityscapes.py" +model = dict(backbone=dict(backbone_cfg=dict(stdc_type="STDCNet2"))) diff --git a/mmsegmentation/configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py b/mmsegmentation/configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py index 4148ac4..d6e20f3 100644 --- a/mmsegmentation/configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py @@ -1,6 +1,7 @@ -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc2_20220308-7dbd9127.pth' # noqa -_base_ = './stdc2_512x1024_80k_cityscapes.py' +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc2_20220308-7dbd9127.pth" # noqa +_base_ = "./stdc2_512x1024_80k_cityscapes.py" model = dict( backbone=dict( - backbone_cfg=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint)))) + backbone_cfg=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint)) + ) +) diff --git a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py index 027bd6f..dcfc509 100644 --- a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py +++ b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_1K.py @@ -1,15 +1,16 @@ _base_ = [ - 'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' - 'pretrain_224x224_1K.py' + "upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_" "pretrain_224x224_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), pretrain_img_size=384, embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], - window_size=12), + window_size=12, + ), decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150), - auxiliary_head=dict(in_channels=512, num_classes=150)) + auxiliary_head=dict(in_channels=512, num_classes=150), +) diff --git a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py index e662d4f..7629672 100644 --- a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py +++ b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K.py @@ -1,8 +1,7 @@ _base_ = [ - './upernet_swin_base_patch4_window12_512x512_160k_ade20k_' - 'pretrain_384x384_1K.py' + "./upernet_swin_base_patch4_window12_512x512_160k_ade20k_" "pretrain_384x384_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth" # noqa model = dict( - backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file))) + backbone=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file)) +) diff --git a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 6e05677..8b557d6 100644 --- a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -1,13 +1,14 @@ _base_ = [ - './upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' - 'pretrain_224x224_1K.py' + "./upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_" "pretrain_224x224_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), embed_dims=128, depths=[2, 2, 18, 2], - num_heads=[4, 8, 16, 32]), + num_heads=[4, 8, 16, 32], + ), decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150), - auxiliary_head=dict(in_channels=512, num_classes=150)) + auxiliary_head=dict(in_channels=512, num_classes=150), +) diff --git a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py index 7a9c506..118845f 100644 --- a/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py +++ b/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py @@ -1,8 +1,7 @@ _base_ = [ - './upernet_swin_base_patch4_window7_512x512_160k_ade20k_' - 'pretrain_224x224_1K.py' + "./upernet_swin_base_patch4_window7_512x512_160k_ade20k_" "pretrain_224x224_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth" # noqa model = dict( - backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file))) + backbone=dict(init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file)) +) diff --git a/mmsegmentation/configs/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k.py b/mmsegmentation/configs/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k.py index a43e5be..92b748e 100644 --- a/mmsegmentation/configs/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k.py +++ b/mmsegmentation/configs/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k.py @@ -1,10 +1,11 @@ _base_ = [ - 'upernet_swin_large_patch4_window7_512x512_' - 'pretrain_224x224_22K_160k_ade20k.py' + "upernet_swin_large_patch4_window7_512x512_" "pretrain_224x224_22K_160k_ade20k.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), pretrain_img_size=384, - window_size=12)) + window_size=12, + ) +) diff --git a/mmsegmentation/configs/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k.py b/mmsegmentation/configs/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k.py index 8a78f32..3eb190c 100644 --- a/mmsegmentation/configs/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k.py +++ b/mmsegmentation/configs/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k.py @@ -1,15 +1,16 @@ _base_ = [ - 'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' - 'pretrain_224x224_1K.py' + "upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_" "pretrain_224x224_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), pretrain_img_size=224, embed_dims=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], - window_size=7), + window_size=7, + ), decode_head=dict(in_channels=[192, 384, 768, 1536], num_classes=150), - auxiliary_head=dict(in_channels=768, num_classes=150)) + auxiliary_head=dict(in_channels=768, num_classes=150), +) diff --git a/mmsegmentation/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/mmsegmentation/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 1958e0e..5eccd4c 100644 --- a/mmsegmentation/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/mmsegmentation/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -1,11 +1,12 @@ _base_ = [ - './upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_' - 'pretrain_224x224_1K.py' + "./upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_" "pretrain_224x224_1K.py" ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), - depths=[2, 2, 18, 2]), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), + depths=[2, 2, 18, 2], + ), decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), - auxiliary_head=dict(in_channels=384, num_classes=150)) + auxiliary_head=dict(in_channels=384, num_classes=150), +) diff --git a/mmsegmentation/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/mmsegmentation/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 6d8c413..b8d5386 100644 --- a/mmsegmentation/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ b/mmsegmentation/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -1,45 +1,52 @@ _base_ = [ - '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/upernet_swin.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth' # noqa +checkpoint_file = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint_file), embed_dims=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, use_abs_pos_embed=False, drop_path_rate=0.3, - patch_norm=True), + patch_norm=True, + ), decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), - auxiliary_head=dict(in_channels=384, num_classes=150)) + auxiliary_head=dict(in_channels=384, num_classes=150), +) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, paramwise_cfg=dict( custom_keys={ - 'absolute_pos_embed': dict(decay_mult=0.), - 'relative_position_bias_table': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + "absolute_pos_embed": dict(decay_mult=0.0), + "relative_position_bias_table": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + } + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) # By default, models are trained on 8 GPUs with 2 images per GPU data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py index b79fefd..d517d04 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,8 +1,9 @@ -_base_ = ['./twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py'] +_base_ = ["./twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_base_20220308-0621964c.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_base_20220308-0621964c.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), - depths=[3, 4, 18, 3]), ) + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), depths=[3, 4, 18, 3] + ), +) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-b_uperhead_8x2_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-b_uperhead_8x2_512x512_160k_ade20k.py index 8c299d3..f233291 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-b_uperhead_8x2_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-b_uperhead_8x2_512x512_160k_ade20k.py @@ -1,11 +1,13 @@ -_base_ = ['./twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py'] +_base_ = ["./twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_base_20220308-0621964c.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_base_20220308-0621964c.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), depths=[3, 4, 18, 3], - drop_path_rate=0.3)) + drop_path_rate=0.3, + ) +) data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py index abb652e..28d2019 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,8 +1,9 @@ -_base_ = ['./twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py'] +_base_ = ["./twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_large_20220308-37579dc6.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_large_20220308-37579dc6.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), - depths=[3, 8, 27, 3])) + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), depths=[3, 8, 27, 3] + ) +) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-l_uperhead_8x2_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-l_uperhead_8x2_512x512_160k_ade20k.py index f6f7d27..f1e06e3 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-l_uperhead_8x2_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-l_uperhead_8x2_512x512_160k_ade20k.py @@ -1,11 +1,13 @@ -_base_ = ['./twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py'] +_base_ = ["./twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_large_20220308-37579dc6.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/pcpvt_large_20220308-37579dc6.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), depths=[3, 8, 27, 3], - drop_path_rate=0.3)) + drop_path_rate=0.3, + ) +) data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py index 3d7be96..c62d10c 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/twins_pcpvt-s_fpn.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/twins_pcpvt-s_fpn.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] -optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001) +optimizer = dict(_delete_=True, type="AdamW", lr=0.0001, weight_decay=0.0001) diff --git a/mmsegmentation/configs/twins/twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py index c888b92..9a9dc89 100644 --- a/mmsegmentation/configs/twins/twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_pcpvt-s_uperhead_8x4_512x512_160k_ade20k.py @@ -1,26 +1,28 @@ _base_ = [ - '../_base_/models/twins_pcpvt-s_upernet.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/twins_pcpvt-s_upernet.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, - paramwise_cfg=dict(custom_keys={ - 'pos_block': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + paramwise_cfg=dict( + custom_keys={"pos_block": dict(decay_mult=0.0), "norm": dict(decay_mult=0.0)} + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) diff --git a/mmsegmentation/configs/twins/twins_svt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py index 00d8957..46e876c 100644 --- a/mmsegmentation/configs/twins/twins_svt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-b_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,12 +1,13 @@ -_base_ = ['./twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py'] +_base_ = ["./twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_base_20220308-1b7eb711.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_base_20220308-1b7eb711.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], - depths=[2, 2, 18, 2]), + depths=[2, 2, 18, 2], + ), neck=dict(in_channels=[96, 192, 384, 768]), ) diff --git a/mmsegmentation/configs/twins/twins_svt-b_uperhead_8x2_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-b_uperhead_8x2_512x512_160k_ade20k.py index a969fed..7b791c5 100644 --- a/mmsegmentation/configs/twins/twins_svt-b_uperhead_8x2_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-b_uperhead_8x2_512x512_160k_ade20k.py @@ -1,12 +1,14 @@ -_base_ = ['./twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py'] +_base_ = ["./twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_base_20220308-1b7eb711.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_base_20220308-1b7eb711.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], - depths=[2, 2, 18, 2]), + depths=[2, 2, 18, 2], + ), decode_head=dict(in_channels=[96, 192, 384, 768]), - auxiliary_head=dict(in_channels=384)) + auxiliary_head=dict(in_channels=384), +) diff --git a/mmsegmentation/configs/twins/twins_svt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py index c68bfd4..e68f2b4 100644 --- a/mmsegmentation/configs/twins/twins_svt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-l_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,13 +1,14 @@ -_base_ = ['./twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py'] +_base_ = ["./twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_large_20220308-fb5936f3.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_large_20220308-fb5936f3.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], depths=[2, 2, 18, 2], - drop_path_rate=0.3), + drop_path_rate=0.3, + ), neck=dict(in_channels=[128, 256, 512, 1024]), ) diff --git a/mmsegmentation/configs/twins/twins_svt-l_uperhead_8x2_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-l_uperhead_8x2_512x512_160k_ade20k.py index f98c070..76d5cf8 100644 --- a/mmsegmentation/configs/twins/twins_svt-l_uperhead_8x2_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-l_uperhead_8x2_512x512_160k_ade20k.py @@ -1,13 +1,15 @@ -_base_ = ['./twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py'] +_base_ = ["./twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py"] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_large_20220308-fb5936f3.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_large_20220308-fb5936f3.pth" # noqa model = dict( backbone=dict( - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], depths=[2, 2, 18, 2], - drop_path_rate=0.3), + drop_path_rate=0.3, + ), decode_head=dict(in_channels=[128, 256, 512, 1024]), - auxiliary_head=dict(in_channels=512)) + auxiliary_head=dict(in_channels=512), +) diff --git a/mmsegmentation/configs/twins/twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py index dbb944c..e3f0add 100644 --- a/mmsegmentation/configs/twins/twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-s_fpn_fpnhead_8x4_512x512_80k_ade20k.py @@ -1,22 +1,25 @@ _base_ = [ - '../_base_/models/twins_pcpvt-s_fpn.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/twins_pcpvt-s_fpn.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_small_20220308-7e1c3695.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_small_20220308-7e1c3695.pth" # noqa model = dict( backbone=dict( - type='SVT', - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + type="SVT", + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], depths=[2, 2, 10, 4], windiow_sizes=[7, 7, 7, 7], - norm_after_stage=True), + norm_after_stage=True, + ), neck=dict(in_channels=[64, 128, 256, 512], out_channels=256, num_outs=4), decode_head=dict(num_classes=150), ) -optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001) +optimizer = dict(_delete_=True, type="AdamW", lr=0.0001, weight_decay=0.0001) diff --git a/mmsegmentation/configs/twins/twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py b/mmsegmentation/configs/twins/twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py index 44bf60b..9643b77 100644 --- a/mmsegmentation/configs/twins/twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/twins/twins_svt-s_uperhead_8x2_512x512_160k_ade20k.py @@ -1,43 +1,47 @@ _base_ = [ - '../_base_/models/twins_pcpvt-s_upernet.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/twins_pcpvt-s_upernet.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_small_20220308-7e1c3695.pth' # noqa +checkpoint = "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/twins/alt_gvt_small_20220308-7e1c3695.pth" # noqa model = dict( backbone=dict( - type='SVT', - init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + type="SVT", + init_cfg=dict(type="Pretrained", checkpoint=checkpoint), embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], depths=[2, 2, 10, 4], windiow_sizes=[7, 7, 7, 7], - norm_after_stage=True), + norm_after_stage=True, + ), decode_head=dict(in_channels=[64, 128, 256, 512]), - auxiliary_head=dict(in_channels=256)) + auxiliary_head=dict(in_channels=256), +) optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, - paramwise_cfg=dict(custom_keys={ - 'pos_block': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + paramwise_cfg=dict( + custom_keys={"pos_block": dict(decay_mult=0.0), "norm": dict(decay_mult=0.0)} + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) data = dict(samples_per_gpu=2, workers_per_gpu=2) diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py index c706cf3..b94014f 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py @@ -1,7 +1,8 @@ _base_ = [ - '../_base_/models/deeplabv3_unet_s5-d16.py', - '../_base_/datasets/chase_db1.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_40k.py' + "../_base_/models/deeplabv3_unet_s5-d16.py", + "../_base_/datasets/chase_db1.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_stare.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_stare.py index 0ef02dc..1c38b46 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_128x128_40k_stare.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/stare.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/deeplabv3_unet_s5-d16.py", + "../_base_/datasets/stare.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf.py index 118428b..1231a50 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/hrf.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/deeplabv3_unet_s5-d16.py", + "../_base_/datasets/hrf.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_64x64_40k_drive.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_64x64_40k_drive.py index 1f8862a..0086570 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_64x64_40k_drive.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/deeplabv3_unet_s5-d16.py', '../_base_/datasets/drive.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/deeplabv3_unet_s5-d16.py", + "../_base_/datasets/drive.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(64, 64), stride=(42, 42))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py index 1c48cbc..a664d10 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py @@ -1,6 +1,9 @@ -_base_ = './deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py' +_base_ = "./deeplabv3_unet_s5-d16_128x128_40k_chase_db1.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py index 1022ede..47c8abd 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py @@ -1,6 +1,9 @@ -_base_ = './deeplabv3_unet_s5-d16_128x128_40k_stare.py' +_base_ = "./deeplabv3_unet_s5-d16_128x128_40k_stare.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py index fc17da7..85eacd6 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py @@ -1,6 +1,9 @@ -_base_ = './deeplabv3_unet_s5-d16_256x256_40k_hrf.py' +_base_ = "./deeplabv3_unet_s5-d16_256x256_40k_hrf.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py index 3f1f12e..6ba0e03 100644 --- a/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py @@ -1,6 +1,9 @@ -_base_ = './deeplabv3_unet_s5-d16_64x64_40k_drive.py' +_base_ = "./deeplabv3_unet_s5-d16_64x64_40k_drive.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_chase_db1.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_chase_db1.py index 2bc52d9..48ad438 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_chase_db1.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_chase_db1.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/chase_db1.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/fcn_unet_s5-d16.py", + "../_base_/datasets/chase_db1.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_stare.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_stare.py index 5d836c6..b0ccfab 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_128x128_40k_stare.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/stare.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/fcn_unet_s5-d16.py", + "../_base_/datasets/stare.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_256x256_40k_hrf.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_256x256_40k_hrf.py index be8eec7..2f0328e 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_256x256_40k_hrf.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/hrf.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/fcn_unet_s5-d16.py", + "../_base_/datasets/hrf.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py index a2f7dbe..0538508 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/fcn_unet_s5-d16.py", + "../_base_/datasets/cityscapes.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] model = dict( @@ -8,7 +10,8 @@ auxiliary_head=dict(num_classes=19), # model training and testing settings train_cfg=dict(), - test_cfg=dict(mode='whole')) + test_cfg=dict(mode="whole"), +) data = dict( samples_per_gpu=4, diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_64x64_40k_drive.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_64x64_40k_drive.py index 80483ad..c0fd076 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_64x64_40k_drive.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/drive.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/fcn_unet_s5-d16.py", + "../_base_/datasets/drive.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(64, 64), stride=(42, 42))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py index 5264866..d7765b4 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py @@ -1,6 +1,9 @@ -_base_ = './fcn_unet_s5-d16_128x128_40k_chase_db1.py' +_base_ = "./fcn_unet_s5-d16_128x128_40k_chase_db1.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py index cf5fa1f..9531c73 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py @@ -1,6 +1,9 @@ -_base_ = './fcn_unet_s5-d16_128x128_40k_stare.py' +_base_ = "./fcn_unet_s5-d16_128x128_40k_stare.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py index a154d7e..54cbe1f 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py @@ -1,6 +1,9 @@ -_base_ = './fcn_unet_s5-d16_256x256_40k_hrf.py' +_base_ = "./fcn_unet_s5-d16_256x256_40k_hrf.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py index 1b8f860..4403a53 100644 --- a/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py @@ -1,6 +1,9 @@ -_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py' +_base_ = "./fcn_unet_s5-d16_64x64_40k_drive.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1.py index b085a17..4aad1e2 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1.py @@ -1,7 +1,8 @@ _base_ = [ - '../_base_/models/pspnet_unet_s5-d16.py', - '../_base_/datasets/chase_db1.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_40k.py' + "../_base_/models/pspnet_unet_s5-d16.py", + "../_base_/datasets/chase_db1.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_stare.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_stare.py index 9d729ce..c8c596a 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_128x128_40k_stare.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/pspnet_unet_s5-d16.py', '../_base_/datasets/stare.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/pspnet_unet_s5-d16.py", + "../_base_/datasets/stare.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(128, 128), stride=(85, 85))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_256x256_40k_hrf.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_256x256_40k_hrf.py index f57c916..af8fb79 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_256x256_40k_hrf.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/pspnet_unet_s5-d16.py', '../_base_/datasets/hrf.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/pspnet_unet_s5-d16.py", + "../_base_/datasets/hrf.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(256, 256), stride=(170, 170))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_64x64_40k_drive.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_64x64_40k_drive.py index 7b5421a..4f1d683 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_64x64_40k_drive.py @@ -1,6 +1,8 @@ _base_ = [ - '../_base_/models/pspnet_unet_s5-d16.py', '../_base_/datasets/drive.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/pspnet_unet_s5-d16.py", + "../_base_/datasets/drive.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict(test_cfg=dict(crop_size=(64, 64), stride=(42, 42))) -evaluation = dict(metric='mDice') +evaluation = dict(metric="mDice") diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py index a63dc11..f4d2694 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1.py @@ -1,6 +1,9 @@ -_base_ = './pspnet_unet_s5-d16_128x128_40k_chase_db1.py' +_base_ = "./pspnet_unet_s5-d16_128x128_40k_chase_db1.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py index 1a3b665..7358571 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare.py @@ -1,6 +1,9 @@ -_base_ = './pspnet_unet_s5-d16_128x128_40k_stare.py' +_base_ = "./pspnet_unet_s5-d16_128x128_40k_stare.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py index e19d6cf..2cd7a05 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf.py @@ -1,6 +1,9 @@ -_base_ = './pspnet_unet_s5-d16_256x256_40k_hrf.py' +_base_ = "./pspnet_unet_s5-d16_256x256_40k_hrf.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py index 7934923..0dcabb6 100644 --- a/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py +++ b/mmsegmentation/configs/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive.py @@ -1,6 +1,9 @@ -_base_ = './pspnet_unet_s5-d16_64x64_40k_drive.py' +_base_ = "./pspnet_unet_s5-d16_64x64_40k_drive.py" model = dict( - decode_head=dict(loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0) - ])) + decode_head=dict( + loss_decode=[ + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=3.0), + ] + ) +) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x1024_40k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r101_512x1024_40k_cityscapes.py index b90b597..7f84f0a 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x1024_40k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x1024_40k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x1024_40k_cityscapes.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x1024_40k_cityscapes.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x1024_80k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r101_512x1024_80k_cityscapes.py index 420ca2e..7c6959b 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x1024_80k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x1024_80k_cityscapes.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x1024_80k_cityscapes.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x512_160k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r101_512x512_160k_ade20k.py index 146f13e..d9750b6 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x512_160k_ade20k.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x512_160k_ade20k.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x512_160k_ade20k.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x512_20k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r101_512x512_20k_voc12aug.py index 56345d1..409b141 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x512_20k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x512_20k_voc12aug.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x512_20k_voc12aug.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x512_20k_voc12aug.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x512_40k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r101_512x512_40k_voc12aug.py index 0669b74..0b98927 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x512_40k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x512_40k_voc12aug.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x512_40k_voc12aug.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x512_40k_voc12aug.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_512x512_80k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r101_512x512_80k_ade20k.py index abfb9c5..21af2dd 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r101_512x512_80k_ade20k.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_512x512_80k_ade20k.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_512x512_80k_ade20k.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_769x769_40k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r101_769x769_40k_cityscapes.py index e5f3a3f..74e0ffc 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_769x769_40k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r101_769x769_40k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_769x769_40k_cityscapes.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_769x769_40k_cityscapes.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r101_769x769_80k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r101_769x769_80k_cityscapes.py index a709165..dc6df2b 100644 --- a/mmsegmentation/configs/upernet/upernet_r101_769x769_80k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r101_769x769_80k_cityscapes.py @@ -1,2 +1,2 @@ -_base_ = './upernet_r50_769x769_80k_cityscapes.py' -model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) +_base_ = "./upernet_r50_769x769_80k_cityscapes.py" +model = dict(pretrained="open-mmlab://resnet101_v1c", backbone=dict(depth=101)) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x1024_40k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r18_512x1024_40k_cityscapes.py index f5aec1f..50aec22 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x1024_40k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x1024_40k_cityscapes.py @@ -1,6 +1,7 @@ -_base_ = './upernet_r50_512x1024_40k_cityscapes.py' +_base_ = "./upernet_r50_512x1024_40k_cityscapes.py" model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512]), - auxiliary_head=dict(in_channels=256)) + auxiliary_head=dict(in_channels=256), +) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x1024_80k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r18_512x1024_80k_cityscapes.py index 444f362..49cd177 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x1024_80k_cityscapes.py @@ -1,6 +1,7 @@ -_base_ = './upernet_r50_512x1024_80k_cityscapes.py' +_base_ = "./upernet_r50_512x1024_80k_cityscapes.py" model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512]), - auxiliary_head=dict(in_channels=256)) + auxiliary_head=dict(in_channels=256), +) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x512_160k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r18_512x512_160k_ade20k.py index 9ac6c35..9edb405 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x512_160k_ade20k.py @@ -1,9 +1,12 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512], num_classes=150), - auxiliary_head=dict(in_channels=256, num_classes=150)) + auxiliary_head=dict(in_channels=256, num_classes=150), +) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x512_20k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r18_512x512_20k_voc12aug.py index 5cae4f5..2b8192e 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x512_20k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x512_20k_voc12aug.py @@ -1,10 +1,12 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_20k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/pascal_voc12_aug.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_20k.py", ] model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512], num_classes=21), - auxiliary_head=dict(in_channels=256, num_classes=21)) + auxiliary_head=dict(in_channels=256, num_classes=21), +) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x512_40k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r18_512x512_40k_voc12aug.py index 652ded7..1223a80 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x512_40k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x512_40k_voc12aug.py @@ -1,10 +1,12 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_40k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/pascal_voc12_aug.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512], num_classes=21), - auxiliary_head=dict(in_channels=256, num_classes=21)) + auxiliary_head=dict(in_channels=256, num_classes=21), +) diff --git a/mmsegmentation/configs/upernet/upernet_r18_512x512_80k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r18_512x512_80k_ade20k.py index 1a7956d..e0a8eea 100644 --- a/mmsegmentation/configs/upernet/upernet_r18_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r18_512x512_80k_ade20k.py @@ -1,9 +1,12 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] model = dict( - pretrained='open-mmlab://resnet18_v1c', + pretrained="open-mmlab://resnet18_v1c", backbone=dict(depth=18), decode_head=dict(in_channels=[64, 128, 256, 512], num_classes=150), - auxiliary_head=dict(in_channels=256, num_classes=150)) + auxiliary_head=dict(in_channels=256, num_classes=150), +) diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x1024_40k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r50_512x1024_40k_cityscapes.py index d621e89..e28a9bc 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x1024_40k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x1024_40k_cityscapes.py @@ -1,4 +1,6 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/cityscapes.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x1024_80k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r50_512x1024_80k_cityscapes.py index 95fffcc..e442f24 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x1024_80k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x1024_80k_cityscapes.py @@ -1,4 +1,6 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/cityscapes.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x512_160k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r50_512x512_160k_ade20k.py index f5dd9aa..e9b0062 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x512_160k_ade20k.py @@ -1,6 +1,7 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] -model = dict( - decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) +model = dict(decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x512_20k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r50_512x512_20k_voc12aug.py index 95f5c09..1cd491f 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x512_20k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x512_20k_voc12aug.py @@ -1,7 +1,7 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_20k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/pascal_voc12_aug.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_20k.py", ] -model = dict( - decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21)) +model = dict(decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21)) diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x512_40k_voc12aug.py b/mmsegmentation/configs/upernet/upernet_r50_512x512_40k_voc12aug.py index 9621fd1..b644546 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x512_40k_voc12aug.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x512_40k_voc12aug.py @@ -1,7 +1,7 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_40k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/pascal_voc12_aug.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] -model = dict( - decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21)) +model = dict(decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21)) diff --git a/mmsegmentation/configs/upernet/upernet_r50_512x512_80k_ade20k.py b/mmsegmentation/configs/upernet/upernet_r50_512x512_80k_ade20k.py index f561e30..565e6de 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/upernet/upernet_r50_512x512_80k_ade20k.py @@ -1,6 +1,7 @@ _base_ = [ - '../_base_/models/upernet_r50.py', '../_base_/datasets/ade20k.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] -model = dict( - decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) +model = dict(decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150)) diff --git a/mmsegmentation/configs/upernet/upernet_r50_769x769_40k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r50_769x769_40k_cityscapes.py index 89b18aa..b4a26e0 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_769x769_40k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r50_769x769_40k_cityscapes.py @@ -1,9 +1,11 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_40k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/cityscapes_769x769.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_40k.py", ] model = dict( decode_head=dict(align_corners=True), auxiliary_head=dict(align_corners=True), - test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513))) + test_cfg=dict(mode="slide", crop_size=(769, 769), stride=(513, 513)), +) diff --git a/mmsegmentation/configs/upernet/upernet_r50_769x769_80k_cityscapes.py b/mmsegmentation/configs/upernet/upernet_r50_769x769_80k_cityscapes.py index 29af98f..7864ad6 100644 --- a/mmsegmentation/configs/upernet/upernet_r50_769x769_80k_cityscapes.py +++ b/mmsegmentation/configs/upernet/upernet_r50_769x769_80k_cityscapes.py @@ -1,9 +1,11 @@ _base_ = [ - '../_base_/models/upernet_r50.py', - '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_80k.py' + "../_base_/models/upernet_r50.py", + "../_base_/datasets/cityscapes_769x769.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] model = dict( decode_head=dict(align_corners=True), auxiliary_head=dict(align_corners=True), - test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513))) + test_cfg=dict(mode="slide", crop_size=(769, 769), stride=(513, 513)), +) diff --git a/mmsegmentation/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py index 68f4bd4..4f35cf7 100644 --- a/mmsegmentation/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py @@ -1,6 +1,7 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', + pretrained="pretrain/deit_base_patch16_224-b5f2ef4d.pth", backbone=dict(drop_path_rate=0.1), - neck=None) + neck=None, +) diff --git a/mmsegmentation/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py index 7204826..30c040e 100644 --- a/mmsegmentation/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py @@ -1,6 +1,7 @@ -_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_80k_ade20k.py" model = dict( - pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', + pretrained="pretrain/deit_base_patch16_224-b5f2ef4d.pth", backbone=dict(drop_path_rate=0.1), - neck=None) + neck=None, +) diff --git a/mmsegmentation/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py index 32909ff..5bd3f9c 100644 --- a/mmsegmentation/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py @@ -1,5 +1,6 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', - backbone=dict(drop_path_rate=0.1, final_norm=True)) + pretrained="pretrain/deit_base_patch16_224-b5f2ef4d.pth", + backbone=dict(drop_path_rate=0.1, final_norm=True), +) diff --git a/mmsegmentation/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py index 4abefe8..9b09f8a 100644 --- a/mmsegmentation/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py @@ -1,6 +1,6 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_base_patch16_224-b5f2ef4d.pth', + pretrained="pretrain/deit_base_patch16_224-b5f2ef4d.pth", backbone=dict(drop_path_rate=0.1), ) diff --git a/mmsegmentation/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py index 290ff19..2a05804 100644 --- a/mmsegmentation/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py @@ -1,8 +1,9 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', + pretrained="pretrain/deit_small_patch16_224-cd65a155.pth", backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=None, - auxiliary_head=dict(num_classes=150, in_channels=384)) + auxiliary_head=dict(num_classes=150, in_channels=384), +) diff --git a/mmsegmentation/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py index 605d264..885bd3d 100644 --- a/mmsegmentation/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py @@ -1,8 +1,9 @@ -_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_80k_ade20k.py" model = dict( - pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', + pretrained="pretrain/deit_small_patch16_224-cd65a155.pth", backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=None, - auxiliary_head=dict(num_classes=150, in_channels=384)) + auxiliary_head=dict(num_classes=150, in_channels=384), +) diff --git a/mmsegmentation/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py index ef743a2..756ad5b 100644 --- a/mmsegmentation/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py @@ -1,9 +1,9 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', - backbone=dict( - num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True), + pretrained="pretrain/deit_small_patch16_224-cd65a155.pth", + backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1, final_norm=True), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=dict(in_channels=[384, 384, 384, 384], out_channels=384), - auxiliary_head=dict(num_classes=150, in_channels=384)) + auxiliary_head=dict(num_classes=150, in_channels=384), +) diff --git a/mmsegmentation/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py index 069cab7..1c5850f 100644 --- a/mmsegmentation/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py @@ -1,8 +1,9 @@ -_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py' +_base_ = "./upernet_vit-b16_mln_512x512_160k_ade20k.py" model = dict( - pretrained='pretrain/deit_small_patch16_224-cd65a155.pth', + pretrained="pretrain/deit_small_patch16_224-cd65a155.pth", backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1), decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]), neck=dict(in_channels=[384, 384, 384, 384], out_channels=384), - auxiliary_head=dict(num_classes=150, in_channels=384)) + auxiliary_head=dict(num_classes=150, in_channels=384), +) diff --git a/mmsegmentation/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py index 51eeda0..d23b83c 100644 --- a/mmsegmentation/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py @@ -1,39 +1,44 @@ _base_ = [ - '../_base_/models/upernet_vit-b16_ln_mln.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/upernet_vit-b16_ln_mln.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] model = dict( - pretrained='pretrain/vit_base_patch16_224.pth', + pretrained="pretrain/vit_base_patch16_224.pth", backbone=dict(drop_path_rate=0.1, final_norm=True), decode_head=dict(num_classes=150), - auxiliary_head=dict(num_classes=150)) + auxiliary_head=dict(num_classes=150), +) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, paramwise_cfg=dict( custom_keys={ - 'pos_embed': dict(decay_mult=0.), - 'cls_token': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + "pos_embed": dict(decay_mult=0.0), + "cls_token": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + } + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) # By default, models are trained on 8 GPUs with 2 images per GPU data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py b/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py index 5b148d7..8853e33 100644 --- a/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py @@ -1,38 +1,43 @@ _base_ = [ - '../_base_/models/upernet_vit-b16_ln_mln.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_160k.py' + "../_base_/models/upernet_vit-b16_ln_mln.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_160k.py", ] model = dict( - pretrained='pretrain/vit_base_patch16_224.pth', + pretrained="pretrain/vit_base_patch16_224.pth", decode_head=dict(num_classes=150), - auxiliary_head=dict(num_classes=150)) + auxiliary_head=dict(num_classes=150), +) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, paramwise_cfg=dict( custom_keys={ - 'pos_embed': dict(decay_mult=0.), - 'cls_token': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + "pos_embed": dict(decay_mult=0.0), + "cls_token": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + } + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) # By default, models are trained on 8 GPUs with 2 images per GPU data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py b/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py index f893500..c0d9fc5 100644 --- a/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py +++ b/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py @@ -1,38 +1,43 @@ _base_ = [ - '../_base_/models/upernet_vit-b16_ln_mln.py', - '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', - '../_base_/schedules/schedule_80k.py' + "../_base_/models/upernet_vit-b16_ln_mln.py", + "../_base_/datasets/ade20k.py", + "../_base_/default_runtime.py", + "../_base_/schedules/schedule_80k.py", ] model = dict( - pretrained='pretrain/vit_base_patch16_224.pth', + pretrained="pretrain/vit_base_patch16_224.pth", decode_head=dict(num_classes=150), - auxiliary_head=dict(num_classes=150)) + auxiliary_head=dict(num_classes=150), +) # AdamW optimizer, no weight decay for position embedding & layer norm # in backbone optimizer = dict( _delete_=True, - type='AdamW', + type="AdamW", lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, paramwise_cfg=dict( custom_keys={ - 'pos_embed': dict(decay_mult=0.), - 'cls_token': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - })) + "pos_embed": dict(decay_mult=0.0), + "cls_token": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + } + ), +) lr_config = dict( _delete_=True, - policy='poly', - warmup='linear', + policy="poly", + warmup="linear", warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, - by_epoch=False) + by_epoch=False, +) # By default, models are trained on 8 GPUs with 2 images per GPU data = dict(samples_per_gpu=2) diff --git a/mmsegmentation/demo/image_demo.py b/mmsegmentation/demo/image_demo.py index 87d6d6c..2902f57 100644 --- a/mmsegmentation/demo/image_demo.py +++ b/mmsegmentation/demo/image_demo.py @@ -7,21 +7,22 @@ def main(): parser = ArgumentParser() - parser.add_argument('img', help='Image file') - parser.add_argument('config', help='Config file') - parser.add_argument('checkpoint', help='Checkpoint file') - parser.add_argument('--out-file', default=None, help='Path to output file') + parser.add_argument("img", help="Image file") + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--out-file", default=None, help="Path to output file") + parser.add_argument("--device", default="cuda:0", help="Device used for inference") parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') + "--palette", + default="cityscapes", + help="Color palette used for segmentation map", + ) parser.add_argument( - '--palette', - default='cityscapes', - help='Color palette used for segmentation map') - parser.add_argument( - '--opacity', + "--opacity", type=float, default=0.5, - help='Opacity of painted segmentation map. In (0, 1] range.') + help="Opacity of painted segmentation map. In (0, 1] range.", + ) args = parser.parse_args() # build the model from a config file and a checkpoint file @@ -35,8 +36,9 @@ def main(): result, get_palette(args.palette), opacity=args.opacity, - out_file=args.out_file) + out_file=args.out_file, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/demo/video_demo.py b/mmsegmentation/demo/video_demo.py index eb4fd69..27a4019 100644 --- a/mmsegmentation/demo/video_demo.py +++ b/mmsegmentation/demo/video_demo.py @@ -9,54 +9,52 @@ def main(): parser = ArgumentParser() - parser.add_argument('video', help='Video file or webcam id') - parser.add_argument('config', help='Config file') - parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument("video", help="Video file or webcam id") + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--device", default="cuda:0", help="Device used for inference") parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') + "--palette", + default="cityscapes", + help="Color palette used for segmentation map", + ) parser.add_argument( - '--palette', - default='cityscapes', - help='Color palette used for segmentation map') + "--show", action="store_true", help="Whether to show draw result" + ) parser.add_argument( - '--show', action='store_true', help='Whether to show draw result') + "--show-wait-time", default=1, type=int, help="Wait time after imshow" + ) parser.add_argument( - '--show-wait-time', default=1, type=int, help='Wait time after imshow') + "--output-file", default=None, type=str, help="Output video file path" + ) parser.add_argument( - '--output-file', default=None, type=str, help='Output video file path') + "--output-fourcc", default="MJPG", type=str, help="Fourcc of the output video" + ) parser.add_argument( - '--output-fourcc', - default='MJPG', - type=str, - help='Fourcc of the output video') + "--output-fps", default=-1, type=int, help="FPS of the output video" + ) parser.add_argument( - '--output-fps', default=-1, type=int, help='FPS of the output video') + "--output-height", default=-1, type=int, help="Frame height of the output video" + ) parser.add_argument( - '--output-height', - default=-1, - type=int, - help='Frame height of the output video') + "--output-width", default=-1, type=int, help="Frame width of the output video" + ) parser.add_argument( - '--output-width', - default=-1, - type=int, - help='Frame width of the output video') - parser.add_argument( - '--opacity', + "--opacity", type=float, default=0.5, - help='Opacity of painted segmentation map. In (0, 1] range.') + help="Opacity of painted segmentation map. In (0, 1] range.", + ) args = parser.parse_args() - assert args.show or args.output_file, \ - 'At least one output should be enabled.' + assert args.show or args.output_file, "At least one output should be enabled." # build the model from a config file and a checkpoint file model = init_segmentor(args.config, args.checkpoint, device=args.device) # build input video cap = cv2.VideoCapture(args.video) - assert (cap.isOpened()) + assert cap.isOpened() input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) input_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) input_fps = cap.get(cv2.CAP_PROP_FPS) @@ -68,12 +66,13 @@ def main(): if args.output_file is not None: fourcc = cv2.VideoWriter_fourcc(*args.output_fourcc) output_fps = args.output_fps if args.output_fps > 0 else input_fps - output_height = args.output_height if args.output_height > 0 else int( - input_height) - output_width = args.output_width if args.output_width > 0 else int( - input_width) - writer = cv2.VideoWriter(args.output_file, fourcc, output_fps, - (output_width, output_height), True) + output_height = ( + args.output_height if args.output_height > 0 else int(input_height) + ) + output_width = args.output_width if args.output_width > 0 else int(input_width) + writer = cv2.VideoWriter( + args.output_file, fourcc, output_fps, (output_width, output_height), True + ) # start looping try: @@ -91,16 +90,18 @@ def main(): result, palette=get_palette(args.palette), show=False, - opacity=args.opacity) + opacity=args.opacity, + ) if args.show: - cv2.imshow('video_demo', draw_img) + cv2.imshow("video_demo", draw_img) cv2.waitKey(args.show_wait_time) if writer: - if draw_img.shape[0] != output_height or draw_img.shape[ - 1] != output_width: - draw_img = cv2.resize(draw_img, - (output_width, output_height)) + if ( + draw_img.shape[0] != output_height + or draw_img.shape[1] != output_width + ): + draw_img = cv2.resize(draw_img, (output_width, output_height)) writer.write(draw_img) finally: if writer: @@ -108,5 +109,5 @@ def main(): cap.release() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/docs/en/conf.py b/mmsegmentation/docs/en/conf.py index cd2113d..5803603 100644 --- a/mmsegmentation/docs/en/conf.py +++ b/mmsegmentation/docs/en/conf.py @@ -17,20 +17,20 @@ import pytorch_sphinx_theme -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- -project = 'MMSegmentation' -copyright = '2020-2021, OpenMMLab' -author = 'MMSegmentation Authors' -version_file = '../../mmseg/version.py' +project = "MMSegmentation" +copyright = "2020-2021, OpenMMLab" +author = "MMSegmentation Authors" +version_file = "../../mmseg/version.py" def get_version(): - with open(version_file, 'r') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + with open(version_file) as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] # The full version, including alpha/beta/rc tags @@ -42,36 +42,38 @@ def get_version(): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'sphinx_markdown_tables', 'sphinx_copybutton', 'myst_parser' + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_markdown_tables", + "sphinx_copybutton", + "myst_parser", ] -autodoc_mock_imports = [ - 'matplotlib', 'pycocotools', 'mmseg.version', 'mmcv.ops' -] +autodoc_mock_imports = ["matplotlib", "pycocotools", "mmseg.version", "mmcv.ops"] # Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', + ".rst": "restructuredtext", + ".md": "markdown", } # The master toctree document. -master_doc = 'index' +master_doc = "index" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -79,56 +81,48 @@ def get_version(): # a list of builtin themes. # # html_theme = 'sphinx_rtd_theme' -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] html_theme_options = { - 'logo_url': - 'https://mmsegmentation.readthedocs.io/en/latest/', - 'menu': [ - { - 'name': - 'Tutorial', - 'url': - 'https://github.com/open-mmlab/mmsegmentation/blob/master/' - 'demo/MMSegmentation_Tutorial.ipynb' - }, + "logo_url": "https://mmsegmentation.readthedocs.io/en/latest/", + "menu": [ { - 'name': 'GitHub', - 'url': 'https://github.com/open-mmlab/mmsegmentation' + "name": "Tutorial", + "url": "https://github.com/open-mmlab/mmsegmentation/blob/master/" + "demo/MMSegmentation_Tutorial.ipynb", }, + {"name": "GitHub", "url": "https://github.com/open-mmlab/mmsegmentation"}, { - 'name': - 'Upstream', - 'children': [ + "name": "Upstream", + "children": [ { - 'name': 'MMCV', - 'url': 'https://github.com/open-mmlab/mmcv', - 'description': 'Foundational library for computer vision' + "name": "MMCV", + "url": "https://github.com/open-mmlab/mmcv", + "description": "Foundational library for computer vision", }, - ] + ], }, ], # Specify the language of shared menu - 'menu_lang': - 'en' + "menu_lang": "en", } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_css_files = ['css/readthedocs.css'] +html_static_path = ["_static"] +html_css_files = ["css/readthedocs.css"] # Enable ::: for my_st -myst_enable_extensions = ['colon_fence'] +myst_enable_extensions = ["colon_fence"] myst_heading_anchors = 3 -language = 'en' +language = "en" def builder_inited_handler(app): - subprocess.run(['./stat.py']) + subprocess.run(["./stat.py"]) def setup(app): - app.connect('builder-inited', builder_inited_handler) + app.connect("builder-inited", builder_inited_handler) diff --git a/mmsegmentation/docs/en/stat.py b/mmsegmentation/docs/en/stat.py index 1398a70..f9a3273 100755 --- a/mmsegmentation/docs/en/stat.py +++ b/mmsegmentation/docs/en/stat.py @@ -7,34 +7,34 @@ import numpy as np -url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/' +url_prefix = "https://github.com/open-mmlab/mmsegmentation/blob/master/" -files = sorted(glob.glob('../../configs/*/README.md')) +files = sorted(glob.glob("../../configs/*/README.md")) stats = [] titles = [] num_ckpts = 0 for f in files: - url = osp.dirname(f.replace('../../', url_prefix)) + url = osp.dirname(f.replace("../../", url_prefix)) - with open(f, 'r') as content_file: + with open(f) as content_file: content = content_file.read() - title = content.split('\n')[0].replace('#', '').strip() - ckpts = set(x.lower().strip() - for x in re.findall(r'https?://download.*\.pth', content) - if 'mmsegmentation' in x) + title = content.split("\n")[0].replace("#", "").strip() + ckpts = { + x.lower().strip() + for x in re.findall(r"https?://download.*\.pth", content) + if "mmsegmentation" in x + } if len(ckpts) == 0: continue - _papertype = [ - x for x in re.findall(r'', content) - ] + _papertype = [x for x in re.findall(r"", content)] assert len(_papertype) > 0 papertype = _papertype[0] - paper = set([(papertype, title)]) + paper = {(papertype, title)} titles.append(title) num_ckpts += len(ckpts) @@ -44,12 +44,10 @@ stats.append((paper, ckpts, statsmsg)) allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats]) -msglist = '\n'.join(x for _, _, x in stats) +msglist = "\n".join(x for _, _, x in stats) -papertypes, papercounts = np.unique([t for t, _ in allpapers], - return_counts=True) -countstr = '\n'.join( - [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)]) +papertypes, papercounts = np.unique([t for t, _ in allpapers], return_counts=True) +countstr = "\n".join([f" - {t}: {c}" for t, c in zip(papertypes, papercounts)]) modelzoo = f""" # Model Zoo Statistics @@ -61,5 +59,5 @@ {msglist} """ -with open('modelzoo_statistics.md', 'w') as f: +with open("modelzoo_statistics.md", "w") as f: f.write(modelzoo) diff --git a/mmsegmentation/docs/zh_cn/conf.py b/mmsegmentation/docs/zh_cn/conf.py index 4dec48d..4dfacca 100644 --- a/mmsegmentation/docs/zh_cn/conf.py +++ b/mmsegmentation/docs/zh_cn/conf.py @@ -17,20 +17,20 @@ import pytorch_sphinx_theme -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- -project = 'MMSegmentation' -copyright = '2020-2021, OpenMMLab' -author = 'MMSegmentation Authors' -version_file = '../../mmseg/version.py' +project = "MMSegmentation" +copyright = "2020-2021, OpenMMLab" +author = "MMSegmentation Authors" +version_file = "../../mmseg/version.py" def get_version(): - with open(version_file, 'r') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + with open(version_file) as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] # The full version, including alpha/beta/rc tags @@ -42,36 +42,38 @@ def get_version(): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'sphinx_markdown_tables', 'sphinx_copybutton', 'myst_parser' + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_markdown_tables", + "sphinx_copybutton", + "myst_parser", ] -autodoc_mock_imports = [ - 'matplotlib', 'pycocotools', 'mmseg.version', 'mmcv.ops' -] +autodoc_mock_imports = ["matplotlib", "pycocotools", "mmseg.version", "mmcv.ops"] # Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', + ".rst": "restructuredtext", + ".md": "markdown", } # The master toctree document. -master_doc = 'index' +master_doc = "index" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -79,56 +81,48 @@ def get_version(): # a list of builtin themes. # # html_theme = 'sphinx_rtd_theme' -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] html_theme_options = { - 'logo_url': - 'https://mmsegmentation.readthedocs.io/zh-CN/latest/', - 'menu': [ - { - 'name': - '教程', - 'url': - 'https://github.com/open-mmlab/mmsegmentation/blob/master/' - 'demo/MMSegmentation_Tutorial.ipynb' - }, + "logo_url": "https://mmsegmentation.readthedocs.io/zh-CN/latest/", + "menu": [ { - 'name': 'GitHub', - 'url': 'https://github.com/open-mmlab/mmsegmentation' + "name": "教程", + "url": "https://github.com/open-mmlab/mmsegmentation/blob/master/" + "demo/MMSegmentation_Tutorial.ipynb", }, + {"name": "GitHub", "url": "https://github.com/open-mmlab/mmsegmentation"}, { - 'name': - '上游库', - 'children': [ + "name": "上游库", + "children": [ { - 'name': 'MMCV', - 'url': 'https://github.com/open-mmlab/mmcv', - 'description': '基础视觉库' + "name": "MMCV", + "url": "https://github.com/open-mmlab/mmcv", + "description": "基础视觉库", }, - ] + ], }, ], # Specify the language of shared menu - 'menu_lang': - 'cn', + "menu_lang": "cn", } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_css_files = ['css/readthedocs.css'] +html_static_path = ["_static"] +html_css_files = ["css/readthedocs.css"] # Enable ::: for my_st -myst_enable_extensions = ['colon_fence'] +myst_enable_extensions = ["colon_fence"] myst_heading_anchors = 3 -language = 'zh-CN' +language = "zh-CN" def builder_inited_handler(app): - subprocess.run(['./stat.py']) + subprocess.run(["./stat.py"]) def setup(app): - app.connect('builder-inited', builder_inited_handler) + app.connect("builder-inited", builder_inited_handler) diff --git a/mmsegmentation/docs/zh_cn/stat.py b/mmsegmentation/docs/zh_cn/stat.py index b3a1d73..5beb577 100755 --- a/mmsegmentation/docs/zh_cn/stat.py +++ b/mmsegmentation/docs/zh_cn/stat.py @@ -7,34 +7,34 @@ import numpy as np -url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/' +url_prefix = "https://github.com/open-mmlab/mmsegmentation/blob/master/" -files = sorted(glob.glob('../../configs/*/README.md')) +files = sorted(glob.glob("../../configs/*/README.md")) stats = [] titles = [] num_ckpts = 0 for f in files: - url = osp.dirname(f.replace('../../', url_prefix)) + url = osp.dirname(f.replace("../../", url_prefix)) - with open(f, 'r') as content_file: + with open(f) as content_file: content = content_file.read() - title = content.split('\n')[0].replace('#', '').strip() - ckpts = set(x.lower().strip() - for x in re.findall(r'https?://download.*\.pth', content) - if 'mmsegmentation' in x) + title = content.split("\n")[0].replace("#", "").strip() + ckpts = { + x.lower().strip() + for x in re.findall(r"https?://download.*\.pth", content) + if "mmsegmentation" in x + } if len(ckpts) == 0: continue - _papertype = [ - x for x in re.findall(r'', content) - ] + _papertype = [x for x in re.findall(r"", content)] assert len(_papertype) > 0 papertype = _papertype[0] - paper = set([(papertype, title)]) + paper = {(papertype, title)} titles.append(title) num_ckpts += len(ckpts) @@ -44,12 +44,10 @@ stats.append((paper, ckpts, statsmsg)) allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats]) -msglist = '\n'.join(x for _, _, x in stats) +msglist = "\n".join(x for _, _, x in stats) -papertypes, papercounts = np.unique([t for t, _ in allpapers], - return_counts=True) -countstr = '\n'.join( - [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)]) +papertypes, papercounts = np.unique([t for t, _ in allpapers], return_counts=True) +countstr = "\n".join([f" - {t}: {c}" for t, c in zip(papertypes, papercounts)]) modelzoo = f""" # 模型库统计数据 @@ -61,5 +59,5 @@ {msglist} """ -with open('modelzoo_statistics.md', 'w') as f: +with open("modelzoo_statistics.md", "w") as f: f.write(modelzoo) diff --git a/mmsegmentation/mmseg/__init__.py b/mmsegmentation/mmseg/__init__.py index c28bf4e..66bda18 100644 --- a/mmsegmentation/mmseg/__init__.py +++ b/mmsegmentation/mmseg/__init__.py @@ -6,8 +6,8 @@ from .version import __version__, version_info -MMCV_MIN = '1.3.13' -MMCV_MAX = '1.8.0' +MMCV_MIN = "1.3.13" +MMCV_MAX = "1.8.0" def digit_version(version_str: str, length: int = 4): @@ -24,19 +24,21 @@ def digit_version(version_str: str, length: int = 4): tuple[int]: The version info in digits (integers). """ version = parse(version_str) - assert version.release, f'failed to parse version {version_str}' + assert version.release, f"failed to parse version {version_str}" release = list(version.release) release = release[:length] if len(release) < length: release = release + [0] * (length - len(release)) if version.is_prerelease: - mapping = {'a': -3, 'b': -2, 'rc': -1} + mapping = {"a": -3, "b": -2, "rc": -1} val = -4 # version.pre can be None if version.pre: if version.pre[0] not in mapping: - warnings.warn(f'unknown prerelease version {version.pre[0]}, ' - 'version checking may go wrong') + warnings.warn( + f"unknown prerelease version {version.pre[0]}, " + "version checking may go wrong" + ) else: val = mapping[version.pre[0]] release.extend([val, version.pre[-1]]) @@ -55,8 +57,9 @@ def digit_version(version_str: str, length: int = 4): mmcv_version = digit_version(mmcv.__version__) -assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ - f'MMCV=={mmcv.__version__} is used but incompatible. ' \ - f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' +assert mmcv_min_version <= mmcv_version < mmcv_max_version, ( + f"MMCV=={mmcv.__version__} is used but incompatible. " + f"Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}." +) -__all__ = ['__version__', 'version_info', 'digit_version'] +__all__ = ["__version__", "version_info", "digit_version"] diff --git a/mmsegmentation/mmseg/apis/__init__.py b/mmsegmentation/mmseg/apis/__init__.py index c688180..28f78c4 100644 --- a/mmsegmentation/mmseg/apis/__init__.py +++ b/mmsegmentation/mmseg/apis/__init__.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import inference_segmentor, init_segmentor, show_result_pyplot from .test import multi_gpu_test, single_gpu_test -from .train import (get_root_logger, init_random_seed, set_random_seed, - train_segmentor) +from .train import get_root_logger, init_random_seed, set_random_seed, train_segmentor __all__ = [ - 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', - 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', - 'show_result_pyplot', 'init_random_seed' + "get_root_logger", + "set_random_seed", + "train_segmentor", + "init_segmentor", + "inference_segmentor", + "multi_gpu_test", + "single_gpu_test", + "show_result_pyplot", + "init_random_seed", ] diff --git a/mmsegmentation/mmseg/apis/inference.py b/mmsegmentation/mmseg/apis/inference.py index 5bbe666..24ecae9 100644 --- a/mmsegmentation/mmseg/apis/inference.py +++ b/mmsegmentation/mmseg/apis/inference.py @@ -9,7 +9,7 @@ from mmseg.models import build_segmentor -def init_segmentor(config, checkpoint=None, device='cuda:0'): +def init_segmentor(config, checkpoint=None, device="cuda:0"): """Initialize a segmentor from config file. Args: @@ -25,15 +25,17 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): if isinstance(config, str): config = mmcv.Config.fromfile(config) elif not isinstance(config, mmcv.Config): - raise TypeError('config must be a filename or Config object, ' - 'but got {}'.format(type(config))) + raise TypeError( + "config must be a filename or Config object, " + "but got {}".format(type(config)) + ) config.model.pretrained = None config.model.train_cfg = None - model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + model = build_segmentor(config.model, test_cfg=config.get("test_cfg")) if checkpoint is not None: - checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - model.CLASSES = checkpoint['meta']['CLASSES'] - model.PALETTE = checkpoint['meta']['PALETTE'] + checkpoint = load_checkpoint(model, checkpoint, map_location="cpu") + model.CLASSES = checkpoint["meta"]["CLASSES"] + model.PALETTE = checkpoint["meta"]["PALETTE"] model.cfg = config # save the config in the model for convenience model.to(device) model.eval() @@ -54,16 +56,16 @@ def __call__(self, results): dict: ``results`` will be returned containing loaded image. """ - if isinstance(results['img'], str): - results['filename'] = results['img'] - results['ori_filename'] = results['img'] + if isinstance(results["img"], str): + results["filename"] = results["img"] + results["ori_filename"] = results["img"] else: - results['filename'] = None - results['ori_filename'] = None - img = mmcv.imread(results['img']) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["filename"] = None + results["ori_filename"] = None + img = mmcv.imread(results["img"]) + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape return results @@ -95,7 +97,7 @@ def inference_segmentor(model, imgs): # scatter to specified GPU data = scatter(data, [device])[0] else: - data['img_metas'] = [i.data[0] for i in data['img_metas']] + data["img_metas"] = [i.data[0] for i in data["img_metas"]] # forward the model with torch.no_grad(): @@ -103,15 +105,17 @@ def inference_segmentor(model, imgs): return result -def show_result_pyplot(model, - img, - result, - palette=None, - fig_size=(15, 10), - opacity=0.5, - title='', - block=True, - out_file=None): +def show_result_pyplot( + model, + img, + result, + palette=None, + fig_size=(15, 10), + opacity=0.5, + title="", + block=True, + out_file=None, +): """Visualize the segmentation results on the image. Args: @@ -132,10 +136,9 @@ def show_result_pyplot(model, out_file (str or None): The path to write the image. Default: None. """ - if hasattr(model, 'module'): + if hasattr(model, "module"): model = model.module - img = model.show_result( - img, result, palette=palette, show=False, opacity=opacity) + img = model.show_result(img, result, palette=palette, show=False, opacity=opacity) plt.figure(figsize=fig_size) plt.imshow(mmcv.bgr2rgb(img)) plt.title(title) diff --git a/mmsegmentation/mmseg/apis/test.py b/mmsegmentation/mmseg/apis/test.py index cc4fcc9..8c06126 100644 --- a/mmsegmentation/mmseg/apis/test.py +++ b/mmsegmentation/mmseg/apis/test.py @@ -26,20 +26,23 @@ def np2tmp(array, temp_file_name=None, tmpdir=None): if temp_file_name is None: temp_file_name = tempfile.NamedTemporaryFile( - suffix='.npy', delete=False, dir=tmpdir).name + suffix=".npy", delete=False, dir=tmpdir + ).name np.save(temp_file_name, array) return temp_file_name -def single_gpu_test(model, - data_loader, - show=False, - out_dir=None, - efficient_test=False, - opacity=0.5, - pre_eval=False, - format_only=False, - format_args={}): +def single_gpu_test( + model, + data_loader, + show=False, + out_dir=None, + efficient_test=False, + opacity=0.5, + pre_eval=False, + format_only=False, + format_args={}, +): """Test with single GPU by progressive mode. Args: @@ -66,14 +69,16 @@ def single_gpu_test(model, """ if efficient_test: warnings.warn( - 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' - 'evaluation is CPU memory friendly with pre_eval=True') - mmcv.mkdir_or_exist('.efficient_test') + "DeprecationWarning: ``efficient_test`` will be deprecated, the " + "evaluation is CPU memory friendly with pre_eval=True" + ) + mmcv.mkdir_or_exist(".efficient_test") # when none of them is set true, return segmentation results as # a list of np.array. - assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ - '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ - 'exclusive, only one of them could be true .' + assert [efficient_test, pre_eval, format_only].count(True) <= 1, ( + "``efficient_test``, ``pre_eval`` and ``format_only`` are mutually " + "exclusive, only one of them could be true ." + ) model.eval() results = [] @@ -91,20 +96,20 @@ def single_gpu_test(model, result = model(return_loss=False, **data) if show or out_dir: - img_tensor = data['img'][0] - img_metas = data['img_metas'][0].data[0] - imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + img_tensor = data["img"][0] + img_metas = data["img_metas"][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]["img_norm_cfg"]) assert len(imgs) == len(img_metas) for img, img_meta in zip(imgs, img_metas): - h, w, _ = img_meta['img_shape'] + h, w, _ = img_meta["img_shape"] img_show = img[:h, :w, :] - ori_h, ori_w = img_meta['ori_shape'][:-1] + ori_h, ori_w = img_meta["ori_shape"][:-1] img_show = mmcv.imresize(img_show, (ori_w, ori_h)) if out_dir: - out_file = osp.join(out_dir, img_meta['ori_filename']) + out_file = osp.join(out_dir, img_meta["ori_filename"]) else: out_file = None @@ -114,14 +119,16 @@ def single_gpu_test(model, palette=dataset.PALETTE, show=show, out_file=out_file, - opacity=opacity) + opacity=opacity, + ) if efficient_test: - result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + result = [np2tmp(_, tmpdir=".efficient_test") for _ in result] if format_only: result = dataset.format_results( - result, indices=batch_indices, **format_args) + result, indices=batch_indices, **format_args + ) if pre_eval: # TODO: adapt samples_per_gpu > 1. # only samples_per_gpu=1 valid now @@ -137,14 +144,16 @@ def single_gpu_test(model, return results -def multi_gpu_test(model, - data_loader, - tmpdir=None, - gpu_collect=False, - efficient_test=False, - pre_eval=False, - format_only=False, - format_args={}): +def multi_gpu_test( + model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False, + pre_eval=False, + format_only=False, + format_args={}, +): """Test model with multiple gpus by progressive mode. This method tests model with multiple gpus and collects the results @@ -177,14 +186,16 @@ def multi_gpu_test(model, """ if efficient_test: warnings.warn( - 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' - 'evaluation is CPU memory friendly with pre_eval=True') - mmcv.mkdir_or_exist('.efficient_test') + "DeprecationWarning: ``efficient_test`` will be deprecated, the " + "evaluation is CPU memory friendly with pre_eval=True" + ) + mmcv.mkdir_or_exist(".efficient_test") # when none of them is set true, return segmentation results as # a list of np.array. - assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ - '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ - 'exclusive, only one of them could be true .' + assert [efficient_test, pre_eval, format_only].count(True) <= 1, ( + "``efficient_test``, ``pre_eval`` and ``format_only`` are mutually " + "exclusive, only one of them could be true ." + ) model.eval() results = [] @@ -208,11 +219,12 @@ def multi_gpu_test(model, result = model(return_loss=False, rescale=True, **data) if efficient_test: - result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + result = [np2tmp(_, tmpdir=".efficient_test") for _ in result] if format_only: result = dataset.format_results( - result, indices=batch_indices, **format_args) + result, indices=batch_indices, **format_args + ) if pre_eval: # TODO: adapt samples_per_gpu > 1. # only samples_per_gpu=1 valid now diff --git a/mmsegmentation/mmseg/apis/train.py b/mmsegmentation/mmseg/apis/train.py index be8e422..90dcfe3 100644 --- a/mmsegmentation/mmseg/apis/train.py +++ b/mmsegmentation/mmseg/apis/train.py @@ -7,18 +7,22 @@ import numpy as np import torch import torch.distributed as dist -from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, - build_runner, get_dist_info) +from mmcv.runner import ( + HOOKS, + DistSamplerSeedHook, + EpochBasedRunner, + build_runner, + get_dist_info, +) from mmcv.utils import build_from_cfg from mmseg import digit_version from mmseg.core import DistEvalHook, EvalHook, build_optimizer from mmseg.datasets import build_dataloader, build_dataset -from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint, - get_root_logger) +from mmseg.utils import build_ddp, build_dp, find_latest_checkpoint, get_root_logger -def init_random_seed(seed=None, device='cuda'): +def init_random_seed(seed=None, device="cuda"): """Initialize random seed. If the seed is not set, the seed will be automatically randomized, @@ -68,13 +72,9 @@ def set_random_seed(seed, deterministic=False): torch.backends.cudnn.benchmark = False -def train_segmentor(model, - dataset, - cfg, - distributed=False, - validate=False, - timestamp=None, - meta=None): +def train_segmentor( + model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None +): """Launch segmentor training.""" logger = get_root_logger(cfg.log_level) @@ -86,45 +86,58 @@ def train_segmentor(model, num_gpus=len(cfg.gpu_ids), dist=distributed, seed=cfg.seed, - drop_last=True) + drop_last=True, + ) # The overall dataloader settings - loader_cfg.update({ - k: v - for k, v in cfg.data.items() if k not in [ - 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', - 'test_dataloader' - ] - }) + loader_cfg.update( + { + k: v + for k, v in cfg.data.items() + if k + not in [ + "train", + "val", + "test", + "train_dataloader", + "val_dataloader", + "test_dataloader", + ] + } + ) # The specific dataloader settings - train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})} + train_loader_cfg = {**loader_cfg, **cfg.data.get("train_dataloader", {})} data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset] # put model on devices if distributed: - find_unused_parameters = cfg.get('find_unused_parameters', False) + find_unused_parameters = cfg.get("find_unused_parameters", False) # Sets the `find_unused_parameters` parameter in # DDP wrapper model = build_ddp( model, cfg.device, - device_ids=[int(os.environ['LOCAL_RANK'])], + device_ids=[int(os.environ["LOCAL_RANK"])], broadcast_buffers=False, - find_unused_parameters=find_unused_parameters) + find_unused_parameters=find_unused_parameters, + ) else: if not torch.cuda.is_available(): - assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ - 'Please use MMCV >= 1.4.4 for CPU training!' + assert digit_version(mmcv.__version__) >= digit_version( + "1.4.4" + ), "Please use MMCV >= 1.4.4 for CPU training!" model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) # build runner optimizer = build_optimizer(model, cfg.optimizer) - if cfg.get('runner') is None: - cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} + if cfg.get("runner") is None: + cfg.runner = {"type": "IterBasedRunner", "max_iters": cfg.total_iters} warnings.warn( - 'config is now expected to have a `runner` section, ' - 'please set `runner` in your config.', UserWarning) + "config is now expected to have a `runner` section, " + "please set `runner` in your config.", + UserWarning, + ) runner = build_runner( cfg.runner, @@ -134,12 +147,18 @@ def train_segmentor(model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, - meta=meta)) + meta=meta, + ), + ) # register hooks - runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, - cfg.checkpoint_config, cfg.log_config, - cfg.get('momentum_config', None)) + runner.register_training_hooks( + cfg.lr_config, + cfg.optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get("momentum_config", None), + ) if distributed: # when distributed training by epoch, using`DistSamplerSeedHook` to set # the different seed to distributed sampler for each epoch, it will @@ -156,34 +175,35 @@ def train_segmentor(model, # The specific dataloader settings val_loader_cfg = { **loader_cfg, - 'samples_per_gpu': 1, - 'shuffle': False, # Not shuffle by default - **cfg.data.get('val_dataloader', {}), + "samples_per_gpu": 1, + "shuffle": False, # Not shuffle by default + **cfg.data.get("val_dataloader", {}), } val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) - eval_cfg = cfg.get('evaluation', {}) - eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_cfg = cfg.get("evaluation", {}) + eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner" eval_hook = DistEvalHook if distributed else EvalHook # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. - runner.register_hook( - eval_hook(val_dataloader, **eval_cfg), priority='LOW') + runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority="LOW") # user-defined hooks - if cfg.get('custom_hooks', None): + if cfg.get("custom_hooks", None): custom_hooks = cfg.custom_hooks - assert isinstance(custom_hooks, list), \ - f'custom_hooks expect list type, but got {type(custom_hooks)}' + assert isinstance( + custom_hooks, list + ), f"custom_hooks expect list type, but got {type(custom_hooks)}" for hook_cfg in cfg.custom_hooks: - assert isinstance(hook_cfg, dict), \ - 'Each item in custom_hooks expects dict type, but got ' \ - f'{type(hook_cfg)}' + assert isinstance(hook_cfg, dict), ( + "Each item in custom_hooks expects dict type, but got " + f"{type(hook_cfg)}" + ) hook_cfg = hook_cfg.copy() - priority = hook_cfg.pop('priority', 'NORMAL') + priority = hook_cfg.pop("priority", "NORMAL") hook = build_from_cfg(hook_cfg, HOOKS) runner.register_hook(hook, priority=priority) - if cfg.resume_from is None and cfg.get('auto_resume'): + if cfg.resume_from is None and cfg.get("auto_resume"): resume_from = find_latest_checkpoint(cfg.work_dir) if resume_from is not None: cfg.resume_from = resume_from diff --git a/mmsegmentation/mmseg/core/__init__.py b/mmsegmentation/mmseg/core/__init__.py index 82f2422..80ba449 100644 --- a/mmsegmentation/mmseg/core/__init__.py +++ b/mmsegmentation/mmseg/core/__init__.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import (OPTIMIZER_BUILDERS, build_optimizer, - build_optimizer_constructor) +from .builder import OPTIMIZER_BUILDERS, build_optimizer, build_optimizer_constructor from .evaluation import * # noqa: F401, F403 from .hook import * # noqa: F401, F403 from .optimizers import * # noqa: F401, F403 from .seg import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 -__all__ = [ - 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' -] +__all__ = ["OPTIMIZER_BUILDERS", "build_optimizer", "build_optimizer_constructor"] diff --git a/mmsegmentation/mmseg/core/builder.py b/mmsegmentation/mmseg/core/builder.py index 406dd9b..59732a7 100644 --- a/mmsegmentation/mmseg/core/builder.py +++ b/mmsegmentation/mmseg/core/builder.py @@ -4,30 +4,32 @@ from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS from mmcv.utils import Registry, build_from_cfg -OPTIMIZER_BUILDERS = Registry( - 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) +OPTIMIZER_BUILDERS = Registry("optimizer builder", parent=MMCV_OPTIMIZER_BUILDERS) def build_optimizer_constructor(cfg): - constructor_type = cfg.get('type') + constructor_type = cfg.get("type") if constructor_type in OPTIMIZER_BUILDERS: return build_from_cfg(cfg, OPTIMIZER_BUILDERS) elif constructor_type in MMCV_OPTIMIZER_BUILDERS: return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) else: - raise KeyError(f'{constructor_type} is not registered ' - 'in the optimizer builder registry.') + raise KeyError( + f"{constructor_type} is not registered " + "in the optimizer builder registry." + ) def build_optimizer(model, cfg): optimizer_cfg = copy.deepcopy(cfg) - constructor_type = optimizer_cfg.pop('constructor', - 'DefaultOptimizerConstructor') - paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + constructor_type = optimizer_cfg.pop("constructor", "DefaultOptimizerConstructor") + paramwise_cfg = optimizer_cfg.pop("paramwise_cfg", None) optim_constructor = build_optimizer_constructor( dict( type=constructor_type, optimizer_cfg=optimizer_cfg, - paramwise_cfg=paramwise_cfg)) + paramwise_cfg=paramwise_cfg, + ) + ) optimizer = optim_constructor(model) return optimizer diff --git a/mmsegmentation/mmseg/core/evaluation/__init__.py b/mmsegmentation/mmseg/core/evaluation/__init__.py index 3d16d17..9e7a1c2 100644 --- a/mmsegmentation/mmseg/core/evaluation/__init__.py +++ b/mmsegmentation/mmseg/core/evaluation/__init__.py @@ -1,11 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. from .class_names import get_classes, get_palette from .eval_hooks import DistEvalHook, EvalHook -from .metrics import (eval_metrics, intersect_and_union, mean_dice, - mean_fscore, mean_iou, pre_eval_to_metrics) +from .metrics import ( + eval_metrics, + intersect_and_union, + mean_dice, + mean_fscore, + mean_iou, + pre_eval_to_metrics, +) __all__ = [ - 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', - 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', - 'intersect_and_union' + "EvalHook", + "DistEvalHook", + "mean_dice", + "mean_iou", + "mean_fscore", + "eval_metrics", + "get_classes", + "get_palette", + "pre_eval_to_metrics", + "intersect_and_union", ] diff --git a/mmsegmentation/mmseg/core/evaluation/class_names.py b/mmsegmentation/mmseg/core/evaluation/class_names.py index e3bff62..c29d266 100644 --- a/mmsegmentation/mmseg/core/evaluation/class_names.py +++ b/mmsegmentation/mmseg/core/evaluation/class_names.py @@ -5,259 +5,894 @@ def cityscapes_classes(): """Cityscapes class names for external use.""" return [ - 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', - 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', - 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', - 'bicycle' + "road", + "sidewalk", + "building", + "wall", + "fence", + "pole", + "traffic light", + "traffic sign", + "vegetation", + "terrain", + "sky", + "person", + "rider", + "car", + "truck", + "bus", + "train", + "motorcycle", + "bicycle", ] def ade_classes(): """ADE20K class names for external use.""" return [ - 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', - 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', - 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', - 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', - 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', - 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', - 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', - 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', - 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', - 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', - 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', - 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', - 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', - 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', - 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', - 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', - 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', - 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', - 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', - 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', - 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', - 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', - 'clock', 'flag' + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed ", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", ] def voc_classes(): """Pascal VOC class names for external use.""" return [ - 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', - 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', - 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', - 'tvmonitor' + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", ] def cocostuff_classes(): """CocoStuff class names for external use.""" return [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', - 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', - 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', - 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', - 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', - 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', - 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', - 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', - 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', - 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', - 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', - 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', - 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', - 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', - 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', - 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', - 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', - 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', - 'window-blind', 'window-other', 'wood' + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + "banner", + "blanket", + "branch", + "bridge", + "building-other", + "bush", + "cabinet", + "cage", + "cardboard", + "carpet", + "ceiling-other", + "ceiling-tile", + "cloth", + "clothes", + "clouds", + "counter", + "cupboard", + "curtain", + "desk-stuff", + "dirt", + "door-stuff", + "fence", + "floor-marble", + "floor-other", + "floor-stone", + "floor-tile", + "floor-wood", + "flower", + "fog", + "food-other", + "fruit", + "furniture-other", + "grass", + "gravel", + "ground-other", + "hill", + "house", + "leaves", + "light", + "mat", + "metal", + "mirror-stuff", + "moss", + "mountain", + "mud", + "napkin", + "net", + "paper", + "pavement", + "pillow", + "plant-other", + "plastic", + "platform", + "playingfield", + "railing", + "railroad", + "river", + "road", + "rock", + "roof", + "rug", + "salad", + "sand", + "sea", + "shelf", + "sky-other", + "skyscraper", + "snow", + "solid-other", + "stairs", + "stone", + "straw", + "structural-other", + "table", + "tent", + "textile-other", + "towel", + "tree", + "vegetable", + "wall-brick", + "wall-concrete", + "wall-other", + "wall-panel", + "wall-stone", + "wall-tile", + "wall-wood", + "water-other", + "waterdrops", + "window-blind", + "window-other", + "wood", ] def loveda_classes(): """LoveDA class names for external use.""" return [ - 'background', 'building', 'road', 'water', 'barren', 'forest', - 'agricultural' + "background", + "building", + "road", + "water", + "barren", + "forest", + "agricultural", ] def potsdam_classes(): """Potsdam class names for external use.""" return [ - 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', - 'clutter' + "impervious_surface", + "building", + "low_vegetation", + "tree", + "car", + "clutter", ] def vaihingen_classes(): """Vaihingen class names for external use.""" return [ - 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', - 'clutter' + "impervious_surface", + "building", + "low_vegetation", + "tree", + "car", + "clutter", ] def isaid_classes(): """iSAID class names for external use.""" return [ - 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', - 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', - 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', - 'Soccer_ball_field', 'plane', 'Harbor' + "background", + "ship", + "store_tank", + "baseball_diamond", + "tennis_court", + "basketball_court", + "Ground_Track_Field", + "Bridge", + "Large_Vehicle", + "Small_Vehicle", + "Helicopter", + "Swimming_pool", + "Roundabout", + "Soccer_ball_field", + "plane", + "Harbor", ] def stare_classes(): """stare class names for external use.""" - return ['background', 'vessel'] + return ["background", "vessel"] def cityscapes_palette(): """Cityscapes palette for external use.""" - return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], - [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], - [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], - [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], - [0, 0, 230], [119, 11, 32]] + return [ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + ] def ade_palette(): """ADE20K palette for external use.""" - return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]] + return [ + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] def voc_palette(): """Pascal VOC palette for external use.""" - return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], - [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], - [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], - [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], - [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + return [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], + ] def cocostuff_palette(): """CocoStuff palette for external use.""" - return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], - [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], - [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], - [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], - [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], - [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], - [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], - [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], - [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], - [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], - [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], - [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32], - [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], - [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], - [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], - [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], - [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], - [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], - [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], - [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], - [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], - [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], - [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], - [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], - [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], - [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], - [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], - [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128], - [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], - [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], - [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], - [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], - [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], - [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192], - [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], - [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], - [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], - [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], - [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], - [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], - [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], - [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], - [64, 160, 64], [64, 64, 0]] + return [ + [0, 192, 64], + [0, 192, 64], + [0, 64, 96], + [128, 192, 192], + [0, 64, 64], + [0, 192, 224], + [0, 192, 192], + [128, 192, 64], + [0, 192, 96], + [128, 192, 64], + [128, 32, 192], + [0, 0, 224], + [0, 0, 64], + [0, 160, 192], + [128, 0, 96], + [128, 0, 192], + [0, 32, 192], + [128, 128, 224], + [0, 0, 192], + [128, 160, 192], + [128, 128, 0], + [128, 0, 32], + [128, 32, 0], + [128, 0, 128], + [64, 128, 32], + [0, 160, 0], + [0, 0, 0], + [192, 128, 160], + [0, 32, 0], + [0, 128, 128], + [64, 128, 160], + [128, 160, 0], + [0, 128, 0], + [192, 128, 32], + [128, 96, 128], + [0, 0, 128], + [64, 0, 32], + [0, 224, 128], + [128, 0, 0], + [192, 0, 160], + [0, 96, 128], + [128, 128, 128], + [64, 0, 160], + [128, 224, 128], + [128, 128, 64], + [192, 0, 32], + [128, 96, 0], + [128, 0, 192], + [0, 128, 32], + [64, 224, 0], + [0, 0, 64], + [128, 128, 160], + [64, 96, 0], + [0, 128, 192], + [0, 128, 160], + [192, 224, 0], + [0, 128, 64], + [128, 128, 32], + [192, 32, 128], + [0, 64, 192], + [0, 0, 32], + [64, 160, 128], + [128, 64, 64], + [128, 0, 160], + [64, 32, 128], + [128, 192, 192], + [0, 0, 160], + [192, 160, 128], + [128, 192, 0], + [128, 0, 96], + [192, 32, 0], + [128, 64, 128], + [64, 128, 96], + [64, 160, 0], + [0, 64, 0], + [192, 128, 224], + [64, 32, 0], + [0, 192, 128], + [64, 128, 224], + [192, 160, 0], + [0, 192, 0], + [192, 128, 96], + [192, 96, 128], + [0, 64, 128], + [64, 0, 96], + [64, 224, 128], + [128, 64, 0], + [192, 0, 224], + [64, 96, 128], + [128, 192, 128], + [64, 0, 224], + [192, 224, 128], + [128, 192, 64], + [192, 0, 96], + [192, 96, 0], + [128, 64, 192], + [0, 128, 96], + [0, 224, 0], + [64, 64, 64], + [128, 128, 224], + [0, 96, 0], + [64, 192, 192], + [0, 128, 224], + [128, 224, 0], + [64, 192, 64], + [128, 128, 96], + [128, 32, 128], + [64, 0, 192], + [0, 64, 96], + [0, 160, 128], + [192, 0, 64], + [128, 64, 224], + [0, 32, 128], + [192, 128, 192], + [0, 64, 224], + [128, 160, 128], + [192, 128, 0], + [128, 64, 32], + [128, 32, 64], + [192, 0, 128], + [64, 192, 32], + [0, 160, 64], + [64, 0, 0], + [192, 192, 160], + [0, 32, 64], + [64, 128, 128], + [64, 192, 160], + [128, 160, 64], + [64, 128, 0], + [192, 192, 32], + [128, 96, 192], + [64, 0, 128], + [64, 64, 32], + [0, 224, 192], + [192, 0, 0], + [192, 64, 160], + [0, 96, 192], + [192, 128, 128], + [64, 64, 160], + [128, 224, 192], + [192, 128, 64], + [192, 64, 32], + [128, 96, 64], + [192, 0, 192], + [0, 192, 32], + [64, 224, 64], + [64, 0, 64], + [128, 192, 160], + [64, 96, 64], + [64, 128, 192], + [0, 192, 160], + [192, 224, 64], + [64, 128, 64], + [128, 192, 32], + [192, 32, 192], + [64, 64, 192], + [0, 64, 32], + [64, 160, 192], + [192, 64, 64], + [128, 64, 160], + [64, 32, 192], + [192, 192, 192], + [0, 64, 160], + [192, 160, 192], + [192, 192, 0], + [128, 64, 96], + [192, 32, 64], + [192, 64, 128], + [64, 192, 96], + [64, 160, 64], + [64, 64, 0], + ] def loveda_palette(): """LoveDA palette for external use.""" - return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], - [159, 129, 183], [0, 255, 0], [255, 195, 128]] + return [ + [255, 255, 255], + [255, 0, 0], + [255, 255, 0], + [0, 0, 255], + [159, 129, 183], + [0, 255, 0], + [255, 195, 128], + ] def potsdam_palette(): """Potsdam palette for external use.""" - return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], - [255, 255, 0], [255, 0, 0]] + return [ + [255, 255, 255], + [0, 0, 255], + [0, 255, 255], + [0, 255, 0], + [255, 255, 0], + [255, 0, 0], + ] def vaihingen_palette(): """Vaihingen palette for external use.""" - return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], - [255, 255, 0], [255, 0, 0]] + return [ + [255, 255, 255], + [0, 0, 255], + [0, 255, 255], + [0, 255, 0], + [255, 255, 0], + [255, 0, 0], + ] def isaid_palette(): """iSAID palette for external use.""" - return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], - [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, - 127], [0, 0, 127], - [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], - [0, 127, 255], [0, 100, 155]] + return [ + [0, 0, 0], + [0, 0, 63], + [0, 63, 63], + [0, 63, 0], + [0, 63, 127], + [0, 63, 191], + [0, 63, 255], + [0, 127, 63], + [0, 127, 127], + [0, 0, 127], + [0, 0, 191], + [0, 0, 255], + [0, 191, 127], + [0, 127, 191], + [0, 127, 255], + [0, 100, 155], + ] def stare_palette(): @@ -266,19 +901,25 @@ def stare_palette(): dataset_aliases = { - 'cityscapes': ['cityscapes'], - 'ade': ['ade', 'ade20k'], - 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], - 'loveda': ['loveda'], - 'potsdam': ['potsdam'], - 'vaihingen': ['vaihingen'], - 'cocostuff': [ - 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', - 'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', - 'coco_stuff164k' + "cityscapes": ["cityscapes"], + "ade": ["ade", "ade20k"], + "voc": ["voc", "pascal_voc", "voc12", "voc12aug"], + "loveda": ["loveda"], + "potsdam": ["potsdam"], + "vaihingen": ["vaihingen"], + "cocostuff": [ + "cocostuff", + "cocostuff10k", + "cocostuff164k", + "coco-stuff", + "coco-stuff10k", + "coco-stuff164k", + "coco_stuff", + "coco_stuff10k", + "coco_stuff164k", ], - 'isaid': ['isaid', 'iSAID'], - 'stare': ['stare', 'STARE'] + "isaid": ["isaid", "iSAID"], + "stare": ["stare", "STARE"], } @@ -291,11 +932,11 @@ def get_classes(dataset): if mmcv.is_str(dataset): if dataset in alias2name: - labels = eval(alias2name[dataset] + '_classes()') + labels = eval(alias2name[dataset] + "_classes()") else: - raise ValueError(f'Unrecognized dataset: {dataset}') + raise ValueError(f"Unrecognized dataset: {dataset}") else: - raise TypeError(f'dataset must a str, but got {type(dataset)}') + raise TypeError(f"dataset must a str, but got {type(dataset)}") return labels @@ -308,9 +949,9 @@ def get_palette(dataset): if mmcv.is_str(dataset): if dataset in alias2name: - labels = eval(alias2name[dataset] + '_palette()') + labels = eval(alias2name[dataset] + "_palette()") else: - raise ValueError(f'Unrecognized dataset: {dataset}') + raise ValueError(f"Unrecognized dataset: {dataset}") else: - raise TypeError(f'dataset must a str, but got {type(dataset)}') + raise TypeError(f"dataset must a str, but got {type(dataset)}") return labels diff --git a/mmsegmentation/mmseg/core/evaluation/eval_hooks.py b/mmsegmentation/mmseg/core/evaluation/eval_hooks.py index 8f2be57..585304e 100644 --- a/mmsegmentation/mmseg/core/evaluation/eval_hooks.py +++ b/mmsegmentation/mmseg/core/evaluation/eval_hooks.py @@ -23,24 +23,22 @@ class EvalHook(_EvalHook): list: The prediction results. """ - greater_keys = ['mIoU', 'mAcc', 'aAcc'] - - def __init__(self, - *args, - by_epoch=False, - efficient_test=False, - pre_eval=False, - **kwargs): + greater_keys = ["mIoU", "mAcc", "aAcc"] + + def __init__( + self, *args, by_epoch=False, efficient_test=False, pre_eval=False, **kwargs + ): super().__init__(*args, by_epoch=by_epoch, **kwargs) self.pre_eval = pre_eval self.latest_results = None if efficient_test: warnings.warn( - 'DeprecationWarning: ``efficient_test`` for evaluation hook ' - 'is deprecated, the evaluation hook is CPU memory friendly ' - 'with ``pre_eval=True`` as argument for ``single_gpu_test()`` ' - 'function') + "DeprecationWarning: ``efficient_test`` for evaluation hook " + "is deprecated, the evaluation hook is CPU memory friendly " + "with ``pre_eval=True`` as argument for ``single_gpu_test()`` " + "function" + ) def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" @@ -48,11 +46,13 @@ def _do_evaluate(self, runner): return from mmseg.apis import single_gpu_test + results = single_gpu_test( - runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) + runner.model, self.dataloader, show=False, pre_eval=self.pre_eval + ) self.latest_results = results runner.log_buffer.clear() - runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) key_score = self.evaluate(runner, results) if self.save_best: self._save_ckpt(runner, key_score) @@ -73,23 +73,21 @@ class DistEvalHook(_DistEvalHook): list: The prediction results. """ - greater_keys = ['mIoU', 'mAcc', 'aAcc'] + greater_keys = ["mIoU", "mAcc", "aAcc"] - def __init__(self, - *args, - by_epoch=False, - efficient_test=False, - pre_eval=False, - **kwargs): + def __init__( + self, *args, by_epoch=False, efficient_test=False, pre_eval=False, **kwargs + ): super().__init__(*args, by_epoch=by_epoch, **kwargs) self.pre_eval = pre_eval self.latest_results = None if efficient_test: warnings.warn( - 'DeprecationWarning: ``efficient_test`` for evaluation hook ' - 'is deprecated, the evaluation hook is CPU memory friendly ' - 'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` ' - 'function') + "DeprecationWarning: ``efficient_test`` for evaluation hook " + "is deprecated, the evaluation hook is CPU memory friendly " + "with ``pre_eval=True`` as argument for ``multi_gpu_test()`` " + "function" + ) def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" @@ -101,8 +99,7 @@ def _do_evaluate(self, runner): if self.broadcast_bn_buffer: model = runner.model for name, module in model.named_modules(): - if isinstance(module, - _BatchNorm) and module.track_running_stats: + if isinstance(module, _BatchNorm) and module.track_running_stats: dist.broadcast(module.running_var, 0) dist.broadcast(module.running_mean, 0) @@ -111,21 +108,23 @@ def _do_evaluate(self, runner): tmpdir = self.tmpdir if tmpdir is None: - tmpdir = osp.join(runner.work_dir, '.eval_hook') + tmpdir = osp.join(runner.work_dir, ".eval_hook") from mmseg.apis import multi_gpu_test + results = multi_gpu_test( runner.model, self.dataloader, tmpdir=tmpdir, gpu_collect=self.gpu_collect, - pre_eval=self.pre_eval) + pre_eval=self.pre_eval, + ) self.latest_results = results runner.log_buffer.clear() if runner.rank == 0: - print('\n') - runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + print("\n") + runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) key_score = self.evaluate(runner, results) if self.save_best: diff --git a/mmsegmentation/mmseg/core/evaluation/metrics.py b/mmsegmentation/mmseg/core/evaluation/metrics.py index 31be596..cd454e6 100644 --- a/mmsegmentation/mmseg/core/evaluation/metrics.py +++ b/mmsegmentation/mmseg/core/evaluation/metrics.py @@ -18,17 +18,18 @@ def f_score(precision, recall, beta=1): Returns: [torch.tensor]: The f-score value. """ - score = (1 + beta**2) * (precision * recall) / ( - (beta**2 * precision) + recall) + score = (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall) return score -def intersect_and_union(pred_label, - label, - num_classes, - ignore_index, - label_map=dict(), - reduce_zero_label=False): +def intersect_and_union( + pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False, +): """Calculate intersection and Union. Args: @@ -55,11 +56,10 @@ def intersect_and_union(pred_label, if isinstance(pred_label, str): pred_label = torch.from_numpy(np.load(pred_label)) else: - pred_label = torch.from_numpy((pred_label)) + pred_label = torch.from_numpy(pred_label) if isinstance(label, str): - label = torch.from_numpy( - mmcv.imread(label, flag='unchanged', backend='pillow')) + label = torch.from_numpy(mmcv.imread(label, flag="unchanged", backend="pillow")) else: label = torch.from_numpy(label) @@ -72,27 +72,32 @@ def intersect_and_union(pred_label, label = label - 1 label[label == 254] = 255 - mask = (label != ignore_index) + mask = label != ignore_index pred_label = pred_label[mask] label = label[mask] intersect = pred_label[pred_label == label] area_intersect = torch.histc( - intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + intersect.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) area_pred_label = torch.histc( - pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) area_label = torch.histc( - label.float(), bins=(num_classes), min=0, max=num_classes - 1) + label.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) area_union = area_pred_label + area_label - area_intersect return area_intersect, area_union, area_pred_label, area_label -def total_intersect_and_union(results, - gt_seg_maps, - num_classes, - ignore_index, - label_map=dict(), - reduce_zero_label=False): +def total_intersect_and_union( + results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False, +): """Calculate Total Intersection and Union. Args: @@ -113,30 +118,35 @@ def total_intersect_and_union(results, ndarray: The prediction histogram on all classes. ndarray: The ground truth histogram on all classes. """ - total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) - total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) + total_area_union = torch.zeros((num_classes,), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64) + total_area_label = torch.zeros((num_classes,), dtype=torch.float64) for result, gt_seg_map in zip(results, gt_seg_maps): - area_intersect, area_union, area_pred_label, area_label = \ - intersect_and_union( - result, gt_seg_map, num_classes, ignore_index, - label_map, reduce_zero_label) + area_intersect, area_union, area_pred_label, area_label = intersect_and_union( + result, gt_seg_map, num_classes, ignore_index, label_map, reduce_zero_label + ) total_area_intersect += area_intersect total_area_union += area_union total_area_pred_label += area_pred_label total_area_label += area_label - return total_area_intersect, total_area_union, total_area_pred_label, \ - total_area_label - - -def mean_iou(results, - gt_seg_maps, - num_classes, - ignore_index, - nan_to_num=None, - label_map=dict(), - reduce_zero_label=False): + return ( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + ) + + +def mean_iou( + results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, +): """Calculate Mean Intersection and Union (mIoU) Args: @@ -162,20 +172,23 @@ def mean_iou(results, gt_seg_maps=gt_seg_maps, num_classes=num_classes, ignore_index=ignore_index, - metrics=['mIoU'], + metrics=["mIoU"], nan_to_num=nan_to_num, label_map=label_map, - reduce_zero_label=reduce_zero_label) + reduce_zero_label=reduce_zero_label, + ) return iou_result -def mean_dice(results, - gt_seg_maps, - num_classes, - ignore_index, - nan_to_num=None, - label_map=dict(), - reduce_zero_label=False): +def mean_dice( + results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, +): """Calculate Mean Dice (mDice) Args: @@ -202,21 +215,24 @@ def mean_dice(results, gt_seg_maps=gt_seg_maps, num_classes=num_classes, ignore_index=ignore_index, - metrics=['mDice'], + metrics=["mDice"], nan_to_num=nan_to_num, label_map=label_map, - reduce_zero_label=reduce_zero_label) + reduce_zero_label=reduce_zero_label, + ) return dice_result -def mean_fscore(results, - gt_seg_maps, - num_classes, - ignore_index, - nan_to_num=None, - label_map=dict(), - reduce_zero_label=False, - beta=1): +def mean_fscore( + results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1, +): """Calculate Mean F-Score (mFscore) Args: @@ -246,23 +262,26 @@ def mean_fscore(results, gt_seg_maps=gt_seg_maps, num_classes=num_classes, ignore_index=ignore_index, - metrics=['mFscore'], + metrics=["mFscore"], nan_to_num=nan_to_num, label_map=label_map, reduce_zero_label=reduce_zero_label, - beta=beta) + beta=beta, + ) return fscore_result -def eval_metrics(results, - gt_seg_maps, - num_classes, - ignore_index, - metrics=['mIoU'], - nan_to_num=None, - label_map=dict(), - reduce_zero_label=False, - beta=1): +def eval_metrics( + results, + gt_seg_maps, + num_classes, + ignore_index, + metrics=["mIoU"], + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1, +): """Calculate evaluation metrics Args: results (list[ndarray] | list[str]): List of prediction segmentation @@ -282,22 +301,28 @@ def eval_metrics(results, ndarray: Per category evaluation metrics, shape (num_classes, ). """ - total_area_intersect, total_area_union, total_area_pred_label, \ - total_area_label = total_intersect_and_union( - results, gt_seg_maps, num_classes, ignore_index, label_map, - reduce_zero_label) - ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, - total_area_pred_label, - total_area_label, metrics, nan_to_num, - beta) + ( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + ) = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, reduce_zero_label + ) + ret_metrics = total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics, + nan_to_num, + beta, + ) return ret_metrics -def pre_eval_to_metrics(pre_eval_results, - metrics=['mIoU'], - nan_to_num=None, - beta=1): +def pre_eval_to_metrics(pre_eval_results, metrics=["mIoU"], nan_to_num=None, beta=1): """Convert pre-eval results to metrics. Args: @@ -323,21 +348,28 @@ def pre_eval_to_metrics(pre_eval_results, total_area_pred_label = sum(pre_eval_results[2]) total_area_label = sum(pre_eval_results[3]) - ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, - total_area_pred_label, - total_area_label, metrics, nan_to_num, - beta) + ret_metrics = total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics, + nan_to_num, + beta, + ) return ret_metrics -def total_area_to_metrics(total_area_intersect, - total_area_union, - total_area_pred_label, - total_area_label, - metrics=['mIoU'], - nan_to_num=None, - beta=1): +def total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics=["mIoU"], + nan_to_num=None, + beta=1, +): """Calculate evaluation metrics Args: total_area_intersect (ndarray): The intersection of prediction and @@ -357,40 +389,39 @@ def total_area_to_metrics(total_area_intersect, """ if isinstance(metrics, str): metrics = [metrics] - allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + allowed_metrics = ["mIoU", "mDice", "mFscore"] if not set(metrics).issubset(set(allowed_metrics)): - raise KeyError('metrics {} is not supported'.format(metrics)) + raise KeyError(f"metrics {metrics} is not supported") all_acc = total_area_intersect.sum() / total_area_label.sum() - ret_metrics = OrderedDict({'aAcc': all_acc}) + ret_metrics = OrderedDict({"aAcc": all_acc}) for metric in metrics: - if metric == 'mIoU': + if metric == "mIoU": iou = total_area_intersect / total_area_union acc = total_area_intersect / total_area_label - ret_metrics['IoU'] = iou - ret_metrics['Acc'] = acc - elif metric == 'mDice': - dice = 2 * total_area_intersect / ( - total_area_pred_label + total_area_label) + ret_metrics["IoU"] = iou + ret_metrics["Acc"] = acc + elif metric == "mDice": + dice = 2 * total_area_intersect / (total_area_pred_label + total_area_label) acc = total_area_intersect / total_area_label - ret_metrics['Dice'] = dice - ret_metrics['Acc'] = acc - elif metric == 'mFscore': + ret_metrics["Dice"] = dice + ret_metrics["Acc"] = acc + elif metric == "mFscore": precision = total_area_intersect / total_area_pred_label recall = total_area_intersect / total_area_label f_value = torch.tensor( - [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) - ret_metrics['Fscore'] = f_value - ret_metrics['Precision'] = precision - ret_metrics['Recall'] = recall - - ret_metrics = { - metric: value.numpy() - for metric, value in ret_metrics.items() - } + [f_score(x[0], x[1], beta) for x in zip(precision, recall)] + ) + ret_metrics["Fscore"] = f_value + ret_metrics["Precision"] = precision + ret_metrics["Recall"] = recall + + ret_metrics = {metric: value.numpy() for metric, value in ret_metrics.items()} if nan_to_num is not None: - ret_metrics = OrderedDict({ - metric: np.nan_to_num(metric_value, nan=nan_to_num) - for metric, metric_value in ret_metrics.items() - }) + ret_metrics = OrderedDict( + { + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + } + ) return ret_metrics diff --git a/mmsegmentation/mmseg/core/hook/__init__.py b/mmsegmentation/mmseg/core/hook/__init__.py index 02fe93d..b954166 100644 --- a/mmsegmentation/mmseg/core/hook/__init__.py +++ b/mmsegmentation/mmseg/core/hook/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .wandblogger_hook import MMSegWandbHook -__all__ = ['MMSegWandbHook'] +__all__ = ["MMSegWandbHook"] diff --git a/mmsegmentation/mmseg/core/hook/wandblogger_hook.py b/mmsegmentation/mmseg/core/hook/wandblogger_hook.py index b35c526..d9df217 100644 --- a/mmsegmentation/mmseg/core/hook/wandblogger_hook.py +++ b/mmsegmentation/mmseg/core/hook/wandblogger_hook.py @@ -83,27 +83,28 @@ class MMSegWandbHook(WandbLoggerHook): Default: 100 """ - def __init__(self, - init_kwargs=None, - interval=50, - log_checkpoint=False, - log_checkpoint_metadata=False, - num_eval_images=100, - **kwargs): - super(MMSegWandbHook, self).__init__(init_kwargs, interval, **kwargs) + def __init__( + self, + init_kwargs=None, + interval=50, + log_checkpoint=False, + log_checkpoint_metadata=False, + num_eval_images=100, + **kwargs, + ): + super().__init__(init_kwargs, interval, **kwargs) self.log_checkpoint = log_checkpoint - self.log_checkpoint_metadata = ( - log_checkpoint and log_checkpoint_metadata) + self.log_checkpoint_metadata = log_checkpoint and log_checkpoint_metadata self.num_eval_images = num_eval_images - self.log_evaluation = (num_eval_images > 0) + self.log_evaluation = num_eval_images > 0 self.ckpt_hook: CheckpointHook = None self.eval_hook: EvalHook = None self.test_fn = None @master_only def before_run(self, runner): - super(MMSegWandbHook, self).before_run(runner) + super().before_run(runner) # Check if EvalHook and CheckpointHook are available. for hook in runner.hooks: @@ -111,10 +112,12 @@ def before_run(self, runner): self.ckpt_hook = hook if isinstance(hook, EvalHook): from mmseg.apis import single_gpu_test + self.eval_hook = hook self.test_fn = single_gpu_test if isinstance(hook, DistEvalHook): from mmseg.apis import multi_gpu_test + self.eval_hook = hook self.test_fn = multi_gpu_test @@ -124,8 +127,9 @@ def before_run(self, runner): self.log_checkpoint = False self.log_checkpoint_metadata = False runner.logger.warning( - 'To log checkpoint in MMSegWandbHook, `CheckpointHook` is' - 'required, please check hooks in the runner.') + "To log checkpoint in MMSegWandbHook, `CheckpointHook` is" + "required, please check hooks in the runner." + ) else: self.ckpt_interval = self.ckpt_hook.interval @@ -135,10 +139,11 @@ def before_run(self, runner): self.log_evaluation = False self.log_checkpoint_metadata = False runner.logger.warning( - 'To log evaluation or checkpoint metadata in ' - 'MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg ' - 'is required, please check whether the validation ' - 'is enabled.') + "To log evaluation or checkpoint metadata in " + "MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg " + "is required, please check whether the validation " + "is enabled." + ) else: self.eval_interval = self.eval_hook.interval self.val_dataset = self.eval_hook.dataloader.dataset @@ -146,18 +151,20 @@ def before_run(self, runner): if self.num_eval_images > len(self.val_dataset): self.num_eval_images = len(self.val_dataset) runner.logger.warning( - f'The num_eval_images ({self.num_eval_images}) is ' - 'greater than the total number of validation samples ' - f'({len(self.val_dataset)}). The complete validation ' - 'dataset will be logged.') + f"The num_eval_images ({self.num_eval_images}) is " + "greater than the total number of validation samples " + f"({len(self.val_dataset)}). The complete validation " + "dataset will be logged." + ) # Check conditions to log checkpoint metadata if self.log_checkpoint_metadata: - assert self.ckpt_interval % self.eval_interval == 0, \ - 'To log checkpoint metadata in MMSegWandbHook, the interval ' \ - f'of checkpoint saving ({self.ckpt_interval}) should be ' \ - 'divisible by the interval of evaluation ' \ - f'({self.eval_interval}).' + assert self.ckpt_interval % self.eval_interval == 0, ( + "To log checkpoint metadata in MMSegWandbHook, the interval " + f"of checkpoint saving ({self.ckpt_interval}) should be " + "divisible by the interval of evaluation " + f"({self.eval_interval})." + ) # Initialize evaluation table if self.log_evaluation: @@ -171,14 +178,14 @@ def before_run(self, runner): # for the reason of this double-layered structure, refer to # https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076 def after_train_iter(self, runner): - if self.get_mode(runner) == 'train': + if self.get_mode(runner) == "train": # An ugly patch. The iter-based eval hook will call the # `after_train_iter` method of all logger hooks before evaluation. # Use this trick to skip that call. # Don't call super method at first, it will clear the log_buffer - return super(MMSegWandbHook, self).after_train_iter(runner) + return super().after_train_iter(runner) else: - super(MMSegWandbHook, self).after_train_iter(runner) + super().after_train_iter(runner) self._after_train_iter(runner) @master_only @@ -187,19 +194,17 @@ def _after_train_iter(self, runner): return # Save checkpoint and metadata - if (self.log_checkpoint - and self.every_n_iters(runner, self.ckpt_interval) - or (self.ckpt_hook.save_last and self.is_last_iter(runner))): + if ( + self.log_checkpoint + and self.every_n_iters(runner, self.ckpt_interval) + or (self.ckpt_hook.save_last and self.is_last_iter(runner)) + ): if self.log_checkpoint_metadata and self.eval_hook: - metadata = { - 'iter': runner.iter + 1, - **self._get_eval_results() - } + metadata = {"iter": runner.iter + 1, **self._get_eval_results()} else: metadata = None - aliases = [f'iter_{runner.iter+1}', 'latest'] - model_path = osp.join(self.ckpt_hook.out_dir, - f'iter_{runner.iter+1}.pth') + aliases = [f"iter_{runner.iter+1}", "latest"] + model_path = osp.join(self.ckpt_hook.out_dir, f"iter_{runner.iter+1}.pth") self._log_ckpt_as_artifact(model_path, aliases, metadata) # Save prediction table @@ -228,7 +233,8 @@ def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None): metadata (dict, optional): Metadata associated with this artifact. """ model_artifact = self.wandb.Artifact( - f'run_{self.wandb.run.id}_model', type='model', metadata=metadata) + f"run_{self.wandb.run.id}_model", type="model", metadata=metadata + ) model_artifact.add_file(model_path) self.wandb.log_artifact(model_artifact, aliases=aliases) @@ -236,22 +242,24 @@ def _get_eval_results(self): """Get model evaluation results.""" results = self.eval_hook.latest_results eval_results = self.val_dataset.evaluate( - results, logger='silent', **self.eval_hook.eval_kwargs) + results, logger="silent", **self.eval_hook.eval_kwargs + ) return eval_results def _init_data_table(self): """Initialize the W&B Tables for validation data.""" - columns = ['image_name', 'image'] + columns = ["image_name", "image"] self.data_table = self.wandb.Table(columns=columns) def _init_pred_table(self): """Initialize the W&B Tables for model evaluation.""" - columns = ['image_name', 'ground_truth', 'prediction'] + columns = ["image_name", "ground_truth", "prediction"] self.eval_table = self.wandb.Table(columns=columns) def _add_ground_truth(self, runner): # Get image loading pipeline from mmseg.datasets.pipelines import LoadImageFromFile + img_loader = None for t in self.val_dataset.pipeline.transforms: if isinstance(t, LoadImageFromFile): @@ -260,8 +268,8 @@ def _add_ground_truth(self, runner): if img_loader is None: self.log_evaluation = False runner.logger.warning( - 'LoadImageFromFile is required to add images ' - 'to W&B Tables.') + "LoadImageFromFile is required to add images " "to W&B Tables." + ) return # Select the images to be logged. @@ -269,23 +277,23 @@ def _add_ground_truth(self, runner): # Set seed so that same validation set is logged each time. np.random.seed(42) np.random.shuffle(self.eval_image_indexs) - self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images] + self.eval_image_indexs = self.eval_image_indexs[: self.num_eval_images] classes = self.val_dataset.CLASSES self.class_id_to_label = {id: name for id, name in enumerate(classes)} - self.class_set = self.wandb.Classes([{ - 'id': id, - 'name': name - } for id, name in self.class_id_to_label.items()]) + self.class_set = self.wandb.Classes( + [{"id": id, "name": name} for id, name in self.class_id_to_label.items()] + ) for idx in self.eval_image_indexs: img_info = self.val_dataset.img_infos[idx] - image_name = img_info['filename'] + image_name = img_info["filename"] # Get image and convert from BGR to RGB img_meta = img_loader( - dict(img_info=img_info, img_prefix=self.val_dataset.img_dir)) - image = mmcv.bgr2rgb(img_meta['img']) + dict(img_info=img_info, img_prefix=self.val_dataset.img_dir) + ) + image = mmcv.bgr2rgb(img_meta["img"]) # Get segmentation mask seg_mask = self.val_dataset.get_gt_seg_map_by_idx(idx) @@ -293,21 +301,22 @@ def _add_ground_truth(self, runner): wandb_masks = None if seg_mask.ndim == 2: wandb_masks = { - 'ground_truth': { - 'mask_data': seg_mask, - 'class_labels': self.class_id_to_label + "ground_truth": { + "mask_data": seg_mask, + "class_labels": self.class_id_to_label, } } # Log a row to the data table. self.data_table.add_data( image_name, - self.wandb.Image( - image, masks=wandb_masks, classes=self.class_set)) + self.wandb.Image(image, masks=wandb_masks, classes=self.class_set), + ) else: runner.logger.warning( - f'The segmentation mask is {seg_mask.ndim}D which ' - 'is not supported by W&B.') + f"The segmentation mask is {seg_mask.ndim}D which " + "is not supported by W&B." + ) self.log_evaluation = False return @@ -322,9 +331,9 @@ def _log_predictions(self, results, runner): if pred_mask.ndim == 2: wandb_masks = { - 'prediction': { - 'mask_data': pred_mask, - 'class_labels': self.class_id_to_label + "prediction": { + "mask_data": pred_mask, + "class_labels": self.class_id_to_label, } } @@ -335,11 +344,14 @@ def _log_predictions(self, results, runner): self.wandb.Image( self.data_table_ref.data[ndx][1], masks=wandb_masks, - classes=self.class_set)) + classes=self.class_set, + ), + ) else: runner.logger.warning( - 'The predictio segmentation mask is ' - f'{pred_mask.ndim}D which is not supported by W&B.') + "The predictio segmentation mask is " + f"{pred_mask.ndim}D which is not supported by W&B." + ) self.log_evaluation = False return @@ -350,13 +362,13 @@ def _log_data_table(self): This allows the data to be uploaded just once. """ - data_artifact = self.wandb.Artifact('val', type='dataset') - data_artifact.add(self.data_table, 'val_data') + data_artifact = self.wandb.Artifact("val", type="dataset") + data_artifact.add(self.data_table, "val_data") self.wandb.run.use_artifact(data_artifact) data_artifact.wait() - self.data_table_ref = data_artifact.get('val_data') + self.data_table_ref = data_artifact.get("val_data") def _log_eval_table(self, iter): """Log the W&B Tables for model evaluation. @@ -365,6 +377,7 @@ def _log_eval_table(self, iter): to compare models at different intervals interactively. """ pred_artifact = self.wandb.Artifact( - f'run_{self.wandb.run.id}_pred', type='evaluation') - pred_artifact.add(self.eval_table, 'eval_data') + f"run_{self.wandb.run.id}_pred", type="evaluation" + ) + pred_artifact.add(self.eval_table, "eval_data") self.wandb.run.log_artifact(pred_artifact) diff --git a/mmsegmentation/mmseg/core/optimizers/__init__.py b/mmsegmentation/mmseg/core/optimizers/__init__.py index 4fbf4ec..811bdcb 100644 --- a/mmsegmentation/mmseg/core/optimizers/__init__.py +++ b/mmsegmentation/mmseg/core/optimizers/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .layer_decay_optimizer_constructor import ( - LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + LayerDecayOptimizerConstructor, + LearningRateDecayOptimizerConstructor, +) -__all__ = [ - 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' -] +__all__ = ["LearningRateDecayOptimizerConstructor", "LayerDecayOptimizerConstructor"] diff --git a/mmsegmentation/mmseg/core/optimizers/layer_decay_optimizer_constructor.py b/mmsegmentation/mmseg/core/optimizers/layer_decay_optimizer_constructor.py index 2b6b8ff..9f261a1 100644 --- a/mmsegmentation/mmseg/core/optimizers/layer_decay_optimizer_constructor.py +++ b/mmsegmentation/mmseg/core/optimizers/layer_decay_optimizer_constructor.py @@ -21,11 +21,10 @@ def get_layer_id_for_convnext(var_name, max_layer_id): ``LearningRateDecayOptimizerConstructor``. """ - if var_name in ('backbone.cls_token', 'backbone.mask_token', - 'backbone.pos_embed'): + if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): return 0 - elif var_name.startswith('backbone.downsample_layers'): - stage_id = int(var_name.split('.')[2]) + elif var_name.startswith("backbone.downsample_layers"): + stage_id = int(var_name.split(".")[2]) if stage_id == 0: layer_id = 0 elif stage_id == 1: @@ -35,9 +34,9 @@ def get_layer_id_for_convnext(var_name, max_layer_id): elif stage_id == 3: layer_id = max_layer_id return layer_id - elif var_name.startswith('backbone.stages'): - stage_id = int(var_name.split('.')[2]) - block_id = int(var_name.split('.')[3]) + elif var_name.startswith("backbone.stages"): + stage_id = int(var_name.split(".")[2]) + block_id = int(var_name.split(".")[3]) if stage_id == 0: layer_id = 1 elif stage_id == 1: @@ -64,13 +63,12 @@ def get_stage_id_for_convnext(var_name, max_stage_id): ``LearningRateDecayOptimizerConstructor``. """ - if var_name in ('backbone.cls_token', 'backbone.mask_token', - 'backbone.pos_embed'): + if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): return 0 - elif var_name.startswith('backbone.downsample_layers'): + elif var_name.startswith("backbone.downsample_layers"): return 0 - elif var_name.startswith('backbone.stages'): - stage_id = int(var_name.split('.')[2]) + elif var_name.startswith("backbone.stages"): + stage_id = int(var_name.split(".")[2]) return stage_id + 1 else: return max_stage_id - 1 @@ -87,13 +85,12 @@ def get_layer_id_for_vit(var_name, max_layer_id): int: Returns the layer id of the key. """ - if var_name in ('backbone.cls_token', 'backbone.mask_token', - 'backbone.pos_embed'): + if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"): return 0 - elif var_name.startswith('backbone.patch_embed'): + elif var_name.startswith("backbone.patch_embed"): return 0 - elif var_name.startswith('backbone.layers'): - layer_id = int(var_name.split('.')[2]) + elif var_name.startswith("backbone.layers"): + layer_id = int(var_name.split(".")[2]) return layer_id + 1 else: return max_layer_id - 1 @@ -121,67 +118,75 @@ def add_params(self, params, module, **kwargs): logger = get_root_logger() parameter_groups = {} - logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') - num_layers = self.paramwise_cfg.get('num_layers') + 2 - decay_rate = self.paramwise_cfg.get('decay_rate') - decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') - logger.info('Build LearningRateDecayOptimizerConstructor ' - f'{decay_type} {decay_rate} - {num_layers}') + logger.info(f"self.paramwise_cfg is {self.paramwise_cfg}") + num_layers = self.paramwise_cfg.get("num_layers") + 2 + decay_rate = self.paramwise_cfg.get("decay_rate") + decay_type = self.paramwise_cfg.get("decay_type", "layer_wise") + logger.info( + "Build LearningRateDecayOptimizerConstructor " + f"{decay_type} {decay_rate} - {num_layers}" + ) weight_decay = self.base_wd for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights - if len(param.shape) == 1 or name.endswith('.bias') or name in ( - 'pos_embed', 'cls_token'): - group_name = 'no_decay' - this_weight_decay = 0. + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or name in ("pos_embed", "cls_token") + ): + group_name = "no_decay" + this_weight_decay = 0.0 else: - group_name = 'decay' + group_name = "decay" this_weight_decay = weight_decay - if 'layer_wise' in decay_type: - if 'ConvNeXt' in module.backbone.__class__.__name__: + if "layer_wise" in decay_type: + if "ConvNeXt" in module.backbone.__class__.__name__: layer_id = get_layer_id_for_convnext( - name, self.paramwise_cfg.get('num_layers')) - logger.info(f'set param {name} as id {layer_id}') - elif 'BEiT' in module.backbone.__class__.__name__ or \ - 'MAE' in module.backbone.__class__.__name__: + name, self.paramwise_cfg.get("num_layers") + ) + logger.info(f"set param {name} as id {layer_id}") + elif ( + "BEiT" in module.backbone.__class__.__name__ + or "MAE" in module.backbone.__class__.__name__ + ): layer_id = get_layer_id_for_vit(name, num_layers) - logger.info(f'set param {name} as id {layer_id}') + logger.info(f"set param {name} as id {layer_id}") else: raise NotImplementedError() - elif decay_type == 'stage_wise': - if 'ConvNeXt' in module.backbone.__class__.__name__: + elif decay_type == "stage_wise": + if "ConvNeXt" in module.backbone.__class__.__name__: layer_id = get_stage_id_for_convnext(name, num_layers) - logger.info(f'set param {name} as id {layer_id}') + logger.info(f"set param {name} as id {layer_id}") else: raise NotImplementedError() - group_name = f'layer_{layer_id}_{group_name}' + group_name = f"layer_{layer_id}_{group_name}" if group_name not in parameter_groups: - scale = decay_rate**(num_layers - layer_id - 1) + scale = decay_rate ** (num_layers - layer_id - 1) parameter_groups[group_name] = { - 'weight_decay': this_weight_decay, - 'params': [], - 'param_names': [], - 'lr_scale': scale, - 'group_name': group_name, - 'lr': scale * self.base_lr, + "weight_decay": this_weight_decay, + "params": [], + "param_names": [], + "lr_scale": scale, + "group_name": group_name, + "lr": scale * self.base_lr, } - parameter_groups[group_name]['params'].append(param) - parameter_groups[group_name]['param_names'].append(name) + parameter_groups[group_name]["params"].append(param) + parameter_groups[group_name]["param_names"].append(name) rank, _ = get_dist_info() if rank == 0: to_display = {} for key in parameter_groups: to_display[key] = { - 'param_names': parameter_groups[key]['param_names'], - 'lr_scale': parameter_groups[key]['lr_scale'], - 'lr': parameter_groups[key]['lr'], - 'weight_decay': parameter_groups[key]['weight_decay'], + "param_names": parameter_groups[key]["param_names"], + "lr_scale": parameter_groups[key]["lr_scale"], + "lr": parameter_groups[key]["lr"], + "weight_decay": parameter_groups[key]["weight_decay"], } - logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') + logger.info(f"Param groups = {json.dumps(to_display, indent=2)}") params.extend(parameter_groups.values()) @@ -195,14 +200,17 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): """ def __init__(self, optimizer_cfg, paramwise_cfg): - warnings.warn('DeprecationWarning: Original ' - 'LayerDecayOptimizerConstructor of BEiT ' - 'will be deprecated. Please use ' - 'LearningRateDecayOptimizerConstructor instead, ' - 'and set decay_type = layer_wise_vit in paramwise_cfg.') - paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) - warnings.warn('DeprecationWarning: Layer_decay_rate will ' - 'be deleted, please use decay_rate instead.') - paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') - super(LayerDecayOptimizerConstructor, - self).__init__(optimizer_cfg, paramwise_cfg) + warnings.warn( + "DeprecationWarning: Original " + "LayerDecayOptimizerConstructor of BEiT " + "will be deprecated. Please use " + "LearningRateDecayOptimizerConstructor instead, " + "and set decay_type = layer_wise_vit in paramwise_cfg." + ) + paramwise_cfg.update({"decay_type": "layer_wise_vit"}) + warnings.warn( + "DeprecationWarning: Layer_decay_rate will " + "be deleted, please use decay_rate instead." + ) + paramwise_cfg["decay_rate"] = paramwise_cfg.pop("layer_decay_rate") + super().__init__(optimizer_cfg, paramwise_cfg) diff --git a/mmsegmentation/mmseg/core/seg/__init__.py b/mmsegmentation/mmseg/core/seg/__init__.py index 5206b96..5bdc6c3 100644 --- a/mmsegmentation/mmseg/core/seg/__init__.py +++ b/mmsegmentation/mmseg/core/seg/__init__.py @@ -2,4 +2,4 @@ from .builder import build_pixel_sampler from .sampler import BasePixelSampler, OHEMPixelSampler -__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] +__all__ = ["build_pixel_sampler", "BasePixelSampler", "OHEMPixelSampler"] diff --git a/mmsegmentation/mmseg/core/seg/builder.py b/mmsegmentation/mmseg/core/seg/builder.py index 1cecd34..525364e 100644 --- a/mmsegmentation/mmseg/core/seg/builder.py +++ b/mmsegmentation/mmseg/core/seg/builder.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.utils import Registry, build_from_cfg -PIXEL_SAMPLERS = Registry('pixel sampler') +PIXEL_SAMPLERS = Registry("pixel sampler") def build_pixel_sampler(cfg, **default_args): diff --git a/mmsegmentation/mmseg/core/seg/sampler/__init__.py b/mmsegmentation/mmseg/core/seg/sampler/__init__.py index 5a76485..3caa926 100644 --- a/mmsegmentation/mmseg/core/seg/sampler/__init__.py +++ b/mmsegmentation/mmseg/core/seg/sampler/__init__.py @@ -2,4 +2,4 @@ from .base_pixel_sampler import BasePixelSampler from .ohem_pixel_sampler import OHEMPixelSampler -__all__ = ['BasePixelSampler', 'OHEMPixelSampler'] +__all__ = ["BasePixelSampler", "OHEMPixelSampler"] diff --git a/mmsegmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/mmsegmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py index 833a287..223cfda 100644 --- a/mmsegmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py +++ b/mmsegmentation/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -23,7 +23,7 @@ class OHEMPixelSampler(BasePixelSampler): """ def __init__(self, context, thresh=None, min_kept=100000): - super(OHEMPixelSampler, self).__init__() + super().__init__() self.context = context assert min_kept > 1 self.thresh = thresh @@ -56,12 +56,11 @@ def sample(self, seg_logit, seg_label): sort_prob, sort_indices = seg_prob[valid_mask].sort() if sort_prob.numel() > 0: - min_threshold = sort_prob[min(batch_kept, - sort_prob.numel() - 1)] + min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] else: min_threshold = 0.0 threshold = max(min_threshold, self.thresh) - valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.0 else: if not isinstance(self.context.loss_decode, nn.ModuleList): losses_decode = [self.context.loss_decode] @@ -74,11 +73,12 @@ def sample(self, seg_logit, seg_label): seg_label, weight=None, ignore_index=self.context.ignore_index, - reduction_override='none') + reduction_override="none", + ) # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa _, sort_indices = losses[valid_mask].sort(descending=True) - valid_seg_weight[sort_indices[:batch_kept]] = 1. + valid_seg_weight[sort_indices[:batch_kept]] = 1.0 seg_weight[valid_mask] = valid_seg_weight diff --git a/mmsegmentation/mmseg/core/utils/__init__.py b/mmsegmentation/mmseg/core/utils/__init__.py index 2888289..b9d7ca8 100644 --- a/mmsegmentation/mmseg/core/utils/__init__.py +++ b/mmsegmentation/mmseg/core/utils/__init__.py @@ -2,4 +2,4 @@ from .dist_util import check_dist_init, sync_random_seed from .misc import add_prefix -__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] +__all__ = ["add_prefix", "check_dist_init", "sync_random_seed"] diff --git a/mmsegmentation/mmseg/core/utils/dist_util.py b/mmsegmentation/mmseg/core/utils/dist_util.py index b328851..4511444 100644 --- a/mmsegmentation/mmseg/core/utils/dist_util.py +++ b/mmsegmentation/mmseg/core/utils/dist_util.py @@ -9,7 +9,7 @@ def check_dist_init(): return dist.is_available() and dist.is_initialized() -def sync_random_seed(seed=None, device='cuda'): +def sync_random_seed(seed=None, device="cuda"): """Make sure different ranks share the same seed. All workers must call this function, otherwise it will deadlock. This method is generally used in `DistributedSampler`, because the seed should be identical across all diff --git a/mmsegmentation/mmseg/core/utils/misc.py b/mmsegmentation/mmseg/core/utils/misc.py index 282bb8d..44d3cfd 100644 --- a/mmsegmentation/mmseg/core/utils/misc.py +++ b/mmsegmentation/mmseg/core/utils/misc.py @@ -13,6 +13,6 @@ def add_prefix(inputs, prefix): outputs = dict() for name, value in inputs.items(): - outputs[f'{prefix}.{name}'] = value + outputs[f"{prefix}.{name}"] = value return outputs diff --git a/mmsegmentation/mmseg/datasets/__init__.py b/mmsegmentation/mmseg/datasets/__init__.py index 4281180..bad2c9a 100644 --- a/mmsegmentation/mmseg/datasets/__init__.py +++ b/mmsegmentation/mmseg/datasets/__init__.py @@ -7,8 +7,7 @@ from .coco_trash import COCOTrashDataset from .custom import CustomDataset from .dark_zurich import DarkZurichDataset -from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, - RepeatDataset) +from .dataset_wrappers import ConcatDataset, MultiImageMixDataset, RepeatDataset from .drive import DRIVEDataset from .face import FaceOccludedDataset from .hrf import HRFDataset @@ -22,12 +21,30 @@ from .voc import PascalVOCDataset __all__ = [ - 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', - 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', - 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', - 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', - 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', - 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', - 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'FaceOccludedDataset', - 'COCOTrashDataset' + "CustomDataset", + "build_dataloader", + "ConcatDataset", + "RepeatDataset", + "DATASETS", + "build_dataset", + "PIPELINES", + "CityscapesDataset", + "PascalVOCDataset", + "ADE20KDataset", + "PascalContextDataset", + "PascalContextDataset59", + "ChaseDB1Dataset", + "DRIVEDataset", + "HRFDataset", + "STAREDataset", + "DarkZurichDataset", + "NightDrivingDataset", + "COCOStuffDataset", + "LoveDADataset", + "MultiImageMixDataset", + "iSAIDDataset", + "ISPRSDataset", + "PotsdamDataset", + "FaceOccludedDataset", + "COCOTrashDataset", ] diff --git a/mmsegmentation/mmseg/datasets/ade.py b/mmsegmentation/mmseg/datasets/ade.py index db94ceb..f4a65c8 100644 --- a/mmsegmentation/mmseg/datasets/ade.py +++ b/mmsegmentation/mmseg/datasets/ade.py @@ -18,77 +18,317 @@ class ADE20KDataset(CustomDataset): The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to '.png'. """ + CLASSES = ( - 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', - 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', - 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', - 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', - 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', - 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', - 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', - 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', - 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', - 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', - 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', - 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', - 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', - 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', - 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', - 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', - 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', - 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', - 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', - 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', - 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', - 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', - 'clock', 'flag') - - PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]] + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed ", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", + ) + + PALETTE = [ + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] def __init__(self, **kwargs): - super(ADE20KDataset, self).__init__( - img_suffix='.jpg', - seg_map_suffix='.png', - reduce_zero_label=True, - **kwargs) + super().__init__( + img_suffix=".jpg", seg_map_suffix=".png", reduce_zero_label=True, **kwargs + ) def results2img(self, results, imgfile_prefix, to_label_id, indices=None): """Write the segmentation results to images. @@ -115,11 +355,10 @@ def results2img(self, results, imgfile_prefix, to_label_id, indices=None): mmcv.mkdir_or_exist(imgfile_prefix) result_files = [] for result, idx in zip(results, indices): - - filename = self.img_infos[idx]['filename'] + filename = self.img_infos[idx]["filename"] basename = osp.splitext(osp.basename(filename))[0] - png_filename = osp.join(imgfile_prefix, f'{basename}.png') + png_filename = osp.join(imgfile_prefix, f"{basename}.png") # The index range of official requirement is from 0 to 150. # But the index range of output is from 0 to 149. @@ -132,11 +371,7 @@ def results2img(self, results, imgfile_prefix, to_label_id, indices=None): return result_files - def format_results(self, - results, - imgfile_prefix, - to_label_id=True, - indices=None): + def format_results(self, results, imgfile_prefix, to_label_id=True, indices=None): """Format the results into dir (standard format for ade20k evaluation). Args: @@ -159,9 +394,8 @@ def format_results(self, if indices is None: indices = list(range(len(self))) - assert isinstance(results, list), 'results must be a list.' - assert isinstance(indices, list), 'indices must be a list.' + assert isinstance(results, list), "results must be a list." + assert isinstance(indices, list), "indices must be a list." - result_files = self.results2img(results, imgfile_prefix, to_label_id, - indices) + result_files = self.results2img(results, imgfile_prefix, to_label_id, indices) return result_files diff --git a/mmsegmentation/mmseg/datasets/builder.py b/mmsegmentation/mmseg/datasets/builder.py index 49ee633..6b2020c 100644 --- a/mmsegmentation/mmseg/datasets/builder.py +++ b/mmsegmentation/mmseg/datasets/builder.py @@ -13,27 +13,29 @@ from .samplers import DistributedSampler -if platform.system() != 'Windows': +if platform.system() != "Windows": # https://github.com/pytorch/pytorch/issues/973 import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) base_soft_limit = rlimit[0] hard_limit = rlimit[1] soft_limit = min(max(4096, base_soft_limit), hard_limit) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) -DATASETS = Registry('dataset') -PIPELINES = Registry('pipeline') +DATASETS = Registry("dataset") +PIPELINES = Registry("pipeline") def _concat_dataset(cfg, default_args=None): """Build :obj:`ConcatDataset by.""" from .dataset_wrappers import ConcatDataset - img_dir = cfg['img_dir'] - ann_dir = cfg.get('ann_dir', None) - split = cfg.get('split', None) + + img_dir = cfg["img_dir"] + ann_dir = cfg.get("ann_dir", None) + split = cfg.get("split", None) # pop 'separate_eval' since it is not a valid key for common datasets. - separate_eval = cfg.pop('separate_eval', True) + separate_eval = cfg.pop("separate_eval", True) num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 if ann_dir is not None: num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 @@ -54,11 +56,11 @@ def _concat_dataset(cfg, default_args=None): for i in range(num_dset): data_cfg = copy.deepcopy(cfg) if isinstance(img_dir, (list, tuple)): - data_cfg['img_dir'] = img_dir[i] + data_cfg["img_dir"] = img_dir[i] if isinstance(ann_dir, (list, tuple)): - data_cfg['ann_dir'] = ann_dir[i] + data_cfg["ann_dir"] = ann_dir[i] if isinstance(split, (list, tuple)): - data_cfg['split'] = split[i] + data_cfg["split"] = split[i] datasets.append(build_dataset(data_cfg, default_args)) return ConcatDataset(datasets, separate_eval) @@ -66,20 +68,22 @@ def _concat_dataset(cfg, default_args=None): def build_dataset(cfg, default_args=None): """Build datasets.""" - from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, - RepeatDataset) + from .dataset_wrappers import ConcatDataset, MultiImageMixDataset, RepeatDataset + if isinstance(cfg, (list, tuple)): dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) - elif cfg['type'] == 'RepeatDataset': + elif cfg["type"] == "RepeatDataset": dataset = RepeatDataset( - build_dataset(cfg['dataset'], default_args), cfg['times']) - elif cfg['type'] == 'MultiImageMixDataset': + build_dataset(cfg["dataset"], default_args), cfg["times"] + ) + elif cfg["type"] == "MultiImageMixDataset": cp_cfg = copy.deepcopy(cfg) - cp_cfg['dataset'] = build_dataset(cp_cfg['dataset']) - cp_cfg.pop('type') + cp_cfg["dataset"] = build_dataset(cp_cfg["dataset"]) + cp_cfg.pop("type") dataset = MultiImageMixDataset(**cp_cfg) - elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( - cfg.get('split', None), (list, tuple)): + elif isinstance(cfg.get("img_dir"), (list, tuple)) or isinstance( + cfg.get("split", None), (list, tuple) + ): dataset = _concat_dataset(cfg, default_args) else: dataset = build_from_cfg(cfg, DATASETS, default_args) @@ -87,17 +91,19 @@ def build_dataset(cfg, default_args=None): return dataset -def build_dataloader(dataset, - samples_per_gpu, - workers_per_gpu, - num_gpus=1, - dist=True, - shuffle=True, - seed=None, - drop_last=False, - pin_memory=True, - persistent_workers=True, - **kwargs): +def build_dataloader( + dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + drop_last=False, + pin_memory=True, + persistent_workers=True, + **kwargs, +): """Build PyTorch DataLoader. In distributed training, each GPU/process has a dataloader. @@ -131,7 +137,8 @@ def build_dataloader(dataset, rank, world_size = get_dist_info() if dist and not isinstance(dataset, IterableDataset): sampler = DistributedSampler( - dataset, world_size, rank, shuffle=shuffle, seed=seed) + dataset, world_size, rank, shuffle=shuffle, seed=seed + ) shuffle = False batch_size = samples_per_gpu num_workers = workers_per_gpu @@ -145,11 +152,13 @@ def build_dataloader(dataset, batch_size = num_gpus * samples_per_gpu num_workers = num_gpus * workers_per_gpu - init_fn = partial( - worker_init_fn, num_workers=num_workers, rank=rank, - seed=seed) if seed is not None else None + init_fn = ( + partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) + if seed is not None + else None + ) - if digit_version(torch.__version__) >= digit_version('1.8.0'): + if digit_version(torch.__version__) >= digit_version("1.8.0"): data_loader = DataLoader( dataset, batch_size=batch_size, @@ -161,7 +170,8 @@ def build_dataloader(dataset, worker_init_fn=init_fn, drop_last=drop_last, persistent_workers=persistent_workers, - **kwargs) + **kwargs, + ) else: data_loader = DataLoader( dataset, @@ -173,7 +183,8 @@ def build_dataloader(dataset, shuffle=shuffle, worker_init_fn=init_fn, drop_last=drop_last, - **kwargs) + **kwargs, + ) return data_loader diff --git a/mmsegmentation/mmseg/datasets/chase_db1.py b/mmsegmentation/mmseg/datasets/chase_db1.py index 5cdc8d8..617f151 100644 --- a/mmsegmentation/mmseg/datasets/chase_db1.py +++ b/mmsegmentation/mmseg/datasets/chase_db1.py @@ -14,14 +14,15 @@ class ChaseDB1Dataset(CustomDataset): '_1stHO.png'. """ - CLASSES = ('background', 'vessel') + CLASSES = ("background", "vessel") PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): - super(ChaseDB1Dataset, self).__init__( - img_suffix='.png', - seg_map_suffix='_1stHO.png', + super().__init__( + img_suffix=".png", + seg_map_suffix="_1stHO.png", reduce_zero_label=False, - **kwargs) + **kwargs, + ) assert self.file_client.exists(self.img_dir) diff --git a/mmsegmentation/mmseg/datasets/cityscapes.py b/mmsegmentation/mmseg/datasets/cityscapes.py index ed633d0..ad90899 100644 --- a/mmsegmentation/mmseg/datasets/cityscapes.py +++ b/mmsegmentation/mmseg/datasets/cityscapes.py @@ -18,23 +18,57 @@ class CityscapesDataset(CustomDataset): fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. """ - CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', - 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', - 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', - 'bicycle') - - PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], - [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], - [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], - [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], - [0, 80, 100], [0, 0, 230], [119, 11, 32]] - - def __init__(self, - img_suffix='_leftImg8bit.png', - seg_map_suffix='_gtFine_labelTrainIds.png', - **kwargs): - super(CityscapesDataset, self).__init__( - img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + CLASSES = ( + "road", + "sidewalk", + "building", + "wall", + "fence", + "pole", + "traffic light", + "traffic sign", + "vegetation", + "terrain", + "sky", + "person", + "rider", + "car", + "truck", + "bus", + "train", + "motorcycle", + "bicycle", + ) + + PALETTE = [ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + ] + + def __init__( + self, + img_suffix="_leftImg8bit.png", + seg_map_suffix="_gtFine_labelTrainIds.png", + **kwargs, + ): + super().__init__(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) @staticmethod def _convert_to_label_id(result): @@ -42,6 +76,7 @@ def _convert_to_label_id(result): if isinstance(result, str): result = np.load(result) import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() for trainId, label in CSLabels.trainId2label.items(): result_copy[result == trainId] = label.id @@ -75,13 +110,14 @@ def results2img(self, results, imgfile_prefix, to_label_id, indices=None): for result, idx in zip(results, indices): if to_label_id: result = self._convert_to_label_id(result) - filename = self.img_infos[idx]['filename'] + filename = self.img_infos[idx]["filename"] basename = osp.splitext(osp.basename(filename))[0] - png_filename = osp.join(imgfile_prefix, f'{basename}.png') + png_filename = osp.join(imgfile_prefix, f"{basename}.png") - output = Image.fromarray(result.astype(np.uint8)).convert('P') + output = Image.fromarray(result.astype(np.uint8)).convert("P") import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) for label_id, label in CSLabels.id2label.items(): palette[label_id] = label.color @@ -92,11 +128,7 @@ def results2img(self, results, imgfile_prefix, to_label_id, indices=None): return result_files - def format_results(self, - results, - imgfile_prefix, - to_label_id=True, - indices=None): + def format_results(self, results, imgfile_prefix, to_label_id=True, indices=None): """Format the results into dir (standard format for Cityscapes evaluation). @@ -119,19 +151,14 @@ def format_results(self, if indices is None: indices = list(range(len(self))) - assert isinstance(results, list), 'results must be a list.' - assert isinstance(indices, list), 'indices must be a list.' + assert isinstance(results, list), "results must be a list." + assert isinstance(indices, list), "indices must be a list." - result_files = self.results2img(results, imgfile_prefix, to_label_id, - indices) + result_files = self.results2img(results, imgfile_prefix, to_label_id, indices) return result_files - def evaluate(self, - results, - metric='mIoU', - logger=None, - imgfile_prefix=None): + def evaluate(self, results, metric="mIoU", logger=None, imgfile_prefix=None): """Evaluation in Cityscapes/default protocol. Args: @@ -155,14 +182,13 @@ def evaluate(self, eval_results = dict() metrics = metric.copy() if isinstance(metric, list) else [metric] - if 'cityscapes' in metrics: + if "cityscapes" in metrics: eval_results.update( - self._evaluate_cityscapes(results, logger, imgfile_prefix)) - metrics.remove('cityscapes') + self._evaluate_cityscapes(results, logger, imgfile_prefix) + ) + metrics.remove("cityscapes") if len(metrics) > 0: - eval_results.update( - super(CityscapesDataset, - self).evaluate(results, metrics, logger)) + eval_results.update(super().evaluate(results, metrics, logger)) return eval_results @@ -181,17 +207,19 @@ def _evaluate_cityscapes(self, results, logger, imgfile_prefix): try: import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa except ImportError: - raise ImportError('Please run "pip install cityscapesscripts" to ' - 'install cityscapesscripts first.') - msg = 'Evaluating in Cityscapes style' + raise ImportError( + 'Please run "pip install cityscapesscripts" to ' + "install cityscapesscripts first." + ) + msg = "Evaluating in Cityscapes style" if logger is None: - msg = '\n' + msg + msg = "\n" + msg print_log(msg, logger=logger) result_dir = imgfile_prefix eval_results = dict() - print_log(f'Evaluating results under {result_dir} ...', logger=logger) + print_log(f"Evaluating results under {result_dir} ...", logger=logger) CSEval.args.evalInstLevelScore = True CSEval.args.predictionPath = osp.abspath(result_dir) @@ -204,11 +232,13 @@ def _evaluate_cityscapes(self, results, logger, imgfile_prefix): # when evaluating with official cityscapesscripts, # **_gtFine_labelIds.png is used for seg_map in mmcv.scandir( - self.ann_dir, 'gtFine_labelIds.png', recursive=True): + self.ann_dir, "gtFine_labelIds.png", recursive=True + ): seg_map_list.append(osp.join(self.ann_dir, seg_map)) pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) eval_results.update( - CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args) + ) return eval_results diff --git a/mmsegmentation/mmseg/datasets/coco_stuff.py b/mmsegmentation/mmseg/datasets/coco_stuff.py index 24d0895..56b4901 100644 --- a/mmsegmentation/mmseg/datasets/coco_stuff.py +++ b/mmsegmentation/mmseg/datasets/coco_stuff.py @@ -14,81 +14,356 @@ class COCOStuffDataset(CustomDataset): 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', and ``seg_map_suffix`` is fixed to '.png'. """ + CLASSES = ( - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', - 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', - 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', - 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', - 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', - 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', - 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', - 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', - 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', - 'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', - 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', - 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', - 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', - 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', - 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', - 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', - 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', - 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', - 'window-blind', 'window-other', 'wood') + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + "banner", + "blanket", + "branch", + "bridge", + "building-other", + "bush", + "cabinet", + "cage", + "cardboard", + "carpet", + "ceiling-other", + "ceiling-tile", + "cloth", + "clothes", + "clouds", + "counter", + "cupboard", + "curtain", + "desk-stuff", + "dirt", + "door-stuff", + "fence", + "floor-marble", + "floor-other", + "floor-stone", + "floor-tile", + "floor-wood", + "flower", + "fog", + "food-other", + "fruit", + "furniture-other", + "grass", + "gravel", + "ground-other", + "hill", + "house", + "leaves", + "light", + "mat", + "metal", + "mirror-stuff", + "moss", + "mountain", + "mud", + "napkin", + "net", + "paper", + "pavement", + "pillow", + "plant-other", + "plastic", + "platform", + "playingfield", + "railing", + "railroad", + "river", + "road", + "rock", + "roof", + "rug", + "salad", + "sand", + "sea", + "shelf", + "sky-other", + "skyscraper", + "snow", + "solid-other", + "stairs", + "stone", + "straw", + "structural-other", + "table", + "tent", + "textile-other", + "towel", + "tree", + "vegetable", + "wall-brick", + "wall-concrete", + "wall-other", + "wall-panel", + "wall-stone", + "wall-tile", + "wall-wood", + "water-other", + "waterdrops", + "window-blind", + "window-other", + "wood", + ) - PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], - [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], - [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], - [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], - [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], - [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], - [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], - [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], - [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], - [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], - [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], - [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], - [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], - [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], - [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], - [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], - [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], - [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], - [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], - [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], - [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], - [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], - [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128], - [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], - [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], - [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0], - [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192], - [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224], - [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128], - [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128], - [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160], - [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64], - [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128], - [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160], - [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192], - [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192], - [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160], - [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64], - [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192], - [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], - [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], - [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], - [64, 192, 96], [64, 160, 64], [64, 64, 0]] + PALETTE = [ + [0, 192, 64], + [0, 192, 64], + [0, 64, 96], + [128, 192, 192], + [0, 64, 64], + [0, 192, 224], + [0, 192, 192], + [128, 192, 64], + [0, 192, 96], + [128, 192, 64], + [128, 32, 192], + [0, 0, 224], + [0, 0, 64], + [0, 160, 192], + [128, 0, 96], + [128, 0, 192], + [0, 32, 192], + [128, 128, 224], + [0, 0, 192], + [128, 160, 192], + [128, 128, 0], + [128, 0, 32], + [128, 32, 0], + [128, 0, 128], + [64, 128, 32], + [0, 160, 0], + [0, 0, 0], + [192, 128, 160], + [0, 32, 0], + [0, 128, 128], + [64, 128, 160], + [128, 160, 0], + [0, 128, 0], + [192, 128, 32], + [128, 96, 128], + [0, 0, 128], + [64, 0, 32], + [0, 224, 128], + [128, 0, 0], + [192, 0, 160], + [0, 96, 128], + [128, 128, 128], + [64, 0, 160], + [128, 224, 128], + [128, 128, 64], + [192, 0, 32], + [128, 96, 0], + [128, 0, 192], + [0, 128, 32], + [64, 224, 0], + [0, 0, 64], + [128, 128, 160], + [64, 96, 0], + [0, 128, 192], + [0, 128, 160], + [192, 224, 0], + [0, 128, 64], + [128, 128, 32], + [192, 32, 128], + [0, 64, 192], + [0, 0, 32], + [64, 160, 128], + [128, 64, 64], + [128, 0, 160], + [64, 32, 128], + [128, 192, 192], + [0, 0, 160], + [192, 160, 128], + [128, 192, 0], + [128, 0, 96], + [192, 32, 0], + [128, 64, 128], + [64, 128, 96], + [64, 160, 0], + [0, 64, 0], + [192, 128, 224], + [64, 32, 0], + [0, 192, 128], + [64, 128, 224], + [192, 160, 0], + [0, 192, 0], + [192, 128, 96], + [192, 96, 128], + [0, 64, 128], + [64, 0, 96], + [64, 224, 128], + [128, 64, 0], + [192, 0, 224], + [64, 96, 128], + [128, 192, 128], + [64, 0, 224], + [192, 224, 128], + [128, 192, 64], + [192, 0, 96], + [192, 96, 0], + [128, 64, 192], + [0, 128, 96], + [0, 224, 0], + [64, 64, 64], + [128, 128, 224], + [0, 96, 0], + [64, 192, 192], + [0, 128, 224], + [128, 224, 0], + [64, 192, 64], + [128, 128, 96], + [128, 32, 128], + [64, 0, 192], + [0, 64, 96], + [0, 160, 128], + [192, 0, 64], + [128, 64, 224], + [0, 32, 128], + [192, 128, 192], + [0, 64, 224], + [128, 160, 128], + [192, 128, 0], + [128, 64, 32], + [128, 32, 64], + [192, 0, 128], + [64, 192, 32], + [0, 160, 64], + [64, 0, 0], + [192, 192, 160], + [0, 32, 64], + [64, 128, 128], + [64, 192, 160], + [128, 160, 64], + [64, 128, 0], + [192, 192, 32], + [128, 96, 192], + [64, 0, 128], + [64, 64, 32], + [0, 224, 192], + [192, 0, 0], + [192, 64, 160], + [0, 96, 192], + [192, 128, 128], + [64, 64, 160], + [128, 224, 192], + [192, 128, 64], + [192, 64, 32], + [128, 96, 64], + [192, 0, 192], + [0, 192, 32], + [64, 224, 64], + [64, 0, 64], + [128, 192, 160], + [64, 96, 64], + [64, 128, 192], + [0, 192, 160], + [192, 224, 64], + [64, 128, 64], + [128, 192, 32], + [192, 32, 192], + [64, 64, 192], + [0, 64, 32], + [64, 160, 192], + [192, 64, 64], + [128, 64, 160], + [64, 32, 192], + [192, 192, 192], + [0, 64, 160], + [192, 160, 192], + [192, 192, 0], + [128, 64, 96], + [192, 32, 64], + [192, 64, 128], + [64, 192, 96], + [64, 160, 64], + [64, 64, 0], + ] def __init__(self, **kwargs): - super(COCOStuffDataset, self).__init__( - img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs) + super().__init__( + img_suffix=".jpg", seg_map_suffix="_labelTrainIds.png", **kwargs + ) diff --git a/mmsegmentation/mmseg/datasets/coco_trash.py b/mmsegmentation/mmseg/datasets/coco_trash.py index 383c9a3..0dcc71b 100644 --- a/mmsegmentation/mmseg/datasets/coco_trash.py +++ b/mmsegmentation/mmseg/datasets/coco_trash.py @@ -8,16 +8,31 @@ class COCOTrashDataset(CustomDataset): """COCO-Trash dataset.""" CLASSES = ( - "Background", "General trash", "Paper", "Paper pack", "Metal", "Glass", - "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing", + "Background", + "General trash", + "Paper", + "Paper pack", + "Metal", + "Glass", + "Plastic", + "Styrofoam", + "Plastic bag", + "Battery", + "Clothing", ) PALETTE = ( - (128, 224, 128), (128, 62, 62), (30, 142, 30), (192, 0, 0), (50, 50, 160), - (0, 224, 224), (0, 0, 224), (192, 224, 0), (192, 224, 224), (192, 96, 0), + (128, 224, 128), + (128, 62, 62), + (30, 142, 30), + (192, 0, 0), + (50, 50, 160), + (0, 224, 224), + (0, 0, 224), + (192, 224, 0), + (192, 224, 224), + (192, 96, 0), (0, 224, 0), ) def __init__(self, **kwargs) -> None: - super(COCOTrashDataset, self).__init__( - img_suffix=".jpg", seg_map_suffix=".png", **kwargs - ) + super().__init__(img_suffix=".jpg", seg_map_suffix=".png", **kwargs) diff --git a/mmsegmentation/mmseg/datasets/custom.py b/mmsegmentation/mmseg/datasets/custom.py index 4615d41..96d0cf4 100644 --- a/mmsegmentation/mmseg/datasets/custom.py +++ b/mmsegmentation/mmseg/datasets/custom.py @@ -77,21 +77,23 @@ class CustomDataset(Dataset): PALETTE = None - def __init__(self, - pipeline, - img_dir, - img_suffix='.jpg', - ann_dir=None, - seg_map_suffix='.png', - split=None, - data_root=None, - test_mode=False, - ignore_index=255, - reduce_zero_label=False, - classes=None, - palette=None, - gt_seg_map_loader_cfg=None, - file_client_args=dict(backend='disk')): + def __init__( + self, + pipeline, + img_dir, + img_suffix=".jpg", + ann_dir=None, + seg_map_suffix=".png", + split=None, + data_root=None, + test_mode=False, + ignore_index=255, + reduce_zero_label=False, + classes=None, + palette=None, + gt_seg_map_loader_cfg=None, + file_client_args=dict(backend="disk"), + ): self.pipeline = Compose(pipeline) self.img_dir = img_dir self.img_suffix = img_suffix @@ -103,18 +105,20 @@ def __init__(self, self.ignore_index = ignore_index self.reduce_zero_label = reduce_zero_label self.label_map = None - self.CLASSES, self.PALETTE = self.get_classes_and_palette( - classes, palette) - self.gt_seg_map_loader = LoadAnnotations( - ) if gt_seg_map_loader_cfg is None else LoadAnnotations( - **gt_seg_map_loader_cfg) + self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes, palette) + self.gt_seg_map_loader = ( + LoadAnnotations() + if gt_seg_map_loader_cfg is None + else LoadAnnotations(**gt_seg_map_loader_cfg) + ) self.file_client_args = file_client_args self.file_client = mmcv.FileClient.infer_client(self.file_client_args) if test_mode: - assert self.CLASSES is not None, \ - '`cls.CLASSES` or `classes` should be specified when testing' + assert ( + self.CLASSES is not None + ), "`cls.CLASSES` or `classes` should be specified when testing" # join paths if data_root is specified if self.data_root is not None: @@ -126,16 +130,15 @@ def __init__(self, self.split = osp.join(self.data_root, self.split) # load annotations - self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, - self.ann_dir, - self.seg_map_suffix, self.split) + self.img_infos = self.load_annotations( + self.img_dir, self.img_suffix, self.ann_dir, self.seg_map_suffix, self.split + ) def __len__(self): """Total number of samples of data.""" return len(self.img_infos) - def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, - split): + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, split): """Load annotation from directory. Args: @@ -153,29 +156,26 @@ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, img_infos = [] if split is not None: - lines = mmcv.list_from_file( - split, file_client_args=self.file_client_args) + lines = mmcv.list_from_file(split, file_client_args=self.file_client_args) for line in lines: img_name = line.strip() img_info = dict(filename=img_name + img_suffix) if ann_dir is not None: seg_map = img_name + seg_map_suffix - img_info['ann'] = dict(seg_map=seg_map) + img_info["ann"] = dict(seg_map=seg_map) img_infos.append(img_info) else: for img in self.file_client.list_dir_or_file( - dir_path=img_dir, - list_dir=False, - suffix=img_suffix, - recursive=True): + dir_path=img_dir, list_dir=False, suffix=img_suffix, recursive=True + ): img_info = dict(filename=img) if ann_dir is not None: seg_map = img.replace(img_suffix, seg_map_suffix) - img_info['ann'] = dict(seg_map=seg_map) + img_info["ann"] = dict(seg_map=seg_map) img_infos.append(img_info) - img_infos = sorted(img_infos, key=lambda x: x['filename']) + img_infos = sorted(img_infos, key=lambda x: x["filename"]) - print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + print_log(f"Loaded {len(img_infos)} images", logger=get_root_logger()) return img_infos def get_ann_info(self, idx): @@ -188,15 +188,15 @@ def get_ann_info(self, idx): dict: Annotation info of specified index. """ - return self.img_infos[idx]['ann'] + return self.img_infos[idx]["ann"] def pre_pipeline(self, results): """Prepare results dict for pipeline.""" - results['seg_fields'] = [] - results['img_prefix'] = self.img_dir - results['seg_prefix'] = self.ann_dir + results["seg_fields"] = [] + results["img_prefix"] = self.img_dir + results["seg_prefix"] = self.ann_dir if self.custom_classes: - results['label_map'] = self.label_map + results["label_map"] = self.label_map def __getitem__(self, idx): """Get training/test data after pipeline. @@ -257,22 +257,23 @@ def get_gt_seg_map_by_idx(self, index): results = dict(ann_info=ann_info) self.pre_pipeline(results) self.gt_seg_map_loader(results) - return results['gt_semantic_seg'] + return results["gt_semantic_seg"] def get_gt_seg_maps(self, efficient_test=None): """Get ground truth segmentation maps for evaluation.""" if efficient_test is not None: warnings.warn( - 'DeprecationWarning: ``efficient_test`` has been deprecated ' - 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ' - 'friendly by default. ') + "DeprecationWarning: ``efficient_test`` has been deprecated " + "since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory " + "friendly by default. " + ) for idx in range(len(self)): ann_info = self.get_ann_info(idx) results = dict(ann_info=ann_info) self.pre_pipeline(results) self.gt_seg_map_loader(results) - yield results['gt_semantic_seg'] + yield results["gt_semantic_seg"] def pre_eval(self, preds, indices): """Collect eval result from each iteration. @@ -309,7 +310,9 @@ def pre_eval(self, preds, indices): # https://github.com/open-mmlab/mmsegmentation/issues/1415 # for more ditails label_map=dict(), - reduce_zero_label=self.reduce_zero_label)) + reduce_zero_label=self.reduce_zero_label, + ) + ) return pre_eval_results @@ -337,11 +340,11 @@ def get_classes_and_palette(self, classes=None, palette=None): elif isinstance(classes, (tuple, list)): class_names = classes else: - raise ValueError(f'Unsupported type {type(classes)} of classes.') + raise ValueError(f"Unsupported type {type(classes)} of classes.") if self.CLASSES: if not set(class_names).issubset(self.CLASSES): - raise ValueError('classes is not a subset of CLASSES.') + raise ValueError("classes is not a subset of CLASSES.") # dictionary, its keys are the old label ids and its values # are the new label ids. @@ -358,12 +361,10 @@ def get_classes_and_palette(self, classes=None, palette=None): return class_names, palette def get_palette_for_custom_classes(self, class_names, palette=None): - if self.label_map is not None: # return subset of palette palette = [] - for old_id, new_id in sorted( - self.label_map.items(), key=lambda x: x[1]): + for old_id, new_id in sorted(self.label_map.items(), key=lambda x: x[1]): if new_id != -1: palette.append(self.PALETTE[old_id]) palette = type(self.PALETTE)(palette) @@ -385,12 +386,7 @@ def get_palette_for_custom_classes(self, class_names, palette=None): return palette - def evaluate(self, - results, - metric='mIoU', - logger=None, - gt_seg_maps=None, - **kwargs): + def evaluate(self, results, metric="mIoU", logger=None, gt_seg_maps=None, **kwargs): """Evaluate the dataset. Args: @@ -409,14 +405,13 @@ def evaluate(self, """ if isinstance(metric, str): metric = [metric] - allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + allowed_metrics = ["mIoU", "mDice", "mFscore"] if not set(metric).issubset(set(allowed_metrics)): - raise KeyError('metric {} is not supported'.format(metric)) + raise KeyError(f"metric {metric} is not supported") eval_results = {} # test a list of files - if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( - results, str): + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(results, str): if gt_seg_maps is None: gt_seg_maps = self.get_gt_seg_maps() num_classes = len(self.CLASSES) @@ -427,7 +422,8 @@ def evaluate(self, self.ignore_index, metric, label_map=dict(), - reduce_zero_label=self.reduce_zero_label) + reduce_zero_label=self.reduce_zero_label, + ) # test a list of pre_eval_results else: ret_metrics = pre_eval_to_metrics(results, metric) @@ -439,19 +435,23 @@ def evaluate(self, class_names = self.CLASSES # summary table - ret_metrics_summary = OrderedDict({ - ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) - for ret_metric, ret_metric_value in ret_metrics.items() - }) + ret_metrics_summary = OrderedDict( + { + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + } + ) # each class table - ret_metrics.pop('aAcc', None) - ret_metrics_class = OrderedDict({ - ret_metric: np.round(ret_metric_value * 100, 2) - for ret_metric, ret_metric_value in ret_metrics.items() - }) - ret_metrics_class.update({'Class': class_names}) - ret_metrics_class.move_to_end('Class', last=False) + ret_metrics.pop("aAcc", None) + ret_metrics_class = OrderedDict( + { + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + } + ) + ret_metrics_class.update({"Class": class_names}) + ret_metrics_class.move_to_end("Class", last=False) # for logger class_table_data = PrettyTable() @@ -460,28 +460,30 @@ def evaluate(self, summary_table_data = PrettyTable() for key, val in ret_metrics_summary.items(): - if key == 'aAcc': + if key == "aAcc": summary_table_data.add_column(key, [val]) else: - summary_table_data.add_column('m' + key, [val]) + summary_table_data.add_column("m" + key, [val]) - print_log('per class results:', logger) - print_log('\n' + class_table_data.get_string(), logger=logger) - print_log('Summary:', logger) - print_log('\n' + summary_table_data.get_string(), logger=logger) + print_log("per class results:", logger) + print_log("\n" + class_table_data.get_string(), logger=logger) + print_log("Summary:", logger) + print_log("\n" + summary_table_data.get_string(), logger=logger) # each metric dict for key, value in ret_metrics_summary.items(): - if key == 'aAcc': + if key == "aAcc": eval_results[key] = value / 100.0 else: - eval_results['m' + key] = value / 100.0 + eval_results["m" + key] = value / 100.0 - ret_metrics_class.pop('Class', None) + ret_metrics_class.pop("Class", None) for key, value in ret_metrics_class.items(): - eval_results.update({ - key + '.' + str(name): value[idx] / 100.0 - for idx, name in enumerate(class_names) - }) + eval_results.update( + { + key + "." + str(name): value[idx] / 100.0 + for idx, name in enumerate(class_names) + } + ) return eval_results diff --git a/mmsegmentation/mmseg/datasets/dark_zurich.py b/mmsegmentation/mmseg/datasets/dark_zurich.py index 0b6fda6..59a6e53 100644 --- a/mmsegmentation/mmseg/datasets/dark_zurich.py +++ b/mmsegmentation/mmseg/datasets/dark_zurich.py @@ -9,6 +9,5 @@ class DarkZurichDataset(CityscapesDataset): def __init__(self, **kwargs): super().__init__( - img_suffix='_rgb_anon.png', - seg_map_suffix='_gt_labelTrainIds.png', - **kwargs) + img_suffix="_rgb_anon.png", seg_map_suffix="_gt_labelTrainIds.png", **kwargs + ) diff --git a/mmsegmentation/mmseg/datasets/dataset_wrappers.py b/mmsegmentation/mmseg/datasets/dataset_wrappers.py index 1fb089f..324bddf 100644 --- a/mmsegmentation/mmseg/datasets/dataset_wrappers.py +++ b/mmsegmentation/mmseg/datasets/dataset_wrappers.py @@ -27,17 +27,18 @@ class ConcatDataset(_ConcatDataset): """ def __init__(self, datasets, separate_eval=True): - super(ConcatDataset, self).__init__(datasets) + super().__init__(datasets) self.CLASSES = datasets[0].CLASSES self.PALETTE = datasets[0].PALETTE self.separate_eval = separate_eval - assert separate_eval in [True, False], \ - f'separate_eval can only be True or False,' \ - f'but get {separate_eval}' + assert separate_eval in [True, False], ( + f"separate_eval can only be True or False," f"but get {separate_eval}" + ) if any([isinstance(ds, CityscapesDataset) for ds in datasets]): raise NotImplementedError( - 'Evaluating ConcatDataset containing CityscapesDataset' - 'is not supported!') + "Evaluating ConcatDataset containing CityscapesDataset" + "is not supported!" + ) def evaluate(self, results, logger=None, **kwargs): """Evaluate the results. @@ -54,53 +55,60 @@ def evaluate(self, results, logger=None, **kwargs): or each separate dataset if `self.separate_eval=True`. """ - assert len(results) == self.cumulative_sizes[-1], \ - ('Dataset and results have different sizes: ' - f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + assert len(results) == self.cumulative_sizes[-1], ( + "Dataset and results have different sizes: " + f"{self.cumulative_sizes[-1]} v.s. {len(results)}" + ) # Check whether all the datasets support evaluation for dataset in self.datasets: - assert hasattr(dataset, 'evaluate'), \ - f'{type(dataset)} does not implement evaluate function' + assert hasattr( + dataset, "evaluate" + ), f"{type(dataset)} does not implement evaluate function" if self.separate_eval: dataset_idx = -1 total_eval_results = dict() for size, dataset in zip(self.cumulative_sizes, self.datasets): - start_idx = 0 if dataset_idx == -1 else \ - self.cumulative_sizes[dataset_idx] + start_idx = ( + 0 if dataset_idx == -1 else self.cumulative_sizes[dataset_idx] + ) end_idx = self.cumulative_sizes[dataset_idx + 1] results_per_dataset = results[start_idx:end_idx] print_log( - f'\nEvaluateing {dataset.img_dir} with ' - f'{len(results_per_dataset)} images now', - logger=logger) + f"\nEvaluateing {dataset.img_dir} with " + f"{len(results_per_dataset)} images now", + logger=logger, + ) eval_results_per_dataset = dataset.evaluate( - results_per_dataset, logger=logger, **kwargs) + results_per_dataset, logger=logger, **kwargs + ) dataset_idx += 1 for k, v in eval_results_per_dataset.items(): - total_eval_results.update({f'{dataset_idx}_{k}': v}) + total_eval_results.update({f"{dataset_idx}_{k}": v}) return total_eval_results - if len(set([type(ds) for ds in self.datasets])) != 1: + if len({type(ds) for ds in self.datasets}) != 1: raise NotImplementedError( - 'All the datasets should have same types when ' - 'self.separate_eval=False') + "All the datasets should have same types when " + "self.separate_eval=False" + ) else: - if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( - results, str): + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(results, str): # merge the generators of gt_seg_maps gt_seg_maps = chain( - *[dataset.get_gt_seg_maps() for dataset in self.datasets]) + *[dataset.get_gt_seg_maps() for dataset in self.datasets] + ) else: # if the results are `pre_eval` results, # we do not need gt_seg_maps to evaluate gt_seg_maps = None eval_results = self.datasets[0].evaluate( - results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs) + results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs + ) return eval_results def get_dataset_idx_and_sample_idx(self, indice): @@ -117,7 +125,8 @@ def get_dataset_idx_and_sample_idx(self, indice): if indice < 0: if -indice > len(self): raise ValueError( - 'absolute value of index should not exceed dataset length') + "absolute value of index should not exceed dataset length" + ) indice = len(self) + indice dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice) if dataset_idx == 0: @@ -131,18 +140,18 @@ def format_results(self, results, imgfile_prefix, indices=None, **kwargs): if indices is None: indices = list(range(len(self))) - assert isinstance(results, list), 'results must be a list.' - assert isinstance(indices, list), 'indices must be a list.' + assert isinstance(results, list), "results must be a list." + assert isinstance(indices, list), "indices must be a list." ret_res = [] for i, indice in enumerate(indices): - dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( - indice) + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(indice) res = self.datasets[dataset_idx].format_results( [results[i]], - imgfile_prefix + f'/{dataset_idx}', + imgfile_prefix + f"/{dataset_idx}", indices=[sample_idx], - **kwargs) + **kwargs, + ) ret_res.append(res) return sum(ret_res, []) @@ -155,15 +164,14 @@ def pre_eval(self, preds, indices): preds = [preds] ret_res = [] for i, indice in enumerate(indices): - dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( - indice) + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(indice) res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx) ret_res.append(res) return sum(ret_res, []) @DATASETS.register_module() -class RepeatDataset(object): +class RepeatDataset: """A wrapper of repeated dataset. The length of repeated dataset will be `times` larger than the original @@ -214,21 +222,20 @@ class MultiImageMixDataset: def __init__(self, dataset, pipeline, skip_type_keys=None): assert isinstance(pipeline, collections.abc.Sequence) if skip_type_keys is not None: - assert all([ - isinstance(skip_type_key, str) - for skip_type_key in skip_type_keys - ]) + assert all( + [isinstance(skip_type_key, str) for skip_type_key in skip_type_keys] + ) self._skip_type_keys = skip_type_keys self.pipeline = [] self.pipeline_types = [] for transform in pipeline: if isinstance(transform, dict): - self.pipeline_types.append(transform['type']) + self.pipeline_types.append(transform["type"]) transform = build_from_cfg(transform, PIPELINES) self.pipeline.append(transform) else: - raise TypeError('pipeline must be a dict') + raise TypeError("pipeline must be a dict") self.dataset = dataset self.CLASSES = dataset.CLASSES @@ -240,25 +247,24 @@ def __len__(self): def __getitem__(self, idx): results = copy.deepcopy(self.dataset[idx]) - for (transform, transform_type) in zip(self.pipeline, - self.pipeline_types): - if self._skip_type_keys is not None and \ - transform_type in self._skip_type_keys: + for transform, transform_type in zip(self.pipeline, self.pipeline_types): + if ( + self._skip_type_keys is not None + and transform_type in self._skip_type_keys + ): continue - if hasattr(transform, 'get_indexes'): + if hasattr(transform, "get_indexes"): indexes = transform.get_indexes(self.dataset) if not isinstance(indexes, collections.abc.Sequence): indexes = [indexes] - mix_results = [ - copy.deepcopy(self.dataset[index]) for index in indexes - ] - results['mix_results'] = mix_results + mix_results = [copy.deepcopy(self.dataset[index]) for index in indexes] + results["mix_results"] = mix_results results = transform(results) - if 'mix_results' in results: - results.pop('mix_results') + if "mix_results" in results: + results.pop("mix_results") return results @@ -271,7 +277,5 @@ def update_skip_type_keys(self, skip_type_keys): skip_type_keys (list[str], optional): Sequence of type string to be skip pipeline. """ - assert all([ - isinstance(skip_type_key, str) for skip_type_key in skip_type_keys - ]) + assert all([isinstance(skip_type_key, str) for skip_type_key in skip_type_keys]) self._skip_type_keys = skip_type_keys diff --git a/mmsegmentation/mmseg/datasets/drive.py b/mmsegmentation/mmseg/datasets/drive.py index d44fb0d..1f984eb 100644 --- a/mmsegmentation/mmseg/datasets/drive.py +++ b/mmsegmentation/mmseg/datasets/drive.py @@ -14,14 +14,15 @@ class DRIVEDataset(CustomDataset): '_manual1.png'. """ - CLASSES = ('background', 'vessel') + CLASSES = ("background", "vessel") PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): - super(DRIVEDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='_manual1.png', + super().__init__( + img_suffix=".png", + seg_map_suffix="_manual1.png", reduce_zero_label=False, - **kwargs) + **kwargs, + ) assert self.file_client.exists(self.img_dir) diff --git a/mmsegmentation/mmseg/datasets/face.py b/mmsegmentation/mmseg/datasets/face.py index cbc2345..cfbf72e 100755 --- a/mmsegmentation/mmseg/datasets/face.py +++ b/mmsegmentation/mmseg/datasets/face.py @@ -13,11 +13,12 @@ class FaceOccludedDataset(CustomDataset): split (str): Split txt file for Pascal VOC. """ - CLASSES = ('background', 'face') + CLASSES = ("background", "face") PALETTE = [[0, 0, 0], [128, 0, 0]] def __init__(self, split, **kwargs): - super(FaceOccludedDataset, self).__init__( - img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + super().__init__( + img_suffix=".jpg", seg_map_suffix=".png", split=split, **kwargs + ) assert osp.exists(self.img_dir) and self.split is not None diff --git a/mmsegmentation/mmseg/datasets/hrf.py b/mmsegmentation/mmseg/datasets/hrf.py index cf3ea8d..3400cf6 100644 --- a/mmsegmentation/mmseg/datasets/hrf.py +++ b/mmsegmentation/mmseg/datasets/hrf.py @@ -14,14 +14,12 @@ class HRFDataset(CustomDataset): '.png'. """ - CLASSES = ('background', 'vessel') + CLASSES = ("background", "vessel") PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): - super(HRFDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.png', - reduce_zero_label=False, - **kwargs) + super().__init__( + img_suffix=".png", seg_map_suffix=".png", reduce_zero_label=False, **kwargs + ) assert self.file_client.exists(self.img_dir) diff --git a/mmsegmentation/mmseg/datasets/isaid.py b/mmsegmentation/mmseg/datasets/isaid.py index db24f93..1a9dd98 100644 --- a/mmsegmentation/mmseg/datasets/isaid.py +++ b/mmsegmentation/mmseg/datasets/isaid.py @@ -10,38 +10,60 @@ @DATASETS.register_module() class iSAIDDataset(CustomDataset): - """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images + """iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images In segmentation map annotation for iSAID dataset, which is included in 16 categories. ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to '_manual1.png'. """ - CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond', - 'tennis_court', 'basketball_court', 'Ground_Track_Field', - 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', - 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', - 'Harbor') + CLASSES = ( + "background", + "ship", + "store_tank", + "baseball_diamond", + "tennis_court", + "basketball_court", + "Ground_Track_Field", + "Bridge", + "Large_Vehicle", + "Small_Vehicle", + "Helicopter", + "Swimming_pool", + "Roundabout", + "Soccer_ball_field", + "plane", + "Harbor", + ) - PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], - [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], - [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], - [0, 127, 191], [0, 127, 255], [0, 100, 155]] + PALETTE = [ + [0, 0, 0], + [0, 0, 63], + [0, 63, 63], + [0, 63, 0], + [0, 63, 127], + [0, 63, 191], + [0, 63, 255], + [0, 127, 63], + [0, 127, 127], + [0, 0, 127], + [0, 0, 191], + [0, 0, 255], + [0, 191, 127], + [0, 127, 191], + [0, 127, 255], + [0, 100, 155], + ] def __init__(self, **kwargs): - super(iSAIDDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.png', - ignore_index=255, - **kwargs) + super().__init__( + img_suffix=".png", seg_map_suffix=".png", ignore_index=255, **kwargs + ) assert self.file_client.exists(self.img_dir) - def load_annotations(self, - img_dir, - img_suffix, - ann_dir, - seg_map_suffix=None, - split=None): + def load_annotations( + self, img_dir, img_suffix, ann_dir, seg_map_suffix=None, split=None + ): """Load annotation from directory. Args: @@ -64,9 +86,9 @@ def load_annotations(self, name = line.strip() img_info = dict(filename=name + img_suffix) if ann_dir is not None: - ann_name = name + '_instance_color_RGB' + ann_name = name + "_instance_color_RGB" seg_map = ann_name + seg_map_suffix - img_info['ann'] = dict(seg_map=seg_map) + img_info["ann"] = dict(seg_map=seg_map) img_infos.append(img_info) else: for img in mmcv.scandir(img_dir, img_suffix, recursive=True): @@ -74,9 +96,10 @@ def load_annotations(self, if ann_dir is not None: seg_img = img seg_map = seg_img.replace( - img_suffix, '_instance_color_RGB' + seg_map_suffix) - img_info['ann'] = dict(seg_map=seg_map) + img_suffix, "_instance_color_RGB" + seg_map_suffix + ) + img_info["ann"] = dict(seg_map=seg_map) img_infos.append(img_info) - print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + print_log(f"Loaded {len(img_infos)} images", logger=get_root_logger()) return img_infos diff --git a/mmsegmentation/mmseg/datasets/isprs.py b/mmsegmentation/mmseg/datasets/isprs.py index 5f23e1a..31bed19 100644 --- a/mmsegmentation/mmseg/datasets/isprs.py +++ b/mmsegmentation/mmseg/datasets/isprs.py @@ -11,15 +11,26 @@ class ISPRSDataset(CustomDataset): ``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png'. """ - CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', - 'car', 'clutter') - PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], - [255, 255, 0], [255, 0, 0]] + CLASSES = ( + "impervious_surface", + "building", + "low_vegetation", + "tree", + "car", + "clutter", + ) + + PALETTE = [ + [255, 255, 255], + [0, 0, 255], + [0, 255, 255], + [0, 255, 0], + [255, 255, 0], + [255, 0, 0], + ] def __init__(self, **kwargs): - super(ISPRSDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.png', - reduce_zero_label=True, - **kwargs) + super().__init__( + img_suffix=".png", seg_map_suffix=".png", reduce_zero_label=True, **kwargs + ) diff --git a/mmsegmentation/mmseg/datasets/loveda.py b/mmsegmentation/mmseg/datasets/loveda.py index 90d654f..6cbeb24 100644 --- a/mmsegmentation/mmseg/datasets/loveda.py +++ b/mmsegmentation/mmseg/datasets/loveda.py @@ -17,18 +17,31 @@ class LoveDADataset(CustomDataset): ``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png'. """ - CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', - 'agricultural') - PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], - [159, 129, 183], [0, 255, 0], [255, 195, 128]] + CLASSES = ( + "background", + "building", + "road", + "water", + "barren", + "forest", + "agricultural", + ) + + PALETTE = [ + [255, 255, 255], + [255, 0, 0], + [255, 255, 0], + [0, 0, 255], + [159, 129, 183], + [0, 255, 0], + [255, 195, 128], + ] def __init__(self, **kwargs): - super(LoveDADataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.png', - reduce_zero_label=True, - **kwargs) + super().__init__( + img_suffix=".png", seg_map_suffix=".png", reduce_zero_label=True, **kwargs + ) def results2img(self, results, imgfile_prefix, indices=None): """Write the segmentation results to images. @@ -51,11 +64,10 @@ def results2img(self, results, imgfile_prefix, indices=None): mmcv.mkdir_or_exist(imgfile_prefix) result_files = [] for result, idx in zip(results, indices): - - filename = self.img_infos[idx]['filename'] + filename = self.img_infos[idx]["filename"] basename = osp.splitext(osp.basename(filename))[0] - png_filename = osp.join(imgfile_prefix, f'{basename}.png') + png_filename = osp.join(imgfile_prefix, f"{basename}.png") # The index range of official requirement is from 0 to 6. output = Image.fromarray(result.astype(np.uint8)) @@ -84,8 +96,8 @@ def format_results(self, results, imgfile_prefix, indices=None): if indices is None: indices = list(range(len(self))) - assert isinstance(results, list), 'results must be a list.' - assert isinstance(indices, list), 'indices must be a list.' + assert isinstance(results, list), "results must be a list." + assert isinstance(indices, list), "indices must be a list." result_files = self.results2img(results, imgfile_prefix, indices) diff --git a/mmsegmentation/mmseg/datasets/night_driving.py b/mmsegmentation/mmseg/datasets/night_driving.py index 6620586..a00bf53 100644 --- a/mmsegmentation/mmseg/datasets/night_driving.py +++ b/mmsegmentation/mmseg/datasets/night_driving.py @@ -9,6 +9,7 @@ class NightDrivingDataset(CityscapesDataset): def __init__(self, **kwargs): super().__init__( - img_suffix='_leftImg8bit.png', - seg_map_suffix='_gtCoarse_labelTrainIds.png', - **kwargs) + img_suffix="_leftImg8bit.png", + seg_map_suffix="_gtCoarse_labelTrainIds.png", + **kwargs, + ) diff --git a/mmsegmentation/mmseg/datasets/pascal_context.py b/mmsegmentation/mmseg/datasets/pascal_context.py index 20285d8..fcfebcb 100644 --- a/mmsegmentation/mmseg/datasets/pascal_context.py +++ b/mmsegmentation/mmseg/datasets/pascal_context.py @@ -17,40 +17,140 @@ class PascalContextDataset(CustomDataset): split (str): Split txt file for PascalContext. """ - CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', - 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', - 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', - 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', - 'floor', 'flower', 'food', 'grass', 'ground', 'horse', - 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', - 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', - 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', - 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', - 'window', 'wood') + CLASSES = ( + "background", + "aeroplane", + "bag", + "bed", + "bedclothes", + "bench", + "bicycle", + "bird", + "boat", + "book", + "bottle", + "building", + "bus", + "cabinet", + "car", + "cat", + "ceiling", + "chair", + "cloth", + "computer", + "cow", + "cup", + "curtain", + "dog", + "door", + "fence", + "floor", + "flower", + "food", + "grass", + "ground", + "horse", + "keyboard", + "light", + "motorbike", + "mountain", + "mouse", + "person", + "plate", + "platform", + "pottedplant", + "road", + "rock", + "sheep", + "shelves", + "sidewalk", + "sign", + "sky", + "snow", + "sofa", + "table", + "track", + "train", + "tree", + "truck", + "tvmonitor", + "wall", + "water", + "window", + "wood", + ) - PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] + PALETTE = [ + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + ] def __init__(self, split, **kwargs): - super(PascalContextDataset, self).__init__( - img_suffix='.jpg', - seg_map_suffix='.png', + super().__init__( + img_suffix=".jpg", + seg_map_suffix=".png", split=split, reduce_zero_label=False, - **kwargs) + **kwargs, + ) assert self.file_client.exists(self.img_dir) and self.split is not None @@ -67,37 +167,136 @@ class PascalContextDataset59(CustomDataset): split (str): Split txt file for PascalContext. """ - CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', - 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', - 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', - 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', - 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', - 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', - 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', - 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', - 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood') + CLASSES = ( + "aeroplane", + "bag", + "bed", + "bedclothes", + "bench", + "bicycle", + "bird", + "boat", + "book", + "bottle", + "building", + "bus", + "cabinet", + "car", + "cat", + "ceiling", + "chair", + "cloth", + "computer", + "cow", + "cup", + "curtain", + "dog", + "door", + "fence", + "floor", + "flower", + "food", + "grass", + "ground", + "horse", + "keyboard", + "light", + "motorbike", + "mountain", + "mouse", + "person", + "plate", + "platform", + "pottedplant", + "road", + "rock", + "sheep", + "shelves", + "sidewalk", + "sign", + "sky", + "snow", + "sofa", + "table", + "track", + "train", + "tree", + "truck", + "tvmonitor", + "wall", + "water", + "window", + "wood", + ) - PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], - [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], - [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], - [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], - [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], - [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], - [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], - [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], - [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], - [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], - [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], - [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], - [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], - [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], - [0, 235, 255], [0, 173, 255], [31, 0, 255]] + PALETTE = [ + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + ] def __init__(self, split, **kwargs): - super(PascalContextDataset59, self).__init__( - img_suffix='.jpg', - seg_map_suffix='.png', + super().__init__( + img_suffix=".jpg", + seg_map_suffix=".png", split=split, reduce_zero_label=True, - **kwargs) + **kwargs, + ) assert self.file_client.exists(self.img_dir) and self.split is not None diff --git a/mmsegmentation/mmseg/datasets/pipelines/__init__.py b/mmsegmentation/mmseg/datasets/pipelines/__init__.py index 8256a6f..e8747ef 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/__init__.py +++ b/mmsegmentation/mmseg/datasets/pipelines/__init__.py @@ -1,19 +1,55 @@ # Copyright (c) OpenMMLab. All rights reserved. from .compose import Compose -from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, - Transpose, to_tensor) +from .formatting import ( + Collect, + ImageToTensor, + ToDataContainer, + ToTensor, + Transpose, + to_tensor, +) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug -from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, - PhotoMetricDistortion, RandomCrop, RandomCutOut, - RandomFlip, RandomMosaic, RandomRotate, Rerange, - Resize, RGB2Gray, SegRescale) +from .transforms import ( + CLAHE, + AdjustGamma, + Normalize, + Pad, + PhotoMetricDistortion, + RandomCrop, + RandomCutOut, + RandomFlip, + RandomMosaic, + RandomRotate, + Rerange, + Resize, + RGB2Gray, + SegRescale, +) __all__ = [ - 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', - 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', - 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', - 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', - 'RandomMosaic' + "Compose", + "to_tensor", + "ToTensor", + "ImageToTensor", + "ToDataContainer", + "Transpose", + "Collect", + "LoadAnnotations", + "LoadImageFromFile", + "MultiScaleFlipAug", + "Resize", + "RandomFlip", + "Pad", + "RandomCrop", + "Normalize", + "SegRescale", + "PhotoMetricDistortion", + "RandomRotate", + "AdjustGamma", + "CLAHE", + "Rerange", + "RGB2Gray", + "RandomCutOut", + "RandomMosaic", ] diff --git a/mmsegmentation/mmseg/datasets/pipelines/compose.py b/mmsegmentation/mmseg/datasets/pipelines/compose.py index 30280c1..343e519 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/compose.py +++ b/mmsegmentation/mmseg/datasets/pipelines/compose.py @@ -7,7 +7,7 @@ @PIPELINES.register_module() -class Compose(object): +class Compose: """Compose multiple transforms sequentially. Args: @@ -25,7 +25,7 @@ def __init__(self, transforms): elif callable(transform): self.transforms.append(transform) else: - raise TypeError('transform must be callable or a dict') + raise TypeError("transform must be callable or a dict") def __call__(self, data): """Call function to apply transforms sequentially. @@ -44,9 +44,9 @@ def __call__(self, data): return data def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += f' {t}' - format_string += '\n)' + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" return format_string diff --git a/mmsegmentation/mmseg/datasets/pipelines/formating.py b/mmsegmentation/mmseg/datasets/pipelines/formating.py index f6e53bf..17abc83 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/formating.py +++ b/mmsegmentation/mmseg/datasets/pipelines/formating.py @@ -2,8 +2,8 @@ # flake8: noqa import warnings -from .formatting import * - -warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be ' - 'deprecated in 2021, please replace it with ' - 'mmseg.datasets.pipelines.formatting.') +warnings.warn( + "DeprecationWarning: mmseg.datasets.pipelines.formating will be " + "deprecated in 2021, please replace it with " + "mmseg.datasets.pipelines.formatting." +) diff --git a/mmsegmentation/mmseg/datasets/pipelines/formatting.py b/mmsegmentation/mmseg/datasets/pipelines/formatting.py index 4e057c1..0aa0bdd 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/formatting.py +++ b/mmsegmentation/mmseg/datasets/pipelines/formatting.py @@ -31,11 +31,11 @@ def to_tensor(data): elif isinstance(data, float): return torch.FloatTensor([data]) else: - raise TypeError(f'type {type(data)} cannot be converted to tensor.') + raise TypeError(f"type {type(data)} cannot be converted to tensor.") @PIPELINES.register_module() -class ToTensor(object): +class ToTensor: """Convert some results to :obj:`torch.Tensor` by given keys. Args: @@ -61,11 +61,11 @@ def __call__(self, results): return results def __repr__(self): - return self.__class__.__name__ + f'(keys={self.keys})' + return self.__class__.__name__ + f"(keys={self.keys})" @PIPELINES.register_module() -class ImageToTensor(object): +class ImageToTensor: """Convert image to :obj:`torch.Tensor` by given keys. The dimension order of input image is (H, W, C). The pipeline will convert @@ -99,11 +99,11 @@ def __call__(self, results): return results def __repr__(self): - return self.__class__.__name__ + f'(keys={self.keys})' + return self.__class__.__name__ + f"(keys={self.keys})" @PIPELINES.register_module() -class Transpose(object): +class Transpose: """Transpose some results by given keys. Args: @@ -132,12 +132,11 @@ def __call__(self, results): return results def __repr__(self): - return self.__class__.__name__ + \ - f'(keys={self.keys}, order={self.order})' + return self.__class__.__name__ + f"(keys={self.keys}, order={self.order})" @PIPELINES.register_module() -class ToDataContainer(object): +class ToDataContainer: """Convert results to :obj:`mmcv.DataContainer` by given fields. Args: @@ -148,9 +147,9 @@ class ToDataContainer(object): dict(key='gt_semantic_seg'))``. """ - def __init__(self, - fields=(dict(key='img', - stack=True), dict(key='gt_semantic_seg'))): + def __init__( + self, fields=(dict(key="img", stack=True), dict(key="gt_semantic_seg")) + ): self.fields = fields def __call__(self, results): @@ -167,16 +166,16 @@ def __call__(self, results): for field in self.fields: field = field.copy() - key = field.pop('key') + key = field.pop("key") results[key] = DC(results[key], **field) return results def __repr__(self): - return self.__class__.__name__ + f'(fields={self.fields})' + return self.__class__.__name__ + f"(fields={self.fields})" @PIPELINES.register_module() -class DefaultFormatBundle(object): +class DefaultFormatBundle: """Default formatting bundle. It simplifies the pipeline of formatting common fields, including "img" @@ -198,18 +197,18 @@ def __call__(self, results): default bundle. """ - if 'img' in results: - img = results['img'] + if "img" in results: + img = results["img"] if len(img.shape) < 3: img = np.expand_dims(img, -1) img = np.ascontiguousarray(img.transpose(2, 0, 1)) - results['img'] = DC(to_tensor(img), stack=True) - if 'gt_semantic_seg' in results: + results["img"] = DC(to_tensor(img), stack=True) + if "gt_semantic_seg" in results: # convert to long - results['gt_semantic_seg'] = DC( - to_tensor(results['gt_semantic_seg'][None, - ...].astype(np.int64)), - stack=True) + results["gt_semantic_seg"] = DC( + to_tensor(results["gt_semantic_seg"][None, ...].astype(np.int64)), + stack=True, + ) return results def __repr__(self): @@ -217,7 +216,7 @@ def __repr__(self): @PIPELINES.register_module() -class Collect(object): +class Collect: """Collect data from the loader relevant to the specific task. This is usually the last stage of the data loader pipeline. Typically keys @@ -254,11 +253,21 @@ class Collect(object): ``flip_direction``, ``img_norm_cfg``) """ - def __init__(self, - keys, - meta_keys=('filename', 'ori_filename', 'ori_shape', - 'img_shape', 'pad_shape', 'scale_factor', 'flip', - 'flip_direction', 'img_norm_cfg')): + def __init__( + self, + keys, + meta_keys=( + "filename", + "ori_filename", + "ori_shape", + "img_shape", + "pad_shape", + "scale_factor", + "flip", + "flip_direction", + "img_norm_cfg", + ), + ): self.keys = keys self.meta_keys = meta_keys @@ -279,11 +288,12 @@ def __call__(self, results): img_meta = {} for key in self.meta_keys: img_meta[key] = results[key] - data['img_metas'] = DC(img_meta, cpu_only=True) + data["img_metas"] = DC(img_meta, cpu_only=True) for key in self.keys: data[key] = results[key] return data def __repr__(self): - return self.__class__.__name__ + \ - f'(keys={self.keys}, meta_keys={self.meta_keys})' + return ( + self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})" + ) diff --git a/mmsegmentation/mmseg/datasets/pipelines/loading.py b/mmsegmentation/mmseg/datasets/pipelines/loading.py index 572e434..3a8ca44 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/loading.py +++ b/mmsegmentation/mmseg/datasets/pipelines/loading.py @@ -8,7 +8,7 @@ @PIPELINES.register_module() -class LoadImageFromFile(object): +class LoadImageFromFile: """Load an image from file. Required keys are "img_prefix" and "img_info" (a dict that must contain the @@ -29,11 +29,13 @@ class LoadImageFromFile(object): 'cv2' """ - def __init__(self, - to_float32=False, - color_type='color', - file_client_args=dict(backend='disk'), - imdecode_backend='cv2'): + def __init__( + self, + to_float32=False, + color_type="color", + file_client_args=dict(backend="disk"), + imdecode_backend="cv2", + ): self.to_float32 = to_float32 self.color_type = color_type self.file_client_args = file_client_args.copy() @@ -53,42 +55,43 @@ def __call__(self, results): if self.file_client is None: self.file_client = mmcv.FileClient(**self.file_client_args) - if results.get('img_prefix') is not None: - filename = osp.join(results['img_prefix'], - results['img_info']['filename']) + if results.get("img_prefix") is not None: + filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) else: - filename = results['img_info']['filename'] + filename = results["img_info"]["filename"] img_bytes = self.file_client.get(filename) img = mmcv.imfrombytes( - img_bytes, flag=self.color_type, backend=self.imdecode_backend) + img_bytes, flag=self.color_type, backend=self.imdecode_backend + ) if self.to_float32: img = img.astype(np.float32) - results['filename'] = filename - results['ori_filename'] = results['img_info']['filename'] - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["filename"] = filename + results["ori_filename"] = results["img_info"]["filename"] + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 num_channels = 1 if len(img.shape) < 3 else img.shape[2] - results['img_norm_cfg'] = dict( + results["img_norm_cfg"] = dict( mean=np.zeros(num_channels, dtype=np.float32), std=np.ones(num_channels, dtype=np.float32), - to_rgb=False) + to_rgb=False, + ) return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(to_float32={self.to_float32},' + repr_str += f"(to_float32={self.to_float32}," repr_str += f"color_type='{self.color_type}'," repr_str += f"imdecode_backend='{self.imdecode_backend}')" return repr_str @PIPELINES.register_module() -class LoadAnnotations(object): +class LoadAnnotations: """Load annotations for semantic segmentation. Args: @@ -102,10 +105,12 @@ class LoadAnnotations(object): 'pillow' """ - def __init__(self, - reduce_zero_label=False, - file_client_args=dict(backend='disk'), - imdecode_backend='pillow'): + def __init__( + self, + reduce_zero_label=False, + file_client_args=dict(backend="disk"), + imdecode_backend="pillow", + ): self.reduce_zero_label = reduce_zero_label self.file_client_args = file_client_args.copy() self.file_client = None @@ -124,22 +129,23 @@ def __call__(self, results): if self.file_client is None: self.file_client = mmcv.FileClient(**self.file_client_args) - if results.get('seg_prefix', None) is not None: - filename = osp.join(results['seg_prefix'], - results['ann_info']['seg_map']) + if results.get("seg_prefix", None) is not None: + filename = osp.join(results["seg_prefix"], results["ann_info"]["seg_map"]) else: - filename = results['ann_info']['seg_map'] + filename = results["ann_info"]["seg_map"] img_bytes = self.file_client.get(filename) - gt_semantic_seg = mmcv.imfrombytes( - img_bytes, flag='unchanged', - backend=self.imdecode_backend).squeeze().astype(np.uint8) + gt_semantic_seg = ( + mmcv.imfrombytes(img_bytes, flag="unchanged", backend=self.imdecode_backend) + .squeeze() + .astype(np.uint8) + ) # modify if custom classes - if results.get('label_map', None) is not None: + if results.get("label_map", None) is not None: # Add deep copy to solve bug of repeatedly # replace `gt_semantic_seg`, which is reported in # https://github.com/open-mmlab/mmsegmentation/pull/1445/ gt_semantic_seg_copy = gt_semantic_seg.copy() - for old_id, new_id in results['label_map'].items(): + for old_id, new_id in results["label_map"].items(): gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id # reduce zero_label if self.reduce_zero_label: @@ -147,12 +153,12 @@ def __call__(self, results): gt_semantic_seg[gt_semantic_seg == 0] = 255 gt_semantic_seg = gt_semantic_seg - 1 gt_semantic_seg[gt_semantic_seg == 254] = 255 - results['gt_semantic_seg'] = gt_semantic_seg - results['seg_fields'].append('gt_semantic_seg') + results["gt_semantic_seg"] = gt_semantic_seg + results["seg_fields"].append("gt_semantic_seg") return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(reduce_zero_label={self.reduce_zero_label},' + repr_str += f"(reduce_zero_label={self.reduce_zero_label}," repr_str += f"imdecode_backend='{self.imdecode_backend}')" return repr_str diff --git a/mmsegmentation/mmseg/datasets/pipelines/test_time_aug.py b/mmsegmentation/mmseg/datasets/pipelines/test_time_aug.py index 4964087..b4ef940 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/test_time_aug.py +++ b/mmsegmentation/mmseg/datasets/pipelines/test_time_aug.py @@ -8,7 +8,7 @@ @PIPELINES.register_module() -class MultiScaleFlipAug(object): +class MultiScaleFlipAug: """Test-time augmentation with multiple scales and flipping. An example configuration is as followed: @@ -51,53 +51,49 @@ class MultiScaleFlipAug(object): It has no effect when flip == False. Default: "horizontal". """ - def __init__(self, - transforms, - img_scale, - img_ratios=None, - flip=False, - flip_direction='horizontal'): + def __init__( + self, + transforms, + img_scale, + img_ratios=None, + flip=False, + flip_direction="horizontal", + ): if flip: - trans_index = { - key['type']: index - for index, key in enumerate(transforms) - } - if 'RandomFlip' in trans_index and 'Pad' in trans_index: - assert trans_index['RandomFlip'] < trans_index['Pad'], \ - 'Pad must be executed after RandomFlip when flip is True' + trans_index = {key["type"]: index for index, key in enumerate(transforms)} + if "RandomFlip" in trans_index and "Pad" in trans_index: + assert ( + trans_index["RandomFlip"] < trans_index["Pad"] + ), "Pad must be executed after RandomFlip when flip is True" self.transforms = Compose(transforms) if img_ratios is not None: - img_ratios = img_ratios if isinstance(img_ratios, - list) else [img_ratios] + img_ratios = img_ratios if isinstance(img_ratios, list) else [img_ratios] assert mmcv.is_list_of(img_ratios, float) if img_scale is None: # mode 1: given img_scale=None and a range of image ratio self.img_scale = None assert mmcv.is_list_of(img_ratios, float) - elif isinstance(img_scale, tuple) and mmcv.is_list_of( - img_ratios, float): + elif isinstance(img_scale, tuple) and mmcv.is_list_of(img_ratios, float): assert len(img_scale) == 2 # mode 2: given a scale and a range of image ratio - self.img_scale = [(int(img_scale[0] * ratio), - int(img_scale[1] * ratio)) - for ratio in img_ratios] + self.img_scale = [ + (int(img_scale[0] * ratio), int(img_scale[1] * ratio)) + for ratio in img_ratios + ] else: # mode 3: given multiple scales - self.img_scale = img_scale if isinstance(img_scale, - list) else [img_scale] + self.img_scale = img_scale if isinstance(img_scale, list) else [img_scale] assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None self.flip = flip self.img_ratios = img_ratios - self.flip_direction = flip_direction if isinstance( - flip_direction, list) else [flip_direction] + self.flip_direction = ( + flip_direction if isinstance(flip_direction, list) else [flip_direction] + ) assert mmcv.is_list_of(self.flip_direction, str) - if not self.flip and self.flip_direction != ['horizontal']: - warnings.warn( - 'flip_direction has no effect when flip is set to False') - if (self.flip - and not any([t['type'] == 'RandomFlip' for t in transforms])): - warnings.warn( - 'flip has no effect when RandomFlip is not in transforms') + if not self.flip and self.flip_direction != ["horizontal"]: + warnings.warn("flip_direction has no effect when flip is set to False") + if self.flip and not any([t["type"] == "RandomFlip" for t in transforms]): + warnings.warn("flip has no effect when RandomFlip is not in transforms") def __call__(self, results): """Call function to apply test time augment transforms on results. @@ -112,9 +108,8 @@ def __call__(self, results): aug_data = [] if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): - h, w = results['img'].shape[:2] - img_scale = [(int(w * ratio), int(h * ratio)) - for ratio in self.img_ratios] + h, w = results["img"].shape[:2] + img_scale = [(int(w * ratio), int(h * ratio)) for ratio in self.img_ratios] else: img_scale = self.img_scale flip_aug = [False, True] if self.flip else [False] @@ -122,9 +117,9 @@ def __call__(self, results): for flip in flip_aug: for direction in self.flip_direction: _results = results.copy() - _results['scale'] = scale - _results['flip'] = flip - _results['flip_direction'] = direction + _results["scale"] = scale + _results["flip"] = flip + _results["flip_direction"] = direction data = self.transforms(_results) aug_data.append(data) # list of dict to dict of list @@ -136,7 +131,7 @@ def __call__(self, results): def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(transforms={self.transforms}, ' - repr_str += f'img_scale={self.img_scale}, flip={self.flip})' - repr_str += f'flip_direction={self.flip_direction}' + repr_str += f"(transforms={self.transforms}, " + repr_str += f"img_scale={self.img_scale}, flip={self.flip})" + repr_str += f"flip_direction={self.flip_direction}" return repr_str diff --git a/mmsegmentation/mmseg/datasets/pipelines/transforms.py b/mmsegmentation/mmseg/datasets/pipelines/transforms.py index 5673b64..25ad26c 100644 --- a/mmsegmentation/mmseg/datasets/pipelines/transforms.py +++ b/mmsegmentation/mmseg/datasets/pipelines/transforms.py @@ -10,7 +10,7 @@ @PIPELINES.register_module() -class ResizeToMultiple(object): +class ResizeToMultiple: """Resize images & seg to multiple of divisor. Args: @@ -35,39 +35,39 @@ def __call__(self, results): dict: Resized results, 'img_shape', 'pad_shape' keys are updated. """ # Align image to multiple of size divisor. - img = results['img'] + img = results["img"] img = mmcv.imresize_to_multiple( img, self.size_divisor, scale_factor=1, - interpolation=self.interpolation - if self.interpolation else 'bilinear') + interpolation=self.interpolation if self.interpolation else "bilinear", + ) - results['img'] = img - results['img_shape'] = img.shape - results['pad_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["pad_shape"] = img.shape # Align segmentation map to multiple of size divisor. - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): gt_seg = results[key] gt_seg = mmcv.imresize_to_multiple( - gt_seg, - self.size_divisor, - scale_factor=1, - interpolation='nearest') + gt_seg, self.size_divisor, scale_factor=1, interpolation="nearest" + ) results[key] = gt_seg return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += (f'(size_divisor={self.size_divisor}, ' - f'interpolation={self.interpolation})') + repr_str += ( + f"(size_divisor={self.size_divisor}, " + f"interpolation={self.interpolation})" + ) return repr_str @PIPELINES.register_module() -class Resize(object): +class Resize: """Resize images & seg. This transform resizes the input image to some scale. If the input dict @@ -106,12 +106,14 @@ class Resize(object): bigger than the crop size in ``slide_inference``. Default: None """ - def __init__(self, - img_scale=None, - multiscale_mode='range', - ratio_range=None, - keep_ratio=True, - min_size=None): + def __init__( + self, + img_scale=None, + multiscale_mode="range", + ratio_range=None, + keep_ratio=True, + min_size=None, + ): if img_scale is None: self.img_scale = None else: @@ -127,7 +129,7 @@ def __init__(self, assert self.img_scale is None or len(self.img_scale) == 1 else: # mode 3 and 4: given multiple scales or a range of scales - assert multiscale_mode in ['value', 'range'] + assert multiscale_mode in ["value", "range"] self.multiscale_mode = multiscale_mode self.ratio_range = ratio_range @@ -170,12 +172,8 @@ def random_sample(img_scales): assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 img_scale_long = [max(s) for s in img_scales] img_scale_short = [min(s) for s in img_scales] - long_edge = np.random.randint( - min(img_scale_long), - max(img_scale_long) + 1) - short_edge = np.random.randint( - min(img_scale_short), - max(img_scale_short) + 1) + long_edge = np.random.randint(min(img_scale_long), max(img_scale_long) + 1) + short_edge = np.random.randint(min(img_scale_short), max(img_scale_short) + 1) img_scale = (long_edge, short_edge) return img_scale, None @@ -226,23 +224,23 @@ def _random_scale(self, results): if self.ratio_range is not None: if self.img_scale is None: - h, w = results['img'].shape[:2] - scale, scale_idx = self.random_sample_ratio((w, h), - self.ratio_range) + h, w = results["img"].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), self.ratio_range) else: scale, scale_idx = self.random_sample_ratio( - self.img_scale[0], self.ratio_range) + self.img_scale[0], self.ratio_range + ) elif len(self.img_scale) == 1: scale, scale_idx = self.img_scale[0], 0 - elif self.multiscale_mode == 'range': + elif self.multiscale_mode == "range": scale, scale_idx = self.random_sample(self.img_scale) - elif self.multiscale_mode == 'value': + elif self.multiscale_mode == "value": scale, scale_idx = self.random_select(self.img_scale) else: raise NotImplementedError - results['scale'] = scale - results['scale_idx'] = scale_idx + results["scale"] = scale + results["scale_idx"] = scale_idx def _resize_img(self, results): """Resize images with ``results['scale']``.""" @@ -252,46 +250,49 @@ def _resize_img(self, results): # shape of images is (min_size, min_size, 3). 'min_size' # with tuple type will be supported, i.e. the width and # height are not equal. - if min(results['scale']) < self.min_size: + if min(results["scale"]) < self.min_size: new_short = self.min_size else: - new_short = min(results['scale']) + new_short = min(results["scale"]) - h, w = results['img'].shape[:2] + h, w = results["img"].shape[:2] if h > w: new_h, new_w = new_short * h / w, new_short else: new_h, new_w = new_short, new_short * w / h - results['scale'] = (new_h, new_w) + results["scale"] = (new_h, new_w) img, scale_factor = mmcv.imrescale( - results['img'], results['scale'], return_scale=True) + results["img"], results["scale"], return_scale=True + ) # the w_scale and h_scale has minor difference # a real fix should be done in the mmcv.imrescale in the future new_h, new_w = img.shape[:2] - h, w = results['img'].shape[:2] + h, w = results["img"].shape[:2] w_scale = new_w / w h_scale = new_h / h else: img, w_scale, h_scale = mmcv.imresize( - results['img'], results['scale'], return_scale=True) - scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], - dtype=np.float32) - results['img'] = img - results['img_shape'] = img.shape - results['pad_shape'] = img.shape # in case that there is no padding - results['scale_factor'] = scale_factor - results['keep_ratio'] = self.keep_ratio + results["img"], results["scale"], return_scale=True + ) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + results["img"] = img + results["img_shape"] = img.shape + results["pad_shape"] = img.shape # in case that there is no padding + results["scale_factor"] = scale_factor + results["keep_ratio"] = self.keep_ratio def _resize_seg(self, results): """Resize semantic segmentation map with ``results['scale']``.""" - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): if self.keep_ratio: gt_seg = mmcv.imrescale( - results[key], results['scale'], interpolation='nearest') + results[key], results["scale"], interpolation="nearest" + ) else: gt_seg = mmcv.imresize( - results[key], results['scale'], interpolation='nearest') + results[key], results["scale"], interpolation="nearest" + ) results[key] = gt_seg def __call__(self, results): @@ -306,7 +307,7 @@ def __call__(self, results): 'keep_ratio' keys are added into result dict. """ - if 'scale' not in results: + if "scale" not in results: self._random_scale(results) self._resize_img(results) self._resize_seg(results) @@ -314,15 +315,17 @@ def __call__(self, results): def __repr__(self): repr_str = self.__class__.__name__ - repr_str += (f'(img_scale={self.img_scale}, ' - f'multiscale_mode={self.multiscale_mode}, ' - f'ratio_range={self.ratio_range}, ' - f'keep_ratio={self.keep_ratio})') + repr_str += ( + f"(img_scale={self.img_scale}, " + f"multiscale_mode={self.multiscale_mode}, " + f"ratio_range={self.ratio_range}, " + f"keep_ratio={self.keep_ratio})" + ) return repr_str @PIPELINES.register_module() -class RandomFlip(object): +class RandomFlip: """Flip the image & seg. If the input dict contains the key "flip", then the flag will be used, @@ -335,13 +338,13 @@ class RandomFlip(object): 'horizontal' and 'vertical'. Default: 'horizontal'. """ - @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') - def __init__(self, prob=None, direction='horizontal'): + @deprecated_api_warning({"flip_ratio": "prob"}, cls_name="RandomFlip") + def __init__(self, prob=None, direction="horizontal"): self.prob = prob self.direction = direction if prob is not None: assert prob >= 0 and prob <= 1 - assert direction in ['horizontal', 'vertical'] + assert direction in ["horizontal", "vertical"] def __call__(self, results): """Call function to flip bounding boxes, masks, semantic segmentation @@ -355,29 +358,31 @@ def __call__(self, results): result dict. """ - if 'flip' not in results: + if "flip" not in results: flip = True if np.random.rand() < self.prob else False - results['flip'] = flip - if 'flip_direction' not in results: - results['flip_direction'] = self.direction - if results['flip']: + results["flip"] = flip + if "flip_direction" not in results: + results["flip_direction"] = self.direction + if results["flip"]: # flip image - results['img'] = mmcv.imflip( - results['img'], direction=results['flip_direction']) + results["img"] = mmcv.imflip( + results["img"], direction=results["flip_direction"] + ) # flip segs - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): # use copy() to make numpy stride positive results[key] = mmcv.imflip( - results[key], direction=results['flip_direction']).copy() + results[key], direction=results["flip_direction"] + ).copy() return results def __repr__(self): - return self.__class__.__name__ + f'(prob={self.prob})' + return self.__class__.__name__ + f"(prob={self.prob})" @PIPELINES.register_module() -class Pad(object): +class Pad: """Pad the image & mask. There are two padding modes: (1) pad to a fixed size and (2) pad to the @@ -392,11 +397,7 @@ class Pad(object): Default: 255. """ - def __init__(self, - size=None, - size_divisor=None, - pad_val=0, - seg_pad_val=255): + def __init__(self, size=None, size_divisor=None, pad_val=0, seg_pad_val=255): self.size = size self.size_divisor = size_divisor self.pad_val = pad_val @@ -409,22 +410,23 @@ def _pad_img(self, results): """Pad images according to ``self.size``.""" if self.size is not None: padded_img = mmcv.impad( - results['img'], shape=self.size, pad_val=self.pad_val) + results["img"], shape=self.size, pad_val=self.pad_val + ) elif self.size_divisor is not None: padded_img = mmcv.impad_to_multiple( - results['img'], self.size_divisor, pad_val=self.pad_val) - results['img'] = padded_img - results['pad_shape'] = padded_img.shape - results['pad_fixed_size'] = self.size - results['pad_size_divisor'] = self.size_divisor + results["img"], self.size_divisor, pad_val=self.pad_val + ) + results["img"] = padded_img + results["pad_shape"] = padded_img.shape + results["pad_fixed_size"] = self.size + results["pad_size_divisor"] = self.size_divisor def _pad_seg(self, results): """Pad masks according to ``results['pad_shape']``.""" - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): results[key] = mmcv.impad( - results[key], - shape=results['pad_shape'][:2], - pad_val=self.seg_pad_val) + results[key], shape=results["pad_shape"][:2], pad_val=self.seg_pad_val + ) def __call__(self, results): """Call function to pad images, masks, semantic segmentation maps. @@ -442,13 +444,15 @@ def __call__(self, results): def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ - f'pad_val={self.pad_val})' + repr_str += ( + f"(size={self.size}, size_divisor={self.size_divisor}, " + f"pad_val={self.pad_val})" + ) return repr_str @PIPELINES.register_module() -class Normalize(object): +class Normalize: """Normalize the image. Added key is "img_norm_cfg". @@ -476,21 +480,20 @@ def __call__(self, results): result dict. """ - results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, - self.to_rgb) - results['img_norm_cfg'] = dict( - mean=self.mean, std=self.std, to_rgb=self.to_rgb) + results["img"] = mmcv.imnormalize( + results["img"], self.mean, self.std, self.to_rgb + ) + results["img_norm_cfg"] = dict(mean=self.mean, std=self.std, to_rgb=self.to_rgb) return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ - f'{self.to_rgb})' + repr_str += f"(mean={self.mean}, std={self.std}, to_rgb=" f"{self.to_rgb})" return repr_str @PIPELINES.register_module() -class Rerange(object): +class Rerange: """Rerange the image pixel value. Args: @@ -516,7 +519,7 @@ def __call__(self, results): dict: Reranged results. """ - img = results['img'] + img = results["img"] img_min_value = np.min(img) img_max_value = np.max(img) @@ -525,18 +528,18 @@ def __call__(self, results): img = (img - img_min_value) / (img_max_value - img_min_value) # rerange to [min_value, max_value] img = img * (self.max_value - self.min_value) + self.min_value - results['img'] = img + results["img"] = img return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + repr_str += f"(min_value={self.min_value}, max_value={self.max_value})" return repr_str @PIPELINES.register_module() -class CLAHE(object): +class CLAHE: """Use CLAHE method to process the image. See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. @@ -566,22 +569,25 @@ def __call__(self, results): dict: Processed results. """ - for i in range(results['img'].shape[2]): - results['img'][:, :, i] = mmcv.clahe( - np.array(results['img'][:, :, i], dtype=np.uint8), - self.clip_limit, self.tile_grid_size) + for i in range(results["img"].shape[2]): + results["img"][:, :, i] = mmcv.clahe( + np.array(results["img"][:, :, i], dtype=np.uint8), + self.clip_limit, + self.tile_grid_size, + ) return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(clip_limit={self.clip_limit}, '\ - f'tile_grid_size={self.tile_grid_size})' + repr_str += ( + f"(clip_limit={self.clip_limit}, " f"tile_grid_size={self.tile_grid_size})" + ) return repr_str @PIPELINES.register_module() -class RandomCrop(object): +class RandomCrop: """Random crop the image & seg. Args: @@ -590,7 +596,7 @@ class RandomCrop(object): occupy. """ - def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): + def __init__(self, crop_size, cat_max_ratio=1.0, ignore_index=255): assert crop_size[0] > 0 and crop_size[1] > 0 self.crop_size = crop_size self.cat_max_ratio = cat_max_ratio @@ -624,37 +630,36 @@ def __call__(self, results): updated according to crop size. """ - img = results['img'] + img = results["img"] crop_bbox = self.get_crop_bbox(img) - if self.cat_max_ratio < 1.: + if self.cat_max_ratio < 1.0: # Repeat 10 times for _ in range(10): - seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) + seg_temp = self.crop(results["gt_semantic_seg"], crop_bbox) labels, cnt = np.unique(seg_temp, return_counts=True) cnt = cnt[labels != self.ignore_index] - if len(cnt) > 1 and np.max(cnt) / np.sum( - cnt) < self.cat_max_ratio: + if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio: break crop_bbox = self.get_crop_bbox(img) # crop the image img = self.crop(img, crop_bbox) img_shape = img.shape - results['img'] = img - results['img_shape'] = img_shape + results["img"] = img + results["img_shape"] = img_shape # crop semantic seg - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): results[key] = self.crop(results[key], crop_bbox) return results def __repr__(self): - return self.__class__.__name__ + f'(crop_size={self.crop_size})' + return self.__class__.__name__ + f"(crop_size={self.crop_size})" @PIPELINES.register_module() -class RandomRotate(object): +class RandomRotate: """Rotate the image & seg. Args: @@ -672,22 +677,19 @@ class RandomRotate(object): rotated image. Default: False """ - def __init__(self, - prob, - degree, - pad_val=0, - seg_pad_val=255, - center=None, - auto_bound=False): + def __init__( + self, prob, degree, pad_val=0, seg_pad_val=255, center=None, auto_bound=False + ): self.prob = prob assert prob >= 0 and prob <= 1 if isinstance(degree, (float, int)): - assert degree > 0, f'degree {degree} should be positive' + assert degree > 0, f"degree {degree} should be positive" self.degree = (-degree, degree) else: self.degree = degree - assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ - f'tuple of (min, max)' + assert len(self.degree) == 2, ( + f"degree {self.degree} should be a " f"tuple of (min, max)" + ) self.pal_val = pad_val self.seg_pad_val = seg_pad_val self.center = center @@ -707,37 +709,41 @@ def __call__(self, results): degree = np.random.uniform(min(*self.degree), max(*self.degree)) if rotate: # rotate image - results['img'] = mmcv.imrotate( - results['img'], + results["img"] = mmcv.imrotate( + results["img"], angle=degree, border_value=self.pal_val, center=self.center, - auto_bound=self.auto_bound) + auto_bound=self.auto_bound, + ) # rotate segs - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): results[key] = mmcv.imrotate( results[key], angle=degree, border_value=self.seg_pad_val, center=self.center, auto_bound=self.auto_bound, - interpolation='nearest') + interpolation="nearest", + ) return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(prob={self.prob}, ' \ - f'degree={self.degree}, ' \ - f'pad_val={self.pal_val}, ' \ - f'seg_pad_val={self.seg_pad_val}, ' \ - f'center={self.center}, ' \ - f'auto_bound={self.auto_bound})' + repr_str += ( + f"(prob={self.prob}, " + f"degree={self.degree}, " + f"pad_val={self.pal_val}, " + f"seg_pad_val={self.seg_pad_val}, " + f"center={self.center}, " + f"auto_bound={self.auto_bound})" + ) return repr_str @PIPELINES.register_module() -class RGB2Gray(object): +class RGB2Gray: """Convert RGB image to grayscale image. This transform calculate the weighted mean of input image channels with @@ -769,7 +775,7 @@ def __call__(self, results): Returns: dict: Result dict with grayscale image. """ - img = results['img'] + img = results["img"] assert len(img.shape) == 3 assert img.shape[2] == len(self.weights) weights = np.array(self.weights).reshape((1, 1, -1)) @@ -779,20 +785,19 @@ def __call__(self, results): else: img = img.repeat(self.out_channels, axis=2) - results['img'] = img - results['img_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(out_channels={self.out_channels}, ' \ - f'weights={self.weights})' + repr_str += f"(out_channels={self.out_channels}, " f"weights={self.weights})" return repr_str @PIPELINES.register_module() -class AdjustGamma(object): +class AdjustGamma: """Using gamma correction to process the image. Args: @@ -805,8 +810,9 @@ def __init__(self, gamma=1.0): assert gamma > 0 self.gamma = gamma inv_gamma = 1.0 / gamma - self.table = np.array([(i / 255.0)**inv_gamma * 255 - for i in np.arange(256)]).astype('uint8') + self.table = np.array( + [(i / 255.0) ** inv_gamma * 255 for i in np.arange(256)] + ).astype("uint8") def __call__(self, results): """Call function to process the image with gamma correction. @@ -818,17 +824,18 @@ def __call__(self, results): dict: Processed results. """ - results['img'] = mmcv.lut_transform( - np.array(results['img'], dtype=np.uint8), self.table) + results["img"] = mmcv.lut_transform( + np.array(results["img"], dtype=np.uint8), self.table + ) return results def __repr__(self): - return self.__class__.__name__ + f'(gamma={self.gamma})' + return self.__class__.__name__ + f"(gamma={self.gamma})" @PIPELINES.register_module() -class SegRescale(object): +class SegRescale: """Rescale semantic segmentation maps. Args: @@ -847,18 +854,19 @@ def __call__(self, results): Returns: dict: Result dict with semantic segmentation map scaled. """ - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): if self.scale_factor != 1: results[key] = mmcv.imrescale( - results[key], self.scale_factor, interpolation='nearest') + results[key], self.scale_factor, interpolation="nearest" + ) return results def __repr__(self): - return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + return self.__class__.__name__ + f"(scale_factor={self.scale_factor})" @PIPELINES.register_module() -class PhotoMetricDistortion(object): +class PhotoMetricDistortion: """Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in second or second to last. @@ -878,11 +886,13 @@ class PhotoMetricDistortion(object): hue_delta (int): delta of hue. """ - def __init__(self, - brightness_delta=32, - contrast_range=(0.5, 1.5), - saturation_range=(0.5, 1.5), - hue_delta=18): + def __init__( + self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18, + ): self.brightness_delta = brightness_delta self.contrast_lower, self.contrast_upper = contrast_range self.saturation_lower, self.saturation_upper = saturation_range @@ -898,17 +908,16 @@ def brightness(self, img): """Brightness distortion.""" if random.randint(2): return self.convert( - img, - beta=random.uniform(-self.brightness_delta, - self.brightness_delta)) + img, beta=random.uniform(-self.brightness_delta, self.brightness_delta) + ) return img def contrast(self, img): """Contrast distortion.""" if random.randint(2): return self.convert( - img, - alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + img, alpha=random.uniform(self.contrast_lower, self.contrast_upper) + ) return img def saturation(self, img): @@ -917,8 +926,8 @@ def saturation(self, img): img = mmcv.bgr2hsv(img) img[:, :, 1] = self.convert( img[:, :, 1], - alpha=random.uniform(self.saturation_lower, - self.saturation_upper)) + alpha=random.uniform(self.saturation_lower, self.saturation_upper), + ) img = mmcv.hsv2bgr(img) return img @@ -926,9 +935,10 @@ def hue(self, img): """Hue distortion.""" if random.randint(2): img = mmcv.bgr2hsv(img) - img[:, :, - 0] = (img[:, :, 0].astype(int) + - random.randint(-self.hue_delta, self.hue_delta)) % 180 + img[:, :, 0] = ( + img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta) + ) % 180 img = mmcv.hsv2bgr(img) return img @@ -942,7 +952,7 @@ def __call__(self, results): dict: Result dict with images distorted. """ - img = results['img'] + img = results["img"] # random brightness img = self.brightness(img) @@ -962,22 +972,24 @@ def __call__(self, results): if mode == 0: img = self.contrast(img) - results['img'] = img + results["img"] = img return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += (f'(brightness_delta={self.brightness_delta}, ' - f'contrast_range=({self.contrast_lower}, ' - f'{self.contrast_upper}), ' - f'saturation_range=({self.saturation_lower}, ' - f'{self.saturation_upper}), ' - f'hue_delta={self.hue_delta})') + repr_str += ( + f"(brightness_delta={self.brightness_delta}, " + f"contrast_range=({self.contrast_lower}, " + f"{self.contrast_upper}), " + f"saturation_range=({self.saturation_lower}, " + f"{self.saturation_upper}), " + f"hue_delta={self.hue_delta})" + ) return repr_str @PIPELINES.register_module() -class RandomCutOut(object): +class RandomCutOut: """CutOut operation. Randomly drop some regions of image used in @@ -1002,26 +1014,30 @@ class RandomCutOut(object): If seg_fill_in is None, skip. Default: None. """ - def __init__(self, - prob, - n_holes, - cutout_shape=None, - cutout_ratio=None, - fill_in=(0, 0, 0), - seg_fill_in=None): - + def __init__( + self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None, + ): assert 0 <= prob and prob <= 1 - assert (cutout_shape is None) ^ (cutout_ratio is None), \ - 'Either cutout_shape or cutout_ratio should be specified.' - assert (isinstance(cutout_shape, (list, tuple)) - or isinstance(cutout_ratio, (list, tuple))) + assert (cutout_shape is None) ^ ( + cutout_ratio is None + ), "Either cutout_shape or cutout_ratio should be specified." + assert isinstance(cutout_shape, (list, tuple)) or isinstance( + cutout_ratio, (list, tuple) + ) if isinstance(n_holes, tuple): assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] else: n_holes = (n_holes, n_holes) if seg_fill_in is not None: - assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in - and seg_fill_in <= 255) + assert ( + isinstance(seg_fill_in, int) and 0 <= seg_fill_in and seg_fill_in <= 255 + ) self.prob = prob self.n_holes = n_holes self.fill_in = fill_in @@ -1035,7 +1051,7 @@ def __call__(self, results): """Call function to drop some regions of image.""" cutout = True if np.random.rand() < self.prob else False if cutout: - h, w, c = results['img'].shape + h, w, c = results["img"].shape n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) for _ in range(n_holes): x1 = np.random.randint(0, w) @@ -1049,27 +1065,30 @@ def __call__(self, results): x2 = np.clip(x1 + cutout_w, 0, w) y2 = np.clip(y1 + cutout_h, 0, h) - results['img'][y1:y2, x1:x2, :] = self.fill_in + results["img"][y1:y2, x1:x2, :] = self.fill_in if self.seg_fill_in is not None: - for key in results.get('seg_fields', []): + for key in results.get("seg_fields", []): results[key][y1:y2, x1:x2] = self.seg_fill_in return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(prob={self.prob}, ' - repr_str += f'n_holes={self.n_holes}, ' - repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio - else f'cutout_shape={self.candidates}, ') - repr_str += f'fill_in={self.fill_in}, ' - repr_str += f'seg_fill_in={self.seg_fill_in})' + repr_str += f"(prob={self.prob}, " + repr_str += f"n_holes={self.n_holes}, " + repr_str += ( + f"cutout_ratio={self.candidates}, " + if self.with_ratio + else f"cutout_shape={self.candidates}, " + ) + repr_str += f"fill_in={self.fill_in}, " + repr_str += f"seg_fill_in={self.seg_fill_in})" return repr_str @PIPELINES.register_module() -class RandomMosaic(object): +class RandomMosaic: """Mosaic augmentation. Given 4 images, mosaic transform combines them into one output image. The output image is composed of the parts from each sub- image. @@ -1111,12 +1130,14 @@ class RandomMosaic(object): seg_pad_val (int): Pad value of segmentation map. Default: 255. """ - def __init__(self, - prob, - img_scale=(640, 640), - center_ratio_range=(0.5, 1.5), - pad_val=0, - seg_pad_val=255): + def __init__( + self, + prob, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + pad_val=0, + seg_pad_val=255, + ): assert 0 <= prob and prob <= 1 assert isinstance(img_scale, tuple) self.prob = prob @@ -1163,52 +1184,57 @@ def _mosaic_transform_img(self, results): dict: Updated result dict. """ - assert 'mix_results' in results - if len(results['img'].shape) == 3: + assert "mix_results" in results + if len(results["img"].shape) == 3: mosaic_img = np.full( (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), self.pad_val, - dtype=results['img'].dtype) + dtype=results["img"].dtype, + ) else: mosaic_img = np.full( (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), self.pad_val, - dtype=results['img'].dtype) + dtype=results["img"].dtype, + ) # mosaic center x, y self.center_x = int( - random.uniform(*self.center_ratio_range) * self.img_scale[1]) + random.uniform(*self.center_ratio_range) * self.img_scale[1] + ) self.center_y = int( - random.uniform(*self.center_ratio_range) * self.img_scale[0]) + random.uniform(*self.center_ratio_range) * self.img_scale[0] + ) center_position = (self.center_x, self.center_y) - loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + loc_strs = ("top_left", "top_right", "bottom_left", "bottom_right") for i, loc in enumerate(loc_strs): - if loc == 'top_left': + if loc == "top_left": result_patch = copy.deepcopy(results) else: - result_patch = copy.deepcopy(results['mix_results'][i - 1]) + result_patch = copy.deepcopy(results["mix_results"][i - 1]) - img_i = result_patch['img'] + img_i = result_patch["img"] h_i, w_i = img_i.shape[:2] # keep_ratio resize - scale_ratio_i = min(self.img_scale[0] / h_i, - self.img_scale[1] / w_i) + scale_ratio_i = min(self.img_scale[0] / h_i, self.img_scale[1] / w_i) img_i = mmcv.imresize( - img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)) + ) # compute the combine parameters paste_coord, crop_coord = self._mosaic_combine( - loc, center_position, img_i.shape[:2][::-1]) + loc, center_position, img_i.shape[:2][::-1] + ) x1_p, y1_p, x2_p, y2_p = paste_coord x1_c, y1_c, x2_c, y2_c = crop_coord # crop and paste image mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] - results['img'] = mosaic_img - results['img_shape'] = mosaic_img.shape - results['ori_shape'] = mosaic_img.shape + results["img"] = mosaic_img + results["img_shape"] = mosaic_img.shape + results["ori_shape"] = mosaic_img.shape return results @@ -1222,42 +1248,43 @@ def _mosaic_transform_seg(self, results): dict: Updated result dict. """ - assert 'mix_results' in results - for key in results.get('seg_fields', []): + assert "mix_results" in results + for key in results.get("seg_fields", []): mosaic_seg = np.full( (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), self.seg_pad_val, - dtype=results[key].dtype) + dtype=results[key].dtype, + ) # mosaic center x, y center_position = (self.center_x, self.center_y) - loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + loc_strs = ("top_left", "top_right", "bottom_left", "bottom_right") for i, loc in enumerate(loc_strs): - if loc == 'top_left': + if loc == "top_left": result_patch = copy.deepcopy(results) else: - result_patch = copy.deepcopy(results['mix_results'][i - 1]) + result_patch = copy.deepcopy(results["mix_results"][i - 1]) gt_seg_i = result_patch[key] h_i, w_i = gt_seg_i.shape[:2] # keep_ratio resize - scale_ratio_i = min(self.img_scale[0] / h_i, - self.img_scale[1] / w_i) + scale_ratio_i = min(self.img_scale[0] / h_i, self.img_scale[1] / w_i) gt_seg_i = mmcv.imresize( gt_seg_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)), - interpolation='nearest') + interpolation="nearest", + ) # compute the combine parameters paste_coord, crop_coord = self._mosaic_combine( - loc, center_position, gt_seg_i.shape[:2][::-1]) + loc, center_position, gt_seg_i.shape[:2][::-1] + ) x1_p, y1_p, x2_p, y2_p = paste_coord x1_c, y1_c, x2_c, y2_c = crop_coord # crop and paste image - mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c, - x1_c:x2_c] + mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c, x1_c:x2_c] results[key] = mosaic_seg @@ -1281,55 +1308,75 @@ def _mosaic_combine(self, loc, center_position_xy, img_shape_wh): - crop_coord (tuple): crop corner coordinate in mosaic image. """ - assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') - if loc == 'top_left': + assert loc in ("top_left", "top_right", "bottom_left", "bottom_right") + if loc == "top_left": # index0 to top left part of image - x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ - max(center_position_xy[1] - img_shape_wh[1], 0), \ - center_position_xy[0], \ - center_position_xy[1] - crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( - y2 - y1), img_shape_wh[0], img_shape_wh[1] - - elif loc == 'top_right': + x1, y1, x2, y2 = ( + max(center_position_xy[0] - img_shape_wh[0], 0), + max(center_position_xy[1] - img_shape_wh[1], 0), + center_position_xy[0], + center_position_xy[1], + ) + crop_coord = ( + img_shape_wh[0] - (x2 - x1), + img_shape_wh[1] - (y2 - y1), + img_shape_wh[0], + img_shape_wh[1], + ) + + elif loc == "top_right": # index1 to top right part of image - x1, y1, x2, y2 = center_position_xy[0], \ - max(center_position_xy[1] - img_shape_wh[1], 0), \ - min(center_position_xy[0] + img_shape_wh[0], - self.img_scale[1] * 2), \ - center_position_xy[1] - crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( - img_shape_wh[0], x2 - x1), img_shape_wh[1] - - elif loc == 'bottom_left': + x1, y1, x2, y2 = ( + center_position_xy[0], + max(center_position_xy[1] - img_shape_wh[1], 0), + min(center_position_xy[0] + img_shape_wh[0], self.img_scale[1] * 2), + center_position_xy[1], + ) + crop_coord = ( + 0, + img_shape_wh[1] - (y2 - y1), + min(img_shape_wh[0], x2 - x1), + img_shape_wh[1], + ) + + elif loc == "bottom_left": # index2 to bottom left part of image - x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ - center_position_xy[1], \ - center_position_xy[0], \ - min(self.img_scale[0] * 2, center_position_xy[1] + - img_shape_wh[1]) - crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( - y2 - y1, img_shape_wh[1]) + x1, y1, x2, y2 = ( + max(center_position_xy[0] - img_shape_wh[0], 0), + center_position_xy[1], + center_position_xy[0], + min(self.img_scale[0] * 2, center_position_xy[1] + img_shape_wh[1]), + ) + crop_coord = ( + img_shape_wh[0] - (x2 - x1), + 0, + img_shape_wh[0], + min(y2 - y1, img_shape_wh[1]), + ) else: # index3 to bottom right part of image - x1, y1, x2, y2 = center_position_xy[0], \ - center_position_xy[1], \ - min(center_position_xy[0] + img_shape_wh[0], - self.img_scale[1] * 2), \ - min(self.img_scale[0] * 2, center_position_xy[1] + - img_shape_wh[1]) - crop_coord = 0, 0, min(img_shape_wh[0], - x2 - x1), min(y2 - y1, img_shape_wh[1]) + x1, y1, x2, y2 = ( + center_position_xy[0], + center_position_xy[1], + min(center_position_xy[0] + img_shape_wh[0], self.img_scale[1] * 2), + min(self.img_scale[0] * 2, center_position_xy[1] + img_shape_wh[1]), + ) + crop_coord = ( + 0, + 0, + min(img_shape_wh[0], x2 - x1), + min(y2 - y1, img_shape_wh[1]), + ) paste_coord = x1, y1, x2, y2 return paste_coord, crop_coord def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(prob={self.prob}, ' - repr_str += f'img_scale={self.img_scale}, ' - repr_str += f'center_ratio_range={self.center_ratio_range}, ' - repr_str += f'pad_val={self.pad_val}, ' - repr_str += f'seg_pad_val={self.pad_val})' + repr_str += f"(prob={self.prob}, " + repr_str += f"img_scale={self.img_scale}, " + repr_str += f"center_ratio_range={self.center_ratio_range}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"seg_pad_val={self.pad_val})" return repr_str diff --git a/mmsegmentation/mmseg/datasets/potsdam.py b/mmsegmentation/mmseg/datasets/potsdam.py index 2986b8f..44dd4e7 100644 --- a/mmsegmentation/mmseg/datasets/potsdam.py +++ b/mmsegmentation/mmseg/datasets/potsdam.py @@ -11,15 +11,26 @@ class PotsdamDataset(CustomDataset): ``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png'. """ - CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', - 'car', 'clutter') - PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], - [255, 255, 0], [255, 0, 0]] + CLASSES = ( + "impervious_surface", + "building", + "low_vegetation", + "tree", + "car", + "clutter", + ) + + PALETTE = [ + [255, 255, 255], + [0, 0, 255], + [0, 255, 255], + [0, 255, 0], + [255, 255, 0], + [255, 0, 0], + ] def __init__(self, **kwargs): - super(PotsdamDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.png', - reduce_zero_label=True, - **kwargs) + super().__init__( + img_suffix=".png", seg_map_suffix=".png", reduce_zero_label=True, **kwargs + ) diff --git a/mmsegmentation/mmseg/datasets/samplers/__init__.py b/mmsegmentation/mmseg/datasets/samplers/__init__.py index da09eff..c2c3d98 100644 --- a/mmsegmentation/mmseg/datasets/samplers/__init__.py +++ b/mmsegmentation/mmseg/datasets/samplers/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .distributed_sampler import DistributedSampler -__all__ = ['DistributedSampler'] +__all__ = ["DistributedSampler"] diff --git a/mmsegmentation/mmseg/datasets/samplers/distributed_sampler.py b/mmsegmentation/mmseg/datasets/samplers/distributed_sampler.py index 4f9bf35..ece8b57 100644 --- a/mmsegmentation/mmseg/datasets/samplers/distributed_sampler.py +++ b/mmsegmentation/mmseg/datasets/samplers/distributed_sampler.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from __future__ import division from typing import Iterator, Optional import torch @@ -27,14 +26,15 @@ class DistributedSampler(_DistributedSampler): processes in the distributed group. Default: ``0``. """ - def __init__(self, - dataset: Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - seed=0) -> None: - super().__init__( - dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed=0, + ) -> None: + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) # In distributed sampling, different ranks should sample # non-overlapped data in the dataset. Therefore, this function @@ -47,8 +47,8 @@ def __init__(self, def __iter__(self) -> Iterator: """ - Yields: - Iterator: iterator of indices for rank. + Yields: + Iterator: iterator of indices for rank. """ # deterministically shuffle based on epoch if self.shuffle: @@ -63,11 +63,11 @@ def __iter__(self) -> Iterator: indices = torch.arange(len(self.dataset)).tolist() # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) diff --git a/mmsegmentation/mmseg/datasets/stare.py b/mmsegmentation/mmseg/datasets/stare.py index a24d1d9..2eec5ba 100644 --- a/mmsegmentation/mmseg/datasets/stare.py +++ b/mmsegmentation/mmseg/datasets/stare.py @@ -15,14 +15,15 @@ class STAREDataset(CustomDataset): '.ah.png'. """ - CLASSES = ('background', 'vessel') + CLASSES = ("background", "vessel") PALETTE = [[120, 120, 120], [6, 230, 230]] def __init__(self, **kwargs): - super(STAREDataset, self).__init__( - img_suffix='.png', - seg_map_suffix='.ah.png', + super().__init__( + img_suffix=".png", + seg_map_suffix=".ah.png", reduce_zero_label=False, - **kwargs) + **kwargs, + ) assert osp.exists(self.img_dir) diff --git a/mmsegmentation/mmseg/datasets/voc.py b/mmsegmentation/mmseg/datasets/voc.py index 3cec9e3..149bf37 100644 --- a/mmsegmentation/mmseg/datasets/voc.py +++ b/mmsegmentation/mmseg/datasets/voc.py @@ -13,18 +13,56 @@ class PascalVOCDataset(CustomDataset): split (str): Split txt file for Pascal VOC. """ - CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', - 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', - 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', - 'train', 'tvmonitor') + CLASSES = ( + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", + ) - PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], - [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], - [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], - [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], - [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + PALETTE = [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], + ] def __init__(self, split, **kwargs): - super(PascalVOCDataset, self).__init__( - img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + super().__init__( + img_suffix=".jpg", seg_map_suffix=".png", split=split, **kwargs + ) assert osp.exists(self.img_dir) and self.split is not None diff --git a/mmsegmentation/mmseg/models/__init__.py b/mmsegmentation/mmseg/models/__init__.py index 87d8108..05765bc 100644 --- a/mmsegmentation/mmseg/models/__init__.py +++ b/mmsegmentation/mmseg/models/__init__.py @@ -1,13 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. from .backbones import * # noqa: F401,F403 -from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, - build_head, build_loss, build_segmentor) +from .builder import ( + BACKBONES, + HEADS, + LOSSES, + SEGMENTORS, + build_backbone, + build_head, + build_loss, + build_segmentor, +) from .decode_heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .segmentors import * # noqa: F401,F403 __all__ = [ - 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', - 'build_head', 'build_loss', 'build_segmentor' + "BACKBONES", + "HEADS", + "LOSSES", + "SEGMENTORS", + "build_backbone", + "build_head", + "build_loss", + "build_segmentor", ] diff --git a/mmsegmentation/mmseg/models/backbones/__init__.py b/mmsegmentation/mmseg/models/backbones/__init__.py index bda42bb..4e686e2 100644 --- a/mmsegmentation/mmseg/models/backbones/__init__.py +++ b/mmsegmentation/mmseg/models/backbones/__init__.py @@ -22,9 +22,29 @@ from .vit import VisionTransformer __all__ = [ - 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', - 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', - 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', - 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', - 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE' + "ResNet", + "ResNetV1c", + "ResNetV1d", + "ResNeXt", + "HRNet", + "FastSCNN", + "ResNeSt", + "MobileNetV2", + "UNet", + "CGNet", + "MobileNetV3", + "VisionTransformer", + "SwinTransformer", + "MixVisionTransformer", + "BiSeNetV1", + "BiSeNetV2", + "ICNet", + "TIMMBackbone", + "ERFNet", + "PCPVT", + "SVT", + "STDCNet", + "STDCContextPathNet", + "BEiT", + "MAE", ] diff --git a/mmsegmentation/mmseg/models/backbones/beit.py b/mmsegmentation/mmseg/models/backbones/beit.py index fade601..949beca 100644 --- a/mmsegmentation/mmseg/models/backbones/beit.py +++ b/mmsegmentation/mmseg/models/backbones/beit.py @@ -7,8 +7,7 @@ import torch.nn.functional as F from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.drop import build_dropout -from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, - trunc_normal_) +from mmcv.cnn.utils.weight_init import constant_init, kaiming_init, trunc_normal_ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple @@ -45,16 +44,18 @@ class BEiTAttention(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - num_heads, - window_size, - bias='qv_bias', - qk_scale=None, - attn_drop_rate=0., - proj_drop_rate=0., - init_cfg=None, - **kwargs): + def __init__( + self, + embed_dims, + num_heads, + window_size, + bias="qv_bias", + qk_scale=None, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + init_cfg=None, + **kwargs, + ): super().__init__(init_cfg=init_cfg) self.embed_dims = embed_dims self.num_heads = num_heads @@ -63,7 +64,7 @@ def __init__(self, self.scale = qk_scale or head_embed_dims**-0.5 qkv_bias = bias - if bias == 'qv_bias': + if bias == "qv_bias": self._init_qv_bias() qkv_bias = False @@ -85,7 +86,8 @@ def _init_rel_pos_embedding(self): self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) self.relative_position_bias_table = nn.Parameter( - torch.zeros(self.num_relative_distance, self.num_heads)) + torch.zeros(self.num_relative_distance, self.num_heads) + ) # get pair-wise relative position index for # each token inside the window @@ -95,8 +97,7 @@ def _init_rel_pos_embedding(self): coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # coords_flatten shape is (2, Wh*Ww) coords_flatten = torch.flatten(coords, 1) - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :]) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # relative_coords shape is (Wh*Ww, Wh*Ww, 2) relative_coords = relative_coords.permute(1, 2, 0).contiguous() # shift to start from 0 @@ -104,15 +105,15 @@ def _init_rel_pos_embedding(self): relative_coords[:, :, 1] += Ww - 1 relative_coords[:, :, 0] *= 2 * Ww - 1 relative_position_index = torch.zeros( - size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + size=(Wh * Ww + 1,) * 2, dtype=relative_coords.dtype + ) # relative_position_index shape is (Wh*Ww, Wh*Ww) relative_position_index[1:, 1:] = relative_coords.sum(-1) relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 - self.register_buffer('relative_position_index', - relative_position_index) + self.register_buffer("relative_position_index", relative_position_index) def init_weights(self): trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -124,7 +125,7 @@ def forward(self, x): """ B, N, C = x.shape - if self.bias == 'qv_bias': + if self.bias == "qv_bias": k_bias = torch.zeros_like(self.v_bias, requires_grad=False) qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) @@ -134,15 +135,16 @@ def forward(self, x): qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) if self.relative_position_bias_table is not None: Wh = self.window_size[0] Ww = self.window_size[1] relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view( - Wh * Ww + 1, Wh * Ww + 1, -1) + self.relative_position_index.view(-1) + ].view(Wh * Ww + 1, Wh * Ww + 1, -1) relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) @@ -178,45 +180,51 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): and FFN with learnable scaling. Default: None. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - attn_drop_rate=0., - drop_path_rate=0., - num_fcs=2, - bias='qv_bias', - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - window_size=None, - attn_cfg=dict(), - ffn_cfg=dict(add_identity=False), - init_values=None): + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + attn_drop_rate=0.0, + drop_path_rate=0.0, + num_fcs=2, + bias="qv_bias", + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + window_size=None, + attn_cfg=dict(), + ffn_cfg=dict(add_identity=False), + init_values=None, + ): attn_cfg.update(dict(window_size=window_size, qk_scale=None)) - super(BEiTTransformerEncoderLayer, self).__init__( + super().__init__( embed_dims=embed_dims, num_heads=num_heads, feedforward_channels=feedforward_channels, attn_drop_rate=attn_drop_rate, - drop_path_rate=0., - drop_rate=0., + drop_path_rate=0.0, + drop_rate=0.0, num_fcs=num_fcs, qkv_bias=bias, act_cfg=act_cfg, norm_cfg=norm_cfg, attn_cfg=attn_cfg, - ffn_cfg=ffn_cfg) + ffn_cfg=ffn_cfg, + ) # NOTE: drop path for stochastic depth, we shall see if # this is better than dropout here - dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) - self.drop_path = build_dropout( - dropout_layer) if dropout_layer else nn.Identity() + dropout_layer = dict(type="DropPath", drop_prob=drop_path_rate) + self.drop_path = ( + build_dropout(dropout_layer) if dropout_layer else nn.Identity() + ) self.gamma_1 = nn.Parameter( - init_values * torch.ones((embed_dims)), requires_grad=True) + init_values * torch.ones(embed_dims), requires_grad=True + ) self.gamma_2 = nn.Parameter( - init_values * torch.ones((embed_dims)), requires_grad=True) + init_values * torch.ones(embed_dims), requires_grad=True + ) def build_attn(self, attn_cfg): self.attn = BEiTAttention(**attn_cfg) @@ -266,45 +274,51 @@ class BEiT(BaseModule): Default: None. """ - def __init__(self, - img_size=224, - patch_size=16, - in_channels=3, - embed_dims=768, - num_layers=12, - num_heads=12, - mlp_ratio=4, - out_indices=-1, - qv_bias=True, - attn_drop_rate=0., - drop_path_rate=0., - norm_cfg=dict(type='LN'), - act_cfg=dict(type='GELU'), - patch_norm=False, - final_norm=False, - num_fcs=2, - norm_eval=False, - pretrained=None, - init_values=0.1, - init_cfg=None): - super(BEiT, self).__init__(init_cfg=init_cfg) + def __init__( + self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qv_bias=True, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_cfg=dict(type="LN"), + act_cfg=dict(type="GELU"), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) - assert len(img_size) == 2, \ - f'The size of image should have length 1 or 2, ' \ - f'but got {len(img_size)}' - - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be set at the same time' + assert len(img_size) == 2, ( + f"The size of image should have length 1 or 2, " + f"but got {len(img_size)}" + ) + + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be set at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is not None: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.in_channels = in_channels self.img_size = img_size @@ -323,8 +337,7 @@ def __init__(self, self.norm_cfg = norm_cfg self.patch_norm = patch_norm self.init_values = init_values - self.window_size = (img_size[0] // patch_size, - img_size[1] // patch_size) + self.window_size = (img_size[0] // patch_size, img_size[1] // patch_size) self.patch_shape = self.window_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) @@ -338,12 +351,11 @@ def __init__(self, elif isinstance(out_indices, list) or isinstance(out_indices, tuple): self.out_indices = out_indices else: - raise TypeError('out_indices must be type of int, list or tuple') + raise TypeError("out_indices must be type of int, list or tuple") self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, embed_dims, postfix=1) + self.norm1_name, norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) def _build_patch_embedding(self): @@ -351,19 +363,19 @@ def _build_patch_embedding(self): self.patch_embed = PatchEmbed( in_channels=self.in_channels, embed_dims=self.embed_dims, - conv_type='Conv2d', + conv_type="Conv2d", kernel_size=self.patch_size, stride=self.patch_size, padding=0, norm_cfg=self.norm_cfg if self.patch_norm else None, - init_cfg=None) + init_cfg=None, + ) def _build_layers(self): """Build transformer encoding layers.""" dpr = [ - x.item() - for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers) ] self.layers = ModuleList() for i in range(self.num_layers): @@ -375,18 +387,19 @@ def _build_layers(self): attn_drop_rate=self.attn_drop_rate, drop_path_rate=dpr[i], num_fcs=self.num_fcs, - bias='qv_bias' if self.qv_bias else False, + bias="qv_bias" if self.qv_bias else False, act_cfg=self.act_cfg, norm_cfg=self.norm_cfg, window_size=self.window_size, - init_values=self.init_values)) + init_values=self.init_values, + ) + ) @property def norm1(self): return getattr(self, self.norm1_name) - def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, - num): + def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, num): """Get new sequence via geometric sequence interpolation. Args: @@ -419,7 +432,7 @@ def geometric_progression(a, r, n): cur = 1 for i in range(src_size // 2): dis.append(cur) - cur += q**(i + 1) + cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis @@ -430,9 +443,10 @@ def geometric_progression(a, r, n): new_sequence = [] for i in range(num): z = sequence[:, i].view(src_size, src_size).float().numpy() - f = interpolate.interp2d(x, y, z, kind='cubic') + f = interpolate.interp2d(x, y, z, kind="cubic") new_sequence.append( - torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence)) + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence) + ) new_sequence = torch.cat(new_sequence, dim=-1) return new_sequence @@ -449,19 +463,19 @@ def resize_rel_pos_embed(self, checkpoint): state_dict (dict): Interpolate the relative pos_embed weights in the pre-train model to the current model size. """ - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint all_keys = list(state_dict.keys()) for key in all_keys: - if 'relative_position_index' in key: + if "relative_position_index" in key: state_dict.pop(key) # In order to keep the center of pos_bias as consistent as # possible after interpolation, and vice versa in the edge # area, the geometric sequence interpolation method is adopted. - if 'relative_position_bias_table' in key: + if "relative_position_bias_table" in key: rel_pos_bias = state_dict[key] src_num_pos, num_attn_heads = rel_pos_bias.size() dst_num_pos, _ = self.state_dict()[key].size() @@ -469,27 +483,28 @@ def resize_rel_pos_embed(self, checkpoint): if dst_patch_shape[0] != dst_patch_shape[1]: raise NotImplementedError() # Count the number of extra tokens. - num_extra_tokens = dst_num_pos - ( - dst_patch_shape[0] * 2 - 1) * ( - dst_patch_shape[1] * 2 - 1) - src_size = int((src_num_pos - num_extra_tokens)**0.5) - dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1 + ) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) if src_size != dst_size: extra_tokens = rel_pos_bias[-num_extra_tokens:, :] rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] new_rel_pos_bias = self._geometric_sequence_interpolation( - src_size, dst_size, rel_pos_bias, num_attn_heads) + src_size, dst_size, rel_pos_bias, num_attn_heads + ) new_rel_pos_bias = torch.cat( - (new_rel_pos_bias, extra_tokens), dim=0) + (new_rel_pos_bias, extra_tokens), dim=0 + ) state_dict[key] = new_rel_pos_bias return state_dict def init_weights(self): - def _init_weights(m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -498,33 +513,36 @@ def _init_weights(m): self.apply(_init_weights) - if (isinstance(self.init_cfg, dict) - and self.init_cfg.get('type') == 'Pretrained'): + if ( + isinstance(self.init_cfg, dict) + and self.init_cfg.get("type") == "Pretrained" + ): logger = get_root_logger() checkpoint = _load_checkpoint( - self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + self.init_cfg["checkpoint"], logger=logger, map_location="cpu" + ) state_dict = self.resize_rel_pos_embed(checkpoint) self.load_state_dict(state_dict, False) elif self.init_cfg is not None: - super(BEiT, self).init_weights() + super().init_weights() else: # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 # Copyright 2019 Ross Wightman # Licensed under the Apache License, Version 2.0 (the "License") - trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.cls_token, std=0.02) for n, m in self.named_modules(): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if m.bias is not None: - if 'ffn' in n: - nn.init.normal_(m.bias, mean=0., std=1e-6) + if "ffn" in n: + nn.init.normal_(m.bias, mean=0.0, std=1e-6) else: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', bias=0.) + kaiming_init(m, mode="fan_in", bias=0.0) elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) def forward(self, inputs): B = inputs.shape[0] @@ -545,14 +563,17 @@ def forward(self, inputs): # Remove class token and reshape token for decoder head out = x[:, 1:] B, _, C = out.shape - out = out.reshape(B, hw_shape[0], hw_shape[1], - C).permute(0, 3, 1, 2).contiguous() + out = ( + out.reshape(B, hw_shape[0], hw_shape[1], C) + .permute(0, 3, 1, 2) + .contiguous() + ) outs.append(out) return tuple(outs) def train(self, mode=True): - super(BEiT, self).train(mode) + super().train(mode) if mode and self.norm_eval: for m in self.modules(): if isinstance(m, nn.LayerNorm): diff --git a/mmsegmentation/mmseg/models/backbones/bisenetv1.py b/mmsegmentation/mmseg/models/backbones/bisenetv1.py index 4beb7b3..5d25ff6 100644 --- a/mmsegmentation/mmseg/models/backbones/bisenetv1.py +++ b/mmsegmentation/mmseg/models/backbones/bisenetv1.py @@ -22,20 +22,24 @@ class SpatialPath(BaseModule): x (torch.Tensor): Feature map for Feature Fusion Module. """ - def __init__(self, - in_channels=3, - num_channels=(64, 64, 64, 128), - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(SpatialPath, self).__init__(init_cfg=init_cfg) - assert len(num_channels) == 4, 'Length of input channels \ - of Spatial Path must be 4!' + def __init__( + self, + in_channels=3, + num_channels=(64, 64, 64, 128), + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + len(num_channels) == 4 + ), "Length of input channels \ + of Spatial Path must be 4!" self.layers = [] for i in range(len(num_channels)): - layer_name = f'layer{i + 1}' + layer_name = f"layer{i + 1}" self.layers.append(layer_name) if i == 0: self.add_module( @@ -48,7 +52,9 @@ def __init__(self, padding=3, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ), + ) elif i == len(num_channels) - 1: self.add_module( layer_name, @@ -60,7 +66,9 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ), + ) else: self.add_module( layer_name, @@ -72,7 +80,9 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ), + ) def forward(self, x): for i, layer_name in enumerate(self.layers): @@ -91,14 +101,16 @@ class AttentionRefinementModule(BaseModule): x_out (torch.Tensor): Feature map of Attention Refinement Module. """ - def __init__(self, - in_channels, - out_channel, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channel, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv_layer = ConvModule( in_channels=in_channels, out_channels=out_channel, @@ -107,7 +119,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.atten_conv_layer = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), ConvModule( @@ -117,7 +130,10 @@ def __init__(self, bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None), nn.Sigmoid()) + act_cfg=None, + ), + nn.Sigmoid(), + ) def forward(self, x): x = self.conv_layer(x) @@ -144,25 +160,27 @@ class ContextPath(BaseModule): Fusion Module and Auxiliary Head. """ - def __init__(self, - backbone_cfg, - context_channels=(128, 256, 512), - align_corners=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(ContextPath, self).__init__(init_cfg=init_cfg) - assert len(context_channels) == 3, 'Length of input channels \ - of Context Path must be 3!' + def __init__( + self, + backbone_cfg, + context_channels=(128, 256, 512), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + len(context_channels) == 3 + ), "Length of input channels \ + of Context Path must be 3!" self.backbone = build_backbone(backbone_cfg) self.align_corners = align_corners - self.arm16 = AttentionRefinementModule(context_channels[1], - context_channels[0]) - self.arm32 = AttentionRefinementModule(context_channels[2], - context_channels[0]) + self.arm16 = AttentionRefinementModule(context_channels[1], context_channels[0]) + self.arm32 = AttentionRefinementModule(context_channels[2], context_channels[0]) self.conv_head32 = ConvModule( in_channels=context_channels[0], out_channels=context_channels[0], @@ -171,7 +189,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.conv_head16 = ConvModule( in_channels=context_channels[0], out_channels=context_channels[0], @@ -180,7 +199,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.gap_conv = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), ConvModule( @@ -191,7 +211,9 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ), + ) def forward(self, x): x_4, x_8, x_16, x_32 = self.backbone(x) @@ -199,12 +221,12 @@ def forward(self, x): x_32_arm = self.arm32(x_32) x_32_sum = x_32_arm + x_gap - x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest') + x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode="nearest") x_32_up = self.conv_head32(x_32_up) x_16_arm = self.arm16(x_16) x_16_sum = x_16_arm + x_32_up - x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest') + x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode="nearest") x_16_up = self.conv_head16(x_16_up) return x_16_up, x_32_up @@ -221,14 +243,16 @@ class FeatureFusionModule(BaseModule): x_out (torch.Tensor): Feature map of Feature Fusion Module. """ - def __init__(self, - in_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv1 = ConvModule( in_channels=in_channels, out_channels=out_channels, @@ -237,7 +261,8 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.gap = nn.AdaptiveAvgPool2d((1, 1)) self.conv_atten = nn.Sequential( ConvModule( @@ -249,7 +274,10 @@ def __init__(self, bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), nn.Sigmoid()) + act_cfg=act_cfg, + ), + nn.Sigmoid(), + ) def forward(self, x_sp, x_cp): x_concat = torch.cat([x_sp, x_cp], dim=1) @@ -291,30 +319,36 @@ class BiSeNetV1(BaseModule): Default: 256. """ - def __init__(self, - backbone_cfg, - in_channels=3, - spatial_channels=(64, 64, 64, 128), - context_channels=(128, 256, 512), - out_indices=(0, 1, 2), - align_corners=False, - out_channels=256, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - init_cfg=None): - - super(BiSeNetV1, self).__init__(init_cfg=init_cfg) - assert len(spatial_channels) == 4, 'Length of input channels \ - of Spatial Path must be 4!' - - assert len(context_channels) == 3, 'Length of input channels \ - of Context Path must be 3!' + def __init__( + self, + backbone_cfg, + in_channels=3, + spatial_channels=(64, 64, 64, 128), + context_channels=(128, 256, 512), + out_indices=(0, 1, 2), + align_corners=False, + out_channels=256, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + len(spatial_channels) == 4 + ), "Length of input channels \ + of Spatial Path must be 4!" + + assert ( + len(context_channels) == 3 + ), "Length of input channels \ + of Context Path must be 3!" self.out_indices = out_indices self.align_corners = align_corners - self.context_path = ContextPath(backbone_cfg, context_channels, - self.align_corners) + self.context_path = ContextPath( + backbone_cfg, context_channels, self.align_corners + ) self.spatial_path = SpatialPath(in_channels, spatial_channels) self.ffm = FeatureFusionModule(context_channels[1], out_channels) self.conv_cfg = conv_cfg diff --git a/mmsegmentation/mmseg/models/backbones/bisenetv2.py b/mmsegmentation/mmseg/models/backbones/bisenetv2.py index d908b32..c97928c 100644 --- a/mmsegmentation/mmseg/models/backbones/bisenetv2.py +++ b/mmsegmentation/mmseg/models/backbones/bisenetv2.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn -from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, - build_activation_layer, build_norm_layer) +from mmcv.cnn import ( + ConvModule, + DepthwiseSeparableConvModule, + build_activation_layer, + build_norm_layer, +) from mmcv.runner import BaseModule from mmseg.ops import resize @@ -30,14 +34,16 @@ class DetailBranch(BaseModule): x (torch.Tensor): Feature map of Detail Branch. """ - def __init__(self, - detail_channels=(64, 64, 128), - in_channels=3, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(DetailBranch, self).__init__(init_cfg=init_cfg) + def __init__( + self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) detail_branch = [] for i in range(len(detail_channels)): if i == 0: @@ -51,7 +57,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + ), ConvModule( in_channels=detail_channels[i], out_channels=detail_channels[i], @@ -60,7 +67,10 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg))) + act_cfg=act_cfg, + ), + ) + ) else: detail_branch.append( nn.Sequential( @@ -72,7 +82,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + ), ConvModule( in_channels=detail_channels[i], out_channels=detail_channels[i], @@ -81,7 +92,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + ), ConvModule( in_channels=detail_channels[i], out_channels=detail_channels[i], @@ -90,7 +102,10 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg))) + act_cfg=act_cfg, + ), + ) + ) self.detail_branch = nn.ModuleList(detail_branch) def forward(self, x): @@ -119,14 +134,16 @@ class StemBlock(BaseModule): x (torch.Tensor): First feature map in Semantic Branch. """ - def __init__(self, - in_channels=3, - out_channels=16, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(StemBlock, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv_first = ConvModule( in_channels=in_channels, @@ -136,7 +153,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.convs = nn.Sequential( ConvModule( in_channels=out_channels, @@ -146,7 +164,8 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + ), ConvModule( in_channels=out_channels // 2, out_channels=out_channels, @@ -155,9 +174,10 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) - self.pool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1, ceil_mode=False) + act_cfg=act_cfg, + ), + ) + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) self.fuse_last = ConvModule( in_channels=out_channels * 2, out_channels=out_channels, @@ -166,7 +186,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, x): x = self.conv_first(x) @@ -198,16 +219,18 @@ class GELayer(BaseModule): Semantic Branch. """ - def __init__(self, - in_channels, - out_channels, - exp_ratio=6, - stride=1, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(GELayer, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) mid_channel = in_channels * exp_ratio self.conv1 = ConvModule( in_channels=in_channels, @@ -217,7 +240,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) if stride == 1: self.dwconv = nn.Sequential( # ReLU in ConvModule not shown in paper @@ -230,7 +254,9 @@ def __init__(self, groups=in_channels, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.shortcut = None else: self.dwconv = nn.Sequential( @@ -244,7 +270,8 @@ def __init__(self, bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None), + act_cfg=None, + ), # ReLU in ConvModule not shown in paper ConvModule( in_channels=mid_channel, @@ -255,7 +282,8 @@ def __init__(self, groups=mid_channel, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + ), ) self.shortcut = nn.Sequential( DepthwiseSeparableConvModule( @@ -268,7 +296,8 @@ def __init__(self, dw_act_cfg=None, pw_norm_cfg=norm_cfg, pw_act_cfg=None, - )) + ) + ) self.conv2 = nn.Sequential( ConvModule( @@ -281,7 +310,8 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None, - )) + ) + ) self.act = build_activation_layer(act_cfg) @@ -319,19 +349,22 @@ class CEBlock(BaseModule): x (torch.Tensor): Last feature map in Semantic Branch. """ - def __init__(self, - in_channels=3, - out_channels=16, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(CEBlock, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.out_channels = out_channels self.gap = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), - build_norm_layer(norm_cfg, self.in_channels)[1]) + build_norm_layer(norm_cfg, self.in_channels)[1], + ) self.conv_gap = ConvModule( in_channels=self.in_channels, out_channels=self.out_channels, @@ -340,7 +373,8 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) # Note: in paper here is naive conv2d, no bn-relu self.conv_last = ConvModule( in_channels=self.out_channels, @@ -350,7 +384,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, x): identity = x @@ -380,46 +415,60 @@ class SemanticBranch(BaseModule): Guided Aggregation Layer. """ - def __init__(self, - semantic_channels=(16, 32, 64, 128), - in_channels=3, - exp_ratio=6, - init_cfg=None): - super(SemanticBranch, self).__init__(init_cfg=init_cfg) + def __init__( + self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.semantic_channels = semantic_channels self.semantic_stages = [] for i in range(len(semantic_channels)): - stage_name = f'stage{i + 1}' + stage_name = f"stage{i + 1}" self.semantic_stages.append(stage_name) if i == 0: self.add_module( - stage_name, - StemBlock(self.in_channels, semantic_channels[i])) + stage_name, StemBlock(self.in_channels, semantic_channels[i]) + ) elif i == (len(semantic_channels) - 1): self.add_module( stage_name, nn.Sequential( - GELayer(semantic_channels[i - 1], semantic_channels[i], - exp_ratio, 2), - GELayer(semantic_channels[i], semantic_channels[i], - exp_ratio, 1), - GELayer(semantic_channels[i], semantic_channels[i], - exp_ratio, 1), - GELayer(semantic_channels[i], semantic_channels[i], - exp_ratio, 1))) + GELayer( + semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2 + ), + GELayer( + semantic_channels[i], semantic_channels[i], exp_ratio, 1 + ), + GELayer( + semantic_channels[i], semantic_channels[i], exp_ratio, 1 + ), + GELayer( + semantic_channels[i], semantic_channels[i], exp_ratio, 1 + ), + ), + ) else: self.add_module( stage_name, nn.Sequential( - GELayer(semantic_channels[i - 1], semantic_channels[i], - exp_ratio, 2), - GELayer(semantic_channels[i], semantic_channels[i], - exp_ratio, 1))) - - self.add_module(f'stage{len(semantic_channels)}_CEBlock', - CEBlock(semantic_channels[-1], semantic_channels[-1])) - self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + GELayer( + semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2 + ), + GELayer( + semantic_channels[i], semantic_channels[i], exp_ratio, 1 + ), + ), + ) + + self.add_module( + f"stage{len(semantic_channels)}_CEBlock", + CEBlock(semantic_channels[-1], semantic_channels[-1]), + ) + self.semantic_stages.append(f"stage{len(semantic_channels)}_CEBlock") def forward(self, x): semantic_outs = [] @@ -451,14 +500,16 @@ class BGALayer(BaseModule): output (torch.Tensor): Output feature map for Segment heads. """ - def __init__(self, - out_channels=128, - align_corners=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(BGALayer, self).__init__(init_cfg=init_cfg) + def __init__( + self, + out_channels=128, + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.out_channels = out_channels self.align_corners = align_corners self.detail_dwconv = nn.Sequential( @@ -472,7 +523,8 @@ def __init__(self, dw_act_cfg=None, pw_norm_cfg=None, pw_act_cfg=None, - )) + ) + ) self.detail_down = nn.Sequential( ConvModule( in_channels=self.out_channels, @@ -483,8 +535,10 @@ def __init__(self, bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None), - nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) + act_cfg=None, + ), + nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False), + ) self.semantic_conv = nn.Sequential( ConvModule( in_channels=self.out_channels, @@ -495,7 +549,9 @@ def __init__(self, bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None)) + act_cfg=None, + ) + ) self.semantic_dwconv = nn.Sequential( DepthwiseSeparableConvModule( in_channels=self.out_channels, @@ -507,7 +563,8 @@ def __init__(self, dw_act_cfg=None, pw_norm_cfg=None, pw_act_cfg=None, - )) + ) + ) self.conv = ConvModule( in_channels=self.out_channels, out_channels=self.out_channels, @@ -528,15 +585,17 @@ def forward(self, x_d, x_s): semantic_conv = resize( input=semantic_conv, size=detail_dwconv.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) fuse_2 = resize( input=fuse_2, size=fuse_1.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) output = self.conv(fuse_1 + fuse_2) return output @@ -576,25 +635,26 @@ class BiSeNetV2(BaseModule): Default: None. """ - def __init__(self, - in_channels=3, - detail_channels=(64, 64, 128), - semantic_channels=(16, 32, 64, 128), - semantic_expansion_ratio=6, - bga_channels=128, - out_indices=(0, 1, 2, 3, 4), - align_corners=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): + def __init__( + self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): if init_cfg is None: init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] - super(BiSeNetV2, self).__init__(init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.out_indices = out_indices self.detail_channels = detail_channels @@ -607,9 +667,9 @@ def __init__(self, self.act_cfg = act_cfg self.detail = DetailBranch(self.detail_channels, self.in_channels) - self.semantic = SemanticBranch(self.semantic_channels, - self.in_channels, - self.semantic_expansion_ratio) + self.semantic = SemanticBranch( + self.semantic_channels, self.in_channels, self.semantic_expansion_ratio + ) self.bga = BGALayer(self.bga_channels, self.align_corners) def forward(self, x): diff --git a/mmsegmentation/mmseg/models/backbones/cgnet.py b/mmsegmentation/mmseg/models/backbones/cgnet.py index 168194c..d5d7ef3 100644 --- a/mmsegmentation/mmseg/models/backbones/cgnet.py +++ b/mmsegmentation/mmseg/models/backbones/cgnet.py @@ -25,18 +25,20 @@ class GlobalContextExtractor(nn.Module): """ def __init__(self, channel, reduction=16, with_cp=False): - super(GlobalContextExtractor, self).__init__() + super().__init__() self.channel = channel self.reduction = reduction assert reduction >= 1 and channel >= reduction self.with_cp = with_cp self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel), nn.Sigmoid()) + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) def forward(self, x): - def _inner_forward(x): num_batch, num_channel = x.size()[:2] y = self.avg_pool(x).view(num_batch, num_channel) @@ -76,24 +78,26 @@ class ContextGuidedBlock(nn.Module): memory while slowing down the training speed. Default: False. """ - def __init__(self, - in_channels, - out_channels, - dilation=2, - reduction=16, - skip_connect=True, - downsample=False, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='PReLU'), - with_cp=False): - super(ContextGuidedBlock, self).__init__() + def __init__( + self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="PReLU"), + with_cp=False, + ): + super().__init__() self.with_cp = with_cp self.downsample = downsample channels = out_channels if downsample else out_channels // 2 - if 'type' in act_cfg and act_cfg['type'] == 'PReLU': - act_cfg['num_parameters'] = channels + if "type" in act_cfg and act_cfg["type"] == "PReLU": + act_cfg["num_parameters"] = channels kernel_size = 3 if downsample else 1 stride = 2 if downsample else 1 padding = (kernel_size - 1) // 2 @@ -106,7 +110,8 @@ def __init__(self, padding, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.f_loc = build_conv_layer( conv_cfg, @@ -115,7 +120,8 @@ def __init__(self, kernel_size=3, padding=1, groups=channels, - bias=False) + bias=False, + ) self.f_sur = build_conv_layer( conv_cfg, channels, @@ -124,24 +130,21 @@ def __init__(self, padding=dilation, groups=channels, dilation=dilation, - bias=False) + bias=False, + ) self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] self.activate = nn.PReLU(2 * channels) if downsample: self.bottleneck = build_conv_layer( - conv_cfg, - 2 * channels, - out_channels, - kernel_size=1, - bias=False) + conv_cfg, 2 * channels, out_channels, kernel_size=1, bias=False + ) self.skip_connect = skip_connect and not downsample self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) def forward(self, x): - def _inner_forward(x): out = self.conv1x1(x) loc = self.f_loc(out) @@ -172,7 +175,7 @@ class InputInjection(nn.Module): """Downsampling module for CGNet.""" def __init__(self, num_downsampling): - super(InputInjection, self).__init__() + super().__init__() self.pool = nn.ModuleList() for i in range(num_downsampling): self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) @@ -216,45 +219,45 @@ class CGNet(BaseModule): Default: None """ - def __init__(self, - in_channels=3, - num_channels=(32, 64, 128), - num_blocks=(3, 21), - dilations=(2, 4), - reductions=(8, 16), - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='PReLU'), - norm_eval=False, - with_cp=False, - pretrained=None, - init_cfg=None): - - super(CGNet, self).__init__(init_cfg) - - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + def __init__( + self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="PReLU"), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) + + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is a deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer=['Conv2d', 'Linear']), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']), - dict(type='Constant', val=0, layer='PReLU') + dict(type="Kaiming", layer=["Conv2d", "Linear"]), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), + dict(type="Constant", val=0, layer="PReLU"), ] else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.in_channels = in_channels self.num_channels = num_channels - assert isinstance(self.num_channels, tuple) and len( - self.num_channels) == 3 + assert isinstance(self.num_channels, tuple) and len(self.num_channels) == 3 self.num_blocks = num_blocks assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 self.dilations = dilations @@ -264,8 +267,8 @@ def __init__(self, self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg - if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': - self.act_cfg['num_parameters'] = num_channels[0] + if "type" in self.act_cfg and self.act_cfg["type"] == "PReLU": + self.act_cfg["num_parameters"] = num_channels[0] self.norm_eval = norm_eval self.with_cp = with_cp @@ -281,7 +284,9 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) cur_channels = num_channels[0] self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 @@ -289,8 +294,8 @@ def __init__(self, cur_channels += in_channels self.norm_prelu_0 = nn.Sequential( - build_norm_layer(norm_cfg, cur_channels)[1], - nn.PReLU(cur_channels)) + build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels) + ) # stage 1 self.level1 = nn.ModuleList() @@ -305,12 +310,14 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - with_cp=with_cp)) # CG block + with_cp=with_cp, + ) + ) # CG block cur_channels = 2 * num_channels[1] + in_channels self.norm_prelu_1 = nn.Sequential( - build_norm_layer(norm_cfg, cur_channels)[1], - nn.PReLU(cur_channels)) + build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels) + ) # stage 2 self.level2 = nn.ModuleList() @@ -325,12 +332,14 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - with_cp=with_cp)) # CG block + with_cp=with_cp, + ) + ) # CG block cur_channels = 2 * num_channels[2] self.norm_prelu_2 = nn.Sequential( - build_norm_layer(norm_cfg, cur_channels)[1], - nn.PReLU(cur_channels)) + build_norm_layer(norm_cfg, cur_channels)[1], nn.PReLU(cur_channels) + ) def forward(self, x): output = [] @@ -364,7 +373,7 @@ def forward(self, x): def train(self, mode=True): """Convert the model into training mode will keeping the normalization layer freezed.""" - super(CGNet, self).train(mode) + super().train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only diff --git a/mmsegmentation/mmseg/models/backbones/erfnet.py b/mmsegmentation/mmseg/models/backbones/erfnet.py index 8921c18..f31b3f1 100644 --- a/mmsegmentation/mmseg/models/backbones/erfnet.py +++ b/mmsegmentation/mmseg/models/backbones/erfnet.py @@ -28,14 +28,16 @@ class DownsamplerBlock(BaseModule): Default: None. """ - def __init__(self, - in_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN', eps=1e-3), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(DownsamplerBlock, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN", eps=1e-3), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg @@ -46,7 +48,8 @@ def __init__(self, out_channels - in_channels, kernel_size=3, stride=2, - padding=1) + padding=1, + ) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] self.act = build_activation_layer(self.act_cfg) @@ -57,8 +60,9 @@ def forward(self, input): pool_out = resize( input=pool_out, size=conv_out.size()[2:], - mode='bilinear', - align_corners=False) + mode="bilinear", + align_corners=False, + ) output = torch.cat([conv_out, pool_out], 1) output = self.bn(output) output = self.act(output) @@ -86,16 +90,18 @@ class NonBottleneck1d(BaseModule): Default: None. """ - def __init__(self, - channels, - drop_rate=0, - dilation=1, - num_conv_layer=2, - conv_cfg=None, - norm_cfg=dict(type='BN', eps=1e-3), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(NonBottleneck1d, self).__init__(init_cfg=init_cfg) + def __init__( + self, + channels, + drop_rate=0, + dilation=1, + num_conv_layer=2, + conv_cfg=None, + norm_cfg=dict(type="BN", eps=1e-3), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -118,7 +124,9 @@ def __init__(self, stride=1, padding=first_conv_padding, bias=True, - dilation=first_conv_dilation)) + dilation=first_conv_dilation, + ) + ) self.convs_layers.append(self.act) self.convs_layers.append( build_conv_layer( @@ -129,9 +137,10 @@ def __init__(self, stride=1, padding=second_conv_padding, bias=True, - dilation=second_conv_dilation)) - self.convs_layers.append( - build_norm_layer(self.norm_cfg, channels)[1]) + dilation=second_conv_dilation, + ) + ) + self.convs_layers.append(build_norm_layer(self.norm_cfg, channels)[1]) if conv_layer == 0: self.convs_layers.append(self.act) else: @@ -161,14 +170,16 @@ class UpsamplerBlock(BaseModule): Default: None. """ - def __init__(self, - in_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN', eps=1e-3), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(UpsamplerBlock, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN", eps=1e-3), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg @@ -180,7 +191,8 @@ def __init__(self, stride=2, padding=1, output_padding=1, - bias=True) + bias=True, + ) self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] self.act = build_activation_layer(self.act_cfg) @@ -227,46 +239,53 @@ class ERFNet(BaseModule): Default 0.1. """ - def __init__(self, - in_channels=3, - enc_downsample_channels=(16, 64, 128), - enc_stage_non_bottlenecks=(5, 8), - enc_non_bottleneck_dilations=(2, 4, 8, 16), - enc_non_bottleneck_channels=(64, 128), - dec_upsample_channels=(64, 16), - dec_stages_non_bottleneck=(2, 2), - dec_non_bottleneck_channels=(64, 16), - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - init_cfg=None): - - super(ERFNet, self).__init__(init_cfg=init_cfg) - assert len(enc_downsample_channels) \ - == len(dec_upsample_channels)+1, 'Number of downsample\ + def __init__( + self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + len(enc_downsample_channels) == len(dec_upsample_channels) + 1 + ), "Number of downsample\ block of encoder does not \ - match number of upsample block of decoder!' - assert len(enc_downsample_channels) \ - == len(enc_stage_non_bottlenecks)+1, 'Number of \ + match number of upsample block of decoder!" + assert ( + len(enc_downsample_channels) == len(enc_stage_non_bottlenecks) + 1 + ), "Number of \ downsample block of encoder does not match \ - number of Non-bottleneck block of encoder!' - assert len(enc_downsample_channels) \ - == len(enc_non_bottleneck_channels)+1, 'Number of \ + number of Non-bottleneck block of encoder!" + assert ( + len(enc_downsample_channels) == len(enc_non_bottleneck_channels) + 1 + ), "Number of \ downsample block of encoder does not match \ - number of channels of Non-bottleneck block of encoder!' - assert enc_stage_non_bottlenecks[-1] \ - % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + number of channels of Non-bottleneck block of encoder!" + assert ( + enc_stage_non_bottlenecks[-1] % len(enc_non_bottleneck_dilations) == 0 + ), "Number of \ Non-bottleneck block of encoder does not match \ - number of Non-bottleneck block of encoder!' - assert len(dec_upsample_channels) \ - == len(dec_stages_non_bottleneck), 'Number of \ + number of Non-bottleneck block of encoder!" + assert len(dec_upsample_channels) == len( + dec_stages_non_bottleneck + ), "Number of \ upsample block of decoder does not match \ - number of Non-bottleneck block of decoder!' - assert len(dec_stages_non_bottleneck) \ - == len(dec_non_bottleneck_channels), 'Number of \ + number of Non-bottleneck block of decoder!" + assert len(dec_stages_non_bottleneck) == len( + dec_non_bottleneck_channels + ), "Number of \ Non-bottleneck block of decoder does not match \ - number of channels of Non-bottleneck block of decoder!' + number of channels of Non-bottleneck block of decoder!" self.in_channels = in_channels self.enc_downsample_channels = enc_downsample_channels @@ -286,40 +305,53 @@ def __init__(self, self.act_cfg = act_cfg self.encoder.append( - DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + DownsamplerBlock(self.in_channels, enc_downsample_channels[0]) + ) for i in range(len(enc_downsample_channels) - 1): self.encoder.append( - DownsamplerBlock(enc_downsample_channels[i], - enc_downsample_channels[i + 1])) + DownsamplerBlock( + enc_downsample_channels[i], enc_downsample_channels[i + 1] + ) + ) # Last part of encoder is some dilated NonBottleneck1d blocks. if i == len(enc_downsample_channels) - 2: - iteration_times = int(enc_stage_non_bottlenecks[-1] / - len(enc_non_bottleneck_dilations)) + iteration_times = int( + enc_stage_non_bottlenecks[-1] / len(enc_non_bottleneck_dilations) + ) for j in range(iteration_times): for k in range(len(enc_non_bottleneck_dilations)): self.encoder.append( - NonBottleneck1d(enc_downsample_channels[-1], - self.dropout_ratio, - enc_non_bottleneck_dilations[k])) + NonBottleneck1d( + enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k], + ) + ) else: for j in range(enc_stage_non_bottlenecks[i]): self.encoder.append( - NonBottleneck1d(enc_downsample_channels[i + 1], - self.dropout_ratio)) + NonBottleneck1d( + enc_downsample_channels[i + 1], self.dropout_ratio + ) + ) for i in range(len(dec_upsample_channels)): if i == 0: self.decoder.append( - UpsamplerBlock(enc_downsample_channels[-1], - dec_non_bottleneck_channels[i])) + UpsamplerBlock( + enc_downsample_channels[-1], dec_non_bottleneck_channels[i] + ) + ) else: self.decoder.append( - UpsamplerBlock(dec_non_bottleneck_channels[i - 1], - dec_non_bottleneck_channels[i])) + UpsamplerBlock( + dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i], + ) + ) for j in range(dec_stages_non_bottleneck[i]): - self.decoder.append( - NonBottleneck1d(dec_non_bottleneck_channels[i])) + self.decoder.append(NonBottleneck1d(dec_non_bottleneck_channels[i])) def forward(self, x): for enc in self.encoder: diff --git a/mmsegmentation/mmseg/models/backbones/fast_scnn.py b/mmsegmentation/mmseg/models/backbones/fast_scnn.py index cbfbcaf..23201c3 100644 --- a/mmsegmentation/mmseg/models/backbones/fast_scnn.py +++ b/mmsegmentation/mmseg/models/backbones/fast_scnn.py @@ -29,15 +29,17 @@ class LearningToDownsample(nn.Module): as `act_cfg`. Default: None. """ - def __init__(self, - in_channels, - dw_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - dw_act_cfg=None): - super(LearningToDownsample, self).__init__() + def __init__( + self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + dw_act_cfg=None, + ): + super().__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg @@ -53,7 +55,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.dsconv1 = DepthwiseSeparableConvModule( dw_channels1, @@ -62,7 +65,8 @@ def __init__(self, stride=2, padding=1, norm_cfg=self.norm_cfg, - dw_act_cfg=self.dw_act_cfg) + dw_act_cfg=self.dw_act_cfg, + ) self.dsconv2 = DepthwiseSeparableConvModule( dw_channels2, @@ -71,7 +75,8 @@ def __init__(self, stride=2, padding=1, norm_cfg=self.norm_cfg, - dw_act_cfg=self.dw_act_cfg) + dw_act_cfg=self.dw_act_cfg, + ) def forward(self, x): x = self.conv(x) @@ -113,32 +118,42 @@ class GlobalFeatureExtractor(nn.Module): Default: False """ - def __init__(self, - in_channels=64, - block_channels=(64, 96, 128), - out_channels=128, - expand_ratio=6, - num_blocks=(3, 3, 3), - strides=(2, 2, 1), - pool_scales=(1, 2, 3, 6), - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False): - super(GlobalFeatureExtractor, self).__init__() + def __init__( + self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + align_corners=False, + ): + super().__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg assert len(block_channels) == len(num_blocks) == 3 - self.bottleneck1 = self._make_layer(in_channels, block_channels[0], - num_blocks[0], strides[0], - expand_ratio) - self.bottleneck2 = self._make_layer(block_channels[0], - block_channels[1], num_blocks[1], - strides[1], expand_ratio) - self.bottleneck3 = self._make_layer(block_channels[1], - block_channels[2], num_blocks[2], - strides[2], expand_ratio) + self.bottleneck1 = self._make_layer( + in_channels, block_channels[0], num_blocks[0], strides[0], expand_ratio + ) + self.bottleneck2 = self._make_layer( + block_channels[0], + block_channels[1], + num_blocks[1], + strides[1], + expand_ratio, + ) + self.bottleneck3 = self._make_layer( + block_channels[1], + block_channels[2], + num_blocks[2], + strides[2], + expand_ratio, + ) self.ppm = PPM( pool_scales, block_channels[2], @@ -146,7 +161,8 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=align_corners) + align_corners=align_corners, + ) self.out = ConvModule( block_channels[2] * 2, @@ -155,14 +171,10 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) - def _make_layer(self, - in_channels, - out_channels, - blocks, - stride=1, - expand_ratio=6): + def _make_layer(self, in_channels, out_channels, blocks, stride=1, expand_ratio=6): layers = [ InvertedResidual( in_channels, @@ -170,7 +182,8 @@ def _make_layer(self, stride, expand_ratio, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) ] for i in range(1, blocks): layers.append( @@ -180,7 +193,9 @@ def _make_layer(self, 1, expand_ratio, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) return nn.Sequential(*layers) def forward(self, x): @@ -212,16 +227,18 @@ class FeatureFusionModule(nn.Module): Default: False. """ - def __init__(self, - higher_in_channels, - lower_in_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dwconv_act_cfg=dict(type='ReLU'), - conv_act_cfg=None, - align_corners=False): - super(FeatureFusionModule, self).__init__() + def __init__( + self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dwconv_act_cfg=dict(type="ReLU"), + conv_act_cfg=None, + align_corners=False, + ): + super().__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.dwconv_act_cfg = dwconv_act_cfg @@ -235,14 +252,16 @@ def __init__(self, groups=out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.dwconv_act_cfg) + act_cfg=self.dwconv_act_cfg, + ) self.conv_lower_res = ConvModule( out_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.conv_act_cfg) + act_cfg=self.conv_act_cfg, + ) self.conv_higher_res = ConvModule( higher_in_channels, @@ -250,7 +269,8 @@ def __init__(self, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.conv_act_cfg) + act_cfg=self.conv_act_cfg, + ) self.relu = nn.ReLU(True) @@ -258,8 +278,9 @@ def forward(self, higher_res_feature, lower_res_feature): lower_res_feature = resize( lower_res_feature, size=higher_res_feature.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) lower_res_feature = self.dwconv(lower_res_feature) lower_res_feature = self.conv_lower_res(lower_res_feature) @@ -323,39 +344,43 @@ class FastSCNN(BaseModule): Default: None """ - def __init__(self, - in_channels=3, - downsample_dw_channels=(32, 48), - global_in_channels=64, - global_block_channels=(64, 96, 128), - global_block_strides=(2, 2, 1), - global_out_channels=128, - higher_in_channels=64, - lower_in_channels=128, - fusion_out_channels=128, - out_indices=(0, 1, 2), - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - dw_act_cfg=None, - init_cfg=None): - - super(FastSCNN, self).__init__(init_cfg) + def __init__( + self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + align_corners=False, + dw_act_cfg=None, + init_cfg=None, + ): + super().__init__(init_cfg) if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] if global_in_channels != higher_in_channels: - raise AssertionError('Global Input Channels must be the same \ - with Higher Input Channels!') + raise AssertionError( + "Global Input Channels must be the same \ + with Higher Input Channels!" + ) elif global_out_channels != lower_in_channels: - raise AssertionError('Global Output Channels must be the same \ - with Lower Input Channels!') + raise AssertionError( + "Global Output Channels must be the same \ + with Lower Input Channels!" + ) self.in_channels = in_channels self.downsample_dw_channels1 = downsample_dw_channels[0] @@ -379,7 +404,8 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - dw_act_cfg=dw_act_cfg) + dw_act_cfg=dw_act_cfg, + ) self.global_feature_extractor = GlobalFeatureExtractor( global_in_channels, global_block_channels, @@ -388,7 +414,8 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) self.feature_fusion = FeatureFusionModule( higher_in_channels, lower_in_channels, @@ -396,13 +423,13 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, dwconv_act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) def forward(self, x): higher_res_features = self.learning_to_downsample(x) lower_res_features = self.global_feature_extractor(higher_res_features) - fusion_output = self.feature_fusion(higher_res_features, - lower_res_features) + fusion_output = self.feature_fusion(higher_res_features, lower_res_features) outs = [higher_res_features, lower_res_features, fusion_output] outs = [outs[i] for i in self.out_indices] diff --git a/mmsegmentation/mmseg/models/backbones/hrnet.py b/mmsegmentation/mmseg/models/backbones/hrnet.py index 90feadc..a941188 100644 --- a/mmsegmentation/mmseg/models/backbones/hrnet.py +++ b/mmsegmentation/mmseg/models/backbones/hrnet.py @@ -18,22 +18,23 @@ class HRModule(BaseModule): is in this module. """ - def __init__(self, - num_branches, - blocks, - num_blocks, - in_channels, - num_channels, - multiscale_output=True, - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - block_init_cfg=None, - init_cfg=None): - super(HRModule, self).__init__(init_cfg) + def __init__( + self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + block_init_cfg=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.block_init_cfg = block_init_cfg - self._check_branches(num_branches, num_blocks, in_channels, - num_channels) + self._check_branches(num_branches, num_blocks, in_channels, num_channels) self.in_channels = in_channels self.num_branches = num_branches @@ -42,40 +43,41 @@ def __init__(self, self.norm_cfg = norm_cfg self.conv_cfg = conv_cfg self.with_cp = with_cp - self.branches = self._make_branches(num_branches, blocks, num_blocks, - num_channels) + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels + ) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=False) - def _check_branches(self, num_branches, num_blocks, in_channels, - num_channels): + def _check_branches(self, num_branches, num_blocks, in_channels, num_channels): """Check branches configuration.""" if num_branches != len(num_blocks): - error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ - f'{len(num_blocks)})' + error_msg = ( + f"NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(" f"{len(num_blocks)})" + ) raise ValueError(error_msg) if num_branches != len(num_channels): - error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ - f'{len(num_channels)})' + error_msg = ( + f"NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(" f"{len(num_channels)})" + ) raise ValueError(error_msg) if num_branches != len(in_channels): - error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ - f'{len(in_channels)})' + error_msg = ( + f"NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(" + f"{len(in_channels)})" + ) raise ValueError(error_msg) - def _make_one_branch(self, - branch_index, - block, - num_blocks, - num_channels, - stride=1): + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): """Build one branch.""" downsample = None - if stride != 1 or \ - self.in_channels[branch_index] != \ - num_channels[branch_index] * block.expansion: + if ( + stride != 1 + or self.in_channels[branch_index] + != num_channels[branch_index] * block.expansion + ): downsample = nn.Sequential( build_conv_layer( self.conv_cfg, @@ -83,9 +85,12 @@ def _make_one_branch(self, num_channels[branch_index] * block.expansion, kernel_size=1, stride=stride, - bias=False), - build_norm_layer(self.norm_cfg, num_channels[branch_index] * - block.expansion)[1]) + bias=False, + ), + build_norm_layer( + self.norm_cfg, num_channels[branch_index] * block.expansion + )[1], + ) layers = [] layers.append( @@ -97,9 +102,10 @@ def _make_one_branch(self, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg, - init_cfg=self.block_init_cfg)) - self.in_channels[branch_index] = \ - num_channels[branch_index] * block.expansion + init_cfg=self.block_init_cfg, + ) + ) + self.in_channels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append( block( @@ -108,7 +114,9 @@ def _make_one_branch(self, with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg, - init_cfg=self.block_init_cfg)) + init_cfg=self.block_init_cfg, + ) + ) return Sequential(*layers) @@ -117,8 +125,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] for i in range(num_branches): - branches.append( - self._make_one_branch(i, block, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) return ModuleList(branches) @@ -144,13 +151,17 @@ def _make_fuse_layers(self): kernel_size=1, stride=1, padding=0, - bias=False), + bias=False, + ), build_norm_layer(self.norm_cfg, in_channels[i])[1], # we set align_corners=False for HRNet Upsample( - scale_factor=2**(j - i), - mode='bilinear', - align_corners=False))) + scale_factor=2 ** (j - i), + mode="bilinear", + align_corners=False, + ), + ) + ) elif j == i: fuse_layer.append(None) else: @@ -166,9 +177,11 @@ def _make_fuse_layers(self): kernel_size=3, stride=2, padding=1, - bias=False), - build_norm_layer(self.norm_cfg, - in_channels[i])[1])) + bias=False, + ), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + ) + ) else: conv_downsamples.append( nn.Sequential( @@ -179,10 +192,12 @@ def _make_fuse_layers(self): kernel_size=3, stride=2, padding=1, - bias=False), - build_norm_layer(self.norm_cfg, - in_channels[j])[1], - nn.ReLU(inplace=False))) + bias=False, + ), + build_norm_layer(self.norm_cfg, in_channels[j])[1], + nn.ReLU(inplace=False), + ) + ) fuse_layer.append(nn.Sequential(*conv_downsamples)) fuse_layers.append(nn.ModuleList(fuse_layer)) @@ -206,8 +221,9 @@ def forward(self, x): y = y + resize( self.fuse_layers[i][j](x[j]), size=x[i].shape[2:], - mode='bilinear', - align_corners=False) + mode="bilinear", + align_corners=False, + ) else: y += self.fuse_layers[i][j](x[j]) x_fuse.append(self.relu(y)) @@ -294,51 +310,59 @@ class HRNet(BaseModule): (1, 256, 1, 1) """ - blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} - - def __init__(self, - extra, - in_channels=3, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - norm_eval=False, - with_cp=False, - frozen_stages=-1, - zero_init_residual=False, - multiscale_output=True, - pretrained=None, - init_cfg=None): - super(HRNet, self).__init__(init_cfg) + blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck} + + def __init__( + self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=False, + with_cp=False, + frozen_stages=-1, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.pretrained = pretrained self.zero_init_residual = zero_init_residual - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") # Assert configurations of 4 stages are in extra - assert 'stage1' in extra and 'stage2' in extra \ - and 'stage3' in extra and 'stage4' in extra + assert ( + "stage1" in extra + and "stage2" in extra + and "stage3" in extra + and "stage4" in extra + ) # Assert whether the length of `num_blocks` and `num_channels` are # equal to `num_branches` for i in range(4): - cfg = extra[f'stage{i + 1}'] - assert len(cfg['num_blocks']) == cfg['num_branches'] and \ - len(cfg['num_channels']) == cfg['num_branches'] + cfg = extra[f"stage{i + 1}"] + assert ( + len(cfg["num_blocks"]) == cfg["num_branches"] + and len(cfg["num_channels"]) == cfg["num_branches"] + ) self.extra = extra self.conv_cfg = conv_cfg @@ -358,66 +382,64 @@ def __init__(self, kernel_size=3, stride=2, padding=1, - bias=False) + bias=False, + ) self.add_module(self.norm1_name, norm1) self.conv2 = build_conv_layer( - self.conv_cfg, - 64, - 64, - kernel_size=3, - stride=2, - padding=1, - bias=False) + self.conv_cfg, 64, 64, kernel_size=3, stride=2, padding=1, bias=False + ) self.add_module(self.norm2_name, norm2) self.relu = nn.ReLU(inplace=True) # stage 1 - self.stage1_cfg = self.extra['stage1'] - num_channels = self.stage1_cfg['num_channels'][0] - block_type = self.stage1_cfg['block'] - num_blocks = self.stage1_cfg['num_blocks'][0] + self.stage1_cfg = self.extra["stage1"] + num_channels = self.stage1_cfg["num_channels"][0] + block_type = self.stage1_cfg["block"] + num_blocks = self.stage1_cfg["num_blocks"][0] block = self.blocks_dict[block_type] stage1_out_channels = num_channels * block.expansion self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) # stage 2 - self.stage2_cfg = self.extra['stage2'] - num_channels = self.stage2_cfg['num_channels'] - block_type = self.stage2_cfg['block'] + self.stage2_cfg = self.extra["stage2"] + num_channels = self.stage2_cfg["num_channels"] + block_type = self.stage2_cfg["block"] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] - self.transition1 = self._make_transition_layer([stage1_out_channels], - num_channels) + self.transition1 = self._make_transition_layer( + [stage1_out_channels], num_channels + ) self.stage2, pre_stage_channels = self._make_stage( - self.stage2_cfg, num_channels) + self.stage2_cfg, num_channels + ) # stage 3 - self.stage3_cfg = self.extra['stage3'] - num_channels = self.stage3_cfg['num_channels'] - block_type = self.stage3_cfg['block'] + self.stage3_cfg = self.extra["stage3"] + num_channels = self.stage3_cfg["num_channels"] + block_type = self.stage3_cfg["block"] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] - self.transition2 = self._make_transition_layer(pre_stage_channels, - num_channels) + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage( - self.stage3_cfg, num_channels) + self.stage3_cfg, num_channels + ) # stage 4 - self.stage4_cfg = self.extra['stage4'] - num_channels = self.stage4_cfg['num_channels'] - block_type = self.stage4_cfg['block'] + self.stage4_cfg = self.extra["stage4"] + num_channels = self.stage4_cfg["num_channels"] + block_type = self.stage4_cfg["block"] block = self.blocks_dict[block_type] num_channels = [channel * block.expansion for channel in num_channels] - self.transition3 = self._make_transition_layer(pre_stage_channels, - num_channels) + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( - self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + self.stage4_cfg, num_channels, multiscale_output=multiscale_output + ) self._freeze_stages() @@ -431,8 +453,7 @@ def norm2(self): """nn.Module: the normalization layer named "norm2" """ return getattr(self, self.norm2_name) - def _make_transition_layer(self, num_channels_pre_layer, - num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): """Make transition layer.""" num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) @@ -450,18 +471,25 @@ def _make_transition_layer(self, num_channels_pre_layer, kernel_size=3, stride=1, padding=1, - bias=False), - build_norm_layer(self.norm_cfg, - num_channels_cur_layer[i])[1], - nn.ReLU(inplace=True))) + bias=False, + ), + build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[ + 1 + ], + nn.ReLU(inplace=True), + ) + ) else: transition_layers.append(None) else: conv_downsamples = [] for j in range(i + 1 - num_branches_pre): in_channels = num_channels_pre_layer[-1] - out_channels = num_channels_cur_layer[i] \ - if j == i - num_branches_pre else in_channels + out_channels = ( + num_channels_cur_layer[i] + if j == i - num_branches_pre + else in_channels + ) conv_downsamples.append( nn.Sequential( build_conv_layer( @@ -471,9 +499,12 @@ def _make_transition_layer(self, num_channels_pre_layer, kernel_size=3, stride=2, padding=1, - bias=False), + bias=False, + ), build_norm_layer(self.norm_cfg, out_channels)[1], - nn.ReLU(inplace=True))) + nn.ReLU(inplace=True), + ) + ) transition_layers.append(nn.Sequential(*conv_downsamples)) return nn.ModuleList(transition_layers) @@ -489,19 +520,26 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): planes * block.expansion, kernel_size=1, stride=stride, - bias=False), - build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + bias=False, + ), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1], + ) layers = [] block_init_cfg = None - if self.pretrained is None and not hasattr( - self, 'init_cfg') and self.zero_init_residual: + if ( + self.pretrained is None + and not hasattr(self, "init_cfg") + and self.zero_init_residual + ): if block is BasicBlock: block_init_cfg = dict( - type='Constant', val=0, override=dict(name='norm2')) + type="Constant", val=0, override=dict(name="norm2") + ) elif block is Bottleneck: block_init_cfg = dict( - type='Constant', val=0, override=dict(name='norm3')) + type="Constant", val=0, override=dict(name="norm3") + ) layers.append( block( @@ -512,7 +550,9 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg, - init_cfg=block_init_cfg)) + init_cfg=block_init_cfg, + ) + ) inplanes = planes * block.expansion for i in range(1, blocks): layers.append( @@ -522,28 +562,35 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg, - init_cfg=block_init_cfg)) + init_cfg=block_init_cfg, + ) + ) return Sequential(*layers) def _make_stage(self, layer_config, in_channels, multiscale_output=True): """Make each stage.""" - num_modules = layer_config['num_modules'] - num_branches = layer_config['num_branches'] - num_blocks = layer_config['num_blocks'] - num_channels = layer_config['num_channels'] - block = self.blocks_dict[layer_config['block']] + num_modules = layer_config["num_modules"] + num_branches = layer_config["num_branches"] + num_blocks = layer_config["num_blocks"] + num_channels = layer_config["num_channels"] + block = self.blocks_dict[layer_config["block"]] hr_modules = [] block_init_cfg = None - if self.pretrained is None and not hasattr( - self, 'init_cfg') and self.zero_init_residual: + if ( + self.pretrained is None + and not hasattr(self, "init_cfg") + and self.zero_init_residual + ): if block is BasicBlock: block_init_cfg = dict( - type='Constant', val=0, override=dict(name='norm2')) + type="Constant", val=0, override=dict(name="norm2") + ) elif block is Bottleneck: block_init_cfg = dict( - type='Constant', val=0, override=dict(name='norm3')) + type="Constant", val=0, override=dict(name="norm3") + ) for i in range(num_modules): # multi_scale_output is only used for the last module @@ -563,14 +610,15 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True): with_cp=self.with_cp, norm_cfg=self.norm_cfg, conv_cfg=self.conv_cfg, - block_init_cfg=block_init_cfg)) + block_init_cfg=block_init_cfg, + ) + ) return Sequential(*hr_modules), in_channels def _freeze_stages(self): """Freeze stages param and norm stats.""" if self.frozen_stages >= 0: - self.norm1.eval() self.norm2.eval() for m in [self.conv1, self.norm1, self.conv2, self.norm2]: @@ -579,13 +627,13 @@ def _freeze_stages(self): for i in range(1, self.frozen_stages + 1): if i == 1: - m = getattr(self, f'layer{i}') - t = getattr(self, f'transition{i}') + m = getattr(self, f"layer{i}") + t = getattr(self, f"transition{i}") elif i == 4: - m = getattr(self, f'stage{i}') + m = getattr(self, f"stage{i}") else: - m = getattr(self, f'stage{i}') - t = getattr(self, f'transition{i}') + m = getattr(self, f"stage{i}") + t = getattr(self, f"transition{i}") m.eval() for param in m.parameters(): param.requires_grad = False @@ -605,7 +653,7 @@ def forward(self, x): x = self.layer1(x) x_list = [] - for i in range(self.stage2_cfg['num_branches']): + for i in range(self.stage2_cfg["num_branches"]): if self.transition1[i] is not None: x_list.append(self.transition1[i](x)) else: @@ -613,7 +661,7 @@ def forward(self, x): y_list = self.stage2(x_list) x_list = [] - for i in range(self.stage3_cfg['num_branches']): + for i in range(self.stage3_cfg["num_branches"]): if self.transition2[i] is not None: x_list.append(self.transition2[i](y_list[-1])) else: @@ -621,7 +669,7 @@ def forward(self, x): y_list = self.stage3(x_list) x_list = [] - for i in range(self.stage4_cfg['num_branches']): + for i in range(self.stage4_cfg["num_branches"]): if self.transition3[i] is not None: x_list.append(self.transition3[i](y_list[-1])) else: @@ -633,7 +681,7 @@ def forward(self, x): def train(self, mode=True): """Convert the model into training mode will keeping the normalization layer freezed.""" - super(HRNet, self).train(mode) + super().train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): diff --git a/mmsegmentation/mmseg/models/backbones/icnet.py b/mmsegmentation/mmseg/models/backbones/icnet.py index 6faaeab..6413a16 100644 --- a/mmsegmentation/mmseg/models/backbones/icnet.py +++ b/mmsegmentation/mmseg/models/backbones/icnet.py @@ -43,35 +43,38 @@ class ICNet(BaseModule): Default: None. """ - def __init__(self, - backbone_cfg, - in_channels=3, - layer_channels=(512, 2048), - light_branch_middle_channels=32, - psp_out_channels=512, - out_channels=(64, 256, 256), - pool_scales=(1, 2, 3, 6), - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - align_corners=False, - init_cfg=None): + def __init__( + self, + backbone_cfg, + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + align_corners=False, + init_cfg=None, + ): if backbone_cfg is None: - raise TypeError('backbone_cfg must be passed from config file!') + raise TypeError("backbone_cfg must be passed from config file!") if init_cfg is None: init_cfg = [ - dict(type='Kaiming', mode='fan_out', layer='Conv2d'), - dict(type='Constant', val=1, layer='_BatchNorm'), - dict(type='Normal', mean=0.01, layer='Linear') + dict(type="Kaiming", mode="fan_out", layer="Conv2d"), + dict(type="Constant", val=1, layer="_BatchNorm"), + dict(type="Normal", mean=0.01, layer="Linear"), ] - super(ICNet, self).__init__(init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) self.align_corners = align_corners self.backbone = build_backbone(backbone_cfg) # Note: Default `ceil_mode` is false in nn.MaxPool2d, set # `ceil_mode=True` to keep information in the corner of feature map. self.backbone.maxpool = nn.MaxPool2d( - kernel_size=3, stride=2, padding=1, ceil_mode=True) + kernel_size=3, stride=2, padding=1, ceil_mode=True + ) self.psp_modules = PPM( pool_scales=pool_scales, @@ -80,7 +83,8 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - align_corners=align_corners) + align_corners=align_corners, + ) self.psp_bottleneck = ConvModule( layer_channels[1] + len(pool_scales) * psp_out_channels, @@ -89,7 +93,8 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.conv_sub1 = nn.Sequential( ConvModule( @@ -99,7 +104,8 @@ def __init__(self, stride=2, padding=1, conv_cfg=conv_cfg, - norm_cfg=norm_cfg), + norm_cfg=norm_cfg, + ), ConvModule( in_channels=light_branch_middle_channels, out_channels=light_branch_middle_channels, @@ -107,7 +113,8 @@ def __init__(self, stride=2, padding=1, conv_cfg=conv_cfg, - norm_cfg=norm_cfg), + norm_cfg=norm_cfg, + ), ConvModule( in_channels=light_branch_middle_channels, out_channels=out_channels[0], @@ -115,21 +122,17 @@ def __init__(self, stride=2, padding=1, conv_cfg=conv_cfg, - norm_cfg=norm_cfg)) + norm_cfg=norm_cfg, + ), + ) self.conv_sub2 = ConvModule( - layer_channels[0], - out_channels[1], - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg) + layer_channels[0], out_channels[1], 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg + ) self.conv_sub4 = ConvModule( - psp_out_channels, - out_channels[2], - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg) + psp_out_channels, out_channels[2], 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg + ) def forward(self, x): output = [] @@ -139,10 +142,8 @@ def forward(self, x): # sub 2 x = resize( - x, - scale_factor=0.5, - mode='bilinear', - align_corners=self.align_corners) + x, scale_factor=0.5, mode="bilinear", align_corners=self.align_corners + ) x = self.backbone.stem(x) x = self.backbone.maxpool(x) x = self.backbone.layer1(x) @@ -151,10 +152,8 @@ def forward(self, x): # sub 4 x = resize( - x, - scale_factor=0.5, - mode='bilinear', - align_corners=self.align_corners) + x, scale_factor=0.5, mode="bilinear", align_corners=self.align_corners + ) x = self.backbone.layer3(x) x = self.backbone.layer4(x) psp_outs = self.psp_modules(x) + [x] diff --git a/mmsegmentation/mmseg/models/backbones/mae.py b/mmsegmentation/mmseg/models/backbones/mae.py index d3e8754..de51099 100644 --- a/mmsegmentation/mmseg/models/backbones/mae.py +++ b/mmsegmentation/mmseg/models/backbones/mae.py @@ -3,8 +3,7 @@ import torch import torch.nn as nn -from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, - trunc_normal_) +from mmcv.cnn.utils.weight_init import constant_init, kaiming_init, trunc_normal_ from mmcv.runner import ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm @@ -28,8 +27,6 @@ def init_weights(self): # with `trunc_normal`, `init_weights` here does # nothing and just passes directly - pass - class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): """Implements one encoder layer in Vision Transformer. @@ -80,27 +77,29 @@ class MAE(BEiT): Default: None. """ - def __init__(self, - img_size=224, - patch_size=16, - in_channels=3, - embed_dims=768, - num_layers=12, - num_heads=12, - mlp_ratio=4, - out_indices=-1, - attn_drop_rate=0., - drop_path_rate=0., - norm_cfg=dict(type='LN'), - act_cfg=dict(type='GELU'), - patch_norm=False, - final_norm=False, - num_fcs=2, - norm_eval=False, - pretrained=None, - init_values=0.1, - init_cfg=None): - super(MAE, self).__init__( + def __init__( + self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_cfg=dict(type="LN"), + act_cfg=dict(type="GELU"), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None, + ): + super().__init__( img_size=img_size, patch_size=patch_size, in_channels=in_channels, @@ -120,18 +119,17 @@ def __init__(self, norm_eval=norm_eval, pretrained=pretrained, init_values=init_values, - init_cfg=init_cfg) + init_cfg=init_cfg, + ) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) self.num_patches = self.patch_shape[0] * self.patch_shape[1] - self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_patches + 1, embed_dims)) + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dims)) def _build_layers(self): dpr = [ - x.item() - for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers) ] self.layers = ModuleList() for i in range(self.num_layers): @@ -147,7 +145,9 @@ def _build_layers(self): act_cfg=self.act_cfg, norm_cfg=self.norm_cfg, window_size=self.patch_shape, - init_values=self.init_values)) + init_values=self.init_values, + ) + ) def fix_init_weight(self): """Rescale the initialization according to layer id. @@ -165,10 +165,9 @@ def rescale(param, layer_id): rescale(layer.ffn.layers[1].weight.data, layer_id + 1) def init_weights(self): - def _init_weights(m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -178,43 +177,45 @@ def _init_weights(m): self.apply(_init_weights) self.fix_init_weight() - if (isinstance(self.init_cfg, dict) - and self.init_cfg.get('type') == 'Pretrained'): + if ( + isinstance(self.init_cfg, dict) + and self.init_cfg.get("type") == "Pretrained" + ): logger = get_root_logger() checkpoint = _load_checkpoint( - self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + self.init_cfg["checkpoint"], logger=logger, map_location="cpu" + ) state_dict = self.resize_rel_pos_embed(checkpoint) state_dict = self.resize_abs_pos_embed(state_dict) self.load_state_dict(state_dict, False) elif self.init_cfg is not None: - super(MAE, self).init_weights() + super().init_weights() else: # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 # Copyright 2019 Ross Wightman # Licensed under the Apache License, Version 2.0 (the "License") - trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.cls_token, std=0.02) for n, m in self.named_modules(): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if m.bias is not None: - if 'ffn' in n: - nn.init.normal_(m.bias, mean=0., std=1e-6) + if "ffn" in n: + nn.init.normal_(m.bias, mean=0.0, std=1e-6) else: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', bias=0.) + kaiming_init(m, mode="fan_in", bias=0.0) elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) def resize_abs_pos_embed(self, state_dict): - if 'pos_embed' in state_dict: - pos_embed_checkpoint = state_dict['pos_embed'] + if "pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["pos_embed"] embedding_size = pos_embed_checkpoint.shape[-1] num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches # height (== width) for the checkpoint position embedding - orig_size = int( - (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(self.num_patches**0.5) # class_token and dist_token are kept unchanged @@ -222,17 +223,18 @@ def resize_abs_pos_embed(self, state_dict): extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, - embedding_size).permute( - 0, 3, 1, 2) + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), - mode='bicubic', - align_corners=False) + mode="bicubic", + align_corners=False, + ) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - state_dict['pos_embed'] = new_pos_embed + state_dict["pos_embed"] = new_pos_embed return state_dict def forward(self, inputs): @@ -254,8 +256,11 @@ def forward(self, inputs): if i in self.out_indices: out = x[:, 1:] B, _, C = out.shape - out = out.reshape(B, hw_shape[0], hw_shape[1], - C).permute(0, 3, 1, 2).contiguous() + out = ( + out.reshape(B, hw_shape[0], hw_shape[1], C) + .permute(0, 3, 1, 2) + .contiguous() + ) outs.append(out) return tuple(outs) diff --git a/mmsegmentation/mmseg/models/backbones/mit.py b/mmsegmentation/mmseg/models/backbones/mit.py index 4417cf1..65f1ea0 100644 --- a/mmsegmentation/mmseg/models/backbones/mit.py +++ b/mmsegmentation/mmseg/models/backbones/mit.py @@ -8,8 +8,7 @@ from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import MultiheadAttention -from mmcv.cnn.utils.weight_init import (constant_init, normal_init, - trunc_normal_init) +from mmcv.cnn.utils.weight_init import constant_init, normal_init, trunc_normal_init from mmcv.runner import BaseModule, ModuleList, Sequential from ..builder import BACKBONES @@ -37,14 +36,16 @@ class MixFFN(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - feedforward_channels, - act_cfg=dict(type='GELU'), - ffn_drop=0., - dropout_layer=None, - init_cfg=None): - super(MixFFN, self).__init__(init_cfg) + def __init__( + self, + embed_dims, + feedforward_channels, + act_cfg=dict(type="GELU"), + ffn_drop=0.0, + dropout_layer=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.embed_dims = embed_dims self.feedforward_channels = feedforward_channels @@ -57,7 +58,8 @@ def __init__(self, out_channels=feedforward_channels, kernel_size=1, stride=1, - bias=True) + bias=True, + ) # 3x3 depth wise conv to provide positional encode information pe_conv = Conv2d( in_channels=feedforward_channels, @@ -66,18 +68,21 @@ def __init__(self, stride=1, padding=(3 - 1) // 2, bias=True, - groups=feedforward_channels) + groups=feedforward_channels, + ) fc2 = Conv2d( in_channels=feedforward_channels, out_channels=in_channels, kernel_size=1, stride=1, - bias=True) + bias=True, + ) drop = nn.Dropout(ffn_drop) layers = [fc1, pe_conv, self.activate, drop, fc2, drop] self.layers = Sequential(*layers) - self.dropout_layer = build_dropout( - dropout_layer) if dropout_layer else torch.nn.Identity() + self.dropout_layer = ( + build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() + ) def forward(self, x, hw_shape, identity=None): out = nlc_to_nchw(x, hw_shape) @@ -114,17 +119,19 @@ class EfficientMultiheadAttention(MultiheadAttention): Attention of Segformer. Default: 1. """ - def __init__(self, - embed_dims, - num_heads, - attn_drop=0., - proj_drop=0., - dropout_layer=None, - init_cfg=None, - batch_first=True, - qkv_bias=False, - norm_cfg=dict(type='LN'), - sr_ratio=1): + def __init__( + self, + embed_dims, + num_heads, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type="LN"), + sr_ratio=1, + ): super().__init__( embed_dims, num_heads, @@ -133,7 +140,8 @@ def __init__(self, dropout_layer=dropout_layer, init_cfg=init_cfg, batch_first=batch_first, - bias=qkv_bias) + bias=qkv_bias, + ) self.sr_ratio = sr_ratio if sr_ratio > 1: @@ -141,21 +149,24 @@ def __init__(self, in_channels=embed_dims, out_channels=embed_dims, kernel_size=sr_ratio, - stride=sr_ratio) + stride=sr_ratio, + ) # The ret[0] of build_norm_layer is norm name. self.norm = build_norm_layer(norm_cfg, embed_dims)[1] # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa from mmseg import digit_version, mmcv_version - if mmcv_version < digit_version('1.3.17'): - warnings.warn('The legacy version of forward function in' - 'EfficientMultiheadAttention is deprecated in' - 'mmcv>=1.3.17 and will no longer support in the' - 'future. Please upgrade your mmcv.') + + if mmcv_version < digit_version("1.3.17"): + warnings.warn( + "The legacy version of forward function in" + "EfficientMultiheadAttention is deprecated in" + "mmcv>=1.3.17 and will no longer support in the" + "future. Please upgrade your mmcv." + ) self.forward = self.legacy_forward def forward(self, x, hw_shape, identity=None): - x_q = x if self.sr_ratio > 1: x_kv = nlc_to_nchw(x, hw_shape) @@ -240,20 +251,22 @@ class TransformerEncoderLayer(BaseModule): some memory while slowing down the training speed. Default: False. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - qkv_bias=True, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - batch_first=True, - sr_ratio=1, - with_cp=False): - super(TransformerEncoderLayer, self).__init__() + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + qkv_bias=True, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + batch_first=True, + sr_ratio=1, + with_cp=False, + ): + super().__init__() # The ret[0] of build_norm_layer is norm name. self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] @@ -263,11 +276,12 @@ def __init__(self, num_heads=num_heads, attn_drop=attn_drop_rate, proj_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), batch_first=batch_first, qkv_bias=qkv_bias, norm_cfg=norm_cfg, - sr_ratio=sr_ratio) + sr_ratio=sr_ratio, + ) # The ret[0] of build_norm_layer is norm name. self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] @@ -276,13 +290,13 @@ def __init__(self, embed_dims=embed_dims, feedforward_channels=feedforward_channels, ffn_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), - act_cfg=act_cfg) + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), + act_cfg=act_cfg, + ) self.with_cp = with_cp def forward(self, x, hw_shape): - def _inner_forward(x): x = self.attn(self.norm1(x), hw_shape, identity=x) x = self.ffn(self.norm2(x), hw_shape, identity=x) @@ -337,36 +351,41 @@ class MixVisionTransformer(BaseModule): some memory while slowing down the training speed. Default: False. """ - def __init__(self, - in_channels=3, - embed_dims=64, - num_stages=4, - num_layers=[3, 4, 6, 3], - num_heads=[1, 2, 4, 8], - patch_sizes=[7, 3, 3, 3], - strides=[4, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - out_indices=(0, 1, 2, 3), - mlp_ratio=4, - qkv_bias=True, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN', eps=1e-6), - pretrained=None, - init_cfg=None, - with_cp=False): - super(MixVisionTransformer, self).__init__(init_cfg=init_cfg) - - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be set at the same time' + def __init__( + self, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN", eps=1e-6), + pretrained=None, + init_cfg=None, + with_cp=False, + ): + super().__init__(init_cfg=init_cfg) + + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be set at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is not None: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.embed_dims = embed_dims self.num_stages = num_stages @@ -376,16 +395,21 @@ def __init__(self, self.strides = strides self.sr_ratios = sr_ratios self.with_cp = with_cp - assert num_stages == len(num_layers) == len(num_heads) \ - == len(patch_sizes) == len(strides) == len(sr_ratios) + assert ( + num_stages + == len(num_layers) + == len(num_heads) + == len(patch_sizes) + == len(strides) + == len(sr_ratios) + ) self.out_indices = out_indices assert max(out_indices) < self.num_stages # transformer encoder dpr = [ - x.item() - for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers)) ] # stochastic num_layer decay rule cur = 0 @@ -398,21 +422,26 @@ def __init__(self, kernel_size=patch_sizes[i], stride=strides[i], padding=patch_sizes[i] // 2, - norm_cfg=norm_cfg) - layer = ModuleList([ - TransformerEncoderLayer( - embed_dims=embed_dims_i, - num_heads=num_heads[i], - feedforward_channels=mlp_ratio * embed_dims_i, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=dpr[cur + idx], - qkv_bias=qkv_bias, - act_cfg=act_cfg, - norm_cfg=norm_cfg, - with_cp=with_cp, - sr_ratio=sr_ratios[i]) for idx in range(num_layer) - ]) + norm_cfg=norm_cfg, + ) + layer = ModuleList( + [ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + sr_ratio=sr_ratios[i], + ) + for idx in range(num_layer) + ] + ) in_channels = embed_dims_i # The ret[0] of build_norm_layer is norm name. norm = build_norm_layer(norm_cfg, embed_dims_i)[1] @@ -423,17 +452,15 @@ def init_weights(self): if self.init_cfg is None: for m in self.modules(): if isinstance(m, nn.Linear): - trunc_normal_init(m, std=.02, bias=0.) + trunc_normal_init(m, std=0.02, bias=0.0) elif isinstance(m, nn.LayerNorm): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[ - 1] * m.out_channels + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups - normal_init( - m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + normal_init(m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) else: - super(MixVisionTransformer, self).init_weights() + super().init_weights() def forward(self, x): outs = [] diff --git a/mmsegmentation/mmseg/models/backbones/mobilenet_v2.py b/mmsegmentation/mmseg/models/backbones/mobilenet_v2.py index cbb9c6c..6621305 100644 --- a/mmsegmentation/mmseg/models/backbones/mobilenet_v2.py +++ b/mmsegmentation/mmseg/models/backbones/mobilenet_v2.py @@ -47,42 +47,51 @@ class MobileNetV2(BaseModule): # Parameters to build layers. 3 parameters are needed to construct a # layer, from left to right: expand_ratio, channel, num_blocks. - arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], - [6, 96, 3], [6, 160, 3], [6, 320, 1]] - - def __init__(self, - widen_factor=1., - strides=(1, 2, 2, 2, 1, 2, 1), - dilations=(1, 1, 1, 1, 1, 1, 1), - out_indices=(1, 2, 4, 6), - frozen_stages=-1, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU6'), - norm_eval=False, - with_cp=False, - pretrained=None, - init_cfg=None): - super(MobileNetV2, self).__init__(init_cfg) + arch_settings = [ + [1, 16, 1], + [6, 24, 2], + [6, 32, 3], + [6, 64, 4], + [6, 96, 3], + [6, 160, 3], + [6, 320, 1], + ] + + def __init__( + self, + widen_factor=1.0, + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU6"), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.pretrained = pretrained - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is a deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.widen_factor = widen_factor self.strides = strides @@ -91,12 +100,16 @@ def __init__(self, self.out_indices = out_indices for index in out_indices: if index not in range(0, 7): - raise ValueError('the item in out_indices must in ' - f'range(0, 7). But received {index}') + raise ValueError( + "the item in out_indices must in " + f"range(0, 7). But received {index}" + ) if frozen_stages not in range(-1, 7): - raise ValueError('frozen_stages must be in range(-1, 7). ' - f'But received {frozen_stages}') + raise ValueError( + "frozen_stages must be in range(-1, 7). " + f"But received {frozen_stages}" + ) self.out_indices = out_indices self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg @@ -115,7 +128,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.layers = [] @@ -129,13 +143,13 @@ def __init__(self, num_blocks=num_blocks, stride=stride, dilation=dilation, - expand_ratio=expand_ratio) - layer_name = f'layer{i + 1}' + expand_ratio=expand_ratio, + ) + layer_name = f"layer{i + 1}" self.add_module(layer_name, inverted_res_layer) self.layers.append(layer_name) - def make_layer(self, out_channels, num_blocks, stride, dilation, - expand_ratio): + def make_layer(self, out_channels, num_blocks, stride, dilation, expand_ratio): """Stack InvertedResidual blocks to build a layer for MobileNetV2. Args: @@ -158,7 +172,9 @@ def make_layer(self, out_channels, num_blocks, stride, dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - with_cp=self.with_cp)) + with_cp=self.with_cp, + ) + ) self.in_channels = out_channels return nn.Sequential(*layers) @@ -183,13 +199,13 @@ def _freeze_stages(self): for param in self.conv1.parameters(): param.requires_grad = False for i in range(1, self.frozen_stages + 1): - layer = getattr(self, f'layer{i}') + layer = getattr(self, f"layer{i}") layer.eval() for param in layer.parameters(): param.requires_grad = False def train(self, mode=True): - super(MobileNetV2, self).train(mode) + super().train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): diff --git a/mmsegmentation/mmseg/models/backbones/mobilenet_v3.py b/mmsegmentation/mmseg/models/backbones/mobilenet_v3.py index dd3d6eb..0670a43 100644 --- a/mmsegmentation/mmseg/models/backbones/mobilenet_v3.py +++ b/mmsegmentation/mmseg/models/backbones/mobilenet_v3.py @@ -39,68 +39,75 @@ class MobileNetV3(BaseModule): init_cfg (dict or list[dict], optional): Initialization config dict. Default: None """ + # Parameters to build each block: # [kernel size, mid channels, out channels, with_se, act type, stride] arch_settings = { - 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 - [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 - [3, 88, 24, False, 'ReLU', 1], - [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 - [5, 240, 40, True, 'HSwish', 1], - [5, 240, 40, True, 'HSwish', 1], - [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 - [5, 144, 48, True, 'HSwish', 1], - [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 - [5, 576, 96, True, 'HSwish', 1], - [5, 576, 96, True, 'HSwish', 1]], - 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 - [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 - [3, 72, 24, False, 'ReLU', 1], - [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 - [5, 120, 40, True, 'ReLU', 1], - [5, 120, 40, True, 'ReLU', 1], - [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 - [3, 200, 80, False, 'HSwish', 1], - [3, 184, 80, False, 'HSwish', 1], - [3, 184, 80, False, 'HSwish', 1], - [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 - [3, 672, 112, True, 'HSwish', 1], - [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 - [5, 960, 160, True, 'HSwish', 1], - [5, 960, 160, True, 'HSwish', 1]] + "small": [ + [3, 16, 16, True, "ReLU", 2], # block0 layer1 os=4 + [3, 72, 24, False, "ReLU", 2], # block1 layer2 os=8 + [3, 88, 24, False, "ReLU", 1], + [5, 96, 40, True, "HSwish", 2], # block2 layer4 os=16 + [5, 240, 40, True, "HSwish", 1], + [5, 240, 40, True, "HSwish", 1], + [5, 120, 48, True, "HSwish", 1], # block3 layer7 os=16 + [5, 144, 48, True, "HSwish", 1], + [5, 288, 96, True, "HSwish", 2], # block4 layer9 os=32 + [5, 576, 96, True, "HSwish", 1], + [5, 576, 96, True, "HSwish", 1], + ], + "large": [ + [3, 16, 16, False, "ReLU", 1], # block0 layer1 os=2 + [3, 64, 24, False, "ReLU", 2], # block1 layer2 os=4 + [3, 72, 24, False, "ReLU", 1], + [5, 72, 40, True, "ReLU", 2], # block2 layer4 os=8 + [5, 120, 40, True, "ReLU", 1], + [5, 120, 40, True, "ReLU", 1], + [3, 240, 80, False, "HSwish", 2], # block3 layer7 os=16 + [3, 200, 80, False, "HSwish", 1], + [3, 184, 80, False, "HSwish", 1], + [3, 184, 80, False, "HSwish", 1], + [3, 480, 112, True, "HSwish", 1], # block4 layer11 os=16 + [3, 672, 112, True, "HSwish", 1], + [5, 672, 160, True, "HSwish", 2], # block5 layer13 os=32 + [5, 960, 160, True, "HSwish", 1], + [5, 960, 160, True, "HSwish", 1], + ], } # yapf: disable - def __init__(self, - arch='small', - conv_cfg=None, - norm_cfg=dict(type='BN'), - out_indices=(0, 1, 12), - frozen_stages=-1, - reduction_factor=1, - norm_eval=False, - with_cp=False, - pretrained=None, - init_cfg=None): - super(MobileNetV3, self).__init__(init_cfg) + def __init__( + self, + arch="small", + conv_cfg=None, + norm_cfg=dict(type="BN"), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.pretrained = pretrained - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is a deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") assert arch in self.arch_settings assert isinstance(reduction_factor, int) and reduction_factor > 0 @@ -108,14 +115,17 @@ def __init__(self, for index in out_indices: if index not in range(0, len(self.arch_settings[arch]) + 2): raise ValueError( - 'the item in out_indices must in ' - f'range(0, {len(self.arch_settings[arch])+2}). ' - f'But received {index}') + "the item in out_indices must in " + f"range(0, {len(self.arch_settings[arch])+2}). " + f"But received {index}" + ) if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): - raise ValueError('frozen_stages must be in range(-1, ' - f'{len(self.arch_settings[arch])+2}). ' - f'But received {frozen_stages}') + raise ValueError( + "frozen_stages must be in range(-1, " + f"{len(self.arch_settings[arch])+2}). " + f"But received {frozen_stages}" + ) self.arch = arch self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -137,19 +147,18 @@ def _make_layer(self): kernel_size=3, stride=2, padding=1, - conv_cfg=dict(type='Conv2dAdaptivePadding'), + conv_cfg=dict(type="Conv2dAdaptivePadding"), norm_cfg=self.norm_cfg, - act_cfg=dict(type='HSwish')) - self.add_module('layer0', layer) - layers.append('layer0') + act_cfg=dict(type="HSwish"), + ) + self.add_module("layer0", layer) + layers.append("layer0") layer_setting = self.arch_settings[self.arch] for i, params in enumerate(layer_setting): - (kernel_size, mid_channels, out_channels, with_se, act, - stride) = params + (kernel_size, mid_channels, out_channels, with_se, act, stride) = params - if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ - i >= 8: + if self.arch == "large" and i >= 12 or self.arch == "small" and i >= 8: mid_channels = mid_channels // self.reduction_factor out_channels = out_channels // self.reduction_factor @@ -157,8 +166,11 @@ def _make_layer(self): se_cfg = dict( channels=mid_channels, ratio=4, - act_cfg=(dict(type='ReLU'), - dict(type='HSigmoid', bias=3.0, divisor=6.0))) + act_cfg=( + dict(type="ReLU"), + dict(type="HSigmoid", bias=3.0, divisor=6.0), + ), + ) else: se_cfg = None @@ -173,9 +185,10 @@ def _make_layer(self): conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=dict(type=act), - with_cp=self.with_cp) + with_cp=self.with_cp, + ) in_channels = out_channels - layer_name = 'layer{}'.format(i + 1) + layer_name = f"layer{i + 1}" self.add_module(layer_name, layer) layers.append(layer_name) @@ -184,20 +197,21 @@ def _make_layer(self): # block6 layer16 os=32 for large model layer = ConvModule( in_channels=in_channels, - out_channels=576 if self.arch == 'small' else 960, + out_channels=576 if self.arch == "small" else 960, kernel_size=1, stride=1, dilation=4, padding=0, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=dict(type='HSwish')) - layer_name = 'layer{}'.format(len(layer_setting) + 1) + act_cfg=dict(type="HSwish"), + ) + layer_name = f"layer{len(layer_setting) + 1}" self.add_module(layer_name, layer) layers.append(layer_name) # next, convert backbone MobileNetV3 to a semantic segmentation version - if self.arch == 'small': + if self.arch == "small": self.layer4.depthwise_conv.conv.stride = (1, 1) self.layer9.depthwise_conv.conv.stride = (1, 1) for i in range(4, len(layers)): @@ -253,13 +267,13 @@ def forward(self, x): def _freeze_stages(self): for i in range(self.frozen_stages + 1): - layer = getattr(self, f'layer{i}') + layer = getattr(self, f"layer{i}") layer.eval() for param in layer.parameters(): param.requires_grad = False def train(self, mode=True): - super(MobileNetV3, self).train(mode) + super().train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): diff --git a/mmsegmentation/mmseg/models/backbones/resnest.py b/mmsegmentation/mmseg/models/backbones/resnest.py index 91952c2..1f5ab43 100644 --- a/mmsegmentation/mmseg/models/backbones/resnest.py +++ b/mmsegmentation/mmseg/models/backbones/resnest.py @@ -56,20 +56,22 @@ class SplitAttentionConv2d(nn.Module): dcn (dict): Config dict for DCN. Default: None. """ - def __init__(self, - in_channels, - channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - radix=2, - reduction_factor=4, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None): - super(SplitAttentionConv2d, self).__init__() + def __init__( + self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dcn=None, + ): + super().__init__() inter_channels = max(in_channels * radix // reduction_factor, 32) self.radix = radix self.groups = groups @@ -78,9 +80,9 @@ def __init__(self, self.dcn = dcn fallback_on_stride = False if self.with_dcn: - fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + fallback_on_stride = self.dcn.pop("fallback_on_stride", False) if self.with_dcn and not fallback_on_stride: - assert conv_cfg is None, 'conv_cfg must be None for DCN' + assert conv_cfg is None, "conv_cfg must be None for DCN" conv_cfg = dcn self.conv = build_conv_layer( conv_cfg, @@ -91,18 +93,19 @@ def __init__(self, padding=padding, dilation=dilation, groups=groups * radix, - bias=False) - self.norm0_name, norm0 = build_norm_layer( - norm_cfg, channels * radix, postfix=0) + bias=False, + ) + self.norm0_name, norm0 = build_norm_layer(norm_cfg, channels * radix, postfix=0) self.add_module(self.norm0_name, norm0) self.relu = nn.ReLU(inplace=True) self.fc1 = build_conv_layer( - None, channels, inter_channels, 1, groups=self.groups) - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, inter_channels, postfix=1) + None, channels, inter_channels, 1, groups=self.groups + ) + self.norm1_name, norm1 = build_norm_layer(norm_cfg, inter_channels, postfix=1) self.add_module(self.norm1_name, norm1) self.fc2 = build_conv_layer( - None, inter_channels, channels * radix, 1, groups=self.groups) + None, inter_channels, channels * radix, 1, groups=self.groups + ) self.rsoftmax = RSoftmax(radix, groups) @property @@ -161,33 +164,35 @@ class Bottleneck(_Bottleneck): Bottleneck. Default: True. kwargs (dict): Key word arguments for base class. """ + expansion = 4 - def __init__(self, - inplanes, - planes, - groups=1, - base_width=4, - base_channels=64, - radix=2, - reduction_factor=4, - avg_down_stride=True, - **kwargs): + def __init__( + self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs, + ): """Bottleneck block for ResNeSt.""" - super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + super().__init__(inplanes, planes, **kwargs) if groups == 1: width = self.planes else: - width = math.floor(self.planes * - (base_width / base_channels)) * groups + width = math.floor(self.planes * (base_width / base_channels)) * groups self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 - self.norm1_name, norm1 = build_norm_layer( - self.norm_cfg, width, postfix=1) + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, width, postfix=1) self.norm3_name, norm3 = build_norm_layer( - self.norm_cfg, self.planes * self.expansion, postfix=3) + self.norm_cfg, self.planes * self.expansion, postfix=3 + ) self.conv1 = build_conv_layer( self.conv_cfg, @@ -195,7 +200,8 @@ def __init__(self, width, kernel_size=1, stride=self.conv1_stride, - bias=False) + bias=False, + ) self.add_module(self.norm1_name, norm1) self.with_modulated_dcn = False self.conv2 = SplitAttentionConv2d( @@ -210,7 +216,8 @@ def __init__(self, reduction_factor=reduction_factor, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - dcn=self.dcn) + dcn=self.dcn, + ) delattr(self, self.norm2_name) if self.avg_down_stride: @@ -221,11 +228,11 @@ def __init__(self, width, self.planes * self.expansion, kernel_size=1, - bias=False) + bias=False, + ) self.add_module(self.norm3_name, norm3) def forward(self, x): - def _inner_forward(x): identity = x @@ -289,22 +296,24 @@ class ResNeSt(ResNetV1d): 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)), - 200: (Bottleneck, (3, 24, 36, 3)) + 200: (Bottleneck, (3, 24, 36, 3)), } - def __init__(self, - groups=1, - base_width=4, - radix=2, - reduction_factor=4, - avg_down_stride=True, - **kwargs): + def __init__( + self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs, + ): self.groups = groups self.base_width = base_width self.radix = radix self.reduction_factor = reduction_factor self.avg_down_stride = avg_down_stride - super(ResNeSt, self).__init__(**kwargs) + super().__init__(**kwargs) def make_res_layer(self, **kwargs): """Pack all blocks in a stage into a ``ResLayer``.""" @@ -315,4 +324,5 @@ def make_res_layer(self, **kwargs): radix=self.radix, reduction_factor=self.reduction_factor, avg_down_stride=self.avg_down_stride, - **kwargs) + **kwargs, + ) diff --git a/mmsegmentation/mmseg/models/backbones/resnet.py b/mmsegmentation/mmseg/models/backbones/resnet.py index e8b961d..357b0c3 100644 --- a/mmsegmentation/mmseg/models/backbones/resnet.py +++ b/mmsegmentation/mmseg/models/backbones/resnet.py @@ -16,22 +16,24 @@ class BasicBlock(BaseModule): expansion = 1 - def __init__(self, - inplanes, - planes, - stride=1, - dilation=1, - downsample=None, - style='pytorch', - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None, - plugins=None, - init_cfg=None): - super(BasicBlock, self).__init__(init_cfg) - assert dcn is None, 'Not implemented yet.' - assert plugins is None, 'Not implemented yet.' + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style="pytorch", + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dcn=None, + plugins=None, + init_cfg=None, + ): + super().__init__(init_cfg) + assert dcn is None, "Not implemented yet." + assert plugins is None, "Not implemented yet." self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) @@ -44,10 +46,12 @@ def __init__(self, stride=stride, padding=dilation, dilation=dilation, - bias=False) + bias=False, + ) self.add_module(self.norm1_name, norm1) self.conv2 = build_conv_layer( - conv_cfg, planes, planes, 3, padding=1, bias=False) + conv_cfg, planes, planes, 3, padding=1, bias=False + ) self.add_module(self.norm2_name, norm2) self.relu = nn.ReLU(inplace=True) @@ -105,26 +109,28 @@ class Bottleneck(BaseModule): expansion = 4 - def __init__(self, - inplanes, - planes, - stride=1, - dilation=1, - downsample=None, - style='pytorch', - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None, - plugins=None, - init_cfg=None): - super(Bottleneck, self).__init__(init_cfg) - assert style in ['pytorch', 'caffe'] + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style="pytorch", + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + dcn=None, + plugins=None, + init_cfg=None, + ): + super().__init__(init_cfg) + assert style in ["pytorch", "caffe"] assert dcn is None or isinstance(dcn, dict) assert plugins is None or isinstance(plugins, list) if plugins is not None: - allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] - assert all(p['position'] in allowed_position for p in plugins) + allowed_position = ["after_conv1", "after_conv2", "after_conv3"] + assert all(p["position"] in allowed_position for p in plugins) self.inplanes = inplanes self.planes = planes @@ -142,19 +148,22 @@ def __init__(self, if self.with_plugins: # collect plugins for conv1/conv2/conv3 self.after_conv1_plugins = [ - plugin['cfg'] for plugin in plugins - if plugin['position'] == 'after_conv1' + plugin["cfg"] + for plugin in plugins + if plugin["position"] == "after_conv1" ] self.after_conv2_plugins = [ - plugin['cfg'] for plugin in plugins - if plugin['position'] == 'after_conv2' + plugin["cfg"] + for plugin in plugins + if plugin["position"] == "after_conv2" ] self.after_conv3_plugins = [ - plugin['cfg'] for plugin in plugins - if plugin['position'] == 'after_conv3' + plugin["cfg"] + for plugin in plugins + if plugin["position"] == "after_conv3" ] - if self.style == 'pytorch': + if self.style == "pytorch": self.conv1_stride = 1 self.conv2_stride = stride else: @@ -164,7 +173,8 @@ def __init__(self, self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) self.norm3_name, norm3 = build_norm_layer( - norm_cfg, planes * self.expansion, postfix=3) + norm_cfg, planes * self.expansion, postfix=3 + ) self.conv1 = build_conv_layer( conv_cfg, @@ -172,11 +182,12 @@ def __init__(self, planes, kernel_size=1, stride=self.conv1_stride, - bias=False) + bias=False, + ) self.add_module(self.norm1_name, norm1) fallback_on_stride = False if self.with_dcn: - fallback_on_stride = dcn.pop('fallback_on_stride', False) + fallback_on_stride = dcn.pop("fallback_on_stride", False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( conv_cfg, @@ -186,9 +197,10 @@ def __init__(self, stride=self.conv2_stride, padding=dilation, dilation=dilation, - bias=False) + bias=False, + ) else: - assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + assert self.conv_cfg is None, "conv_cfg must be None for DCN" self.conv2 = build_conv_layer( dcn, planes, @@ -197,15 +209,13 @@ def __init__(self, stride=self.conv2_stride, padding=dilation, dilation=dilation, - bias=False) + bias=False, + ) self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( - conv_cfg, - planes, - planes * self.expansion, - kernel_size=1, - bias=False) + conv_cfg, planes, planes * self.expansion, kernel_size=1, bias=False + ) self.add_module(self.norm3_name, norm3) self.relu = nn.ReLU(inplace=True) @@ -213,11 +223,14 @@ def __init__(self, if self.with_plugins: self.after_conv1_plugin_names = self.make_block_plugins( - planes, self.after_conv1_plugins) + planes, self.after_conv1_plugins + ) self.after_conv2_plugin_names = self.make_block_plugins( - planes, self.after_conv2_plugins) + planes, self.after_conv2_plugins + ) self.after_conv3_plugin_names = self.make_block_plugins( - planes * self.expansion, self.after_conv3_plugins) + planes * self.expansion, self.after_conv3_plugins + ) def make_block_plugins(self, in_channels, plugins): """make plugins for block. @@ -234,10 +247,9 @@ def make_block_plugins(self, in_channels, plugins): for plugin in plugins: plugin = plugin.copy() name, layer = build_plugin_layer( - plugin, - in_channels=in_channels, - postfix=plugin.pop('postfix', '')) - assert not hasattr(self, name), f'duplicate plugin {name}' + plugin, in_channels=in_channels, postfix=plugin.pop("postfix", "") + ) + assert not hasattr(self, name), f"duplicate plugin {name}" self.add_module(name, layer) plugin_names.append(name) return plugin_names @@ -390,70 +402,70 @@ class ResNet(BaseModule): 34: (BasicBlock, (3, 4, 6, 3)), 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), - 152: (Bottleneck, (3, 8, 36, 3)) + 152: (Bottleneck, (3, 8, 36, 3)), } - def __init__(self, - depth, - in_channels=3, - stem_channels=64, - base_channels=64, - num_stages=4, - strides=(1, 2, 2, 2), - dilations=(1, 1, 1, 1), - out_indices=(0, 1, 2, 3), - style='pytorch', - deep_stem=False, - avg_down=False, - frozen_stages=-1, - conv_cfg=None, - norm_cfg=dict(type='BN', requires_grad=True), - norm_eval=False, - dcn=None, - stage_with_dcn=(False, False, False, False), - plugins=None, - multi_grid=None, - contract_dilation=False, - with_cp=False, - zero_init_residual=True, - pretrained=None, - init_cfg=None): - super(ResNet, self).__init__(init_cfg) + def __init__( + self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style="pytorch", + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type="BN", requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) if depth not in self.arch_settings: - raise KeyError(f'invalid depth {depth} for resnet') + raise KeyError(f"invalid depth {depth} for resnet") self.pretrained = pretrained self.zero_init_residual = zero_init_residual block_init_cfg = None - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is a deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] block = self.arch_settings[depth][0] if self.zero_init_residual: if block is BasicBlock: block_init_cfg = dict( - type='Constant', - val=0, - override=dict(name='norm2')) + type="Constant", val=0, override=dict(name="norm2") + ) elif block is Bottleneck: block_init_cfg = dict( - type='Constant', - val=0, - override=dict(name='norm3')) + type="Constant", val=0, override=dict(name="norm3") + ) else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.depth = depth self.stem_channels = stem_channels @@ -496,8 +508,7 @@ def __init__(self, else: stage_plugins = None # multi grid is applied to last layer only - stage_multi_grid = multi_grid if i == len( - self.stage_blocks) - 1 else None + stage_multi_grid = multi_grid if i == len(self.stage_blocks) - 1 else None planes = base_channels * 2**i res_layer = self.make_res_layer( block=self.block, @@ -515,16 +526,18 @@ def __init__(self, plugins=stage_plugins, multi_grid=stage_multi_grid, contract_dilation=contract_dilation, - init_cfg=block_init_cfg) + init_cfg=block_init_cfg, + ) self.inplanes = planes * self.block.expansion - layer_name = f'layer{i+1}' + layer_name = f"layer{i+1}" self.add_module(layer_name, res_layer) self.res_layers.append(layer_name) self._freeze_stages() - self.feat_dim = self.block.expansion * base_channels * 2**( - len(self.stage_blocks) - 1) + self.feat_dim = ( + self.block.expansion * base_channels * 2 ** (len(self.stage_blocks) - 1) + ) def make_stage_plugins(self, plugins, stage_idx): """make plugins for ResNet 'stage_idx'th stage . @@ -571,7 +584,7 @@ def make_stage_plugins(self, plugins, stage_idx): stage_plugins = [] for plugin in plugins: plugin = plugin.copy() - stages = plugin.pop('stages', None) + stages = plugin.pop("stages", None) assert stages is None or len(stages) == self.num_stages # whether to insert plugin into current stage if stages is None or stages[stage_idx]: @@ -599,7 +612,8 @@ def _make_stem_layer(self, in_channels, stem_channels): kernel_size=3, stride=2, padding=1, - bias=False), + bias=False, + ), build_norm_layer(self.norm_cfg, stem_channels // 2)[1], nn.ReLU(inplace=True), build_conv_layer( @@ -609,7 +623,8 @@ def _make_stem_layer(self, in_channels, stem_channels): kernel_size=3, stride=1, padding=1, - bias=False), + bias=False, + ), build_norm_layer(self.norm_cfg, stem_channels // 2)[1], nn.ReLU(inplace=True), build_conv_layer( @@ -619,9 +634,11 @@ def _make_stem_layer(self, in_channels, stem_channels): kernel_size=3, stride=1, padding=1, - bias=False), + bias=False, + ), build_norm_layer(self.norm_cfg, stem_channels)[1], - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) else: self.conv1 = build_conv_layer( self.conv_cfg, @@ -630,9 +647,11 @@ def _make_stem_layer(self, in_channels, stem_channels): kernel_size=7, stride=2, padding=3, - bias=False) + bias=False, + ) self.norm1_name, norm1 = build_norm_layer( - self.norm_cfg, stem_channels, postfix=1) + self.norm_cfg, stem_channels, postfix=1 + ) self.add_module(self.norm1_name, norm1) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -651,7 +670,7 @@ def _freeze_stages(self): param.requires_grad = False for i in range(1, self.frozen_stages + 1): - m = getattr(self, f'layer{i}') + m = getattr(self, f"layer{i}") m.eval() for param in m.parameters(): param.requires_grad = False @@ -676,7 +695,7 @@ def forward(self, x): def train(self, mode=True): """Convert the model into training mode while keep normalization layer freezed.""" - super(ResNet, self).train(mode) + super().train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): @@ -696,8 +715,7 @@ class ResNetV1c(ResNet): """ def __init__(self, **kwargs): - super(ResNetV1c, self).__init__( - deep_stem=True, avg_down=False, **kwargs) + super().__init__(deep_stem=True, avg_down=False, **kwargs) @BACKBONES.register_module() @@ -710,5 +728,4 @@ class ResNetV1d(ResNet): """ def __init__(self, **kwargs): - super(ResNetV1d, self).__init__( - deep_stem=True, avg_down=True, **kwargs) + super().__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/mmsegmentation/mmseg/models/backbones/resnext.py b/mmsegmentation/mmseg/models/backbones/resnext.py index 805c27b..a612b93 100644 --- a/mmsegmentation/mmseg/models/backbones/resnext.py +++ b/mmsegmentation/mmseg/models/backbones/resnext.py @@ -16,27 +16,21 @@ class Bottleneck(_Bottleneck): "caffe", the stride-two layer is the first 1x1 conv layer. """ - def __init__(self, - inplanes, - planes, - groups=1, - base_width=4, - base_channels=64, - **kwargs): - super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + def __init__( + self, inplanes, planes, groups=1, base_width=4, base_channels=64, **kwargs + ): + super().__init__(inplanes, planes, **kwargs) if groups == 1: width = self.planes else: - width = math.floor(self.planes * - (base_width / base_channels)) * groups + width = math.floor(self.planes * (base_width / base_channels)) * groups - self.norm1_name, norm1 = build_norm_layer( - self.norm_cfg, width, postfix=1) - self.norm2_name, norm2 = build_norm_layer( - self.norm_cfg, width, postfix=2) + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, width, postfix=2) self.norm3_name, norm3 = build_norm_layer( - self.norm_cfg, self.planes * self.expansion, postfix=3) + self.norm_cfg, self.planes * self.expansion, postfix=3 + ) self.conv1 = build_conv_layer( self.conv_cfg, @@ -44,12 +38,13 @@ def __init__(self, width, kernel_size=1, stride=self.conv1_stride, - bias=False) + bias=False, + ) self.add_module(self.norm1_name, norm1) fallback_on_stride = False self.with_modulated_dcn = False if self.with_dcn: - fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + fallback_on_stride = self.dcn.pop("fallback_on_stride", False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( self.conv_cfg, @@ -60,9 +55,10 @@ def __init__(self, padding=self.dilation, dilation=self.dilation, groups=groups, - bias=False) + bias=False, + ) else: - assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + assert self.conv_cfg is None, "conv_cfg must be None for DCN" self.conv2 = build_conv_layer( self.dcn, width, @@ -72,7 +68,8 @@ def __init__(self, padding=self.dilation, dilation=self.dilation, groups=groups, - bias=False) + bias=False, + ) self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( @@ -80,7 +77,8 @@ def __init__(self, width, self.planes * self.expansion, kernel_size=1, - bias=False) + bias=False, + ) self.add_module(self.norm3_name, norm3) @@ -133,13 +131,13 @@ class ResNeXt(ResNet): arch_settings = { 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), - 152: (Bottleneck, (3, 8, 36, 3)) + 152: (Bottleneck, (3, 8, 36, 3)), } def __init__(self, groups=1, base_width=4, **kwargs): self.groups = groups self.base_width = base_width - super(ResNeXt, self).__init__(**kwargs) + super().__init__(**kwargs) def make_res_layer(self, **kwargs): """Pack all blocks in a stage into a ``ResLayer``""" @@ -147,4 +145,5 @@ def make_res_layer(self, **kwargs): groups=self.groups, base_width=self.base_width, base_channels=self.base_channels, - **kwargs) + **kwargs, + ) diff --git a/mmsegmentation/mmseg/models/backbones/stdc.py b/mmsegmentation/mmseg/models/backbones/stdc.py index 04f2f7a..81e5b7a 100644 --- a/mmsegmentation/mmseg/models/backbones/stdc.py +++ b/mmsegmentation/mmseg/models/backbones/stdc.py @@ -26,25 +26,28 @@ class STDCModule(BaseModule): Default: None. """ - def __init__(self, - in_channels, - out_channels, - stride, - norm_cfg=None, - act_cfg=None, - num_convs=4, - fusion_type='add', - init_cfg=None): - super(STDCModule, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + stride, + norm_cfg=None, + act_cfg=None, + num_convs=4, + fusion_type="add", + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) assert num_convs > 1 - assert fusion_type in ['add', 'cat'] + assert fusion_type in ["add", "cat"] self.stride = stride self.with_downsample = True if self.stride == 2 else False self.fusion_type = fusion_type self.layers = ModuleList() conv_0 = ConvModule( - in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) + in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg + ) if self.with_downsample: self.downsample = ConvModule( @@ -55,9 +58,10 @@ def __init__(self, padding=1, groups=out_channels // 2, norm_cfg=norm_cfg, - act_cfg=None) + act_cfg=None, + ) - if self.fusion_type == 'add': + if self.fusion_type == "add": self.layers.append(nn.Sequential(conv_0, self.downsample)) self.skip = Sequential( ConvModule( @@ -68,13 +72,12 @@ def __init__(self, padding=1, groups=in_channels, norm_cfg=norm_cfg, - act_cfg=None), + act_cfg=None, + ), ConvModule( - in_channels, - out_channels, - 1, - norm_cfg=norm_cfg, - act_cfg=None)) + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None + ), + ) else: self.layers.append(conv_0) self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) @@ -82,7 +85,7 @@ def __init__(self, self.layers.append(conv_0) for i in range(1, num_convs): - out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i + out_factor = 2 ** (i + 1) if i != num_convs - 1 else 2**i self.layers.append( ConvModule( out_channels // 2**i, @@ -91,10 +94,12 @@ def __init__(self, stride=1, padding=1, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) def forward(self, inputs): - if self.fusion_type == 'add': + if self.fusion_type == "add": out = self.forward_add(inputs) else: out = self.forward_cat(inputs) @@ -148,33 +153,30 @@ class FeatureFusionModule(BaseModule): Default: None. """ - def __init__(self, - in_channels, - out_channels, - scale_factor=4, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels, + out_channels, + scale_factor=4, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) channels = out_channels // scale_factor self.conv0 = ConvModule( - in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg + ) self.attention = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), ConvModule( - out_channels, - channels, - 1, - norm_cfg=None, - bias=False, - act_cfg=act_cfg), + out_channels, channels, 1, norm_cfg=None, bias=False, act_cfg=act_cfg + ), ConvModule( - channels, - out_channels, - 1, - norm_cfg=None, - bias=False, - act_cfg=None), nn.Sigmoid()) + channels, out_channels, 1, norm_cfg=None, bias=False, act_cfg=None + ), + nn.Sigmoid(), + ) def forward(self, spatial_inputs, context_inputs): inputs = torch.cat([spatial_inputs, context_inputs], dim=1) @@ -225,29 +227,35 @@ class STDCNet(BaseModule): """ arch_settings = { - 'STDCNet1': [(2, 1), (2, 1), (2, 1)], - 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] + "STDCNet1": [(2, 1), (2, 1), (2, 1)], + "STDCNet2": [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)], } - def __init__(self, - stdc_type, - in_channels, - channels, - bottleneck_type, - norm_cfg, - act_cfg, - num_convs=4, - with_final_conv=False, - pretrained=None, - init_cfg=None): - super(STDCNet, self).__init__(init_cfg=init_cfg) - assert stdc_type in self.arch_settings, \ - f'invalid structure {stdc_type} for STDCNet.' - assert bottleneck_type in ['add', 'cat'],\ - f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' - - assert len(channels) == 5,\ - f'invalid channels length {len(channels)} for STDCNet.' + def __init__( + self, + stdc_type, + in_channels, + channels, + bottleneck_type, + norm_cfg, + act_cfg, + num_convs=4, + with_final_conv=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + stdc_type in self.arch_settings + ), f"invalid structure {stdc_type} for STDCNet." + assert bottleneck_type in [ + "add", + "cat", + ], f"bottleneck_type must be `add` or `cat`, got {bottleneck_type}" + + assert ( + len(channels) == 5 + ), f"invalid channels length {len(channels)} for STDCNet." self.in_channels = in_channels self.channels = channels @@ -256,24 +264,28 @@ def __init__(self, self.num_convs = num_convs self.with_final_conv = with_final_conv - self.stages = ModuleList([ - ConvModule( - self.in_channels, - self.channels[0], - kernel_size=3, - stride=2, - padding=1, - norm_cfg=norm_cfg, - act_cfg=act_cfg), - ConvModule( - self.channels[0], - self.channels[1], - kernel_size=3, - stride=2, - padding=1, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - ]) + self.stages = ModuleList( + [ + ConvModule( + self.in_channels, + self.channels[0], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ), + ConvModule( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ), + ] + ) # `self.num_shallow_features` is the number of shallow modules in # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. # They are both not used for following modules like Attention @@ -285,8 +297,15 @@ def __init__(self, for strides in self.stage_strides: idx = len(self.stages) - 1 self.stages.append( - self._make_stage(self.channels[idx], self.channels[idx + 1], - strides, norm_cfg, act_cfg, bottleneck_type)) + self._make_stage( + self.channels[idx], + self.channels[idx + 1], + strides, + norm_cfg, + act_cfg, + bottleneck_type, + ) + ) # After appending, `self.stages` is a ModuleList including several # shallow modules and STDCModules. # (len(self.stages) == @@ -297,10 +316,12 @@ def __init__(self, max(1024, self.channels[-1]), 1, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) - def _make_stage(self, in_channels, out_channels, strides, norm_cfg, - act_cfg, bottleneck_type): + def _make_stage( + self, in_channels, out_channels, strides, norm_cfg, act_cfg, bottleneck_type + ): layers = [] for i, stride in enumerate(strides): layers.append( @@ -311,7 +332,9 @@ def _make_stage(self, in_channels, out_channels, strides, norm_cfg, norm_cfg, act_cfg, num_convs=self.num_convs, - fusion_type=bottleneck_type)) + fusion_type=bottleneck_type, + ) + ) return Sequential(*layers) def forward(self, x): @@ -321,7 +344,7 @@ def forward(self, x): outs.append(x) if self.with_final_conv: outs[-1] = self.final_conv(outs[-1]) - outs = outs[self.num_shallow_features:] + outs = outs[self.num_shallow_features :] return tuple(outs) @@ -360,31 +383,29 @@ class STDCContextPathNet(BaseModule): auxiliary heads and decoder head. """ - def __init__(self, - backbone_cfg, - last_in_channels=(1024, 512), - out_channels=128, - ffm_cfg=dict( - in_channels=512, out_channels=256, scale_factor=4), - upsample_mode='nearest', - align_corners=None, - norm_cfg=dict(type='BN'), - init_cfg=None): - super(STDCContextPathNet, self).__init__(init_cfg=init_cfg) + def __init__( + self, + backbone_cfg, + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict(in_channels=512, out_channels=256, scale_factor=4), + upsample_mode="nearest", + align_corners=None, + norm_cfg=dict(type="BN"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.backbone = build_backbone(backbone_cfg) self.arms = ModuleList() self.convs = ModuleList() for channels in last_in_channels: self.arms.append(AttentionRefinementModule(channels, out_channels)) self.convs.append( - ConvModule( - out_channels, - out_channels, - 3, - padding=1, - norm_cfg=norm_cfg)) + ConvModule(out_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg) + ) self.conv_avg = ConvModule( - last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) + last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg + ) self.ffm = FeatureFusionModule(**ffm_cfg) @@ -400,7 +421,8 @@ def forward(self, x): avg_feat, size=outs[-1].shape[2:], mode=self.upsample_mode, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) arms_out = [] for i in range(len(self.arms)): x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up @@ -408,7 +430,8 @@ def forward(self, x): x_arm, size=outs[len(outs) - 1 - i - 1].shape[2:], mode=self.upsample_mode, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) feature_up = self.convs[i](feature_up) arms_out.append(feature_up) diff --git a/mmsegmentation/mmseg/models/backbones/swin.py b/mmsegmentation/mmseg/models/backbones/swin.py index cbf1328..8ca53fb 100644 --- a/mmsegmentation/mmseg/models/backbones/swin.py +++ b/mmsegmentation/mmseg/models/backbones/swin.py @@ -9,10 +9,8 @@ import torch.utils.checkpoint as cp from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN, build_dropout -from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, - trunc_normal_init) -from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, - load_state_dict) +from mmcv.cnn.utils.weight_init import constant_init, trunc_normal_, trunc_normal_init +from mmcv.runner import BaseModule, CheckpointLoader, ModuleList, load_state_dict from mmcv.utils import to_2tuple from ...utils import get_root_logger @@ -39,16 +37,17 @@ class WindowMSA(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - num_heads, - window_size, - qkv_bias=True, - qk_scale=None, - attn_drop_rate=0., - proj_drop_rate=0., - init_cfg=None): - + def __init__( + self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + init_cfg=None, + ): super().__init__(init_cfg=init_cfg) self.embed_dims = embed_dims self.window_size = window_size # Wh, Ww @@ -58,15 +57,15 @@ def __init__(self, # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), - num_heads)) # 2*Wh-1 * 2*Ww-1, nH + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH # About 2x faster than original impl Wh, Ww = self.window_size rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) rel_position_index = rel_index_coords + rel_index_coords.T rel_position_index = rel_position_index.flip(1).contiguous() - self.register_buffer('relative_position_index', rel_position_index) + self.register_buffer("relative_position_index", rel_position_index) self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop_rate) @@ -87,27 +86,34 @@ def forward(self, x, mask=None): Wh*Ww, Wh*Ww), value should be between (-inf, 0]. """ B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) # make torchscript happy (cannot use tensor as tuple) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] - attn = attn.view(B // nW, nW, self.num_heads, N, - N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) @@ -148,17 +154,19 @@ class ShiftWindowMSA(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - num_heads, - window_size, - shift_size=0, - qkv_bias=True, - qk_scale=None, - attn_drop_rate=0, - proj_drop_rate=0, - dropout_layer=dict(type='DropPath', drop_prob=0.), - init_cfg=None): + def __init__( + self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type="DropPath", drop_prob=0.0), + init_cfg=None, + ): super().__init__(init_cfg=init_cfg) self.window_size = window_size @@ -173,14 +181,15 @@ def __init__(self, qk_scale=qk_scale, attn_drop_rate=attn_drop_rate, proj_drop_rate=proj_drop_rate, - init_cfg=None) + init_cfg=None, + ) self.drop = build_dropout(dropout_layer) def forward(self, query, hw_shape): B, L, C = query.shape H, W = hw_shape - assert L == H * W, 'input feature has wrong size' + assert L == H * W, "input feature has wrong size" query = query.view(B, H, W, C) # pad feature maps to multiples of window size @@ -192,18 +201,21 @@ def forward(self, query, hw_shape): # cyclic shift if self.shift_size > 0: shifted_query = torch.roll( - query, - shifts=(-self.shift_size, -self.shift_size), - dims=(1, 2)) + query, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) # calculate attention mask for SW-MSA img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, - -self.shift_size), slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, - -self.shift_size), slice(-self.shift_size, None)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) cnt = 0 for h in h_slices: for w in w_slices: @@ -212,12 +224,11 @@ def forward(self, query, hw_shape): # nW, window_size, window_size, 1 mask_windows = self.window_partition(img_mask) - mask_windows = mask_windows.view( - -1, self.window_size * self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, - float(-100.0)).masked_fill( - attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) else: shifted_query = query attn_mask = None @@ -231,17 +242,15 @@ def forward(self, query, hw_shape): attn_windows = self.w_msa(query_windows, mask=attn_mask) # merge windows - attn_windows = attn_windows.view(-1, self.window_size, - self.window_size, C) + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # B H' W' C shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) # reverse cyclic shift if self.shift_size > 0: x = torch.roll( - shifted_x, - shifts=(self.shift_size, self.shift_size), - dims=(1, 2)) + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) else: x = shifted_x @@ -264,8 +273,9 @@ def window_reverse(self, windows, H, W): """ window_size = self.window_size B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, - window_size, -1) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -278,15 +288,14 @@ def window_partition(self, x): """ B, H, W, C = x.shape window_size = self.window_size - x = x.view(B, H // window_size, window_size, W // window_size, - window_size, C) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() windows = windows.view(-1, window_size, window_size, C) return windows class SwinBlock(BaseModule): - """" + """ " Args: embed_dims (int): The feature dimension. num_heads (int): Parallel attention heads. @@ -310,23 +319,24 @@ class SwinBlock(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - window_size=7, - shift=False, - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - with_cp=False, - init_cfg=None): - - super(SwinBlock, self).__init__(init_cfg=init_cfg) + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + with_cp=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.with_cp = with_cp @@ -340,8 +350,9 @@ def __init__(self, qk_scale=qk_scale, attn_drop_rate=attn_drop_rate, proj_drop_rate=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), - init_cfg=None) + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), + init_cfg=None, + ) self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] self.ffn = FFN( @@ -349,13 +360,13 @@ def __init__(self, feedforward_channels=feedforward_channels, num_fcs=2, ffn_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), act_cfg=act_cfg, add_identity=True, - init_cfg=None) + init_cfg=None, + ) def forward(self, x, hw_shape): - def _inner_forward(x): identity = x x = self.norm1(x) @@ -406,22 +417,24 @@ class SwinBlockSequence(BaseModule): Default: None. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - depth, - window_size=7, - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - downsample=None, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - with_cp=False, - init_cfg=None): + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + downsample=None, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + with_cp=False, + init_cfg=None, + ): super().__init__(init_cfg=init_cfg) if isinstance(drop_path_rate, list): @@ -446,7 +459,8 @@ def __init__(self, act_cfg=act_cfg, norm_cfg=norm_cfg, with_cp=with_cp, - init_cfg=None) + init_cfg=None, + ) self.blocks.append(block) self.downsample = downsample @@ -515,30 +529,32 @@ class SwinTransformer(BaseModule): Defaults to None. """ - def __init__(self, - pretrain_img_size=224, - in_channels=3, - embed_dims=96, - patch_size=4, - window_size=7, - mlp_ratio=4, - depths=(2, 2, 6, 2), - num_heads=(3, 6, 12, 24), - strides=(4, 2, 2, 2), - out_indices=(0, 1, 2, 3), - qkv_bias=True, - qk_scale=None, - patch_norm=True, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.1, - use_abs_pos_embed=False, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - with_cp=False, - pretrained=None, - frozen_stages=-1, - init_cfg=None): + def __init__( + self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + with_cp=False, + pretrained=None, + frozen_stages=-1, + init_cfg=None, + ): self.frozen_stages = frozen_stages if isinstance(pretrain_img_size, int): @@ -546,53 +562,57 @@ def __init__(self, elif isinstance(pretrain_img_size, tuple): if len(pretrain_img_size) == 1: pretrain_img_size = to_2tuple(pretrain_img_size[0]) - assert len(pretrain_img_size) == 2, \ - f'The size of image should have length 1 or 2, ' \ - f'but got {len(pretrain_img_size)}' - - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be specified at the same time' + assert len(pretrain_img_size) == 2, ( + f"The size of image should have length 1 or 2, " + f"but got {len(pretrain_img_size)}" + ) + + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be specified at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: init_cfg = init_cfg else: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") - super(SwinTransformer, self).__init__(init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) num_layers = len(depths) self.out_indices = out_indices self.use_abs_pos_embed = use_abs_pos_embed - assert strides[0] == patch_size, 'Use non-overlapping patch embed.' + assert strides[0] == patch_size, "Use non-overlapping patch embed." self.patch_embed = PatchEmbed( in_channels=in_channels, embed_dims=embed_dims, - conv_type='Conv2d', + conv_type="Conv2d", kernel_size=patch_size, stride=strides[0], - padding='corner', + padding="corner", norm_cfg=norm_cfg if patch_norm else None, - init_cfg=None) + init_cfg=None, + ) if self.use_abs_pos_embed: patch_row = pretrain_img_size[0] // patch_size patch_col = pretrain_img_size[1] // patch_size num_patches = patch_row * patch_col self.absolute_pos_embed = nn.Parameter( - torch.zeros((1, num_patches, embed_dims))) + torch.zeros((1, num_patches, embed_dims)) + ) self.drop_after_pos = nn.Dropout(p=drop_rate) # set stochastic depth decay rule total_depth = sum(depths) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, total_depth) - ] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] self.stages = ModuleList() in_channels = embed_dims @@ -603,7 +623,8 @@ def __init__(self, out_channels=2 * in_channels, stride=strides[i + 1], norm_cfg=norm_cfg if patch_norm else None, - init_cfg=None) + init_cfg=None, + ) else: downsample = None @@ -617,12 +638,13 @@ def __init__(self, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, - drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])], downsample=downsample, act_cfg=act_cfg, norm_cfg=norm_cfg, with_cp=with_cp, - init_cfg=None) + init_cfg=None, + ) self.stages.append(stage) if downsample: in_channels = downsample.out_channels @@ -631,12 +653,12 @@ def __init__(self, # Add a norm layer for each output for i in out_indices: layer = build_norm_layer(norm_cfg, self.num_features[i])[1] - layer_name = f'norm{i}' + layer_name = f"norm{i}" self.add_module(layer_name, layer) def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" - super(SwinTransformer, self).train(mode) + super().train(mode) self._freeze_stages() def _freeze_stages(self): @@ -649,9 +671,8 @@ def _freeze_stages(self): self.drop_after_pos.eval() for i in range(1, self.frozen_stages + 1): - if (i - 1) in self.out_indices: - norm_layer = getattr(self, f'norm{i-1}') + norm_layer = getattr(self, f"norm{i-1}") norm_layer.eval() for param in norm_layer.parameters(): param.requires_grad = False @@ -664,56 +685,63 @@ def _freeze_stages(self): def init_weights(self): logger = get_root_logger() if self.init_cfg is None: - logger.warn(f'No pre-trained weights for ' - f'{self.__class__.__name__}, ' - f'training start from scratch') + logger.warn( + f"No pre-trained weights for " + f"{self.__class__.__name__}, " + f"training start from scratch" + ) if self.use_abs_pos_embed: trunc_normal_(self.absolute_pos_embed, std=0.02) for m in self.modules(): if isinstance(m, nn.Linear): - trunc_normal_init(m, std=.02, bias=0.) + trunc_normal_init(m, std=0.02, bias=0.0) elif isinstance(m, nn.LayerNorm): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) else: - assert 'checkpoint' in self.init_cfg, f'Only support ' \ - f'specify `Pretrained` in ' \ - f'`init_cfg` in ' \ - f'{self.__class__.__name__} ' + assert "checkpoint" in self.init_cfg, ( + f"Only support " + f"specify `Pretrained` in " + f"`init_cfg` in " + f"{self.__class__.__name__} " + ) ckpt = CheckpointLoader.load_checkpoint( - self.init_cfg['checkpoint'], logger=logger, map_location='cpu') - if 'state_dict' in ckpt: - _state_dict = ckpt['state_dict'] - elif 'model' in ckpt: - _state_dict = ckpt['model'] + self.init_cfg["checkpoint"], logger=logger, map_location="cpu" + ) + if "state_dict" in ckpt: + _state_dict = ckpt["state_dict"] + elif "model" in ckpt: + _state_dict = ckpt["model"] else: _state_dict = ckpt state_dict = OrderedDict() for k, v in _state_dict.items(): - if k.startswith('backbone.'): + if k.startswith("backbone."): state_dict[k[9:]] = v else: state_dict[k] = v # strip prefix of state_dict - if list(state_dict.keys())[0].startswith('module.'): + if list(state_dict.keys())[0].startswith("module."): state_dict = {k[7:]: v for k, v in state_dict.items()} # reshape absolute position embedding - if state_dict.get('absolute_pos_embed') is not None: - absolute_pos_embed = state_dict['absolute_pos_embed'] + if state_dict.get("absolute_pos_embed") is not None: + absolute_pos_embed = state_dict["absolute_pos_embed"] N1, L, C1 = absolute_pos_embed.size() N2, C2, H, W = self.absolute_pos_embed.size() if N1 != N2 or C1 != C2 or L != H * W: - logger.warning('Error in loading absolute_pos_embed, pass') + logger.warning("Error in loading absolute_pos_embed, pass") else: - state_dict['absolute_pos_embed'] = absolute_pos_embed.view( - N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + state_dict["absolute_pos_embed"] = ( + absolute_pos_embed.view(N2, H, W, C2) + .permute(0, 3, 1, 2) + .contiguous() + ) # interpolate position bias table if needed relative_position_bias_table_keys = [ - k for k in state_dict.keys() - if 'relative_position_bias_table' in k + k for k in state_dict.keys() if "relative_position_bias_table" in k ] for table_key in relative_position_bias_table_keys: table_pretrained = state_dict[table_key] @@ -721,16 +749,20 @@ def init_weights(self): L1, nH1 = table_pretrained.size() L2, nH2 = table_current.size() if nH1 != nH2: - logger.warning(f'Error in loading {table_key}, pass') + logger.warning(f"Error in loading {table_key}, pass") elif L1 != L2: S1 = int(L1**0.5) S2 = int(L2**0.5) table_pretrained_resized = F.interpolate( table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), size=(S2, S2), - mode='bicubic') - state_dict[table_key] = table_pretrained_resized.view( - nH2, L2).permute(1, 0).contiguous() + mode="bicubic", + ) + state_dict[table_key] = ( + table_pretrained_resized.view(nH2, L2) + .permute(1, 0) + .contiguous() + ) # load state_dict load_state_dict(self, state_dict, strict=False, logger=logger) @@ -746,11 +778,13 @@ def forward(self, x): for i, stage in enumerate(self.stages): x, hw_shape, out, out_hw_shape = stage(x, hw_shape) if i in self.out_indices: - norm_layer = getattr(self, f'norm{i}') + norm_layer = getattr(self, f"norm{i}") out = norm_layer(out) - out = out.view(-1, *out_hw_shape, - self.num_features[i]).permute(0, 3, 1, - 2).contiguous() + out = ( + out.view(-1, *out_hw_shape, self.num_features[i]) + .permute(0, 3, 1, 2) + .contiguous() + ) outs.append(out) return outs diff --git a/mmsegmentation/mmseg/models/backbones/timm_backbone.py b/mmsegmentation/mmseg/models/backbones/timm_backbone.py index 01b29fc..1cd7638 100644 --- a/mmsegmentation/mmseg/models/backbones/timm_backbone.py +++ b/mmsegmentation/mmseg/models/backbones/timm_backbone.py @@ -1,14 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -try: - import timm -except ImportError: - timm = None - from mmcv.cnn.bricks.registry import NORM_LAYERS from mmcv.runner import BaseModule from ..builder import BACKBONES +try: + import timm +except ImportError: + timm = None + @BACKBONES.register_module() class TIMMBackbone(BaseModule): @@ -30,16 +30,16 @@ def __init__( model_name, features_only=True, pretrained=True, - checkpoint_path='', + checkpoint_path="", in_channels=3, init_cfg=None, **kwargs, ): if timm is None: - raise RuntimeError('timm is not installed') - super(TIMMBackbone, self).__init__(init_cfg) - if 'norm_layer' in kwargs: - kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) + raise RuntimeError("timm is not installed") + super().__init__(init_cfg) + if "norm_layer" in kwargs: + kwargs["norm_layer"] = NORM_LAYERS.get(kwargs["norm_layer"]) self.timm_model = timm.create_model( model_name=model_name, features_only=features_only, diff --git a/mmsegmentation/mmseg/models/backbones/twins.py b/mmsegmentation/mmseg/models/backbones/twins.py index 6bd9469..1fba9a5 100644 --- a/mmsegmentation/mmseg/models/backbones/twins.py +++ b/mmsegmentation/mmseg/models/backbones/twins.py @@ -8,8 +8,7 @@ from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import FFN -from mmcv.cnn.utils.weight_init import (constant_init, normal_init, - trunc_normal_init) +from mmcv.cnn.utils.weight_init import constant_init, normal_init, trunc_normal_init from mmcv.runner import BaseModule, ModuleList from torch.nn.modules.batchnorm import _BatchNorm @@ -51,18 +50,20 @@ class GlobalSubsampledAttention(EfficientMultiheadAttention): Defaults to None. """ - def __init__(self, - embed_dims, - num_heads, - attn_drop=0., - proj_drop=0., - dropout_layer=None, - batch_first=True, - qkv_bias=True, - norm_cfg=dict(type='LN'), - sr_ratio=1, - init_cfg=None): - super(GlobalSubsampledAttention, self).__init__( + def __init__( + self, + embed_dims, + num_heads, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type="LN"), + sr_ratio=1, + init_cfg=None, + ): + super().__init__( embed_dims, num_heads, attn_drop=attn_drop, @@ -72,7 +73,8 @@ def __init__(self, qkv_bias=qkv_bias, norm_cfg=norm_cfg, sr_ratio=sr_ratio, - init_cfg=init_cfg) + init_cfg=init_cfg, + ) class GSAEncoderLayer(BaseModule): @@ -99,20 +101,22 @@ class GSAEncoderLayer(BaseModule): Defaults to None. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - num_fcs=2, - qkv_bias=True, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - sr_ratio=1., - init_cfg=None): - super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + sr_ratio=1.0, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] self.attn = GlobalSubsampledAttention( @@ -120,10 +124,11 @@ def __init__(self, num_heads=num_heads, attn_drop=attn_drop_rate, proj_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), qkv_bias=qkv_bias, norm_cfg=norm_cfg, - sr_ratio=sr_ratio) + sr_ratio=sr_ratio, + ) self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] self.ffn = FFN( @@ -131,16 +136,19 @@ def __init__(self, feedforward_channels=feedforward_channels, num_fcs=num_fcs, ffn_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), act_cfg=act_cfg, - add_identity=False) + add_identity=False, + ) - self.drop_path = build_dropout( - dict(type='DropPath', drop_prob=drop_path_rate) - ) if drop_path_rate > 0. else nn.Identity() + self.drop_path = ( + build_dropout(dict(type="DropPath", drop_prob=drop_path_rate)) + if drop_path_rate > 0.0 + else nn.Identity() + ) def forward(self, x, hw_shape): - x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.)) + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.0)) x = x + self.drop_path(self.ffn(self.norm2(x))) return x @@ -163,20 +171,22 @@ class LocallyGroupedSelfAttention(BaseModule): Defaults to None. """ - def __init__(self, - embed_dims, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop_rate=0., - proj_drop_rate=0., - window_size=1, - init_cfg=None): - super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) - - assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ - f'divided by num_heads ' \ - f'{num_heads}.' + def __init__( + self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0.0, + proj_drop_rate=0.0, + window_size=1, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, ( + f"dim {embed_dims} should be " f"divided by num_heads " f"{num_heads}." + ) self.embed_dims = embed_dims self.num_heads = num_heads head_dim = embed_dims // num_heads @@ -207,33 +217,45 @@ def forward(self, x, hw_shape): mask[:, :, -pad_r:].fill_(1) # [B, _h, _w, window_size, window_size, C] - x = x.reshape(b, _h, self.window_size, _w, self.window_size, - c).transpose(2, 3) - mask = mask.reshape(1, _h, self.window_size, _w, - self.window_size).transpose(2, 3).reshape( - 1, _h * _w, - self.window_size * self.window_size) + x = x.reshape(b, _h, self.window_size, _w, self.window_size, c).transpose(2, 3) + mask = ( + mask.reshape(1, _h, self.window_size, _w, self.window_size) + .transpose(2, 3) + .reshape(1, _h * _w, self.window_size * self.window_size) + ) # [1, _h*_w, window_size*window_size, window_size*window_size] attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) - attn_mask = attn_mask.masked_fill(attn_mask != 0, - float(-1000.0)).masked_fill( - attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0) + ) # [3, B, _w*_h, nhead, window_size*window_size, dim] - qkv = self.qkv(x).reshape(b, _h * _w, - self.window_size * self.window_size, 3, - self.num_heads, c // self.num_heads).permute( - 3, 0, 1, 4, 2, 5) + qkv = ( + self.qkv(x) + .reshape( + b, + _h * _w, + self.window_size * self.window_size, + 3, + self.num_heads, + c // self.num_heads, + ) + .permute(3, 0, 1, 4, 2, 5) + ) q, k, v = qkv[0], qkv[1], qkv[2] # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn + attn_mask.unsqueeze(2) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size, - self.window_size, c) - x = attn.transpose(2, 3).reshape(b, _h * self.window_size, - _w * self.window_size, c) + attn = ( + (attn @ v) + .transpose(2, 3) + .reshape(b, _h, _w, self.window_size, self.window_size, c) + ) + x = attn.transpose(2, 3).reshape( + b, _h * self.window_size, _w * self.window_size, c + ) if pad_r > 0 or pad_b > 0: x = x[:, :h, :w, :].contiguous() @@ -269,28 +291,34 @@ class LSAEncoderLayer(BaseModule): Defaults to None. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - num_fcs=2, - qkv_bias=True, - qk_scale=None, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - window_size=1, - init_cfg=None): - - super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + window_size=1, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] - self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, - qkv_bias, qk_scale, - attn_drop_rate, drop_rate, - window_size) + self.attn = LocallyGroupedSelfAttention( + embed_dims, + num_heads, + qkv_bias, + qk_scale, + attn_drop_rate, + drop_rate, + window_size, + ) self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] self.ffn = FFN( @@ -298,13 +326,16 @@ def __init__(self, feedforward_channels=feedforward_channels, num_fcs=num_fcs, ffn_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate), act_cfg=act_cfg, - add_identity=False) + add_identity=False, + ) - self.drop_path = build_dropout( - dict(type='DropPath', drop_prob=drop_path_rate) - ) if drop_path_rate > 0. else nn.Identity() + self.drop_path = ( + build_dropout(dict(type="DropPath", drop_prob=drop_path_rate)) + if drop_path_rate > 0.0 + else nn.Identity() + ) def forward(self, x, hw_shape): x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) @@ -325,7 +356,7 @@ class ConditionalPositionEncoding(BaseModule): """ def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): - super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg) + super().__init__(init_cfg=init_cfg) self.proj = nn.Conv2d( in_channels, embed_dims, @@ -333,7 +364,8 @@ def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): stride=stride, padding=1, bias=True, - groups=embed_dims) + groups=embed_dims, + ) self.stride = stride def forward(self, x, hw_shape): @@ -383,33 +415,38 @@ class PCPVT(BaseModule): Defaults to None. """ - def __init__(self, - in_channels=3, - embed_dims=[64, 128, 256, 512], - patch_sizes=[4, 2, 2, 2], - strides=[4, 2, 2, 2], - num_heads=[1, 2, 4, 8], - mlp_ratios=[4, 4, 4, 4], - out_indices=(0, 1, 2, 3), - qkv_bias=False, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_cfg=dict(type='LN'), - depths=[3, 4, 6, 3], - sr_ratios=[8, 4, 2, 1], - norm_after_stage=False, - pretrained=None, - init_cfg=None): - super(PCPVT, self).__init__(init_cfg=init_cfg) - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be set at the same time' + def __init__( + self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_cfg=dict(type="LN"), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be set at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is not None: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.depths = depths # patch_embed @@ -422,18 +459,22 @@ def __init__(self, PatchEmbed( in_channels=in_channels if i == 0 else embed_dims[i - 1], embed_dims=embed_dims[i], - conv_type='Conv2d', + conv_type="Conv2d", kernel_size=patch_sizes[i], stride=strides[i], - padding='corner', - norm_cfg=norm_cfg)) + padding="corner", + norm_cfg=norm_cfg, + ) + ) self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) - self.position_encodings = ModuleList([ - ConditionalPositionEncoding(embed_dim, embed_dim) - for embed_dim in embed_dims - ]) + self.position_encodings = ModuleList( + [ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in embed_dims + ] + ) # transformer encoder dpr = [ @@ -442,25 +483,28 @@ def __init__(self, cur = 0 for k in range(len(depths)): - _block = ModuleList([ - GSAEncoderLayer( - embed_dims=embed_dims[k], - num_heads=num_heads[k], - feedforward_channels=mlp_ratios[k] * embed_dims[k], - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - drop_path_rate=dpr[cur + i], - num_fcs=2, - qkv_bias=qkv_bias, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - sr_ratio=sr_ratios[k]) for i in range(depths[k]) - ]) + _block = ModuleList( + [ + GSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + sr_ratio=sr_ratios[k], + ) + for i in range(depths[k]) + ] + ) self.layers.append(_block) cur += depths[k] - self.norm_name, norm = build_norm_layer( - norm_cfg, embed_dims[-1], postfix=1) + self.norm_name, norm = build_norm_layer(norm_cfg, embed_dims[-1], postfix=1) self.out_indices = out_indices self.norm_after_stage = norm_after_stage @@ -471,19 +515,17 @@ def __init__(self, def init_weights(self): if self.init_cfg is not None: - super(PCPVT, self).init_weights() + super().init_weights() else: for m in self.modules(): if isinstance(m, nn.Linear): - trunc_normal_init(m, std=.02, bias=0.) + trunc_normal_init(m, std=0.02, bias=0.0) elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[ - 1] * m.out_channels + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups - normal_init( - m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + normal_init(m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) def forward(self, x): outputs = list() @@ -544,30 +586,46 @@ class SVT(PCPVT): Defaults to None. """ - def __init__(self, - in_channels=3, - embed_dims=[64, 128, 256], - patch_sizes=[4, 2, 2, 2], - strides=[4, 2, 2, 2], - num_heads=[1, 2, 4], - mlp_ratios=[4, 4, 4], - out_indices=(0, 1, 2, 3), - qkv_bias=False, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_cfg=dict(type='LN'), - depths=[4, 4, 4], - sr_ratios=[4, 2, 1], - windiow_sizes=[7, 7, 7], - norm_after_stage=True, - pretrained=None, - init_cfg=None): - super(SVT, self).__init__(in_channels, embed_dims, patch_sizes, - strides, num_heads, mlp_ratios, out_indices, - qkv_bias, drop_rate, attn_drop_rate, - drop_path_rate, norm_cfg, depths, sr_ratios, - norm_after_stage, pretrained, init_cfg) + def __init__( + self, + in_channels=3, + embed_dims=[64, 128, 256], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_cfg=dict(type="LN"), + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + windiow_sizes=[7, 7, 7], + norm_after_stage=True, + pretrained=None, + init_cfg=None, + ): + super().__init__( + in_channels, + embed_dims, + patch_sizes, + strides, + num_heads, + mlp_ratios, + out_indices, + qkv_bias, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_cfg, + depths, + sr_ratios, + norm_after_stage, + pretrained, + init_cfg, + ) # transformer encoder dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) @@ -576,13 +634,13 @@ def __init__(self, for k in range(len(depths)): for i in range(depths[k]): if i % 2 == 0: - self.layers[k][i] = \ - LSAEncoderLayer( - embed_dims=embed_dims[k], - num_heads=num_heads[k], - feedforward_channels=mlp_ratios[k] * embed_dims[k], - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=dpr[sum(depths[:k])+i], - qkv_bias=qkv_bias, - window_size=windiow_sizes[k]) + self.layers[k][i] = LSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:k]) + i], + qkv_bias=qkv_bias, + window_size=windiow_sizes[k], + ) diff --git a/mmsegmentation/mmseg/models/backbones/unet.py b/mmsegmentation/mmseg/models/backbones/unet.py index c2d3366..51597b3 100644 --- a/mmsegmentation/mmseg/models/backbones/unet.py +++ b/mmsegmentation/mmseg/models/backbones/unet.py @@ -3,8 +3,12 @@ import torch.nn as nn import torch.utils.checkpoint as cp -from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, - build_norm_layer) +from mmcv.cnn import ( + UPSAMPLE_LAYERS, + ConvModule, + build_activation_layer, + build_norm_layer, +) from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm @@ -43,21 +47,23 @@ class BasicConvBlock(nn.Module): plugins (dict): plugins for convolutional layers. Default: None. """ - def __init__(self, - in_channels, - out_channels, - num_convs=2, - stride=1, - dilation=1, - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - dcn=None, - plugins=None): - super(BasicConvBlock, self).__init__() - assert dcn is None, 'Not implemented yet.' - assert plugins is None, 'Not implemented yet.' + def __init__( + self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + dcn=None, + plugins=None, + ): + super().__init__() + assert dcn is None, "Not implemented yet." + assert plugins is None, "Not implemented yet." self.with_cp = with_cp convs = [] @@ -72,7 +78,9 @@ def __init__(self, padding=1 if i == 0 else dilation, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.convs = nn.Sequential(*convs) @@ -105,23 +113,27 @@ class DeconvModule(nn.Module): kernel_size (int): Kernel size of the convolutional layer. Default: 4. """ - def __init__(self, - in_channels, - out_channels, - with_cp=False, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - *, - kernel_size=4, - scale_factor=2): - super(DeconvModule, self).__init__() - - assert (kernel_size - scale_factor >= 0) and\ - (kernel_size - scale_factor) % 2 == 0,\ - f'kernel_size should be greater than or equal to scale_factor '\ - f'and (kernel_size - scale_factor) should be even numbers, '\ - f'while the kernel size is {kernel_size} and scale_factor is '\ - f'{scale_factor}.' + def __init__( + self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + *, + kernel_size=4, + scale_factor=2, + ): + super().__init__() + + assert (kernel_size - scale_factor >= 0) and ( + kernel_size - scale_factor + ) % 2 == 0, ( + f"kernel_size should be greater than or equal to scale_factor " + f"and (kernel_size - scale_factor) should be even numbers, " + f"while the kernel size is {kernel_size} and scale_factor is " + f"{scale_factor}." + ) stride = scale_factor padding = (kernel_size - scale_factor) // 2 @@ -131,7 +143,8 @@ def __init__(self, out_channels, kernel_size=kernel_size, stride=stride, - padding=padding) + padding=padding, + ) norm_name, norm = build_norm_layer(norm_cfg, out_channels) activate = build_activation_layer(act_cfg) @@ -179,21 +192,22 @@ class InterpConv(nn.Module): scale_factor=2, mode='bilinear', align_corners=False). """ - def __init__(self, - in_channels, - out_channels, - with_cp=False, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - *, - conv_cfg=None, - conv_first=False, - kernel_size=1, - stride=1, - padding=0, - upsample_cfg=dict( - scale_factor=2, mode='bilinear', align_corners=False)): - super(InterpConv, self).__init__() + def __init__( + self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict(scale_factor=2, mode="bilinear", align_corners=False), + ): + super().__init__() self.with_cp = with_cp conv = ConvModule( @@ -204,7 +218,8 @@ def __init__(self, padding=padding, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) upsample = Upsample(**upsample_cfg) if conv_first: self.interp_upsample = nn.Sequential(conv, upsample) @@ -280,79 +295,87 @@ class UNet(BaseModule): in UNet._check_input_divisible. """ - def __init__(self, - in_channels=3, - base_channels=64, - num_stages=5, - strides=(1, 1, 1, 1, 1), - enc_num_convs=(2, 2, 2, 2, 2), - dec_num_convs=(2, 2, 2, 2), - downsamples=(True, True, True, True), - enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1), - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - upsample_cfg=dict(type='InterpConv'), - norm_eval=False, - dcn=None, - plugins=None, - pretrained=None, - init_cfg=None): - super(UNet, self).__init__(init_cfg) + def __init__( + self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + upsample_cfg=dict(type="InterpConv"), + norm_eval=False, + dcn=None, + plugins=None, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) self.pretrained = pretrained - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be setting at the same time' + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be setting at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is a deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is None: if init_cfg is None: self.init_cfg = [ - dict(type='Kaiming', layer='Conv2d'), - dict( - type='Constant', - val=1, - layer=['_BatchNorm', 'GroupNorm']) + dict(type="Kaiming", layer="Conv2d"), + dict(type="Constant", val=1, layer=["_BatchNorm", "GroupNorm"]), ] else: - raise TypeError('pretrained must be a str or None') - - assert dcn is None, 'Not implemented yet.' - assert plugins is None, 'Not implemented yet.' - assert len(strides) == num_stages, \ - 'The length of strides should be equal to num_stages, '\ - f'while the strides is {strides}, the length of '\ - f'strides is {len(strides)}, and the num_stages is '\ - f'{num_stages}.' - assert len(enc_num_convs) == num_stages, \ - 'The length of enc_num_convs should be equal to num_stages, '\ - f'while the enc_num_convs is {enc_num_convs}, the length of '\ - f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ - f'{num_stages}.' - assert len(dec_num_convs) == (num_stages-1), \ - 'The length of dec_num_convs should be equal to (num_stages-1), '\ - f'while the dec_num_convs is {dec_num_convs}, the length of '\ - f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ - f'{num_stages}.' - assert len(downsamples) == (num_stages-1), \ - 'The length of downsamples should be equal to (num_stages-1), '\ - f'while the downsamples is {downsamples}, the length of '\ - f'downsamples is {len(downsamples)}, and the num_stages is '\ - f'{num_stages}.' - assert len(enc_dilations) == num_stages, \ - 'The length of enc_dilations should be equal to num_stages, '\ - f'while the enc_dilations is {enc_dilations}, the length of '\ - f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ - f'{num_stages}.' - assert len(dec_dilations) == (num_stages-1), \ - 'The length of dec_dilations should be equal to (num_stages-1), '\ - f'while the dec_dilations is {dec_dilations}, the length of '\ - f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ - f'{num_stages}.' + raise TypeError("pretrained must be a str or None") + + assert dcn is None, "Not implemented yet." + assert plugins is None, "Not implemented yet." + assert len(strides) == num_stages, ( + "The length of strides should be equal to num_stages, " + f"while the strides is {strides}, the length of " + f"strides is {len(strides)}, and the num_stages is " + f"{num_stages}." + ) + assert len(enc_num_convs) == num_stages, ( + "The length of enc_num_convs should be equal to num_stages, " + f"while the enc_num_convs is {enc_num_convs}, the length of " + f"enc_num_convs is {len(enc_num_convs)}, and the num_stages is " + f"{num_stages}." + ) + assert len(dec_num_convs) == (num_stages - 1), ( + "The length of dec_num_convs should be equal to (num_stages-1), " + f"while the dec_num_convs is {dec_num_convs}, the length of " + f"dec_num_convs is {len(dec_num_convs)}, and the num_stages is " + f"{num_stages}." + ) + assert len(downsamples) == (num_stages - 1), ( + "The length of downsamples should be equal to (num_stages-1), " + f"while the downsamples is {downsamples}, the length of " + f"downsamples is {len(downsamples)}, and the num_stages is " + f"{num_stages}." + ) + assert len(enc_dilations) == num_stages, ( + "The length of enc_dilations should be equal to num_stages, " + f"while the enc_dilations is {enc_dilations}, the length of " + f"enc_dilations is {len(enc_dilations)}, and the num_stages is " + f"{num_stages}." + ) + assert len(dec_dilations) == (num_stages - 1), ( + "The length of dec_dilations should be equal to (num_stages-1), " + f"while the dec_dilations is {dec_dilations}, the length of " + f"dec_dilations is {len(dec_dilations)}, and the num_stages is " + f"{num_stages}." + ) self.num_stages = num_stages self.strides = strides self.downsamples = downsamples @@ -367,13 +390,13 @@ def __init__(self, if i != 0: if strides[i] == 1 and downsamples[i - 1]: enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) - upsample = (strides[i] != 1 or downsamples[i - 1]) + upsample = strides[i] != 1 or downsamples[i - 1] self.decoder.append( UpConvBlock( conv_block=BasicConvBlock, in_channels=base_channels * 2**i, - skip_channels=base_channels * 2**(i - 1), - out_channels=base_channels * 2**(i - 1), + skip_channels=base_channels * 2 ** (i - 1), + out_channels=base_channels * 2 ** (i - 1), num_convs=dec_num_convs[i - 1], stride=1, dilation=dec_dilations[i - 1], @@ -383,7 +406,9 @@ def __init__(self, act_cfg=act_cfg, upsample_cfg=upsample_cfg if upsample else None, dcn=None, - plugins=None)) + plugins=None, + ) + ) enc_conv_block.append( BasicConvBlock( @@ -397,8 +422,10 @@ def __init__(self, norm_cfg=norm_cfg, act_cfg=act_cfg, dcn=None, - plugins=None)) - self.encoder.append((nn.Sequential(*enc_conv_block))) + plugins=None, + ) + ) + self.encoder.append(nn.Sequential(*enc_conv_block)) in_channels = base_channels * 2**i def forward(self, x): @@ -417,7 +444,7 @@ def forward(self, x): def train(self, mode=True): """Convert the model into training mode while keep normalization layer freezed.""" - super(UNet, self).train(mode) + super().train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only @@ -430,9 +457,9 @@ def _check_input_divisible(self, x): for i in range(1, self.num_stages): if self.strides[i] == 2 or self.downsamples[i - 1]: whole_downsample_rate *= 2 - assert (h % whole_downsample_rate == 0) \ - and (w % whole_downsample_rate == 0),\ - f'The input image size {(h, w)} should be divisible by the whole '\ - f'downsample rate {whole_downsample_rate}, when num_stages is '\ - f'{self.num_stages}, strides is {self.strides}, and downsamples '\ - f'is {self.downsamples}.' + assert (h % whole_downsample_rate == 0) and (w % whole_downsample_rate == 0), ( + f"The input image size {(h, w)} should be divisible by the whole " + f"downsample rate {whole_downsample_rate}, when num_stages is " + f"{self.num_stages}, strides is {self.strides}, and downsamples " + f"is {self.downsamples}." + ) diff --git a/mmsegmentation/mmseg/models/backbones/vit.py b/mmsegmentation/mmseg/models/backbones/vit.py index 37b9a4f..e214f4f 100644 --- a/mmsegmentation/mmseg/models/backbones/vit.py +++ b/mmsegmentation/mmseg/models/backbones/vit.py @@ -7,10 +7,8 @@ import torch.utils.checkpoint as cp from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention -from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, - trunc_normal_) -from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, - load_state_dict) +from mmcv.cnn.utils.weight_init import constant_init, kaiming_init, trunc_normal_ +from mmcv.runner import BaseModule, CheckpointLoader, ModuleList, load_state_dict from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple @@ -46,25 +44,26 @@ class TransformerEncoderLayer(BaseModule): some memory while slowing down the training speed. Default: False. """ - def __init__(self, - embed_dims, - num_heads, - feedforward_channels, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - num_fcs=2, - qkv_bias=True, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - batch_first=True, - attn_cfg=dict(), - ffn_cfg=dict(), - with_cp=False): - super(TransformerEncoderLayer, self).__init__() - - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, embed_dims, postfix=1) + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False, + ): + super().__init__() + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) attn_cfg.update( @@ -74,12 +73,13 @@ def __init__(self, attn_drop=attn_drop_rate, proj_drop=drop_rate, batch_first=batch_first, - bias=qkv_bias)) + bias=qkv_bias, + ) + ) self.build_attn(attn_cfg) - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, embed_dims, postfix=2) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2) self.add_module(self.norm2_name, norm2) ffn_cfg.update( @@ -88,9 +88,12 @@ def __init__(self, feedforward_channels=feedforward_channels, num_fcs=num_fcs, ffn_drop=drop_rate, - dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) - if drop_path_rate > 0 else None, - act_cfg=act_cfg)) + dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate) + if drop_path_rate > 0 + else None, + act_cfg=act_cfg, + ) + ) self.build_ffn(ffn_cfg) self.with_cp = with_cp @@ -109,7 +112,6 @@ def norm2(self): return getattr(self, self.norm2_name) def forward(self, x): - def _inner_forward(x): x = self.attn(self.norm1(x), identity=x) x = self.ffn(self.norm2(x), identity=x) @@ -173,54 +175,62 @@ class VisionTransformer(BaseModule): Default: None. """ - def __init__(self, - img_size=224, - patch_size=16, - in_channels=3, - embed_dims=768, - num_layers=12, - num_heads=12, - mlp_ratio=4, - out_indices=-1, - qkv_bias=True, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - with_cls_token=True, - output_cls_token=False, - norm_cfg=dict(type='LN'), - act_cfg=dict(type='GELU'), - patch_norm=False, - final_norm=False, - interpolate_mode='bicubic', - num_fcs=2, - norm_eval=False, - with_cp=False, - pretrained=None, - init_cfg=None): - super(VisionTransformer, self).__init__(init_cfg=init_cfg) + def __init__( + self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type="LN"), + act_cfg=dict(type="GELU"), + patch_norm=False, + final_norm=False, + interpolate_mode="bicubic", + num_fcs=2, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) if isinstance(img_size, int): img_size = to_2tuple(img_size) elif isinstance(img_size, tuple): if len(img_size) == 1: img_size = to_2tuple(img_size[0]) - assert len(img_size) == 2, \ - f'The size of image should have length 1 or 2, ' \ - f'but got {len(img_size)}' + assert len(img_size) == 2, ( + f"The size of image should have length 1 or 2, " + f"but got {len(img_size)}" + ) if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - - assert not (init_cfg and pretrained), \ - 'init_cfg and pretrained cannot be set at the same time' + assert with_cls_token is True, ( + f"with_cls_token must be True if" + f"set output_cls_token to True, but got {with_cls_token}" + ) + + assert not ( + init_cfg and pretrained + ), "init_cfg and pretrained cannot be set at the same time" if isinstance(pretrained, str): - warnings.warn('DeprecationWarning: pretrained is deprecated, ' - 'please use "init_cfg" instead') - self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + warnings.warn( + "DeprecationWarning: pretrained is deprecated, " + 'please use "init_cfg" instead' + ) + self.init_cfg = dict(type="Pretrained", checkpoint=pretrained) elif pretrained is not None: - raise TypeError('pretrained must be a str or None') + raise TypeError("pretrained must be a str or None") self.img_size = img_size self.patch_size = patch_size @@ -232,22 +242,20 @@ def __init__(self, self.patch_embed = PatchEmbed( in_channels=in_channels, embed_dims=embed_dims, - conv_type='Conv2d', + conv_type="Conv2d", kernel_size=patch_size, stride=patch_size, - padding='corner', + padding="corner", norm_cfg=norm_cfg if patch_norm else None, init_cfg=None, ) - num_patches = (img_size[0] // patch_size) * \ - (img_size[1] // patch_size) + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) self.with_cls_token = with_cls_token self.output_cls_token = output_cls_token self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, embed_dims)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims)) self.drop_after_pos = nn.Dropout(p=drop_rate) if isinstance(out_indices, int): @@ -257,7 +265,7 @@ def __init__(self, elif isinstance(out_indices, list) or isinstance(out_indices, tuple): self.out_indices = out_indices else: - raise TypeError('out_indices must be type of int, list or tuple') + raise TypeError("out_indices must be type of int, list or tuple") dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, num_layers) @@ -278,12 +286,13 @@ def __init__(self, act_cfg=act_cfg, norm_cfg=norm_cfg, with_cp=with_cp, - batch_first=True)) + batch_first=True, + ) + ) self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, embed_dims, postfix=1) + self.norm1_name, norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) @property @@ -291,50 +300,56 @@ def norm1(self): return getattr(self, self.norm1_name) def init_weights(self): - if (isinstance(self.init_cfg, dict) - and self.init_cfg.get('type') == 'Pretrained'): + if ( + isinstance(self.init_cfg, dict) + and self.init_cfg.get("type") == "Pretrained" + ): logger = get_root_logger() checkpoint = CheckpointLoader.load_checkpoint( - self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + self.init_cfg["checkpoint"], logger=logger, map_location="cpu" + ) - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint - if 'pos_embed' in state_dict.keys(): - if self.pos_embed.shape != state_dict['pos_embed'].shape: - logger.info(msg=f'Resize the pos_embed shape from ' - f'{state_dict["pos_embed"].shape} to ' - f'{self.pos_embed.shape}') + if "pos_embed" in state_dict.keys(): + if self.pos_embed.shape != state_dict["pos_embed"].shape: + logger.info( + msg=f"Resize the pos_embed shape from " + f'{state_dict["pos_embed"].shape} to ' + f"{self.pos_embed.shape}" + ) h, w = self.img_size - pos_size = int( - math.sqrt(state_dict['pos_embed'].shape[1] - 1)) - state_dict['pos_embed'] = self.resize_pos_embed( - state_dict['pos_embed'], + pos_size = int(math.sqrt(state_dict["pos_embed"].shape[1] - 1)) + state_dict["pos_embed"] = self.resize_pos_embed( + state_dict["pos_embed"], (h // self.patch_size, w // self.patch_size), - (pos_size, pos_size), self.interpolate_mode) + (pos_size, pos_size), + self.interpolate_mode, + ) load_state_dict(self, state_dict, strict=False, logger=logger) elif self.init_cfg is not None: - super(VisionTransformer, self).init_weights() + super().init_weights() else: # We only implement the 'jax_impl' initialization implemented at # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 - trunc_normal_(self.pos_embed, std=.02) - trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) for n, m in self.named_modules(): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if m.bias is not None: - if 'ffn' in n: - nn.init.normal_(m.bias, mean=0., std=1e-6) + if "ffn" in n: + nn.init.normal_(m.bias, mean=0.0, std=1e-6) else: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', bias=0.) + kaiming_init(m, mode="fan_in", bias=0.0) elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): - constant_init(m, val=1.0, bias=0.) + constant_init(m, val=1.0, bias=0.0) def _pos_embeding(self, patched_img, hw_shape, pos_embed): """Positioning embeding method. @@ -350,21 +365,26 @@ def _pos_embeding(self, patched_img, hw_shape, pos_embed): Return: torch.Tensor: The pos encoded image feature. """ - assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ - 'the shapes of patched_img and pos_embed must be [B, L, C]' + assert ( + patched_img.ndim == 3 and pos_embed.ndim == 3 + ), "the shapes of patched_img and pos_embed must be [B, L, C]" x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] if x_len != pos_len: - if pos_len == (self.img_size[0] // self.patch_size) * ( - self.img_size[1] // self.patch_size) + 1: + if ( + pos_len + == (self.img_size[0] // self.patch_size) + * (self.img_size[1] // self.patch_size) + + 1 + ): pos_h = self.img_size[0] // self.patch_size pos_w = self.img_size[1] // self.patch_size else: raise ValueError( - 'Unexpected shape of pos_embed, got {}.'.format( - pos_embed.shape)) - pos_embed = self.resize_pos_embed(pos_embed, hw_shape, - (pos_h, pos_w), - self.interpolate_mode) + "Unexpected shape of pos_embed, got {}.".format(pos_embed.shape) + ) + pos_embed = self.resize_pos_embed( + pos_embed, hw_shape, (pos_h, pos_w), self.interpolate_mode + ) return self.drop_after_pos(patched_img + pos_embed) @staticmethod @@ -384,15 +404,17 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): Return: torch.Tensor: The resized pos_embed of shape [B, L_new, C] """ - assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]" pos_h, pos_w = pos_shape # keep dim for easy deployment cls_token_weight = pos_embed[:, 0:1] - pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :] pos_embed_weight = pos_embed_weight.reshape( - 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + 1, pos_h, pos_w, pos_embed.shape[2] + ).permute(0, 3, 1, 2) pos_embed_weight = resize( - pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode + ) pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) return pos_embed @@ -424,8 +446,11 @@ def forward(self, inputs): else: out = x B, _, C = out.shape - out = out.reshape(B, hw_shape[0], hw_shape[1], - C).permute(0, 3, 1, 2).contiguous() + out = ( + out.reshape(B, hw_shape[0], hw_shape[1], C) + .permute(0, 3, 1, 2) + .contiguous() + ) if self.output_cls_token: out = [out, x[:, 0]] outs.append(out) @@ -433,7 +458,7 @@ def forward(self, inputs): return tuple(outs) def train(self, mode=True): - super(VisionTransformer, self).train(mode) + super().train(mode) if mode and self.norm_eval: for m in self.modules(): if isinstance(m, nn.LayerNorm): diff --git a/mmsegmentation/mmseg/models/builder.py b/mmsegmentation/mmseg/models/builder.py index 5e18e4e..0a8a648 100644 --- a/mmsegmentation/mmseg/models/builder.py +++ b/mmsegmentation/mmseg/models/builder.py @@ -5,8 +5,8 @@ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION from mmcv.utils import Registry -MODELS = Registry('models', parent=MMCV_MODELS) -ATTENTION = Registry('attention', parent=MMCV_ATTENTION) +MODELS = Registry("models", parent=MMCV_MODELS) +ATTENTION = Registry("attention", parent=MMCV_ATTENTION) BACKBONES = MODELS NECKS = MODELS @@ -39,11 +39,15 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None): """Build segmentor.""" if train_cfg is not None or test_cfg is not None: warnings.warn( - 'train_cfg and test_cfg is deprecated, ' - 'please specify them in model', UserWarning) - assert cfg.get('train_cfg') is None or train_cfg is None, \ - 'train_cfg specified in both outer field and model field ' - assert cfg.get('test_cfg') is None or test_cfg is None, \ - 'test_cfg specified in both outer field and model field ' + "train_cfg and test_cfg is deprecated, " "please specify them in model", + UserWarning, + ) + assert ( + cfg.get("train_cfg") is None or train_cfg is None + ), "train_cfg specified in both outer field and model field " + assert ( + cfg.get("test_cfg") is None or test_cfg is None + ), "test_cfg specified in both outer field and model field " return SEGMENTORS.build( - cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg) + ) diff --git a/mmsegmentation/mmseg/models/decode_heads/__init__.py b/mmsegmentation/mmseg/models/decode_heads/__init__.py index 8add761..7878051 100644 --- a/mmsegmentation/mmseg/models/decode_heads/__init__.py +++ b/mmsegmentation/mmseg/models/decode_heads/__init__.py @@ -30,11 +30,36 @@ from .uper_head import UPerHead __all__ = [ - 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', - 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', - 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', - 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', - 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', - 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', - 'KernelUpdateHead', 'KernelUpdator' + "FCNHead", + "PSPHead", + "ASPPHead", + "PSAHead", + "NLHead", + "GCHead", + "CCHead", + "UPerHead", + "DepthwiseSeparableASPPHead", + "ANNHead", + "DAHead", + "OCRHead", + "EncHead", + "DepthwiseSeparableFCNHead", + "FPNHead", + "EMAHead", + "DNLHead", + "PointHead", + "APCHead", + "DMHead", + "LRASPPHead", + "SETRUPHead", + "SETRMLAHead", + "DPTHead", + "SETRMLAHead", + "SegmenterMaskTransformerHead", + "SegformerHead", + "ISAHead", + "STDCHead", + "IterativeDecodeHead", + "KernelUpdateHead", + "KernelUpdator", ] diff --git a/mmsegmentation/mmseg/models/decode_heads/ann_head.py b/mmsegmentation/mmseg/models/decode_heads/ann_head.py index c8d882e..48b3caf 100644 --- a/mmsegmentation/mmseg/models/decode_heads/ann_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/ann_head.py @@ -17,8 +17,9 @@ class PPMConcat(nn.ModuleList): """ def __init__(self, pool_scales=(1, 3, 6, 8)): - super(PPMConcat, self).__init__( - [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + super().__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales] + ) def forward(self, feats): """Forward function.""" @@ -50,15 +51,25 @@ class SelfAttentionBlock(_SelfAttentionBlock): act_cfg (dict|None): Config of activation layers. """ - def __init__(self, low_in_channels, high_in_channels, channels, - out_channels, share_key_query, query_scale, key_pool_scales, - conv_cfg, norm_cfg, act_cfg): + def __init__( + self, + low_in_channels, + high_in_channels, + channels, + out_channels, + share_key_query, + query_scale, + key_pool_scales, + conv_cfg, + norm_cfg, + act_cfg, + ): key_psp = PPMConcat(key_pool_scales) if query_scale > 1: query_downsample = nn.MaxPool2d(kernel_size=query_scale) else: query_downsample = None - super(SelfAttentionBlock, self).__init__( + super().__init__( key_in_channels=low_in_channels, query_in_channels=high_in_channels, channels=channels, @@ -74,7 +85,8 @@ def __init__(self, low_in_channels, high_in_channels, channels, with_out=True, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) class AFNB(nn.Module): @@ -97,10 +109,19 @@ class AFNB(nn.Module): act_cfg (dict|None): Config of activation layers. """ - def __init__(self, low_in_channels, high_in_channels, channels, - out_channels, query_scales, key_pool_scales, conv_cfg, - norm_cfg, act_cfg): - super(AFNB, self).__init__() + def __init__( + self, + low_in_channels, + high_in_channels, + channels, + out_channels, + query_scales, + key_pool_scales, + conv_cfg, + norm_cfg, + act_cfg, + ): + super().__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( @@ -114,14 +135,17 @@ def __init__(self, low_in_channels, high_in_channels, channels, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.bottleneck = ConvModule( out_channels + high_in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None) + act_cfg=None, + ) def forward(self, low_feats, high_feats): """Forward function.""" @@ -148,9 +172,18 @@ class APNB(nn.Module): act_cfg (dict|None): Config of activation layers. """ - def __init__(self, in_channels, channels, out_channels, query_scales, - key_pool_scales, conv_cfg, norm_cfg, act_cfg): - super(APNB, self).__init__() + def __init__( + self, + in_channels, + channels, + out_channels, + query_scales, + key_pool_scales, + conv_cfg, + norm_cfg, + act_cfg, + ): + super().__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( @@ -164,14 +197,17 @@ def __init__(self, in_channels, channels, out_channels, query_scales, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.bottleneck = ConvModule( 2 * in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, feats): """Forward function.""" @@ -196,13 +232,14 @@ class ANNHead(BaseDecodeHead): Default: (1, 3, 6, 8). """ - def __init__(self, - project_channels, - query_scales=(1, ), - key_pool_scales=(1, 3, 6, 8), - **kwargs): - super(ANNHead, self).__init__( - input_transform='multiple_select', **kwargs) + def __init__( + self, + project_channels, + query_scales=(1,), + key_pool_scales=(1, 3, 6, 8), + **kwargs, + ): + super().__init__(input_transform="multiple_select", **kwargs) assert len(self.in_channels) == 2 low_in_channels, high_in_channels = self.in_channels self.project_channels = project_channels @@ -215,7 +252,8 @@ def __init__(self, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.bottleneck = ConvModule( high_in_channels, self.channels, @@ -223,7 +261,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.context = APNB( in_channels=self.channels, out_channels=self.channels, @@ -232,7 +271,8 @@ def __init__(self, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/apc_head.py b/mmsegmentation/mmseg/models/decode_heads/apc_head.py index 3198fd1..ad1cc76 100644 --- a/mmsegmentation/mmseg/models/decode_heads/apc_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/apc_head.py @@ -23,9 +23,10 @@ class ACM(nn.Module): act_cfg (dict): Config of activation layers. """ - def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, - norm_cfg, act_cfg): - super(ACM, self).__init__() + def __init__( + self, pool_scale, fusion, in_channels, channels, conv_cfg, norm_cfg, act_cfg + ): + super().__init__() self.pool_scale = pool_scale self.fusion = fusion self.in_channels = in_channels @@ -39,7 +40,8 @@ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.input_redu_conv = ConvModule( self.in_channels, @@ -47,7 +49,8 @@ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.global_info = ConvModule( self.channels, @@ -55,7 +58,8 @@ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) @@ -65,7 +69,8 @@ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if self.fusion: self.fusion_conv = ConvModule( @@ -74,7 +79,8 @@ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, x): """Forward function.""" @@ -85,13 +91,20 @@ def forward(self, x): pooled_x = self.pooled_redu_conv(pooled_x) batch_size = x.size(0) # [batch_size, pool_scale * pool_scale, channels] - pooled_x = pooled_x.view(batch_size, self.channels, - -1).permute(0, 2, 1).contiguous() + pooled_x = ( + pooled_x.view(batch_size, self.channels, -1).permute(0, 2, 1).contiguous() + ) # [batch_size, h * w, pool_scale * pool_scale] - affinity_matrix = self.gla(x + resize( - self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) - ).permute(0, 2, 3, 1).reshape( - batch_size, -1, self.pool_scale**2) + affinity_matrix = ( + self.gla( + x + + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:] + ) + ) + .permute(0, 2, 3, 1) + .reshape(batch_size, -1, self.pool_scale**2) + ) affinity_matrix = F.sigmoid(affinity_matrix) # [batch_size, h * w, channels] z_out = torch.matmul(affinity_matrix, pooled_x) @@ -123,20 +136,23 @@ class APCHead(BaseDecodeHead): """ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): - super(APCHead, self).__init__(**kwargs) + super().__init__(**kwargs) assert isinstance(pool_scales, (list, tuple)) self.pool_scales = pool_scales self.fusion = fusion acm_modules = [] for pool_scale in self.pool_scales: acm_modules.append( - ACM(pool_scale, + ACM( + pool_scale, self.fusion, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) self.acm_modules = nn.ModuleList(acm_modules) self.bottleneck = ConvModule( self.in_channels + len(pool_scales) * self.channels, @@ -145,7 +161,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/aspp_head.py b/mmsegmentation/mmseg/models/decode_heads/aspp_head.py index 7059aee..088d1f0 100644 --- a/mmsegmentation/mmseg/models/decode_heads/aspp_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/aspp_head.py @@ -20,9 +20,8 @@ class ASPPModule(nn.ModuleList): act_cfg (dict): Config of activation layers. """ - def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, - act_cfg): - super(ASPPModule, self).__init__() + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, act_cfg): + super().__init__() self.dilations = dilations self.in_channels = in_channels self.channels = channels @@ -39,7 +38,9 @@ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, padding=0 if dilation == 1 else dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) def forward(self, x): """Forward function.""" @@ -63,7 +64,7 @@ class ASPPHead(BaseDecodeHead): """ def __init__(self, dilations=(1, 6, 12, 18), **kwargs): - super(ASPPHead, self).__init__(**kwargs) + super().__init__(**kwargs) assert isinstance(dilations, (list, tuple)) self.dilations = dilations self.image_pool = nn.Sequential( @@ -74,14 +75,17 @@ def __init__(self, dilations=(1, 6, 12, 18), **kwargs): 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ), + ) self.aspp_modules = ASPPModule( dilations, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.bottleneck = ConvModule( (len(dilations) + 1) * self.channels, self.channels, @@ -89,7 +93,8 @@ def __init__(self, dilations=(1, 6, 12, 18), **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def _forward_feature(self, inputs): """Forward function for feature maps before classifying each pixel with @@ -107,8 +112,9 @@ def _forward_feature(self, inputs): resize( self.image_pool(x), size=x.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) diff --git a/mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py b/mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py index f7c3da0..08c434f 100644 --- a/mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py @@ -9,15 +9,13 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): :class:`CascadeEncoderDecoder.""" def __init__(self, *args, **kwargs): - super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) @abstractmethod def forward(self, inputs, prev_output): """Placeholder of forward function.""" - pass - def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, - train_cfg): + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. Args: inputs (list[Tensor]): List of multi-level img features. diff --git a/mmsegmentation/mmseg/models/decode_heads/cc_head.py b/mmsegmentation/mmseg/models/decode_heads/cc_head.py index ed19eb4..6de4f97 100644 --- a/mmsegmentation/mmseg/models/decode_heads/cc_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/cc_head.py @@ -24,9 +24,10 @@ class CCHead(FCNHead): def __init__(self, recurrence=2, **kwargs): if CrissCrossAttention is None: - raise RuntimeError('Please install mmcv-full for ' - 'CrissCrossAttention ops') - super(CCHead, self).__init__(num_convs=2, **kwargs) + raise RuntimeError( + "Please install mmcv-full for " "CrissCrossAttention ops" + ) + super().__init__(num_convs=2, **kwargs) self.recurrence = recurrence self.cca = CrissCrossAttention(self.channels) diff --git a/mmsegmentation/mmseg/models/decode_heads/da_head.py b/mmsegmentation/mmseg/models/decode_heads/da_head.py index 77fd663..8c0ca7c 100644 --- a/mmsegmentation/mmseg/models/decode_heads/da_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/da_head.py @@ -19,7 +19,7 @@ class PAM(_SelfAttentionBlock): """ def __init__(self, in_channels, channels): - super(PAM, self).__init__( + super().__init__( key_in_channels=in_channels, query_in_channels=in_channels, channels=channels, @@ -35,13 +35,14 @@ def __init__(self, in_channels, channels): with_out=False, conv_cfg=None, norm_cfg=None, - act_cfg=None) + act_cfg=None, + ) self.gamma = Scale(0) def forward(self, x): """Forward function.""" - out = super(PAM, self).forward(x, x) + out = super().forward(x, x) out = self.gamma(out) + x return out @@ -51,7 +52,7 @@ class CAM(nn.Module): """Channel Attention Module (CAM)""" def __init__(self): - super(CAM, self).__init__() + super().__init__() self.gamma = Scale(0) def forward(self, x): @@ -60,8 +61,7 @@ def forward(self, x): proj_query = x.view(batch_size, channels, -1) proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) energy = torch.bmm(proj_query, proj_key) - energy_new = torch.max( - energy, -1, keepdim=True)[0].expand_as(energy) - energy + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy attention = F.softmax(energy_new, dim=-1) proj_value = x.view(batch_size, channels, -1) @@ -84,7 +84,7 @@ class DAHead(BaseDecodeHead): """ def __init__(self, pam_channels, **kwargs): - super(DAHead, self).__init__(**kwargs) + super().__init__(**kwargs) self.pam_channels = pam_channels self.pam_in_conv = ConvModule( self.in_channels, @@ -93,7 +93,8 @@ def __init__(self, pam_channels, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.pam = PAM(self.channels, pam_channels) self.pam_out_conv = ConvModule( self.channels, @@ -102,9 +103,9 @@ def __init__(self, pam_channels, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) - self.pam_conv_seg = nn.Conv2d( - self.channels, self.num_classes, kernel_size=1) + act_cfg=self.act_cfg, + ) + self.pam_conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) self.cam_in_conv = ConvModule( self.in_channels, @@ -113,7 +114,8 @@ def __init__(self, pam_channels, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.cam = CAM() self.cam_out_conv = ConvModule( self.channels, @@ -122,9 +124,9 @@ def __init__(self, pam_channels, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) - self.cam_conv_seg = nn.Conv2d( - self.channels, self.num_classes, kernel_size=1) + act_cfg=self.act_cfg, + ) + self.cam_conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) def pam_cls_seg(self, feat): """PAM feature classification.""" @@ -166,14 +168,7 @@ def losses(self, seg_logit, seg_label): """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit loss = dict() - loss.update( - add_prefix( - super(DAHead, self).losses(pam_cam_seg_logit, seg_label), - 'pam_cam')) - loss.update( - add_prefix( - super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam')) - loss.update( - add_prefix( - super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam')) + loss.update(add_prefix(super().losses(pam_cam_seg_logit, seg_label), "pam_cam")) + loss.update(add_prefix(super().losses(pam_seg_logit, seg_label), "pam")) + loss.update(add_prefix(super().losses(cam_seg_logit, seg_label), "cam")) return loss diff --git a/mmsegmentation/mmseg/models/decode_heads/decode_head.py b/mmsegmentation/mmseg/models/decode_heads/decode_head.py index c893f76..87da02d 100644 --- a/mmsegmentation/mmseg/models/decode_heads/decode_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/decode_head.py @@ -55,29 +55,27 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): init_cfg (dict or list[dict], optional): Initialization config dict. """ - def __init__(self, - in_channels, - channels, - *, - num_classes, - out_channels=None, - threshold=None, - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type='ReLU'), - in_index=-1, - input_transform=None, - loss_decode=dict( - type='CrossEntropyLoss', - use_sigmoid=False, - loss_weight=1.0), - ignore_index=255, - sampler=None, - align_corners=False, - init_cfg=dict( - type='Normal', std=0.01, override=dict(name='conv_seg'))): - super(BaseDecodeHead, self).__init__(init_cfg) + def __init__( + self, + in_channels, + channels, + *, + num_classes, + out_channels=None, + threshold=None, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type="ReLU"), + in_index=-1, + input_transform=None, + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict(type="Normal", std=0.01, override=dict(name="conv_seg")), + ): + super().__init__(init_cfg) self._init_inputs(in_channels, in_index, input_transform) self.channels = channels self.dropout_ratio = dropout_ratio @@ -91,24 +89,26 @@ def __init__(self, if out_channels is None: if num_classes == 2: - warnings.warn('For binary segmentation, we suggest using' - '`out_channels = 1` to define the output' - 'channels of segmentor, and use `threshold`' - 'to convert seg_logist into a prediction' - 'applying a threshold') + warnings.warn( + "For binary segmentation, we suggest using" + "`out_channels = 1` to define the output" + "channels of segmentor, and use `threshold`" + "to convert seg_logist into a prediction" + "applying a threshold" + ) out_channels = num_classes if out_channels != num_classes and out_channels != 1: raise ValueError( - 'out_channels should be equal to num_classes,' - 'except binary segmentation set out_channels == 1 and' - f'num_classes == 2, but got out_channels={out_channels}' - f'and num_classes={num_classes}') + "out_channels should be equal to num_classes," + "except binary segmentation set out_channels == 1 and" + f"num_classes == 2, but got out_channels={out_channels}" + f"and num_classes={num_classes}" + ) if out_channels == 1 and threshold is None: threshold = 0.3 - warnings.warn('threshold is not defined for binary, and defaults' - 'to 0.3') + warnings.warn("threshold is not defined for binary, and defaults" "to 0.3") self.num_classes = num_classes self.out_channels = out_channels self.threshold = threshold @@ -120,8 +120,10 @@ def __init__(self, for loss in loss_decode: self.loss_decode.append(build_loss(loss)) else: - raise TypeError(f'loss_decode must be a dict or sequence of dict,\ - but got {type(loss_decode)}') + raise TypeError( + f"loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}" + ) if sampler is not None: self.sampler = build_pixel_sampler(sampler, context=self) @@ -137,9 +139,11 @@ def __init__(self, def extra_repr(self): """Extra repr.""" - s = f'input_transform={self.input_transform}, ' \ - f'ignore_index={self.ignore_index}, ' \ - f'align_corners={self.align_corners}' + s = ( + f"input_transform={self.input_transform}, " + f"ignore_index={self.ignore_index}, " + f"align_corners={self.align_corners}" + ) return s def _init_inputs(self, in_channels, in_index, input_transform): @@ -164,14 +168,14 @@ def _init_inputs(self, in_channels, in_index, input_transform): """ if input_transform is not None: - assert input_transform in ['resize_concat', 'multiple_select'] + assert input_transform in ["resize_concat", "multiple_select"] self.input_transform = input_transform self.in_index = in_index if input_transform is not None: assert isinstance(in_channels, (list, tuple)) assert isinstance(in_index, (list, tuple)) assert len(in_channels) == len(in_index) - if input_transform == 'resize_concat': + if input_transform == "resize_concat": self.in_channels = sum(in_channels) else: self.in_channels = in_channels @@ -190,17 +194,19 @@ def _transform_inputs(self, inputs): Tensor: The transformed inputs """ - if self.input_transform == 'resize_concat': + if self.input_transform == "resize_concat": inputs = [inputs[i] for i in self.in_index] upsampled_inputs = [ resize( input=x, size=inputs[0].shape[2:], - mode='bilinear', - align_corners=self.align_corners) for x in inputs + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs ] inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == 'multiple_select': + elif self.input_transform == "multiple_select": inputs = [inputs[i] for i in self.in_index] else: inputs = inputs[self.in_index] @@ -211,7 +217,6 @@ def _transform_inputs(self, inputs): @abstractmethod def forward(self, inputs): """Placeholder of forward function.""" - pass def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. @@ -257,15 +262,16 @@ def cls_seg(self, feat): output = self.conv_seg(feat) return output - @force_fp32(apply_to=('seg_logit', )) + @force_fp32(apply_to=("seg_logit",)) def losses(self, seg_logit, seg_label): """Compute segmentation loss.""" loss = dict() seg_logit = resize( input=seg_logit, size=seg_label.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) if self.sampler is not None: seg_weight = self.sampler.sample(seg_logit, seg_label) else: @@ -282,14 +288,15 @@ def losses(self, seg_logit, seg_label): seg_logit, seg_label, weight=seg_weight, - ignore_index=self.ignore_index) + ignore_index=self.ignore_index, + ) else: loss[loss_decode.loss_name] += loss_decode( seg_logit, seg_label, weight=seg_weight, - ignore_index=self.ignore_index) + ignore_index=self.ignore_index, + ) - loss['acc_seg'] = accuracy( - seg_logit, seg_label, ignore_index=self.ignore_index) + loss["acc_seg"] = accuracy(seg_logit, seg_label, ignore_index=self.ignore_index) return loss diff --git a/mmsegmentation/mmseg/models/decode_heads/dm_head.py b/mmsegmentation/mmseg/models/decode_heads/dm_head.py index ffaa870..742af3c 100644 --- a/mmsegmentation/mmseg/models/decode_heads/dm_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/dm_head.py @@ -22,9 +22,10 @@ class DCM(nn.Module): act_cfg (dict): Config of activation layers. """ - def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, - norm_cfg, act_cfg): - super(DCM, self).__init__() + def __init__( + self, filter_size, fusion, in_channels, channels, conv_cfg, norm_cfg, act_cfg + ): + super().__init__() self.filter_size = filter_size self.fusion = fusion self.in_channels = in_channels @@ -32,8 +33,7 @@ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg - self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, - 0) + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, 0) self.input_redu_conv = ConvModule( self.in_channels, @@ -41,7 +41,8 @@ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if self.norm_cfg is not None: self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] @@ -56,25 +57,28 @@ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, x): """Forward function.""" generated_filter = self.filter_gen_conv( - F.adaptive_avg_pool2d(x, self.filter_size)) + F.adaptive_avg_pool2d(x, self.filter_size) + ) x = self.input_redu_conv(x) b, c, h, w = x.shape # [1, b * c, h, w], c = self.channels x = x.view(1, b * c, h, w) # [b * c, 1, filter_size, filter_size] - generated_filter = generated_filter.view(b * c, 1, self.filter_size, - self.filter_size) + generated_filter = generated_filter.view( + b * c, 1, self.filter_size, self.filter_size + ) pad = (self.filter_size - 1) // 2 if (self.filter_size - 1) % 2 == 0: p2d = (pad, pad, pad, pad) else: p2d = (pad + 1, pad, pad + 1, pad) - x = F.pad(input=x, pad=p2d, mode='constant', value=0) + x = F.pad(input=x, pad=p2d, mode="constant", value=0) # [1, b * c, h, w] output = F.conv2d(input=x, weight=generated_filter, groups=b * c) # [b, c, h, w] @@ -105,20 +109,23 @@ class DMHead(BaseDecodeHead): """ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): - super(DMHead, self).__init__(**kwargs) + super().__init__(**kwargs) assert isinstance(filter_sizes, (list, tuple)) self.filter_sizes = filter_sizes self.fusion = fusion dcm_modules = [] for filter_size in self.filter_sizes: dcm_modules.append( - DCM(filter_size, + DCM( + filter_size, self.fusion, self.in_channels, self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) self.dcm_modules = nn.ModuleList(dcm_modules) self.bottleneck = ConvModule( self.in_channels + len(filter_sizes) * self.channels, @@ -127,7 +134,8 @@ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/dnl_head.py b/mmsegmentation/mmseg/models/decode_heads/dnl_head.py index dabf154..da01954 100644 --- a/mmsegmentation/mmseg/models/decode_heads/dnl_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/dnl_head.py @@ -27,12 +27,9 @@ def embedded_gaussian(self, theta_x, phi_x): if self.use_scale: # theta_x.shape[-1] is `self.inter_channels` pairwise_weight /= torch.tensor( - theta_x.shape[-1], - dtype=torch.float, - device=pairwise_weight.device)**torch.tensor( - 0.5, device=pairwise_weight.device) - pairwise_weight /= torch.tensor( - self.temperature, device=pairwise_weight.device) + theta_x.shape[-1], dtype=torch.float, device=pairwise_weight.device + ) ** torch.tensor(0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor(self.temperature, device=pairwise_weight.device) pairwise_weight = pairwise_weight.softmax(dim=-1) return pairwise_weight @@ -45,14 +42,14 @@ def forward(self, x): g_x = g_x.permute(0, 2, 1) # theta_x: [N, HxW, C], phi_x: [N, C, HxW] - if self.mode == 'gaussian': + if self.mode == "gaussian": theta_x = x.view(n, self.in_channels, -1) theta_x = theta_x.permute(0, 2, 1) if self.sub_sample: phi_x = self.phi(x).view(n, self.in_channels, -1) else: phi_x = x.view(n, self.in_channels, -1) - elif self.mode == 'concatenation': + elif self.mode == "concatenation": theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) else: @@ -71,8 +68,11 @@ def forward(self, x): # y: [N, HxW, C] y = torch.matmul(pairwise_weight, g_x) # y: [N, C, H, W] - y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, - *x.size()[2:]) + y = ( + y.permute(0, 2, 1) + .contiguous() + .reshape(n, self.inter_channels, *x.size()[2:]) + ) # unary_mask: [N, 1, HxW] unary_mask = self.conv_mask(x) @@ -81,8 +81,9 @@ def forward(self, x): # unary_x: [N, 1, C] unary_x = torch.matmul(unary_mask, g_x) # unary_x: [N, C, 1, 1] - unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( - n, self.inter_channels, 1, 1) + unary_x = ( + unary_x.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, 1, 1) + ) output = x + self.conv_out(y + unary_x) @@ -105,13 +106,15 @@ class DNLHead(FCNHead): temperature (float): Temperature to adjust attention. Default: 0.05 """ - def __init__(self, - reduction=2, - use_scale=True, - mode='embedded_gaussian', - temperature=0.05, - **kwargs): - super(DNLHead, self).__init__(num_convs=2, **kwargs) + def __init__( + self, + reduction=2, + use_scale=True, + mode="embedded_gaussian", + temperature=0.05, + **kwargs, + ): + super().__init__(num_convs=2, **kwargs) self.reduction = reduction self.use_scale = use_scale self.mode = mode @@ -123,7 +126,8 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, mode=self.mode, - temperature=self.temperature) + temperature=self.temperature, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/dpt_head.py b/mmsegmentation/mmseg/models/decode_heads/dpt_head.py index 6c895d0..cee39f8 100644 --- a/mmsegmentation/mmseg/models/decode_heads/dpt_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/dpt_head.py @@ -24,55 +24,67 @@ class ReassembleBlocks(BaseModule): init_cfg (dict, optional): Initialization config dict. Default: None. """ - def __init__(self, - in_channels=768, - out_channels=[96, 192, 384, 768], - readout_type='ignore', - patch_size=16, - init_cfg=None): - super(ReassembleBlocks, self).__init__(init_cfg) - - assert readout_type in ['ignore', 'add', 'project'] + def __init__( + self, + in_channels=768, + out_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + init_cfg=None, + ): + super().__init__(init_cfg) + + assert readout_type in ["ignore", "add", "project"] self.readout_type = readout_type self.patch_size = patch_size - self.projects = nn.ModuleList([ - ConvModule( - in_channels=in_channels, - out_channels=out_channel, - kernel_size=1, - act_cfg=None, - ) for out_channel in out_channels - ]) - - self.resize_layers = nn.ModuleList([ - nn.ConvTranspose2d( - in_channels=out_channels[0], - out_channels=out_channels[0], - kernel_size=4, - stride=4, - padding=0), - nn.ConvTranspose2d( - in_channels=out_channels[1], - out_channels=out_channels[1], - kernel_size=2, - stride=2, - padding=0), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], - out_channels=out_channels[3], - kernel_size=3, - stride=2, - padding=1) - ]) - if self.readout_type == 'project': + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ] + ) + if self.readout_type == "project": self.readout_projects = nn.ModuleList() for _ in range(len(self.projects)): self.readout_projects.append( nn.Sequential( Linear(2 * in_channels, in_channels), - build_activation_layer(dict(type='GELU')))) + build_activation_layer(dict(type="GELU")), + ) + ) def forward(self, inputs): assert isinstance(inputs, list) @@ -81,12 +93,12 @@ def forward(self, inputs): assert len(x) == 2 x, cls_token = x[0], x[1] feature_shape = x.shape - if self.readout_type == 'project': + if self.readout_type == "project": x = x.flatten(2).permute((0, 2, 1)) readout = cls_token.unsqueeze(1).expand_as(x) x = self.readout_projects[i](torch.cat((x, readout), -1)) x = x.permute(0, 2, 1).reshape(feature_shape) - elif self.readout_type == 'add': + elif self.readout_type == "add": x = x.flatten(2) + cls_token.unsqueeze(-1) x = x.reshape(feature_shape) else: @@ -109,14 +121,10 @@ class PreActResidualConvUnit(BaseModule): init_cfg (dict, optional): Initialization config dict. Default: None. """ - def __init__(self, - in_channels, - act_cfg, - norm_cfg, - stride=1, - dilation=1, - init_cfg=None): - super(PreActResidualConvUnit, self).__init__(init_cfg) + def __init__( + self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None + ): + super().__init__(init_cfg) self.conv1 = ConvModule( in_channels, @@ -128,7 +136,8 @@ def __init__(self, norm_cfg=norm_cfg, act_cfg=act_cfg, bias=False, - order=('act', 'conv', 'norm')) + order=("act", "conv", "norm"), + ) self.conv2 = ConvModule( in_channels, @@ -138,7 +147,8 @@ def __init__(self, norm_cfg=norm_cfg, act_cfg=act_cfg, bias=False, - order=('act', 'conv', 'norm')) + order=("act", "conv", "norm"), + ) def forward(self, inputs): inputs_ = inputs.clone() @@ -161,14 +171,16 @@ class FeatureFusionBlock(BaseModule): init_cfg (dict, optional): Initialization config dict. Default: None. """ - def __init__(self, - in_channels, - act_cfg, - norm_cfg, - expand=False, - align_corners=True, - init_cfg=None): - super(FeatureFusionBlock, self).__init__(init_cfg) + def __init__( + self, + in_channels, + act_cfg, + norm_cfg, + expand=False, + align_corners=True, + init_cfg=None, + ): + super().__init__(init_cfg) self.in_channels = in_channels self.expand = expand @@ -179,16 +191,15 @@ def __init__(self, self.out_channels = in_channels // 2 self.project = ConvModule( - self.in_channels, - self.out_channels, - kernel_size=1, - act_cfg=None, - bias=True) + self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True + ) self.res_conv_unit1 = PreActResidualConvUnit( - in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg + ) self.res_conv_unit2 = PreActResidualConvUnit( - in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg + ) def forward(self, *inputs): x = inputs[0] @@ -197,17 +208,14 @@ def forward(self, *inputs): res = resize( inputs[1], size=(x.shape[2], x.shape[3]), - mode='bilinear', - align_corners=False) + mode="bilinear", + align_corners=False, + ) else: res = inputs[1] x = x + self.res_conv_unit1(res) x = self.res_conv_unit2(x) - x = resize( - x, - scale_factor=2, - mode='bilinear', - align_corners=self.align_corners) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) x = self.project(x) return x @@ -233,22 +241,24 @@ class DPTHead(BaseDecodeHead): Default: dict(type='BN'). """ - def __init__(self, - embed_dims=768, - post_process_channels=[96, 192, 384, 768], - readout_type='ignore', - patch_size=16, - expand_channels=False, - act_cfg=dict(type='ReLU'), - norm_cfg=dict(type='BN'), - **kwargs): - super(DPTHead, self).__init__(**kwargs) + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + act_cfg=dict(type="ReLU"), + norm_cfg=dict(type="BN"), + **kwargs, + ): + super().__init__(**kwargs) self.in_channels = self.in_channels self.expand_channels = expand_channels - self.reassemble_blocks = ReassembleBlocks(embed_dims, - post_process_channels, - readout_type, patch_size) + self.reassemble_blocks = ReassembleBlocks( + embed_dims, post_process_channels, readout_type, patch_size + ) self.post_process_channels = [ channel * math.pow(2, i) if expand_channels else channel @@ -263,18 +273,18 @@ def __init__(self, kernel_size=3, padding=1, act_cfg=None, - bias=False)) + bias=False, + ) + ) self.fusion_blocks = nn.ModuleList() for _ in range(len(self.convs)): self.fusion_blocks.append( - FeatureFusionBlock(self.channels, act_cfg, norm_cfg)) + FeatureFusionBlock(self.channels, act_cfg, norm_cfg) + ) self.fusion_blocks[0].res_conv_unit1 = None self.project = ConvModule( - self.channels, - self.channels, - kernel_size=3, - padding=1, - norm_cfg=norm_cfg) + self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=norm_cfg + ) self.num_fusion_blocks = len(self.fusion_blocks) self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) self.num_post_process_channels = len(self.post_process_channels) diff --git a/mmsegmentation/mmseg/models/decode_heads/ema_head.py b/mmsegmentation/mmseg/models/decode_heads/ema_head.py index f6de167..c588b7e 100644 --- a/mmsegmentation/mmseg/models/decode_heads/ema_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/ema_head.py @@ -30,17 +30,17 @@ class EMAModule(nn.Module): """ def __init__(self, channels, num_bases, num_stages, momentum): - super(EMAModule, self).__init__() - assert num_stages >= 1, 'num_stages must be at least 1!' + super().__init__() + assert num_stages >= 1, "num_stages must be at least 1!" self.num_bases = num_bases self.num_stages = num_stages self.momentum = momentum bases = torch.zeros(1, channels, self.num_bases) - bases.normal_(0, math.sqrt(2. / self.num_bases)) + bases.normal_(0, math.sqrt(2.0 / self.num_bases)) # [1, channels, num_bases] bases = F.normalize(bases, dim=1, p=2) - self.register_buffer('bases', bases) + self.register_buffer("bases", bases) def forward(self, feats): """Forward function.""" @@ -53,16 +53,16 @@ def forward(self, feats): with torch.no_grad(): for i in range(self.num_stages): # [batch_size, height*width, num_bases] - attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = torch.einsum("bcn,bck->bnk", feats, bases) attention = F.softmax(attention, dim=2) # l1 norm attention_normed = F.normalize(attention, dim=1, p=1) # [batch_size, channels, num_bases] - bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + bases = torch.einsum("bcn,bnk->bck", feats, attention_normed) # l2 norm bases = F.normalize(bases, dim=1, p=2) - feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = torch.einsum("bck,bnk->bcn", bases, attention) feats_recon = feats_recon.view(batch_size, channels, height, width) if self.training: @@ -70,8 +70,7 @@ def forward(self, feats): bases = reduce_mean(bases) # l2 norm bases = F.normalize(bases, dim=1, p=2) - self.bases = (1 - - self.momentum) * self.bases + self.momentum * bases + self.bases = (1 - self.momentum) * self.bases + self.momentum * bases return feats_recon @@ -92,21 +91,24 @@ class EMAHead(BaseDecodeHead): momentum (float): Momentum to update the base. Default: 0.1. """ - def __init__(self, - ema_channels, - num_bases, - num_stages, - concat_input=True, - momentum=0.1, - **kwargs): - super(EMAHead, self).__init__(**kwargs) + def __init__( + self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs, + ): + super().__init__(**kwargs) self.ema_channels = ema_channels self.num_bases = num_bases self.num_stages = num_stages self.concat_input = concat_input self.momentum = momentum - self.ema_module = EMAModule(self.ema_channels, self.num_bases, - self.num_stages, self.momentum) + self.ema_module = EMAModule( + self.ema_channels, self.num_bases, self.num_stages, self.momentum + ) self.ema_in_conv = ConvModule( self.in_channels, @@ -115,7 +117,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) # project (0, inf) -> (-inf, inf) self.ema_mid_conv = ConvModule( self.ema_channels, @@ -123,7 +126,8 @@ def __init__(self, 1, conv_cfg=self.conv_cfg, norm_cfg=None, - act_cfg=None) + act_cfg=None, + ) for param in self.ema_mid_conv.parameters(): param.requires_grad = False @@ -133,7 +137,8 @@ def __init__(self, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=None) + act_cfg=None, + ) self.bottleneck = ConvModule( self.ema_channels, self.channels, @@ -141,7 +146,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if self.concat_input: self.conv_cat = ConvModule( self.in_channels + self.channels, @@ -150,7 +156,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/enc_head.py b/mmsegmentation/mmseg/models/decode_heads/enc_head.py index 648c890..5ddd1b9 100644 --- a/mmsegmentation/mmseg/models/decode_heads/enc_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/enc_head.py @@ -21,32 +21,34 @@ class EncModule(nn.Module): """ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): - super(EncModule, self).__init__() + super().__init__() self.encoding_project = ConvModule( in_channels, in_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) # TODO: resolve this hack # change to 1d if norm_cfg is not None: encoding_norm_cfg = norm_cfg.copy() - if encoding_norm_cfg['type'] in ['BN', 'IN']: - encoding_norm_cfg['type'] += '1d' + if encoding_norm_cfg["type"] in ["BN", "IN"]: + encoding_norm_cfg["type"] += "1d" else: - encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( - '2d', '1d') + encoding_norm_cfg["type"] = encoding_norm_cfg["type"].replace( + "2d", "1d" + ) else: # fallback to BN1d - encoding_norm_cfg = dict(type='BN1d') + encoding_norm_cfg = dict(type="BN1d") self.encoding = nn.Sequential( Encoding(channels=in_channels, num_codes=num_codes), build_norm_layer(encoding_norm_cfg, num_codes)[1], - nn.ReLU(inplace=True)) - self.fc = nn.Sequential( - nn.Linear(in_channels, in_channels), nn.Sigmoid()) + nn.ReLU(inplace=True), + ) + self.fc = nn.Sequential(nn.Linear(in_channels, in_channels), nn.Sigmoid()) def forward(self, x): """Forward function.""" @@ -76,17 +78,15 @@ class EncHead(BaseDecodeHead): Default: dict(type='CrossEntropyLoss', use_sigmoid=True). """ - def __init__(self, - num_codes=32, - use_se_loss=True, - add_lateral=False, - loss_se_decode=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - loss_weight=0.2), - **kwargs): - super(EncHead, self).__init__( - input_transform='multiple_select', **kwargs) + def __init__( + self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=0.2), + **kwargs, + ): + super().__init__(input_transform="multiple_select", **kwargs) self.use_se_loss = use_se_loss self.add_lateral = add_lateral self.num_codes = num_codes @@ -97,7 +97,8 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if add_lateral: self.lateral_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the last one @@ -108,7 +109,9 @@ def __init__(self, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) self.fusion = ConvModule( len(self.in_channels) * self.channels, self.channels, @@ -116,13 +119,15 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.enc_module = EncModule( self.channels, num_codes=num_codes, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if self.use_se_loss: self.loss_se_decode = build_loss(loss_se_decode) self.se_layer = nn.Linear(self.channels, self.num_classes) @@ -136,8 +141,9 @@ def forward(self, inputs): resize( lateral_conv(inputs[i]), size=feat.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) for i, lateral_conv in enumerate(self.lateral_convs) ] feat = self.fusion(torch.cat([feat, *laterals], 1)) @@ -171,8 +177,9 @@ def _convert_to_onehot_labels(seg_label, num_classes): batch_size = seg_label.size(0) onehot_labels = seg_label.new_zeros((batch_size, num_classes)) for i in range(batch_size): - hist = seg_label[i].float().histc( - bins=num_classes, min=0, max=num_classes - 1) + hist = ( + seg_label[i].float().histc(bins=num_classes, min=0, max=num_classes - 1) + ) onehot_labels[i] = hist > 0 return onehot_labels @@ -180,9 +187,9 @@ def losses(self, seg_logit, seg_label): """Compute segmentation and semantic encoding loss.""" seg_logit, se_seg_logit = seg_logit loss = dict() - loss.update(super(EncHead, self).losses(seg_logit, seg_label)) + loss.update(super().losses(seg_logit, seg_label)) se_loss = self.loss_se_decode( - se_seg_logit, - self._convert_to_onehot_labels(seg_label, self.num_classes)) - loss['loss_se'] = se_loss + se_seg_logit, self._convert_to_onehot_labels(seg_label, self.num_classes) + ) + loss["loss_se"] = se_loss return loss diff --git a/mmsegmentation/mmseg/models/decode_heads/fcn_head.py b/mmsegmentation/mmseg/models/decode_heads/fcn_head.py index e27be69..6dd9d1b 100644 --- a/mmsegmentation/mmseg/models/decode_heads/fcn_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/fcn_head.py @@ -21,17 +21,14 @@ class FCNHead(BaseDecodeHead): dilation (int): The dilation rate for convs in the head. Default: 1. """ - def __init__(self, - num_convs=2, - kernel_size=3, - concat_input=True, - dilation=1, - **kwargs): + def __init__( + self, num_convs=2, kernel_size=3, concat_input=True, dilation=1, **kwargs + ): assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) self.num_convs = num_convs self.concat_input = concat_input self.kernel_size = kernel_size - super(FCNHead, self).__init__(**kwargs) + super().__init__(**kwargs) if num_convs == 0: assert self.in_channels == self.channels @@ -48,7 +45,9 @@ def __init__(self, dilation=dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) if len(convs) == 0: self.convs = nn.Identity() @@ -62,7 +61,8 @@ def __init__(self, padding=kernel_size // 2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def _forward_feature(self, inputs): """Forward function for feature maps before classifying each pixel with diff --git a/mmsegmentation/mmseg/models/decode_heads/fpn_head.py b/mmsegmentation/mmseg/models/decode_heads/fpn_head.py index e41f324..2dfda8a 100644 --- a/mmsegmentation/mmseg/models/decode_heads/fpn_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/fpn_head.py @@ -22,8 +22,7 @@ class FPNHead(BaseDecodeHead): """ def __init__(self, feature_strides, **kwargs): - super(FPNHead, self).__init__( - input_transform='multiple_select', **kwargs) + super().__init__(input_transform="multiple_select", **kwargs) assert len(feature_strides) == len(self.in_channels) assert min(feature_strides) == feature_strides[0] self.feature_strides = feature_strides @@ -31,8 +30,8 @@ def __init__(self, feature_strides, **kwargs): self.scale_heads = nn.ModuleList() for i in range(len(feature_strides)): head_length = max( - 1, - int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + 1, int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])) + ) scale_head = [] for k in range(head_length): scale_head.append( @@ -43,17 +42,20 @@ def __init__(self, feature_strides, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) if feature_strides[i] != feature_strides[0]: scale_head.append( Upsample( scale_factor=2, - mode='bilinear', - align_corners=self.align_corners)) + mode="bilinear", + align_corners=self.align_corners, + ) + ) self.scale_heads.append(nn.Sequential(*scale_head)) def forward(self, inputs): - x = self._transform_inputs(inputs) output = self.scale_heads[0](x[0]) @@ -62,8 +64,9 @@ def forward(self, inputs): output = output + resize( self.scale_heads[i](x[i]), size=output.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) output = self.cls_seg(output) return output diff --git a/mmsegmentation/mmseg/models/decode_heads/gc_head.py b/mmsegmentation/mmseg/models/decode_heads/gc_head.py index eed5074..695019f 100644 --- a/mmsegmentation/mmseg/models/decode_heads/gc_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/gc_head.py @@ -21,12 +21,10 @@ class GCHead(FCNHead): Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) """ - def __init__(self, - ratio=1 / 4., - pooling_type='att', - fusion_types=('channel_add', ), - **kwargs): - super(GCHead, self).__init__(num_convs=2, **kwargs) + def __init__( + self, ratio=1 / 4.0, pooling_type="att", fusion_types=("channel_add",), **kwargs + ): + super().__init__(num_convs=2, **kwargs) self.ratio = ratio self.pooling_type = pooling_type self.fusion_types = fusion_types @@ -34,7 +32,8 @@ def __init__(self, in_channels=self.channels, ratio=self.ratio, pooling_type=self.pooling_type, - fusion_types=self.fusion_types) + fusion_types=self.fusion_types, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/isa_head.py b/mmsegmentation/mmseg/models/decode_heads/isa_head.py index 0bf3455..bcd86a7 100644 --- a/mmsegmentation/mmseg/models/decode_heads/isa_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/isa_head.py @@ -22,7 +22,7 @@ class SelfAttentionBlock(_SelfAttentionBlock): """ def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): - super(SelfAttentionBlock, self).__init__( + super().__init__( key_in_channels=in_channels, query_in_channels=in_channels, channels=channels, @@ -38,7 +38,8 @@ def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): with_out=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.output_project = self.build_project( in_channels, @@ -47,11 +48,12 @@ def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): use_conv_module=True, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, x): """Forward function.""" - context = super(SelfAttentionBlock, self).forward(x, x) + context = super().forward(x, x) return self.output_project(context) @@ -68,7 +70,7 @@ class ISAHead(BaseDecodeHead): """ def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): - super(ISAHead, self).__init__(**kwargs) + super().__init__(**kwargs) self.down_factor = down_factor self.in_conv = ConvModule( @@ -78,26 +80,30 @@ def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.global_relation = SelfAttentionBlock( self.channels, isa_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.local_relation = SelfAttentionBlock( self.channels, isa_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.out_conv = ConvModule( self.channels * 2, self.channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" @@ -110,8 +116,7 @@ def forward(self, inputs): glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w if pad_h > 0 or pad_w > 0: # pad if the size is not divisible - padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, - pad_h - pad_h // 2) + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) x = F.pad(x, padding) # global relation @@ -135,7 +140,7 @@ def forward(self, inputs): x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) if pad_h > 0 or pad_w > 0: # remove padding - x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] + x = x[:, :, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w] x = self.out_conv(torch.cat([x, residual], dim=1)) out = self.cls_seg(x) diff --git a/mmsegmentation/mmseg/models/decode_heads/knet_head.py b/mmsegmentation/mmseg/models/decode_heads/knet_head.py index 78a2702..1e082c1 100644 --- a/mmsegmentation/mmseg/models/decode_heads/knet_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/knet_head.py @@ -3,9 +3,12 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer -from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER, - MultiheadAttention, - build_transformer_layer) +from mmcv.cnn.bricks.transformer import ( + FFN, + TRANSFORMER_LAYER, + MultiheadAttention, + build_transformer_layer, +) from mmseg.models.builder import HEADS, build_head from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -35,17 +38,17 @@ class KernelUpdator(nn.Module): """ def __init__( - self, - in_channels=256, - feat_channels=64, - out_channels=None, - gate_sigmoid=True, - gate_norm_act=False, - activate_out=False, - norm_cfg=dict(type='LN'), - act_cfg=dict(type='ReLU', inplace=True), + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type="LN"), + act_cfg=dict(type="ReLU", inplace=True), ): - super(KernelUpdator, self).__init__() + super().__init__() self.in_channels = in_channels self.feat_channels = feat_channels self.out_channels_raw = out_channels @@ -59,10 +62,11 @@ def __init__( self.num_params_in = self.feat_channels self.num_params_out = self.feat_channels self.dynamic_layer = nn.Linear( - self.in_channels, self.num_params_in + self.num_params_out) - self.input_layer = nn.Linear(self.in_channels, - self.num_params_in + self.num_params_out, - 1) + self.in_channels, self.num_params_in + self.num_params_out + ) + self.input_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out, 1 + ) self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) if self.gate_norm_act: @@ -98,17 +102,16 @@ def forward(self, update_feature, input_feature): # dynamic_layer works for # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper parameters = self.dynamic_layer(update_feature) - param_in = parameters[:, :self.num_params_in].view( - -1, self.feat_channels) - param_out = parameters[:, -self.num_params_out:].view( - -1, self.feat_channels) + param_in = parameters[:, : self.num_params_in].view(-1, self.feat_channels) + param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels) # input_layer works for # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper input_feats = self.input_layer( - input_feature.reshape(num_proposals, -1, self.feat_channels)) - input_in = input_feats[..., :self.num_params_in] - input_out = input_feats[..., -self.num_params_out:] + input_feature.reshape(num_proposals, -1, self.feat_channels) + ) + input_in = input_feats[..., : self.num_params_in] + input_out = input_feats[..., -self.num_params_out :] # `gate_feats` is F^G in K-Net paper gate_feats = input_in * param_in.unsqueeze(-2) @@ -129,8 +132,7 @@ def forward(self, update_feature, input_feature): # Gate mechanism. Eq.(5) in original paper. # param_out has shape (batch_size, feat_channels, out_channels) - features = update_gate * param_out.unsqueeze( - -2) + input_gate * input_out + features = update_gate * param_out.unsqueeze(-2) + input_gate * input_out features = self.fc_layer(features) features = self.fc_norm(features) @@ -186,31 +188,34 @@ class KernelUpdateHead(nn.Module): norm_cfg=dict(type='LN')). """ - def __init__(self, - num_classes=150, - num_ffn_fcs=2, - num_heads=8, - num_mask_fcs=3, - feedforward_channels=2048, - in_channels=256, - out_channels=256, - dropout=0.0, - act_cfg=dict(type='ReLU', inplace=True), - ffn_act_cfg=dict(type='ReLU', inplace=True), - conv_kernel_size=1, - feat_transform_cfg=None, - kernel_init=False, - with_ffn=True, - feat_gather_stride=1, - mask_transform_stride=1, - kernel_updator_cfg=dict( - type='DynamicConv', - in_channels=256, - feat_channels=64, - out_channels=256, - act_cfg=dict(type='ReLU', inplace=True), - norm_cfg=dict(type='LN'))): - super(KernelUpdateHead, self).__init__() + def __init__( + self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type="ReLU", inplace=True), + ffn_act_cfg=dict(type="ReLU", inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type="DynamicConv", + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + ), + ): + super().__init__() self.num_classes = num_classes self.in_channels = in_channels self.out_channels = out_channels @@ -223,14 +228,16 @@ def __init__(self, self.feat_gather_stride = feat_gather_stride self.mask_transform_stride = mask_transform_stride - self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, - num_heads, dropout) + self.attention = MultiheadAttention( + in_channels * conv_kernel_size**2, num_heads, dropout + ) self.attention_norm = build_norm_layer( - dict(type='LN'), in_channels * conv_kernel_size**2)[1] + dict(type="LN"), in_channels * conv_kernel_size**2 + )[1] self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) if feat_transform_cfg is not None: - kernel_size = feat_transform_cfg.pop('kernel_size', 1) + kernel_size = feat_transform_cfg.pop("kernel_size", 1) transform_channels = in_channels self.feat_transform = ConvModule( transform_channels, @@ -238,7 +245,8 @@ def __init__(self, kernel_size, stride=feat_gather_stride, padding=int(feat_gather_stride // 2), - **feat_transform_cfg) + **feat_transform_cfg, + ) else: self.feat_transform = None @@ -248,15 +256,14 @@ def __init__(self, feedforward_channels, num_ffn_fcs, act_cfg=ffn_act_cfg, - dropout=dropout) - self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + dropout=dropout, + ) + self.ffn_norm = build_norm_layer(dict(type="LN"), in_channels)[1] self.mask_fcs = nn.ModuleList() for _ in range(num_mask_fcs): - self.mask_fcs.append( - nn.Linear(in_channels, in_channels, bias=False)) - self.mask_fcs.append( - build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append(build_norm_layer(dict(type="LN"), in_channels)[1]) self.mask_fcs.append(build_activation_layer(act_cfg)) self.fc_mask = nn.Linear(in_channels, out_channels) @@ -273,8 +280,7 @@ def init_weights(self): pass if self.kernel_init: logger = get_root_logger() - logger.info( - 'mask kernel in mask head is normal initialized by std 0.01') + logger.info("mask kernel in mask head is normal initialized by std 0.01") nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) def forward(self, x, proposal_feat, mask_preds, mask_shape=None): @@ -303,7 +309,8 @@ def forward(self, x, proposal_feat, mask_preds, mask_shape=None): mask_h, mask_w = mask_preds.shape[-2:] if mask_h != H or mask_w != W: gather_mask = F.interpolate( - mask_preds, (H, W), align_corners=False, mode='bilinear') + mask_preds, (H, W), align_corners=False, mode="bilinear" + ) else: gather_mask = mask_preds @@ -311,12 +318,12 @@ def forward(self, x, proposal_feat, mask_preds, mask_shape=None): # Group Feature Assembling. Eq.(3) in original paper. # einsum is faster than bmm by 30% - x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + x_feat = torch.einsum("bnhw,bchw->bnc", sigmoid_masks, x) # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] - proposal_feat = proposal_feat.reshape(N, num_proposals, - self.in_channels, - -1).permute(0, 1, 3, 2) + proposal_feat = proposal_feat.reshape( + N, num_proposals, self.in_channels, -1 + ).permute(0, 1, 3, 2) obj_feat = self.kernel_update_conv(x_feat, proposal_feat) # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] @@ -340,9 +347,10 @@ def forward(self, x, proposal_feat, mask_preds, mask_shape=None): # [B, N, K*K, C] -> [B, N, C, K*K] mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) - if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + if self.mask_transform_stride == 2 and self.feat_gather_stride == 1: mask_x = F.interpolate( - x, scale_factor=0.5, mode='bilinear', align_corners=False) + x, scale_factor=0.5, mode="bilinear", align_corners=False + ) H, W = mask_x.shape[-2:] else: mask_x = x @@ -358,37 +366,39 @@ def forward(self, x, proposal_feat, mask_preds, mask_shape=None): # mask_feat = mask_feat.reshape(N, num_proposals, -1) # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) # [B, N, C, K*K] -> [B*N, C, K, K] - mask_feat = mask_feat.reshape(N, num_proposals, C, - self.conv_kernel_size, - self.conv_kernel_size) + mask_feat = mask_feat.reshape( + N, num_proposals, C, self.conv_kernel_size, self.conv_kernel_size + ) # [B, C, H, W] -> [1, B*C, H, W] new_mask_preds = [] for i in range(N): new_mask_preds.append( F.conv2d( - mask_x[i:i + 1], + mask_x[i : i + 1], mask_feat[i], - padding=int(self.conv_kernel_size // 2))) + padding=int(self.conv_kernel_size // 2), + ) + ) new_mask_preds = torch.cat(new_mask_preds, dim=0) new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) if self.mask_transform_stride == 2: new_mask_preds = F.interpolate( - new_mask_preds, - scale_factor=2, - mode='bilinear', - align_corners=False) + new_mask_preds, scale_factor=2, mode="bilinear", align_corners=False + ) if mask_shape is not None and mask_shape[0] != H: new_mask_preds = F.interpolate( - new_mask_preds, - mask_shape, - align_corners=False, - mode='bilinear') + new_mask_preds, mask_shape, align_corners=False, mode="bilinear" + ) return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( - N, num_proposals, self.in_channels, self.conv_kernel_size, - self.conv_kernel_size) + N, + num_proposals, + self.in_channels, + self.conv_kernel_size, + self.conv_kernel_size, + ) @HEADS.register_module() @@ -409,8 +419,7 @@ class IterativeDecodeHead(BaseDecodeHead): """ - def __init__(self, num_stages, kernel_generate_head, kernel_update_head, - **kwargs): + def __init__(self, num_stages, kernel_generate_head, kernel_update_head, **kwargs): # ``IterativeDecodeHead`` would skip initialization of # ``BaseDecodeHead`` which would be called when building # ``self.kernel_generate_head``. @@ -433,14 +442,13 @@ def forward(self, inputs): feats = self.kernel_generate_head._forward_feature(inputs) sem_seg = self.kernel_generate_head.cls_seg(feats) seg_kernels = self.kernel_generate_head.conv_seg.weight.clone() - seg_kernels = seg_kernels[None].expand( - feats.size(0), *seg_kernels.size()) + seg_kernels = seg_kernels[None].expand(feats.size(0), *seg_kernels.size()) stage_segs = [sem_seg] for i in range(self.num_stages): - sem_seg, seg_kernels = self.kernel_update_head[i](feats, - seg_kernels, - sem_seg) + sem_seg, seg_kernels = self.kernel_update_head[i]( + feats, seg_kernels, sem_seg + ) stage_segs.append(sem_seg) if self.training: return stage_segs @@ -452,6 +460,6 @@ def losses(self, seg_logit, seg_label): for i, logit in enumerate(seg_logit): loss = self.kernel_generate_head.losses(logit, seg_label) for k, v in loss.items(): - losses[f'{k}.s{i}'] = v + losses[f"{k}.s{i}"] = v return losses diff --git a/mmsegmentation/mmseg/models/decode_heads/lraspp_head.py b/mmsegmentation/mmseg/models/decode_heads/lraspp_head.py index c10ff0d..5101c28 100644 --- a/mmsegmentation/mmseg/models/decode_heads/lraspp_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/lraspp_head.py @@ -22,11 +22,13 @@ class LRASPPHead(BaseDecodeHead): """ def __init__(self, branch_channels=(32, 64), **kwargs): - super(LRASPPHead, self).__init__(**kwargs) - if self.input_transform != 'multiple_select': - raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' - f'must be \'multiple_select\'. But received ' - f'\'{self.input_transform}\'') + super().__init__(**kwargs) + if self.input_transform != "multiple_select": + raise ValueError( + "in Lite R-ASPP (LRASPP) head, input_transform " + f"must be 'multiple_select'. But received " + f"'{self.input_transform}'" + ) assert is_tuple_of(branch_channels, int) assert len(branch_channels) == len(self.in_channels) - 1 self.branch_channels = branch_channels @@ -35,18 +37,20 @@ def __init__(self, branch_channels=(32, 64), **kwargs): self.conv_ups = nn.Sequential() for i in range(len(branch_channels)): self.convs.add_module( - f'conv{i}', - nn.Conv2d( - self.in_channels[i], branch_channels[i], 1, bias=False)) + f"conv{i}", + nn.Conv2d(self.in_channels[i], branch_channels[i], 1, bias=False), + ) self.conv_ups.add_module( - f'conv_up{i}', + f"conv_up{i}", ConvModule( self.channels + branch_channels[i], self.channels, 1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - bias=False)) + bias=False, + ), + ) self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) @@ -56,15 +60,18 @@ def __init__(self, branch_channels=(32, 64), **kwargs): 1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - bias=False) + bias=False, + ) self.image_pool = nn.Sequential( nn.AvgPool2d(kernel_size=49, stride=(16, 20)), ConvModule( self.in_channels[2], self.channels, 1, - act_cfg=dict(type='Sigmoid'), - bias=False)) + act_cfg=dict(type="Sigmoid"), + bias=False, + ), + ) def forward(self, inputs): """Forward function.""" @@ -75,16 +82,18 @@ def forward(self, inputs): x = self.aspp_conv(x) * resize( self.image_pool(x), size=x.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) x = self.conv_up_input(x) for i in range(len(self.branch_channels) - 1, -1, -1): x = resize( x, size=inputs[i].size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) x = torch.cat([x, self.convs[i](inputs[i])], 1) x = self.conv_ups[i](x) diff --git a/mmsegmentation/mmseg/models/decode_heads/nl_head.py b/mmsegmentation/mmseg/models/decode_heads/nl_head.py index 637517e..50df90a 100644 --- a/mmsegmentation/mmseg/models/decode_heads/nl_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/nl_head.py @@ -21,12 +21,8 @@ class NLHead(FCNHead): 'dot_product'. Default: 'embedded_gaussian.'. """ - def __init__(self, - reduction=2, - use_scale=True, - mode='embedded_gaussian', - **kwargs): - super(NLHead, self).__init__(num_convs=2, **kwargs) + def __init__(self, reduction=2, use_scale=True, mode="embedded_gaussian", **kwargs): + super().__init__(num_convs=2, **kwargs) self.reduction = reduction self.use_scale = use_scale self.mode = mode @@ -36,7 +32,8 @@ def __init__(self, use_scale=self.use_scale, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - mode=self.mode) + mode=self.mode, + ) def forward(self, inputs): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/ocr_head.py b/mmsegmentation/mmseg/models/decode_heads/ocr_head.py index 09eadfb..e289c88 100644 --- a/mmsegmentation/mmseg/models/decode_heads/ocr_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/ocr_head.py @@ -18,7 +18,7 @@ class SpatialGatherModule(nn.Module): """ def __init__(self, scale): - super(SpatialGatherModule, self).__init__() + super().__init__() self.scale = scale def forward(self, feats, probs): @@ -40,13 +40,12 @@ def forward(self, feats, probs): class ObjectAttentionBlock(_SelfAttentionBlock): """Make a OCR used SelfAttentionBlock.""" - def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, - act_cfg): + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, act_cfg): if scale > 1: query_downsample = nn.MaxPool2d(kernel_size=scale) else: query_downsample = None - super(ObjectAttentionBlock, self).__init__( + super().__init__( key_in_channels=in_channels, query_in_channels=in_channels, channels=channels, @@ -62,19 +61,20 @@ def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, with_out=True, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.bottleneck = ConvModule( in_channels * 2, in_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, query_feats, key_feats): """Forward function.""" - context = super(ObjectAttentionBlock, - self).forward(query_feats, key_feats) + context = super().forward(query_feats, key_feats) output = self.bottleneck(torch.cat([context, query_feats], dim=1)) if self.query_downsample is not None: output = resize(query_feats) @@ -96,7 +96,7 @@ class OCRHead(BaseCascadeDecodeHead): """ def __init__(self, ocr_channels, scale=1, **kwargs): - super(OCRHead, self).__init__(**kwargs) + super().__init__(**kwargs) self.ocr_channels = ocr_channels self.scale = scale self.object_context_block = ObjectAttentionBlock( @@ -105,7 +105,8 @@ def __init__(self, ocr_channels, scale=1, **kwargs): self.scale, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.spatial_gather_module = SpatialGatherModule(self.scale) self.bottleneck = ConvModule( @@ -115,7 +116,8 @@ def __init__(self, ocr_channels, scale=1, **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs, prev_output): """Forward function.""" diff --git a/mmsegmentation/mmseg/models/decode_heads/point_head.py b/mmsegmentation/mmseg/models/decode_heads/point_head.py index 5e60527..0e14b62 100644 --- a/mmsegmentation/mmseg/models/decode_heads/point_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/point_head.py @@ -5,16 +5,16 @@ import torch.nn as nn from mmcv.cnn import ConvModule -try: - from mmcv.ops import point_sample -except ModuleNotFoundError: - point_sample = None - from mmseg.models.builder import HEADS from mmseg.ops import resize from ..losses import accuracy from .cascade_decode_head import BaseCascadeDecodeHead +try: + from mmcv.ops import point_sample +except ModuleNotFoundError: + point_sample = None + def calculate_uncertainty(seg_logits): """Estimate uncertainty based on seg logits. @@ -64,24 +64,25 @@ class PointHead(BaseCascadeDecodeHead): loss_weight=1.0). """ - def __init__(self, - num_fcs=3, - coarse_pred_each_layer=True, - conv_cfg=dict(type='Conv1d'), - norm_cfg=None, - act_cfg=dict(type='ReLU', inplace=False), - **kwargs): - super(PointHead, self).__init__( - input_transform='multiple_select', + def __init__( + self, + num_fcs=3, + coarse_pred_each_layer=True, + conv_cfg=dict(type="Conv1d"), + norm_cfg=None, + act_cfg=dict(type="ReLU", inplace=False), + **kwargs, + ): + super().__init__( + input_transform="multiple_select", conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - init_cfg=dict( - type='Normal', std=0.01, override=dict(name='fc_seg')), - **kwargs) + init_cfg=dict(type="Normal", std=0.01, override=dict(name="fc_seg")), + **kwargs, + ) if point_sample is None: - raise RuntimeError('Please install mmcv-full for ' - 'point_sample ops') + raise RuntimeError("Please install mmcv-full for " "point_sample ops") self.num_fcs = num_fcs self.coarse_pred_each_layer = coarse_pred_each_layer @@ -98,20 +99,17 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.fcs.append(fc) fc_in_channels = fc_channels - fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ - else 0 + fc_in_channels += self.num_classes if self.coarse_pred_each_layer else 0 self.fc_seg = nn.Conv1d( - fc_in_channels, - self.num_classes, - kernel_size=1, - stride=1, - padding=0) + fc_in_channels, self.num_classes, kernel_size=1, stride=1, padding=0 + ) if self.dropout_ratio > 0: self.dropout = nn.Dropout(self.dropout_ratio) - delattr(self, 'conv_seg') + delattr(self, "conv_seg") def cls_seg(self, feat): """Classify each pixel with fc.""" @@ -142,8 +140,7 @@ def _get_fine_grained_point_feats(self, x, points): """ fine_grained_feats_list = [ - point_sample(_, points, align_corners=self.align_corners) - for _ in x + point_sample(_, points, align_corners=self.align_corners) for _ in x ] if len(fine_grained_feats_list) > 1: fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) @@ -166,12 +163,12 @@ def _get_coarse_point_feats(self, prev_output, points): """ coarse_feats = point_sample( - prev_output, points, align_corners=self.align_corners) + prev_output, points, align_corners=self.align_corners + ) return coarse_feats - def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, - train_cfg): + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, train_cfg): """Forward function for training. Args: inputs (list[Tensor]): List of multi-level img features. @@ -191,17 +188,17 @@ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, x = self._transform_inputs(inputs) with torch.no_grad(): points = self.get_points_train( - prev_output, calculate_uncertainty, cfg=train_cfg) - fine_grained_point_feats = self._get_fine_grained_point_feats( - x, points) + prev_output, calculate_uncertainty, cfg=train_cfg + ) + fine_grained_point_feats = self._get_fine_grained_point_feats(x, points) coarse_point_feats = self._get_coarse_point_feats(prev_output, points) - point_logits = self.forward(fine_grained_point_feats, - coarse_point_feats) + point_logits = self.forward(fine_grained_point_feats, coarse_point_feats) point_label = point_sample( gt_semantic_seg.float(), points, - mode='nearest', - align_corners=self.align_corners) + mode="nearest", + align_corners=self.align_corners, + ) point_label = point_label.squeeze(1).long() losses = self.losses(point_logits, point_label) @@ -231,25 +228,27 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg): refined_seg_logits = resize( refined_seg_logits, scale_factor=test_cfg.scale_factor, - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) batch_size, channels, height, width = refined_seg_logits.shape point_indices, points = self.get_points_test( - refined_seg_logits, calculate_uncertainty, cfg=test_cfg) - fine_grained_point_feats = self._get_fine_grained_point_feats( - x, points) - coarse_point_feats = self._get_coarse_point_feats( - prev_output, points) - point_logits = self.forward(fine_grained_point_feats, - coarse_point_feats) + refined_seg_logits, calculate_uncertainty, cfg=test_cfg + ) + fine_grained_point_feats = self._get_fine_grained_point_feats(x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, coarse_point_feats) point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) refined_seg_logits = refined_seg_logits.reshape( - batch_size, channels, height * width) + batch_size, channels, height * width + ) refined_seg_logits = refined_seg_logits.scatter_( - 2, point_indices, point_logits) + 2, point_indices, point_logits + ) refined_seg_logits = refined_seg_logits.view( - batch_size, channels, height, width) + batch_size, channels, height, width + ) return refined_seg_logits @@ -261,11 +260,13 @@ def losses(self, point_logits, point_label): else: losses_decode = self.loss_decode for loss_module in losses_decode: - loss['point' + loss_module.loss_name] = loss_module( - point_logits, point_label, ignore_index=self.ignore_index) + loss["point" + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index + ) - loss['acc_point'] = accuracy( - point_logits, point_label, ignore_index=self.ignore_index) + loss["acc_point"] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index + ) return loss def get_points_train(self, seg_logits, uncertainty_func, cfg): @@ -294,8 +295,7 @@ def get_points_train(self, seg_logits, uncertainty_func, cfg): assert 0 <= importance_sample_ratio <= 1 batch_size = seg_logits.shape[0] num_sampled = int(num_points * oversample_ratio) - point_coords = torch.rand( - batch_size, num_sampled, 2, device=seg_logits.device) + point_coords = torch.rand(batch_size, num_sampled, 2, device=seg_logits.device) point_logits = point_sample(seg_logits, point_coords) # It is crucial to calculate uncertainty based on the sampled # prediction value for the points. Calculating uncertainties of the @@ -309,16 +309,18 @@ def get_points_train(self, seg_logits, uncertainty_func, cfg): point_uncertainties = uncertainty_func(point_logits) num_uncertain_points = int(importance_sample_ratio * num_points) num_random_points = num_points - num_uncertain_points - idx = torch.topk( - point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] shift = num_sampled * torch.arange( - batch_size, dtype=torch.long, device=seg_logits.device) + batch_size, dtype=torch.long, device=seg_logits.device + ) idx += shift[:, None] point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( - batch_size, num_uncertain_points, 2) + batch_size, num_uncertain_points, 2 + ) if num_random_points > 0: rand_point_coords = torch.rand( - batch_size, num_random_points, 2, device=seg_logits.device) + batch_size, num_random_points, 2, device=seg_logits.device + ) point_coords = torch.cat((point_coords, rand_point_coords), dim=1) return point_coords @@ -352,13 +354,8 @@ def get_points_test(self, seg_logits, uncertainty_func, cfg): num_points = min(height * width, num_points) point_indices = uncertainty_map.topk(num_points, dim=1)[1] point_coords = torch.zeros( - batch_size, - num_points, - 2, - dtype=torch.float, - device=seg_logits.device) - point_coords[:, :, 0] = w_step / 2.0 + (point_indices % - width).float() * w_step - point_coords[:, :, 1] = h_step / 2.0 + (point_indices // - width).float() * h_step + batch_size, num_points, 2, dtype=torch.float, device=seg_logits.device + ) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // width).float() * h_step return point_indices, point_coords diff --git a/mmsegmentation/mmseg/models/decode_heads/psa_head.py b/mmsegmentation/mmseg/models/decode_heads/psa_head.py index df7593c..f4390eb 100644 --- a/mmsegmentation/mmseg/models/decode_heads/psa_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/psa_head.py @@ -33,18 +33,20 @@ class PSAHead(BaseDecodeHead): psa_softmax (bool): Whether use softmax for attention. """ - def __init__(self, - mask_size, - psa_type='bi-direction', - compact=False, - shrink_factor=2, - normalization_factor=1.0, - psa_softmax=True, - **kwargs): + def __init__( + self, + mask_size, + psa_type="bi-direction", + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs, + ): if PSAMask is None: - raise RuntimeError('Please install mmcv-full for PSAMask ops') - super(PSAHead, self).__init__(**kwargs) - assert psa_type in ['collect', 'distribute', 'bi-direction'] + raise RuntimeError("Please install mmcv-full for PSAMask ops") + super().__init__(**kwargs) + assert psa_type in ["collect", "distribute", "bi-direction"] self.psa_type = psa_type self.compact = compact self.shrink_factor = shrink_factor @@ -61,7 +63,8 @@ def __init__(self, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.attention = nn.Sequential( ConvModule( self.channels, @@ -69,17 +72,19 @@ def __init__(self, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), - nn.Conv2d( - self.channels, mask_h * mask_w, kernel_size=1, bias=False)) - if psa_type == 'bi-direction': + act_cfg=self.act_cfg, + ), + nn.Conv2d(self.channels, mask_h * mask_w, kernel_size=1, bias=False), + ) + if psa_type == "bi-direction": self.reduce_p = ConvModule( self.in_channels, self.channels, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.attention_p = nn.Sequential( ConvModule( self.channels, @@ -87,21 +92,23 @@ def __init__(self, kernel_size=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), - nn.Conv2d( - self.channels, mask_h * mask_w, kernel_size=1, bias=False)) - self.psamask_collect = PSAMask('collect', mask_size) - self.psamask_distribute = PSAMask('distribute', mask_size) + act_cfg=self.act_cfg, + ), + nn.Conv2d(self.channels, mask_h * mask_w, kernel_size=1, bias=False), + ) + self.psamask_collect = PSAMask("collect", mask_size) + self.psamask_distribute = PSAMask("distribute", mask_size) else: self.psamask = PSAMask(psa_type, mask_size) self.proj = ConvModule( - self.channels * (2 if psa_type == 'bi-direction' else 1), + self.channels * (2 if psa_type == "bi-direction" else 1), self.in_channels, kernel_size=1, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) self.bottleneck = ConvModule( self.in_channels * 2, self.channels, @@ -109,14 +116,15 @@ def __init__(self, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) identity = x align_corners = self.align_corners - if self.psa_type in ['collect', 'distribute']: + if self.psa_type in ["collect", "distribute"]: out = self.reduce(x) n, c, h, w = out.size() if self.shrink_factor != 1: @@ -129,22 +137,19 @@ def forward(self, inputs): w = w // self.shrink_factor align_corners = False out = resize( - out, - size=(h, w), - mode='bilinear', - align_corners=align_corners) + out, size=(h, w), mode="bilinear", align_corners=align_corners + ) y = self.attention(out) if self.compact: - if self.psa_type == 'collect': - y = y.view(n, h * w, - h * w).transpose(1, 2).view(n, h * w, h, w) + if self.psa_type == "collect": + y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w) else: y = self.psamask(y) if self.psa_softmax: y = F.softmax(y, dim=1) - out = torch.bmm( - out.view(n, c, h * w), y.view(n, h * w, h * w)).view( - n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.bmm(out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w + ) * (1.0 / self.normalization_factor) else: x_col = self.reduce(x) x_dis = self.reduce_p(x) @@ -159,20 +164,15 @@ def forward(self, inputs): w = w // self.shrink_factor align_corners = False x_col = resize( - x_col, - size=(h, w), - mode='bilinear', - align_corners=align_corners) + x_col, size=(h, w), mode="bilinear", align_corners=align_corners + ) x_dis = resize( - x_dis, - size=(h, w), - mode='bilinear', - align_corners=align_corners) + x_dis, size=(h, w), mode="bilinear", align_corners=align_corners + ) y_col = self.attention(x_col) y_dis = self.attention_p(x_dis) if self.compact: - y_dis = y_dis.view(n, h * w, - h * w).transpose(1, 2).view(n, h * w, h, w) + y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w) else: y_col = self.psamask_collect(y_col) y_dis = self.psamask_distribute(y_dis) @@ -180,18 +180,16 @@ def forward(self, inputs): y_col = F.softmax(y_col, dim=1) y_dis = F.softmax(y_dis, dim=1) x_col = torch.bmm( - x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( - n, c, h, w) * (1.0 / self.normalization_factor) + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w) + ).view(n, c, h, w) * (1.0 / self.normalization_factor) x_dis = torch.bmm( - x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( - n, c, h, w) * (1.0 / self.normalization_factor) + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w) + ).view(n, c, h, w) * (1.0 / self.normalization_factor) out = torch.cat([x_col, x_dis], 1) out = self.proj(out) out = resize( - out, - size=identity.shape[2:], - mode='bilinear', - align_corners=align_corners) + out, size=identity.shape[2:], mode="bilinear", align_corners=align_corners + ) out = self.bottleneck(torch.cat((identity, out), dim=1)) out = self.cls_seg(out) return out diff --git a/mmsegmentation/mmseg/models/decode_heads/psp_head.py b/mmsegmentation/mmseg/models/decode_heads/psp_head.py index 6990676..20898b5 100644 --- a/mmsegmentation/mmseg/models/decode_heads/psp_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/psp_head.py @@ -22,9 +22,18 @@ class PPM(nn.ModuleList): align_corners (bool): align_corners argument of F.interpolate. """ - def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, - act_cfg, align_corners, **kwargs): - super(PPM, self).__init__() + def __init__( + self, + pool_scales, + in_channels, + channels, + conv_cfg, + norm_cfg, + act_cfg, + align_corners, + **kwargs, + ): + super().__init__() self.pool_scales = pool_scales self.align_corners = align_corners self.in_channels = in_channels @@ -43,7 +52,10 @@ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - **kwargs))) + **kwargs, + ), + ) + ) def forward(self, x): """Forward function.""" @@ -53,8 +65,9 @@ def forward(self, x): upsampled_ppm_out = resize( ppm_out, size=x.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) ppm_outs.append(upsampled_ppm_out) return ppm_outs @@ -72,7 +85,7 @@ class PSPHead(BaseDecodeHead): """ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): - super(PSPHead, self).__init__(**kwargs) + super().__init__(**kwargs) assert isinstance(pool_scales, (list, tuple)) self.pool_scales = pool_scales self.psp_modules = PPM( @@ -82,7 +95,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) self.bottleneck = ConvModule( self.in_channels + len(pool_scales) * self.channels, self.channels, @@ -90,7 +104,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def _forward_feature(self, inputs): """Forward function for feature maps before classifying each pixel with diff --git a/mmsegmentation/mmseg/models/decode_heads/segformer_head.py b/mmsegmentation/mmseg/models/decode_heads/segformer_head.py index d6e172e..1c2b53d 100644 --- a/mmsegmentation/mmseg/models/decode_heads/segformer_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/segformer_head.py @@ -98,8 +98,8 @@ class SegformerHead(BaseDecodeHead): Default: 'bilinear'. """ - def __init__(self, interpolate_mode='bilinear', **kwargs): - super().__init__(input_transform='multiple_select', **kwargs) + def __init__(self, interpolate_mode="bilinear", **kwargs): + super().__init__(input_transform="multiple_select", **kwargs) self.interpolate_mode = interpolate_mode num_inputs = len(self.in_channels) @@ -115,13 +115,16 @@ def __init__(self, interpolate_mode='bilinear', **kwargs): kernel_size=1, stride=1, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ) + ) self.fusion_conv = ConvModule( in_channels=self.channels * num_inputs, out_channels=self.channels, kernel_size=1, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + ) def forward(self, inputs): # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 @@ -135,7 +138,9 @@ def forward(self, inputs): input=conv(x), size=inputs[0].shape[2:], mode=self.interpolate_mode, - align_corners=self.align_corners)) + align_corners=self.align_corners, + ) + ) out = self.fusion_conv(torch.cat(outs, dim=1)) diff --git a/mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py b/mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py index 6a9b3d4..d330a32 100644 --- a/mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/segmenter_mask_head.py @@ -3,8 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_norm_layer -from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, - trunc_normal_init) +from mmcv.cnn.utils.weight_init import constant_init, trunc_normal_, trunc_normal_init from mmcv.runner import ModuleList from mmseg.models.backbones.vit import TransformerEncoderLayer @@ -45,24 +44,23 @@ class SegmenterMaskTransformerHead(BaseDecodeHead): """ def __init__( - self, - in_channels, - num_layers, - num_heads, - embed_dims, - mlp_ratio=4, - drop_path_rate=0.1, - drop_rate=0.0, - attn_drop_rate=0.0, - num_fcs=2, - qkv_bias=True, - act_cfg=dict(type='GELU'), - norm_cfg=dict(type='LN'), - init_std=0.02, - **kwargs, + self, + in_channels, + num_layers, + num_heads, + embed_dims, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type="GELU"), + norm_cfg=dict(type="LN"), + init_std=0.02, + **kwargs, ): - super(SegmenterMaskTransformerHead, self).__init__( - in_channels=in_channels, **kwargs) + super().__init__(in_channels=in_channels, **kwargs) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] self.layers = ModuleList() @@ -80,23 +78,21 @@ def __init__( act_cfg=act_cfg, norm_cfg=norm_cfg, batch_first=True, - )) + ) + ) self.dec_proj = nn.Linear(in_channels, embed_dims) - self.cls_emb = nn.Parameter( - torch.randn(1, self.num_classes, embed_dims)) + self.cls_emb = nn.Parameter(torch.randn(1, self.num_classes, embed_dims)) self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) - self.decoder_norm = build_norm_layer( - norm_cfg, embed_dims, postfix=1)[1] - self.mask_norm = build_norm_layer( - norm_cfg, self.num_classes, postfix=2)[1] + self.decoder_norm = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.mask_norm = build_norm_layer(norm_cfg, self.num_classes, postfix=2)[1] self.init_std = init_std - delattr(self, 'conv_seg') + delattr(self, "conv_seg") def init_weights(self): trunc_normal_(self.cls_emb, std=self.init_std) @@ -120,8 +116,8 @@ def forward(self, inputs): x = layer(x) x = self.decoder_norm(x) - patches = self.patch_proj(x[:, :-self.num_classes]) - cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + patches = self.patch_proj(x[:, : -self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes :]) patches = F.normalize(patches, dim=2, p=2) cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) diff --git a/mmsegmentation/mmseg/models/decode_heads/sep_aspp_head.py b/mmsegmentation/mmseg/models/decode_heads/sep_aspp_head.py index 4e894e2..c13a66c 100644 --- a/mmsegmentation/mmseg/models/decode_heads/sep_aspp_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/sep_aspp_head.py @@ -13,7 +13,7 @@ class DepthwiseSeparableASPPModule(ASPPModule): conv.""" def __init__(self, **kwargs): - super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) + super().__init__(**kwargs) for i, dilation in enumerate(self.dilations): if dilation > 1: self[i] = DepthwiseSeparableConvModule( @@ -23,7 +23,8 @@ def __init__(self, **kwargs): dilation=dilation, padding=dilation, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) @HEADS.register_module() @@ -41,7 +42,7 @@ class DepthwiseSeparableASPPHead(ASPPHead): """ def __init__(self, c1_in_channels, c1_channels, **kwargs): - super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) + super().__init__(**kwargs) assert c1_in_channels >= 0 self.aspp_modules = DepthwiseSeparableASPPModule( dilations=self.dilations, @@ -49,7 +50,8 @@ def __init__(self, c1_in_channels, c1_channels, **kwargs): channels=self.channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) if c1_in_channels > 0: self.c1_bottleneck = ConvModule( c1_in_channels, @@ -57,7 +59,8 @@ def __init__(self, c1_in_channels, c1_channels, **kwargs): 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) else: self.c1_bottleneck = None self.sep_bottleneck = nn.Sequential( @@ -67,14 +70,17 @@ def __init__(self, c1_in_channels, c1_channels, **kwargs): 3, padding=1, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), + act_cfg=self.act_cfg, + ), DepthwiseSeparableConvModule( self.channels, self.channels, 3, padding=1, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg)) + act_cfg=self.act_cfg, + ), + ) def forward(self, inputs): """Forward function.""" @@ -83,8 +89,9 @@ def forward(self, inputs): resize( self.image_pool(x), size=x.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) @@ -94,8 +101,9 @@ def forward(self, inputs): output = resize( input=output, size=c1_output.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) output = torch.cat([output, c1_output], dim=1) output = self.sep_bottleneck(output) output = self.cls_seg(output) diff --git a/mmsegmentation/mmseg/models/decode_heads/sep_fcn_head.py b/mmsegmentation/mmseg/models/decode_heads/sep_fcn_head.py index 7f9658e..ffa5718 100644 --- a/mmsegmentation/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/sep_fcn_head.py @@ -32,14 +32,15 @@ class DepthwiseSeparableFCNHead(FCNHead): """ def __init__(self, dw_act_cfg=None, **kwargs): - super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) + super().__init__(**kwargs) self.convs[0] = DepthwiseSeparableConvModule( self.in_channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg, - dw_act_cfg=dw_act_cfg) + dw_act_cfg=dw_act_cfg, + ) for i in range(1, self.num_convs): self.convs[i] = DepthwiseSeparableConvModule( @@ -48,7 +49,8 @@ def __init__(self, dw_act_cfg=None, **kwargs): kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg, - dw_act_cfg=dw_act_cfg) + dw_act_cfg=dw_act_cfg, + ) if self.concat_input: self.conv_cat = DepthwiseSeparableConvModule( @@ -57,4 +59,5 @@ def __init__(self, dw_act_cfg=None, **kwargs): kernel_size=self.kernel_size, padding=self.kernel_size // 2, norm_cfg=self.norm_cfg, - dw_act_cfg=dw_act_cfg) + dw_act_cfg=dw_act_cfg, + ) diff --git a/mmsegmentation/mmseg/models/decode_heads/setr_mla_head.py b/mmsegmentation/mmseg/models/decode_heads/setr_mla_head.py index 6bb94ae..c7fed60 100644 --- a/mmsegmentation/mmseg/models/decode_heads/setr_mla_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/setr_mla_head.py @@ -21,8 +21,7 @@ class SETRMLAHead(BaseDecodeHead): """ def __init__(self, mla_channels=128, up_scale=4, **kwargs): - super(SETRMLAHead, self).__init__( - input_transform='multiple_select', **kwargs) + super().__init__(input_transform="multiple_select", **kwargs) self.mla_channels = mla_channels num_inputs = len(self.in_channels) @@ -40,18 +39,23 @@ def __init__(self, mla_channels=128, up_scale=4, **kwargs): kernel_size=3, padding=1, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), + act_cfg=self.act_cfg, + ), ConvModule( in_channels=mla_channels, out_channels=mla_channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), + act_cfg=self.act_cfg, + ), Upsample( scale_factor=up_scale, - mode='bilinear', - align_corners=self.align_corners))) + mode="bilinear", + align_corners=self.align_corners, + ), + ) + ) def forward(self, inputs): inputs = self._transform_inputs(inputs) diff --git a/mmsegmentation/mmseg/models/decode_heads/setr_up_head.py b/mmsegmentation/mmseg/models/decode_heads/setr_up_head.py index 87e7ea7..8cac5c2 100644 --- a/mmsegmentation/mmseg/models/decode_heads/setr_up_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/setr_up_head.py @@ -25,23 +25,21 @@ class SETRUPHead(BaseDecodeHead): type='Constant', val=1.0, bias=0, layer='LayerNorm'). """ - def __init__(self, - norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), - num_convs=1, - up_scale=4, - kernel_size=3, - init_cfg=[ - dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), - dict( - type='Normal', - std=0.01, - override=dict(name='conv_seg')) - ], - **kwargs): + def __init__( + self, + norm_layer=dict(type="LN", eps=1e-6, requires_grad=True), + num_convs=1, + up_scale=4, + kernel_size=3, + init_cfg=[ + dict(type="Constant", val=1.0, bias=0, layer="LayerNorm"), + dict(type="Normal", std=0.01, override=dict(name="conv_seg")), + ], + **kwargs, + ): + assert kernel_size in [1, 3], "kernel_size must be 1 or 3." - assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' - - super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) + super().__init__(init_cfg=init_cfg, **kwargs) assert isinstance(self.in_channels, int) @@ -60,11 +58,15 @@ def __init__(self, stride=1, padding=int(kernel_size - 1) // 2, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg), + act_cfg=self.act_cfg, + ), Upsample( scale_factor=up_scale, - mode='bilinear', - align_corners=self.align_corners))) + mode="bilinear", + align_corners=self.align_corners, + ), + ) + ) in_channels = out_channels def forward(self, x): diff --git a/mmsegmentation/mmseg/models/decode_heads/stdc_head.py b/mmsegmentation/mmseg/models/decode_heads/stdc_head.py index bddf1eb..995e072 100644 --- a/mmsegmentation/mmseg/models/decode_heads/stdc_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/stdc_head.py @@ -17,19 +17,24 @@ class STDCHead(FCNHead): """ def __init__(self, boundary_threshold=0.1, **kwargs): - super(STDCHead, self).__init__(**kwargs) + super().__init__(**kwargs) self.boundary_threshold = boundary_threshold # Using register buffer to make laplacian kernel on the same # device of `seg_label`. self.register_buffer( - 'laplacian_kernel', - torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], - dtype=torch.float32, - requires_grad=False).reshape((1, 1, 3, 3))) + "laplacian_kernel", + torch.tensor( + [-1, -1, -1, -1, 8, -1, -1, -1, -1], + dtype=torch.float32, + requires_grad=False, + ).reshape((1, 1, 3, 3)), + ) self.fusion_kernel = torch.nn.Parameter( - torch.tensor([[6. / 10], [3. / 10], [1. / 10]], - dtype=torch.float32).reshape(1, 3, 1, 1), - requires_grad=False) + torch.tensor( + [[6.0 / 10], [3.0 / 10], [1.0 / 10]], dtype=torch.float32 + ).reshape(1, 3, 1, 1), + requires_grad=False, + ) def losses(self, seg_logit, seg_label): """Compute Detail Aggregation Loss.""" @@ -38,48 +43,47 @@ def losses(self, seg_logit, seg_label): # codebase because it would not be added into computation graph # after threshold operation. seg_label = seg_label.to(self.laplacian_kernel) - boundary_targets = F.conv2d( - seg_label, self.laplacian_kernel, padding=1) + boundary_targets = F.conv2d(seg_label, self.laplacian_kernel, padding=1) boundary_targets = boundary_targets.clamp(min=0) boundary_targets[boundary_targets > self.boundary_threshold] = 1 boundary_targets[boundary_targets <= self.boundary_threshold] = 0 boundary_targets_x2 = F.conv2d( - seg_label, self.laplacian_kernel, stride=2, padding=1) + seg_label, self.laplacian_kernel, stride=2, padding=1 + ) boundary_targets_x2 = boundary_targets_x2.clamp(min=0) boundary_targets_x4 = F.conv2d( - seg_label, self.laplacian_kernel, stride=4, padding=1) + seg_label, self.laplacian_kernel, stride=4, padding=1 + ) boundary_targets_x4 = boundary_targets_x4.clamp(min=0) boundary_targets_x4_up = F.interpolate( - boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x4, boundary_targets.shape[2:], mode="nearest" + ) boundary_targets_x2_up = F.interpolate( - boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2, boundary_targets.shape[2:], mode="nearest" + ) - boundary_targets_x2_up[ - boundary_targets_x2_up > self.boundary_threshold] = 1 - boundary_targets_x2_up[ - boundary_targets_x2_up <= self.boundary_threshold] = 0 + boundary_targets_x2_up[boundary_targets_x2_up > self.boundary_threshold] = 1 + boundary_targets_x2_up[boundary_targets_x2_up <= self.boundary_threshold] = 0 - boundary_targets_x4_up[ - boundary_targets_x4_up > self.boundary_threshold] = 1 - boundary_targets_x4_up[ - boundary_targets_x4_up <= self.boundary_threshold] = 0 + boundary_targets_x4_up[boundary_targets_x4_up > self.boundary_threshold] = 1 + boundary_targets_x4_up[boundary_targets_x4_up <= self.boundary_threshold] = 0 boundary_targets_pyramids = torch.stack( - (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), - dim=1) + (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), dim=1 + ) boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) - boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids, - self.fusion_kernel) + boundary_targets_pyramid = F.conv2d( + boundary_targets_pyramids, self.fusion_kernel + ) + boundary_targets_pyramid[boundary_targets_pyramid > self.boundary_threshold] = 1 boundary_targets_pyramid[ - boundary_targets_pyramid > self.boundary_threshold] = 1 - boundary_targets_pyramid[ - boundary_targets_pyramid <= self.boundary_threshold] = 0 + boundary_targets_pyramid <= self.boundary_threshold + ] = 0 - loss = super(STDCHead, self).losses(seg_logit, - boundary_targets_pyramid.long()) + loss = super().losses(seg_logit, boundary_targets_pyramid.long()) return loss diff --git a/mmsegmentation/mmseg/models/decode_heads/uper_head.py b/mmsegmentation/mmseg/models/decode_heads/uper_head.py index 06b152a..ee21a2c 100644 --- a/mmsegmentation/mmseg/models/decode_heads/uper_head.py +++ b/mmsegmentation/mmseg/models/decode_heads/uper_head.py @@ -22,8 +22,7 @@ class UPerHead(BaseDecodeHead): """ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): - super(UPerHead, self).__init__( - input_transform='multiple_select', **kwargs) + super().__init__(input_transform="multiple_select", **kwargs) # PSP Module self.psp_modules = PPM( pool_scales, @@ -32,7 +31,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) self.bottleneck = ConvModule( self.in_channels[-1] + len(pool_scales) * self.channels, self.channels, @@ -40,7 +40,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) # FPN Module self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() @@ -52,7 +53,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - inplace=False) + inplace=False, + ) fpn_conv = ConvModule( self.channels, self.channels, @@ -61,7 +63,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - inplace=False) + inplace=False, + ) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) @@ -72,7 +75,8 @@ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + ) def psp_forward(self, inputs): """Forward function of PSP module.""" @@ -99,8 +103,7 @@ def _forward_feature(self, inputs): # build laterals laterals = [ - lateral_conv(inputs[i]) - for i, lateral_conv in enumerate(self.lateral_convs) + lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs) ] laterals.append(self.psp_forward(inputs)) @@ -112,13 +115,13 @@ def _forward_feature(self, inputs): laterals[i - 1] = laterals[i - 1] + resize( laterals[i], size=prev_shape, - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) # build outputs fpn_outs = [ - self.fpn_convs[i](laterals[i]) - for i in range(used_backbone_levels - 1) + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1) ] # append psp feature fpn_outs.append(laterals[-1]) @@ -127,8 +130,9 @@ def _forward_feature(self, inputs): fpn_outs[i] = resize( fpn_outs[i], size=fpn_outs[0].shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) fpn_outs = torch.cat(fpn_outs, dim=1) feats = self.fpn_bottleneck(fpn_outs) return feats diff --git a/mmsegmentation/mmseg/models/losses/__init__.py b/mmsegmentation/mmseg/models/losses/__init__.py index d7e0197..84bfbff 100644 --- a/mmsegmentation/mmseg/models/losses/__init__.py +++ b/mmsegmentation/mmseg/models/losses/__init__.py @@ -1,7 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .accuracy import Accuracy, accuracy -from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, - cross_entropy, mask_cross_entropy) +from .cross_entropy_loss import ( + CrossEntropyLoss, + binary_cross_entropy, + cross_entropy, + mask_cross_entropy, +) from .dice_loss import DiceLoss from .focal_loss import FocalLoss from .lovasz_loss import LovaszLoss @@ -9,8 +13,17 @@ from .utils import reduce_loss, weight_reduce_loss, weighted_loss __all__ = [ - 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', - 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', - 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', - 'FocalLoss', 'TverskyLoss' + "accuracy", + "Accuracy", + "cross_entropy", + "binary_cross_entropy", + "mask_cross_entropy", + "CrossEntropyLoss", + "reduce_loss", + "weight_reduce_loss", + "weighted_loss", + "LovaszLoss", + "DiceLoss", + "FocalLoss", + "TverskyLoss", ] diff --git a/mmsegmentation/mmseg/models/losses/accuracy.py b/mmsegmentation/mmseg/models/losses/accuracy.py index 1d9e2d7..9900ce9 100644 --- a/mmsegmentation/mmseg/models/losses/accuracy.py +++ b/mmsegmentation/mmseg/models/losses/accuracy.py @@ -25,19 +25,18 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): """ assert isinstance(topk, (int, tuple)) if isinstance(topk, int): - topk = (topk, ) + topk = (topk,) return_single = True else: return_single = False maxk = max(topk) if pred.size(0) == 0: - accu = [pred.new_tensor(0.) for i in range(len(topk))] + accu = [pred.new_tensor(0.0) for i in range(len(topk))] return accu[0] if return_single else accu assert pred.ndim == target.ndim + 1 assert pred.size(0) == target.size(0) - assert maxk <= pred.size(1), \ - f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + assert maxk <= pred.size(1), f"maxk {maxk} exceeds pred dimension {pred.size(1)}" pred_value, pred_label = pred.topk(maxk, dim=1) # transpose to shape (maxk, N, ...) pred_label = pred_label.transpose(0, 1) @@ -64,7 +63,7 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): class Accuracy(nn.Module): """Accuracy calculation module.""" - def __init__(self, topk=(1, ), thresh=None, ignore_index=None): + def __init__(self, topk=(1,), thresh=None, ignore_index=None): """Module to calculate the accuracy. Args: @@ -88,5 +87,4 @@ def forward(self, pred, target): Returns: tuple[float]: The accuracies under different topk criterions. """ - return accuracy(pred, target, self.topk, self.thresh, - self.ignore_index) + return accuracy(pred, target, self.topk, self.thresh, self.ignore_index) diff --git a/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py b/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py index fe7b4a2..37441b2 100644 --- a/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py +++ b/mmsegmentation/mmseg/models/losses/cross_entropy_loss.py @@ -9,14 +9,16 @@ from .utils import get_class_weight, weight_reduce_loss -def cross_entropy(pred, - label, - weight=None, - class_weight=None, - reduction='mean', - avg_factor=None, - ignore_index=-100, - avg_non_ignore=False): +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False, +): """cross_entropy. The wrapper function for :func:`F.cross_entropy` Args: @@ -43,22 +45,20 @@ def cross_entropy(pred, # class_weight is a manual rescaling weight given to each class. # If given, has to be a Tensor of size C element-wise losses loss = F.cross_entropy( - pred, - label, - weight=class_weight, - reduction='none', - ignore_index=ignore_index) + pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index + ) # apply weights and do the reduction # average loss over non-ignored elements # pytorch's official cross_entropy average loss over non-ignored elements # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa - if (avg_factor is None) and avg_non_ignore and reduction == 'mean': + if (avg_factor is None) and avg_non_ignore and reduction == "mean": avg_factor = label.numel() - (label == ignore_index).sum().item() if weight is not None: weight = weight.float() loss = weight_reduce_loss( - loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + loss, weight=weight, reduction=reduction, avg_factor=avg_factor + ) return loss @@ -86,15 +86,17 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): return bin_labels, bin_label_weights, valid_mask -def binary_cross_entropy(pred, - label, - weight=None, - reduction='mean', - avg_factor=None, - class_weight=None, - ignore_index=-100, - avg_non_ignore=False, - **kwargs): +def binary_cross_entropy( + pred, + label, + weight=None, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs, +): """Calculate the binary CrossEntropy loss. Args: @@ -121,19 +123,22 @@ def binary_cross_entropy(pred, # As the ignore_index often set as 255, so the # binary class label check should mask out # ignore_index - assert label[label != ignore_index].max() <= 1, \ - 'For pred with shape [N, 1, H, W], its label must have at ' \ - 'most 2 classes' + assert label[label != ignore_index].max() <= 1, ( + "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes" + ) pred = pred.squeeze(1) if pred.dim() != label.dim(): assert (pred.dim() == 2 and label.dim() == 1) or ( - pred.dim() == 4 and label.dim() == 3), \ - 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ - 'H, W], label shape [N, H, W] are supported' + pred.dim() == 4 and label.dim() == 3 + ), ( + "Only pred shape [N, C], label shape [N] or pred shape [N, C, " + "H, W], label shape [N, H, W] are supported" + ) # `weight` returned from `_expand_onehot_labels` # has been treated for valid (non-ignore) pixels label, weight, valid_mask = _expand_onehot_labels( - label, weight, pred.shape, ignore_index) + label, weight, pred.shape, ignore_index + ) else: # should mask out the ignored elements valid_mask = ((label >= 0) & (label != ignore_index)).float() @@ -142,26 +147,28 @@ def binary_cross_entropy(pred, else: weight = valid_mask # average loss over non-ignored and valid elements - if reduction == 'mean' and avg_factor is None and avg_non_ignore: + if reduction == "mean" and avg_factor is None and avg_non_ignore: avg_factor = valid_mask.sum().item() loss = F.binary_cross_entropy_with_logits( - pred, label.float(), pos_weight=class_weight, reduction='none') + pred, label.float(), pos_weight=class_weight, reduction="none" + ) # do the reduction for the weighted loss - loss = weight_reduce_loss( - loss, weight, reduction=reduction, avg_factor=avg_factor) + loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) return loss -def mask_cross_entropy(pred, - target, - label, - reduction='mean', - avg_factor=None, - class_weight=None, - ignore_index=None, - **kwargs): +def mask_cross_entropy( + pred, + target, + label, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs, +): """Calculate the CrossEntropy loss for masks. Args: @@ -183,14 +190,15 @@ def mask_cross_entropy(pred, Returns: torch.Tensor: The calculated loss """ - assert ignore_index is None, 'BCE loss does not support ignore_index' + assert ignore_index is None, "BCE loss does not support ignore_index" # TODO: handle these two reserved arguments - assert reduction == 'mean' and avg_factor is None + assert reduction == "mean" and avg_factor is None num_rois = pred.size()[0] inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) pred_slice = pred[inds, label].squeeze(1) return F.binary_cross_entropy_with_logits( - pred_slice, target, weight=class_weight, reduction='mean')[None] + pred_slice, target, weight=class_weight, reduction="mean" + )[None] @LOSSES.register_module() @@ -215,15 +223,17 @@ class CrossEntropyLoss(nn.Module): `New in version 0.23.0.` """ - def __init__(self, - use_sigmoid=False, - use_mask=False, - reduction='mean', - class_weight=None, - loss_weight=1.0, - loss_name='loss_ce', - avg_non_ignore=False): - super(CrossEntropyLoss, self).__init__() + def __init__( + self, + use_sigmoid=False, + use_mask=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + avg_non_ignore=False, + ): + super().__init__() assert (use_sigmoid is False) or (use_mask is False) self.use_sigmoid = use_sigmoid self.use_mask = use_mask @@ -231,12 +241,13 @@ def __init__(self, self.loss_weight = loss_weight self.class_weight = get_class_weight(class_weight) self.avg_non_ignore = avg_non_ignore - if not self.avg_non_ignore and self.reduction == 'mean': + if not self.avg_non_ignore and self.reduction == "mean": warnings.warn( - 'Default ``avg_non_ignore`` is False, if you would like to ' - 'ignore the certain label and average loss over non-ignore ' - 'labels, which is the same with PyTorch official ' - 'cross_entropy, set ``avg_non_ignore=True``.') + "Default ``avg_non_ignore`` is False, if you would like to " + "ignore the certain label and average loss over non-ignore " + "labels, which is the same with PyTorch official " + "cross_entropy, set ``avg_non_ignore=True``." + ) if self.use_sigmoid: self.cls_criterion = binary_cross_entropy @@ -248,21 +259,22 @@ def __init__(self, def extra_repr(self): """Extra repr.""" - s = f'avg_non_ignore={self.avg_non_ignore}' + s = f"avg_non_ignore={self.avg_non_ignore}" return s - def forward(self, - cls_score, - label, - weight=None, - avg_factor=None, - reduction_override=None, - ignore_index=-100, - **kwargs): + def forward( + self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs, + ): """Forward function.""" - assert reduction_override in (None, 'none', 'mean', 'sum') - reduction = ( - reduction_override if reduction_override else self.reduction) + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction if self.class_weight is not None: class_weight = cls_score.new_tensor(self.class_weight) else: @@ -277,7 +289,8 @@ def forward(self, avg_factor=avg_factor, avg_non_ignore=self.avg_non_ignore, ignore_index=ignore_index, - **kwargs) + **kwargs, + ) return loss_cls @property diff --git a/mmsegmentation/mmseg/models/losses/dice_loss.py b/mmsegmentation/mmseg/models/losses/dice_loss.py index a294bc2..5975157 100644 --- a/mmsegmentation/mmseg/models/losses/dice_loss.py +++ b/mmsegmentation/mmseg/models/losses/dice_loss.py @@ -10,13 +10,9 @@ @weighted_loss -def dice_loss(pred, - target, - valid_mask, - smooth=1, - exponent=2, - class_weight=None, - ignore_index=255): +def dice_loss( + pred, target, valid_mask, smooth=1, exponent=2, class_weight=None, ignore_index=255 +): assert pred.shape[0] == target.shape[0] total_loss = 0 num_classes = pred.shape[1] @@ -27,7 +23,8 @@ def dice_loss(pred, target[..., i], valid_mask=valid_mask, smooth=smooth, - exponent=exponent) + exponent=exponent, + ) if class_weight is not None: dice_loss *= class_weight[i] total_loss += dice_loss @@ -71,16 +68,18 @@ class DiceLoss(nn.Module): prefix of the name. Defaults to 'loss_dice'. """ - def __init__(self, - smooth=1, - exponent=2, - reduction='mean', - class_weight=None, - loss_weight=1.0, - ignore_index=255, - loss_name='loss_dice', - **kwargs): - super(DiceLoss, self).__init__() + def __init__( + self, + smooth=1, + exponent=2, + reduction="mean", + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name="loss_dice", + **kwargs, + ): + super().__init__() self.smooth = smooth self.exponent = exponent self.reduction = reduction @@ -89,15 +88,9 @@ def __init__(self, self.ignore_index = ignore_index self._loss_name = loss_name - def forward(self, - pred, - target, - avg_factor=None, - reduction_override=None, - **kwargs): - assert reduction_override in (None, 'none', 'mean', 'sum') - reduction = ( - reduction_override if reduction_override else self.reduction) + def forward(self, pred, target, avg_factor=None, reduction_override=None, **kwargs): + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction if self.class_weight is not None: class_weight = pred.new_tensor(self.class_weight) else: @@ -106,8 +99,8 @@ def forward(self, pred = F.softmax(pred, dim=1) num_classes = pred.shape[1] one_hot_target = F.one_hot( - torch.clamp(target.long(), 0, num_classes - 1), - num_classes=num_classes) + torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes + ) valid_mask = (target != self.ignore_index).long() loss = self.loss_weight * dice_loss( @@ -119,7 +112,8 @@ def forward(self, smooth=self.smooth, exponent=self.exponent, class_weight=class_weight, - ignore_index=self.ignore_index) + ignore_index=self.ignore_index, + ) return loss @property diff --git a/mmsegmentation/mmseg/models/losses/focal_loss.py b/mmsegmentation/mmseg/models/losses/focal_loss.py index cd43ce5..33dd81c 100644 --- a/mmsegmentation/mmseg/models/losses/focal_loss.py +++ b/mmsegmentation/mmseg/models/losses/focal_loss.py @@ -10,16 +10,18 @@ # This method is used when cuda is not available -def py_sigmoid_focal_loss(pred, - target, - one_hot_target=None, - weight=None, - gamma=2.0, - alpha=0.5, - class_weight=None, - valid_mask=None, - reduction='mean', - avg_factor=None): +def py_sigmoid_focal_loss( + pred, + target, + one_hot_target=None, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction="mean", + avg_factor=None, +): """PyTorch version of `Focal Loss `_. Args: @@ -47,11 +49,14 @@ def py_sigmoid_focal_loss(pred, pred_sigmoid = pred.sigmoid() target = target.type_as(pred) one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) - focal_weight = (alpha * target + (1 - alpha) * - (1 - target)) * one_minus_pt.pow(gamma) + focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow( + gamma + ) - loss = F.binary_cross_entropy_with_logits( - pred, target, reduction='none') * focal_weight + loss = ( + F.binary_cross_entropy_with_logits(pred, target, reduction="none") + * focal_weight + ) final_weight = torch.ones(1, pred.size(1)).type_as(loss) if weight is not None: if weight.shape != loss.shape and weight.size(0) == loss.size(0): @@ -68,16 +73,18 @@ def py_sigmoid_focal_loss(pred, return loss -def sigmoid_focal_loss(pred, - target, - one_hot_target, - weight=None, - gamma=2.0, - alpha=0.5, - class_weight=None, - valid_mask=None, - reduction='mean', - avg_factor=None): +def sigmoid_focal_loss( + pred, + target, + one_hot_target, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction="mean", + avg_factor=None, +): r"""A wrapper of cuda version `Focal Loss `_. Args: @@ -110,14 +117,20 @@ def sigmoid_focal_loss(pred, # multiplying the loss by 2, the effect of setting alpha as 0.5 is # undone. The alpha of type list is used to regulate the loss in the # post-processing process. - loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), - gamma, 0.5, None, 'none') * 2 + loss = ( + _sigmoid_focal_loss( + pred.contiguous(), target.contiguous(), gamma, 0.5, None, "none" + ) + * 2 + ) alpha = pred.new_tensor(alpha) final_weight = final_weight * ( - alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target)) + alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target) + ) else: - loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), - gamma, alpha, None, 'none') + loss = _sigmoid_focal_loss( + pred.contiguous(), target.contiguous(), gamma, alpha, None, "none" + ) if weight is not None: if weight.shape != loss.shape and weight.size(0) == loss.size(0): # For most cases, weight is of shape (N, ), @@ -135,15 +148,16 @@ def sigmoid_focal_loss(pred, @LOSSES.register_module() class FocalLoss(nn.Module): - - def __init__(self, - use_sigmoid=True, - gamma=2.0, - alpha=0.5, - reduction='mean', - class_weight=None, - loss_weight=1.0, - loss_name='loss_focal'): + def __init__( + self, + use_sigmoid=True, + gamma=2.0, + alpha=0.5, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_focal", + ): """`Focal Loss `_ Args: use_sigmoid (bool, optional): Whether to the prediction is @@ -172,22 +186,26 @@ def __init__(self, loss item to be included into the backward graph, `loss_` must be the prefix of the name. Defaults to 'loss_focal'. """ - super(FocalLoss, self).__init__() - assert use_sigmoid is True, \ - 'AssertionError: Only sigmoid focal loss supported now.' - assert reduction in ('none', 'mean', 'sum'), \ - "AssertionError: reduction should be 'none', 'mean' or " \ - "'sum'" - assert isinstance(alpha, (float, list)), \ - 'AssertionError: alpha should be of type float' - assert isinstance(gamma, float), \ - 'AssertionError: gamma should be of type float' - assert isinstance(loss_weight, float), \ - 'AssertionError: loss_weight should be of type float' - assert isinstance(loss_name, str), \ - 'AssertionError: loss_name should be of type str' - assert isinstance(class_weight, list) or class_weight is None, \ - 'AssertionError: class_weight must be None or of type list' + super().__init__() + assert ( + use_sigmoid is True + ), "AssertionError: Only sigmoid focal loss supported now." + assert reduction in ("none", "mean", "sum"), ( + "AssertionError: reduction should be 'none', 'mean' or " "'sum'" + ) + assert isinstance( + alpha, (float, list) + ), "AssertionError: alpha should be of type float" + assert isinstance(gamma, float), "AssertionError: gamma should be of type float" + assert isinstance( + loss_weight, float + ), "AssertionError: loss_weight should be of type float" + assert isinstance( + loss_name, str + ), "AssertionError: loss_name should be of type str" + assert ( + isinstance(class_weight, list) or class_weight is None + ), "AssertionError: class_weight must be None or of type list" self.use_sigmoid = use_sigmoid self.gamma = gamma self.alpha = alpha @@ -196,14 +214,16 @@ def __init__(self, self.loss_weight = loss_weight self._loss_name = loss_name - def forward(self, - pred, - target, - weight=None, - avg_factor=None, - reduction_override=None, - ignore_index=255, - **kwargs): + def forward( + self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs, + ): """Forward function. Args: @@ -228,15 +248,13 @@ def forward(self, Returns: torch.Tensor: The calculated loss """ - assert isinstance(ignore_index, int), \ - 'ignore_index must be of type int' - assert reduction_override in (None, 'none', 'mean', 'sum'), \ - "AssertionError: reduction should be 'none', 'mean' or " \ - "'sum'" - assert pred.shape == target.shape or \ - (pred.size(0) == target.size(0) and - pred.shape[2:] == target.shape[1:]), \ - "The shape of pred doesn't match the shape of target" + assert isinstance(ignore_index, int), "ignore_index must be of type int" + assert reduction_override in (None, "none", "mean", "sum"), ( + "AssertionError: reduction should be 'none', 'mean' or " "'sum'" + ) + assert pred.shape == target.shape or ( + pred.size(0) == target.size(0) and pred.shape[2:] == target.shape[1:] + ), "The shape of pred doesn't match the shape of target" original_shape = pred.shape @@ -262,11 +280,9 @@ def forward(self, target = target.view(-1).contiguous() valid_mask = (target != ignore_index).view(-1, 1) # avoid raising error when using F.one_hot() - target = torch.where(target == ignore_index, target.new_tensor(0), - target) + target = torch.where(target == ignore_index, target.new_tensor(0), target) - reduction = ( - reduction_override if reduction_override else self.reduction) + reduction = reduction_override if reduction_override else self.reduction if self.use_sigmoid: num_classes = pred.size(1) if torch.cuda.is_available() and pred.is_cuda: @@ -282,8 +298,7 @@ def forward(self, if target.dim() == 1: target = F.one_hot(target, num_classes=num_classes) else: - valid_mask = (target.argmax(dim=1) != ignore_index).view( - -1, 1) + valid_mask = (target.argmax(dim=1) != ignore_index).view(-1, 1) calculate_loss_func = py_sigmoid_focal_loss loss_cls = self.loss_weight * calculate_loss_func( @@ -296,16 +311,17 @@ def forward(self, class_weight=self.class_weight, valid_mask=valid_mask, reduction=reduction, - avg_factor=avg_factor) + avg_factor=avg_factor, + ) - if reduction == 'none': + if reduction == "none": # [N, C] -> [C, N] loss_cls = loss_cls.transpose(0, 1) # [C, N] -> [C, B, d1, d2, ...] # original_shape: [B, C, d1, d2, ...] - loss_cls = loss_cls.reshape(original_shape[1], - original_shape[0], - *original_shape[2:]) + loss_cls = loss_cls.reshape( + original_shape[1], original_shape[0], *original_shape[2:] + ) # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...] loss_cls = loss_cls.transpose(0, 1).contiguous() else: diff --git a/mmsegmentation/mmseg/models/losses/lovasz_loss.py b/mmsegmentation/mmseg/models/losses/lovasz_loss.py index 2bb0fad..2c8e930 100644 --- a/mmsegmentation/mmseg/models/losses/lovasz_loss.py +++ b/mmsegmentation/mmseg/models/losses/lovasz_loss.py @@ -21,7 +21,7 @@ def lovasz_grad(gt_sorted): gts = gt_sorted.sum() intersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) - jaccard = 1. - intersection / union + jaccard = 1.0 - intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] return jaccard @@ -34,7 +34,7 @@ def flatten_binary_logits(logits, labels, ignore_index=None): labels = labels.view(-1) if ignore_index is None: return logits, labels - valid = (labels != ignore_index) + valid = labels != ignore_index vlogits = logits[valid] vlabels = labels[valid] return vlogits, vlabels @@ -51,7 +51,7 @@ def flatten_probs(probs, labels, ignore_index=None): labels = labels.view(-1) if ignore_index is None: return probs, labels - valid = (labels != ignore_index) + valid = labels != ignore_index vprobs = probs[valid.nonzero().squeeze()] vlabels = labels[valid] return vprobs, vlabels @@ -70,9 +70,9 @@ def lovasz_hinge_flat(logits, labels): """ if len(labels) == 0: # only void pixels, the gradients should be 0 - return logits.sum() * 0. - signs = 2. * labels.float() - 1. - errors = (1. - logits * signs) + return logits.sum() * 0.0 + signs = 2.0 * labels.float() - 1.0 + errors = 1.0 - logits * signs errors_sorted, perm = torch.sort(errors, dim=0, descending=True) perm = perm.data gt_sorted = labels[perm] @@ -81,14 +81,16 @@ def lovasz_hinge_flat(logits, labels): return loss -def lovasz_hinge(logits, - labels, - classes='present', - per_image=False, - class_weight=None, - reduction='mean', - avg_factor=None, - ignore_index=255): +def lovasz_hinge( + logits, + labels, + classes="present", + per_image=False, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=255, +): """Binary Lovasz hinge loss. Args: @@ -114,19 +116,20 @@ def lovasz_hinge(logits, """ if per_image: loss = [ - lovasz_hinge_flat(*flatten_binary_logits( - logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + lovasz_hinge_flat( + *flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index + ) + ) for logit, label in zip(logits, labels) ] - loss = weight_reduce_loss( - torch.stack(loss), None, reduction, avg_factor) + loss = weight_reduce_loss(torch.stack(loss), None, reduction, avg_factor) else: - loss = lovasz_hinge_flat( - *flatten_binary_logits(logits, labels, ignore_index)) + loss = lovasz_hinge_flat(*flatten_binary_logits(logits, labels, ignore_index)) return loss -def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): +def lovasz_softmax_flat(probs, labels, classes="present", class_weight=None): """Multi-class Lovasz-Softmax loss. Args: @@ -144,17 +147,17 @@ def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): """ if probs.numel() == 0: # only void pixels, the gradients should be 0 - return probs * 0. + return probs * 0.0 C = probs.size(1) losses = [] - class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + class_to_sum = list(range(C)) if classes in ["all", "present"] else classes for c in class_to_sum: fg = (labels == c).float() # foreground for class c - if (classes == 'present' and fg.sum() == 0): + if classes == "present" and fg.sum() == 0: continue if C == 1: if len(classes) > 1: - raise ValueError('Sigmoid output possible only with 1 class') + raise ValueError("Sigmoid output possible only with 1 class") class_pred = probs[:, 0] else: class_pred = probs[:, c] @@ -169,14 +172,16 @@ def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): return torch.stack(losses).mean() -def lovasz_softmax(probs, - labels, - classes='present', - per_image=False, - class_weight=None, - reduction='mean', - avg_factor=None, - ignore_index=255): +def lovasz_softmax( + probs, + labels, + classes="present", + per_image=False, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=255, +): """Multi-class Lovasz-Softmax loss. Args: @@ -206,19 +211,19 @@ def lovasz_softmax(probs, if per_image: loss = [ lovasz_softmax_flat( - *flatten_probs( - prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + *flatten_probs(prob.unsqueeze(0), label.unsqueeze(0), ignore_index), classes=classes, - class_weight=class_weight) + class_weight=class_weight, + ) for prob, label in zip(probs, labels) ] - loss = weight_reduce_loss( - torch.stack(loss), None, reduction, avg_factor) + loss = weight_reduce_loss(torch.stack(loss), None, reduction, avg_factor) else: loss = lovasz_softmax_flat( *flatten_probs(probs, labels, ignore_index), classes=classes, - class_weight=class_weight) + class_weight=class_weight, + ) return loss @@ -249,25 +254,32 @@ class LovaszLoss(nn.Module): prefix of the name. Defaults to 'loss_lovasz'. """ - def __init__(self, - loss_type='multi_class', - classes='present', - per_image=False, - reduction='mean', - class_weight=None, - loss_weight=1.0, - loss_name='loss_lovasz'): - super(LovaszLoss, self).__init__() - assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + def __init__( + self, + loss_type="multi_class", + classes="present", + per_image=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_lovasz", + ): + super().__init__() + assert loss_type in ( + "binary", + "multi_class", + ), "loss_type should be \ 'binary' or 'multi_class'." - if loss_type == 'binary': + if loss_type == "binary": self.cls_criterion = lovasz_hinge else: self.cls_criterion = lovasz_softmax - assert classes in ('all', 'present') or mmcv.is_list_of(classes, int) + assert classes in ("all", "present") or mmcv.is_list_of(classes, int) if not per_image: - assert reduction == 'none', "reduction should be 'none' when \ + assert ( + reduction == "none" + ), "reduction should be 'none' when \ per_image is False." self.classes = classes @@ -277,17 +289,18 @@ def __init__(self, self.class_weight = get_class_weight(class_weight) self._loss_name = loss_name - def forward(self, - cls_score, - label, - weight=None, - avg_factor=None, - reduction_override=None, - **kwargs): + def forward( + self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs, + ): """Forward function.""" - assert reduction_override in (None, 'none', 'mean', 'sum') - reduction = ( - reduction_override if reduction_override else self.reduction) + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction if self.class_weight is not None: class_weight = cls_score.new_tensor(self.class_weight) else: @@ -305,7 +318,8 @@ def forward(self, class_weight=class_weight, reduction=reduction, avg_factor=avg_factor, - **kwargs) + **kwargs, + ) return loss_cls @property diff --git a/mmsegmentation/mmseg/models/losses/tversky_loss.py b/mmsegmentation/mmseg/models/losses/tversky_loss.py index 4ad14f7..6f01d38 100644 --- a/mmsegmentation/mmseg/models/losses/tversky_loss.py +++ b/mmsegmentation/mmseg/models/losses/tversky_loss.py @@ -11,14 +11,16 @@ @weighted_loss -def tversky_loss(pred, - target, - valid_mask, - alpha=0.3, - beta=0.7, - smooth=1, - class_weight=None, - ignore_index=255): +def tversky_loss( + pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255, +): assert pred.shape[0] == target.shape[0] total_loss = 0 num_classes = pred.shape[1] @@ -30,7 +32,8 @@ def tversky_loss(pred, valid_mask=valid_mask, alpha=alpha, beta=beta, - smooth=smooth) + smooth=smooth, + ) if class_weight is not None: tversky_loss *= class_weight[i] total_loss += tversky_loss @@ -38,12 +41,7 @@ def tversky_loss(pred, @weighted_loss -def binary_tversky_loss(pred, - target, - valid_mask, - alpha=0.3, - beta=0.7, - smooth=1): +def binary_tversky_loss(pred, target, valid_mask, alpha=0.3, beta=0.7, smooth=1): assert pred.shape[0] == target.shape[0] pred = pred.reshape(pred.shape[0], -1) target = target.reshape(target.shape[0], -1) @@ -80,20 +78,22 @@ class TverskyLoss(nn.Module): prefix of the name. Defaults to 'loss_tversky'. """ - def __init__(self, - smooth=1, - class_weight=None, - loss_weight=1.0, - ignore_index=255, - alpha=0.3, - beta=0.7, - loss_name='loss_tversky'): - super(TverskyLoss, self).__init__() + def __init__( + self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name="loss_tversky", + ): + super().__init__() self.smooth = smooth self.class_weight = get_class_weight(class_weight) self.loss_weight = loss_weight self.ignore_index = ignore_index - assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + assert alpha + beta == 1.0, "Sum of alpha and beta but be 1.0!" self.alpha = alpha self.beta = beta self._loss_name = loss_name @@ -107,8 +107,8 @@ def forward(self, pred, target, **kwargs): pred = F.softmax(pred, dim=1) num_classes = pred.shape[1] one_hot_target = F.one_hot( - torch.clamp(target.long(), 0, num_classes - 1), - num_classes=num_classes) + torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes + ) valid_mask = (target != self.ignore_index).long() loss = self.loss_weight * tversky_loss( @@ -119,7 +119,8 @@ def forward(self, pred, target, **kwargs): beta=self.beta, smooth=self.smooth, class_weight=class_weight, - ignore_index=self.ignore_index) + ignore_index=self.ignore_index, + ) return loss @property diff --git a/mmsegmentation/mmseg/models/losses/utils.py b/mmsegmentation/mmseg/models/losses/utils.py index 621f57c..3c136fb 100644 --- a/mmsegmentation/mmseg/models/losses/utils.py +++ b/mmsegmentation/mmseg/models/losses/utils.py @@ -16,7 +16,7 @@ def get_class_weight(class_weight): """ if isinstance(class_weight, str): # take it as a file path - if class_weight.endswith('.npy'): + if class_weight.endswith(".npy"): class_weight = np.load(class_weight) else: # pkl, json or yaml @@ -45,7 +45,7 @@ def reduce_loss(loss, reduction): return loss.sum() -def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): +def weight_reduce_loss(loss, weight=None, reduction="mean", avg_factor=None): """Apply element-wise weight and reduce loss. Args: @@ -69,13 +69,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): loss = reduce_loss(loss, reduction) else: # if reduction is mean, then average the loss by avg_factor - if reduction == 'mean': + if reduction == "mean": # Avoid causing ZeroDivisionError when avg_factor is 0.0, # i.e., all labels of an image belong to ignore index. eps = torch.finfo(torch.float32).eps loss = loss.sum() / (avg_factor + eps) # if reduction is 'none', then do nothing, otherwise raise an error - elif reduction != 'none': + elif reduction != "none": raise ValueError('avg_factor can not be used with reduction="sum"') return loss @@ -112,12 +112,7 @@ def weighted_loss(loss_func): """ @functools.wraps(loss_func) - def wrapper(pred, - target, - weight=None, - reduction='mean', - avg_factor=None, - **kwargs): + def wrapper(pred, target, weight=None, reduction="mean", avg_factor=None, **kwargs): # get element-wise loss loss = loss_func(pred, target, **kwargs) loss = weight_reduce_loss(loss, weight, reduction, avg_factor) diff --git a/mmsegmentation/mmseg/models/necks/__init__.py b/mmsegmentation/mmseg/models/necks/__init__.py index ff03186..58ee82b 100644 --- a/mmsegmentation/mmseg/models/necks/__init__.py +++ b/mmsegmentation/mmseg/models/necks/__init__.py @@ -6,6 +6,4 @@ from .mla_neck import MLANeck from .multilevel_neck import MultiLevelNeck -__all__ = [ - 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' -] +__all__ = ["FPN", "MultiLevelNeck", "MLANeck", "ICNeck", "JPU", "Feature2Pyramid"] diff --git a/mmsegmentation/mmseg/models/necks/featurepyramid.py b/mmsegmentation/mmseg/models/necks/featurepyramid.py index 82a00ce..b82a27c 100644 --- a/mmsegmentation/mmseg/models/necks/featurepyramid.py +++ b/mmsegmentation/mmseg/models/necks/featurepyramid.py @@ -19,27 +19,27 @@ class Feature2Pyramid(nn.Module): Default: dict(type='SyncBN', requires_grad=True). """ - def __init__(self, - embed_dim, - rescales=[4, 2, 1, 0.5], - norm_cfg=dict(type='SyncBN', requires_grad=True)): - super(Feature2Pyramid, self).__init__() + def __init__( + self, + embed_dim, + rescales=[4, 2, 1, 0.5], + norm_cfg=dict(type="SyncBN", requires_grad=True), + ): + super().__init__() self.rescales = rescales self.upsample_4x = None for k in self.rescales: if k == 4: self.upsample_4x = nn.Sequential( - nn.ConvTranspose2d( - embed_dim, embed_dim, kernel_size=2, stride=2), + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), build_norm_layer(norm_cfg, embed_dim)[1], nn.GELU(), - nn.ConvTranspose2d( - embed_dim, embed_dim, kernel_size=2, stride=2), + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), ) elif k == 2: self.upsample_2x = nn.Sequential( - nn.ConvTranspose2d( - embed_dim, embed_dim, kernel_size=2, stride=2)) + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2) + ) elif k == 1: self.identity = nn.Identity() elif k == 0.5: @@ -47,20 +47,24 @@ def __init__(self, elif k == 0.25: self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) else: - raise KeyError(f'invalid {k} for feature2pyramid') + raise KeyError(f"invalid {k} for feature2pyramid") def forward(self, inputs): assert len(inputs) == len(self.rescales) outputs = [] if self.upsample_4x is not None: ops = [ - self.upsample_4x, self.upsample_2x, self.identity, - self.downsample_2x + self.upsample_4x, + self.upsample_2x, + self.identity, + self.downsample_2x, ] else: ops = [ - self.upsample_2x, self.identity, self.downsample_2x, - self.downsample_4x + self.upsample_2x, + self.identity, + self.downsample_2x, + self.downsample_4x, ] for i in range(len(inputs)): outputs.append(ops[i](inputs[i])) diff --git a/mmsegmentation/mmseg/models/necks/fpn.py b/mmsegmentation/mmseg/models/necks/fpn.py index 6997de9..ca73ab0 100644 --- a/mmsegmentation/mmseg/models/necks/fpn.py +++ b/mmsegmentation/mmseg/models/necks/fpn.py @@ -64,23 +64,24 @@ class FPN(BaseModule): outputs[3].shape = torch.Size([1, 11, 43, 43]) """ - def __init__(self, - in_channels, - out_channels, - num_outs, - start_level=0, - end_level=-1, - add_extra_convs=False, - extra_convs_on_inputs=False, - relu_before_extra_convs=False, - no_norm_on_lateral=False, - conv_cfg=None, - norm_cfg=None, - act_cfg=None, - upsample_cfg=dict(mode='nearest'), - init_cfg=dict( - type='Xavier', layer='Conv2d', distribution='uniform')): - super(FPN, self).__init__(init_cfg) + def __init__( + self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode="nearest"), + init_cfg=dict(type="Xavier", layer="Conv2d", distribution="uniform"), + ): + super().__init__(init_cfg) assert isinstance(in_channels, list) self.in_channels = in_channels self.out_channels = out_channels @@ -105,14 +106,14 @@ def __init__(self, assert isinstance(add_extra_convs, (str, bool)) if isinstance(add_extra_convs, str): # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' - assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + assert add_extra_convs in ("on_input", "on_lateral", "on_output") elif add_extra_convs: # True if extra_convs_on_inputs: # For compatibility with previous release # TODO: deprecate `extra_convs_on_inputs` - self.add_extra_convs = 'on_input' + self.add_extra_convs = "on_input" else: - self.add_extra_convs = 'on_output' + self.add_extra_convs = "on_output" self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() @@ -125,7 +126,8 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, act_cfg=act_cfg, - inplace=False) + inplace=False, + ) fpn_conv = ConvModule( out_channels, out_channels, @@ -134,7 +136,8 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - inplace=False) + inplace=False, + ) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) @@ -143,7 +146,7 @@ def __init__(self, extra_levels = num_outs - self.backbone_end_level + self.start_level if self.add_extra_convs and extra_levels >= 1: for i in range(extra_levels): - if i == 0 and self.add_extra_convs == 'on_input': + if i == 0 and self.add_extra_convs == "on_input": in_channels = self.in_channels[self.backbone_end_level - 1] else: in_channels = out_channels @@ -156,7 +159,8 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - inplace=False) + inplace=False, + ) self.fpn_convs.append(extra_fpn_conv) @auto_fp16() @@ -174,19 +178,19 @@ def forward(self, inputs): for i in range(used_backbone_levels - 1, 0, -1): # In some cases, fixing `scale factor` (e.g. 2) is preferred, but # it cannot co-exist with `size` in `F.interpolate`. - if 'scale_factor' in self.upsample_cfg: + if "scale_factor" in self.upsample_cfg: laterals[i - 1] = laterals[i - 1] + resize( - laterals[i], **self.upsample_cfg) + laterals[i], **self.upsample_cfg + ) else: prev_shape = laterals[i - 1].shape[2:] laterals[i - 1] = laterals[i - 1] + resize( - laterals[i], size=prev_shape, **self.upsample_cfg) + laterals[i], size=prev_shape, **self.upsample_cfg + ) # build outputs # part 1: from original levels - outs = [ - self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) - ] + outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)] # part 2: add extra levels if self.num_outs > len(outs): # use max pool to get more levels on top of outputs @@ -196,11 +200,11 @@ def forward(self, inputs): outs.append(F.max_pool2d(outs[-1], 1, stride=2)) # add conv layers on top of original feature maps (RetinaNet) else: - if self.add_extra_convs == 'on_input': + if self.add_extra_convs == "on_input": extra_source = inputs[self.backbone_end_level - 1] - elif self.add_extra_convs == 'on_lateral': + elif self.add_extra_convs == "on_lateral": extra_source = laterals[-1] - elif self.add_extra_convs == 'on_output': + elif self.add_extra_convs == "on_output": extra_source = outs[-1] else: raise NotImplementedError diff --git a/mmsegmentation/mmseg/models/necks/ic_neck.py b/mmsegmentation/mmseg/models/necks/ic_neck.py index a5d81ce..8ed1c55 100644 --- a/mmsegmentation/mmseg/models/necks/ic_neck.py +++ b/mmsegmentation/mmseg/models/necks/ic_neck.py @@ -33,16 +33,18 @@ class CascadeFeatureFusion(BaseModule): for Cascade Label Guidance in auxiliary heads. """ - def __init__(self, - low_channels, - high_channels, - out_channels, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - init_cfg=None): - super(CascadeFeatureFusion, self).__init__(init_cfg=init_cfg) + def __init__( + self, + low_channels, + high_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + align_corners=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.align_corners = align_corners self.conv_low = ConvModule( low_channels, @@ -52,21 +54,24 @@ def __init__(self, dilation=2, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.conv_high = ConvModule( high_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, x_low, x_high): x_low = resize( x_low, size=x_high.size()[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) # Note: Different from original paper, `x_low` is underwent # `self.conv_low` rather than another 1x1 conv classifier # before being used for auxiliary head. @@ -100,17 +105,21 @@ class ICNeck(BaseModule): Default: None. """ - def __init__(self, - in_channels=(64, 256, 256), - out_channels=128, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - align_corners=False, - init_cfg=None): - super(ICNeck, self).__init__(init_cfg=init_cfg) - assert len(in_channels) == 3, 'Length of input channels \ - must be 3!' + def __init__( + self, + in_channels=(64, 256, 256), + out_channels=128, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + align_corners=False, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert ( + len(in_channels) == 3 + ), "Length of input channels \ + must be 3!" self.in_channels = in_channels self.out_channels = out_channels @@ -125,7 +134,8 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) self.cff_12 = CascadeFeatureFusion( self.out_channels, @@ -134,11 +144,14 @@ def __init__(self, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, - align_corners=self.align_corners) + align_corners=self.align_corners, + ) def forward(self, inputs): - assert len(inputs) == 3, 'Length of input feature \ - maps must be 3!' + assert ( + len(inputs) == 3 + ), "Length of input feature \ + maps must be 3!" x_sub1, x_sub2, x_sub4 = inputs x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) diff --git a/mmsegmentation/mmseg/models/necks/jpu.py b/mmsegmentation/mmseg/models/necks/jpu.py index 3cc6b9f..350cde7 100644 --- a/mmsegmentation/mmseg/models/necks/jpu.py +++ b/mmsegmentation/mmseg/models/necks/jpu.py @@ -40,18 +40,20 @@ class JPU(BaseModule): Default: None. """ - def __init__(self, - in_channels=(512, 1024, 2048), - mid_channels=512, - start_level=0, - end_level=-1, - dilations=(1, 2, 4, 8), - align_corners=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - init_cfg=None): - super(JPU, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels=(512, 1024, 2048), + mid_channels=512, + start_level=0, + end_level=-1, + dilations=(1, 2, 4, 8), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) assert isinstance(in_channels, tuple) assert isinstance(dilations, tuple) self.in_channels = in_channels @@ -78,13 +80,15 @@ def __init__(self, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.conv_layers.append(conv_layer) for i in range(len(dilations)): dilation_layer = nn.Sequential( DepthwiseSeparableConvModule( - in_channels=(self.backbone_end_level - self.start_level) * - self.mid_channels, + in_channels=(self.backbone_end_level - self.start_level) + * self.mid_channels, out_channels=self.mid_channels, kernel_size=3, stride=1, @@ -93,13 +97,17 @@ def __init__(self, dw_norm_cfg=norm_cfg, dw_act_cfg=None, pw_norm_cfg=norm_cfg, - pw_act_cfg=act_cfg)) + pw_act_cfg=act_cfg, + ) + ) self.dilation_layers.append(dilation_layer) def forward(self, inputs): """Forward function.""" - assert len(inputs) == len(self.in_channels), 'Length of inputs must \ - be the same with self.in_channels!' + assert len(inputs) == len( + self.in_channels + ), "Length of inputs must \ + be the same with self.in_channels!" feats = [ self.conv_layers[i - self.start_level](inputs[i]) @@ -109,16 +117,13 @@ def forward(self, inputs): h, w = feats[0].shape[2:] for i in range(1, len(feats)): feats[i] = resize( - feats[i], - size=(h, w), - mode='bilinear', - align_corners=self.align_corners) + feats[i], size=(h, w), mode="bilinear", align_corners=self.align_corners + ) feat = torch.cat(feats, dim=1) - concat_feat = torch.cat([ - self.dilation_layers[i](feat) for i in range(len(self.dilations)) - ], - dim=1) + concat_feat = torch.cat( + [self.dilation_layers[i](feat) for i in range(len(self.dilations))], dim=1 + ) outs = [] diff --git a/mmsegmentation/mmseg/models/necks/mla_neck.py b/mmsegmentation/mmseg/models/necks/mla_neck.py index 1513e29..b233648 100644 --- a/mmsegmentation/mmseg/models/necks/mla_neck.py +++ b/mmsegmentation/mmseg/models/necks/mla_neck.py @@ -6,13 +6,14 @@ class MLAModule(nn.Module): - - def __init__(self, - in_channels=[1024, 1024, 1024, 1024], - out_channels=256, - norm_cfg=None, - act_cfg=None): - super(MLAModule, self).__init__() + def __init__( + self, + in_channels=[1024, 1024, 1024, 1024], + out_channels=256, + norm_cfg=None, + act_cfg=None, + ): + super().__init__() self.channel_proj = nn.ModuleList() for i in range(len(in_channels)): self.channel_proj.append( @@ -21,7 +22,9 @@ def __init__(self, out_channels=out_channels, kernel_size=1, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) self.feat_extract = nn.ModuleList() for i in range(len(in_channels)): self.feat_extract.append( @@ -31,10 +34,11 @@ def __init__(self, kernel_size=3, padding=1, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) def forward(self, inputs): - # feat_list -> [p2, p3, p4, p5] feat_list = [] for x, conv in zip(inputs, self.channel_proj): @@ -77,29 +81,34 @@ class MLANeck(nn.Module): Default: None. """ - def __init__(self, - in_channels, - out_channels, - norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), - norm_cfg=None, - act_cfg=None): - super(MLANeck, self).__init__() + def __init__( + self, + in_channels, + out_channels, + norm_layer=dict(type="LN", eps=1e-6, requires_grad=True), + norm_cfg=None, + act_cfg=None, + ): + super().__init__() assert isinstance(in_channels, list) self.in_channels = in_channels self.out_channels = out_channels # In order to build general vision transformer backbone, we have to # move MLA to neck. - self.norm = nn.ModuleList([ - build_norm_layer(norm_layer, in_channels[i])[1] - for i in range(len(in_channels)) - ]) + self.norm = nn.ModuleList( + [ + build_norm_layer(norm_layer, in_channels[i])[1] + for i in range(len(in_channels)) + ] + ) self.mla = MLAModule( in_channels=in_channels, out_channels=out_channels, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, inputs): assert len(inputs) == len(self.in_channels) diff --git a/mmsegmentation/mmseg/models/necks/multilevel_neck.py b/mmsegmentation/mmseg/models/necks/multilevel_neck.py index 5151f87..26345aa 100644 --- a/mmsegmentation/mmseg/models/necks/multilevel_neck.py +++ b/mmsegmentation/mmseg/models/necks/multilevel_neck.py @@ -22,13 +22,15 @@ class MultiLevelNeck(nn.Module): Default: None. """ - def __init__(self, - in_channels, - out_channels, - scales=[0.5, 1, 2, 4], - norm_cfg=None, - act_cfg=None): - super(MultiLevelNeck, self).__init__() + def __init__( + self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None, + ): + super().__init__() assert isinstance(in_channels, list) self.in_channels = in_channels self.out_channels = out_channels @@ -43,7 +45,9 @@ def __init__(self, out_channels, kernel_size=1, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) for _ in range(self.num_outs): self.convs.append( ConvModule( @@ -53,26 +57,26 @@ def __init__(self, padding=1, stride=1, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) # default init_weights for conv(msra) and norm in ConvModule def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): - xavier_init(m, distribution='uniform') + xavier_init(m, distribution="uniform") def forward(self, inputs): assert len(inputs) == len(self.in_channels) inputs = [ - lateral_conv(inputs[i]) - for i, lateral_conv in enumerate(self.lateral_convs) + lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs) ] # for len(inputs) not equal to self.num_outs if len(inputs) == 1: inputs = [inputs[0] for _ in range(self.num_outs)] outs = [] for i in range(self.num_outs): - x_resize = resize( - inputs[i], scale_factor=self.scales[i], mode='bilinear') + x_resize = resize(inputs[i], scale_factor=self.scales[i], mode="bilinear") outs.append(self.convs[i](x_resize)) return tuple(outs) diff --git a/mmsegmentation/mmseg/models/segmentors/__init__.py b/mmsegmentation/mmseg/models/segmentors/__init__.py index 387c858..659bb16 100644 --- a/mmsegmentation/mmseg/models/segmentors/__init__.py +++ b/mmsegmentation/mmseg/models/segmentors/__init__.py @@ -3,4 +3,4 @@ from .cascade_encoder_decoder import CascadeEncoderDecoder from .encoder_decoder import EncoderDecoder -__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] +__all__ = ["BaseSegmentor", "EncoderDecoder", "CascadeEncoderDecoder"] diff --git a/mmsegmentation/mmseg/models/segmentors/base.py b/mmsegmentation/mmseg/models/segmentors/base.py index 76dc8f0..5033889 100644 --- a/mmsegmentation/mmseg/models/segmentors/base.py +++ b/mmsegmentation/mmseg/models/segmentors/base.py @@ -14,50 +14,44 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta): """Base class for segmentors.""" def __init__(self, init_cfg=None): - super(BaseSegmentor, self).__init__(init_cfg) + super().__init__(init_cfg) self.fp16_enabled = False @property def with_neck(self): """bool: whether the segmentor has neck""" - return hasattr(self, 'neck') and self.neck is not None + return hasattr(self, "neck") and self.neck is not None @property def with_auxiliary_head(self): """bool: whether the segmentor has auxiliary head""" - return hasattr(self, - 'auxiliary_head') and self.auxiliary_head is not None + return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None @property def with_decode_head(self): """bool: whether the segmentor has decode head""" - return hasattr(self, 'decode_head') and self.decode_head is not None + return hasattr(self, "decode_head") and self.decode_head is not None @abstractmethod def extract_feat(self, imgs): """Placeholder for extract features from images.""" - pass @abstractmethod def encode_decode(self, img, img_metas): """Placeholder for encode images with backbone and decode into a semantic segmentation map of the same size as input.""" - pass @abstractmethod def forward_train(self, imgs, img_metas, **kwargs): """Placeholder for Forward function for training.""" - pass @abstractmethod def simple_test(self, img, img_meta, **kwargs): """Placeholder for single image test.""" - pass @abstractmethod def aug_test(self, imgs, img_metas, **kwargs): """Placeholder for augmentation test.""" - pass def forward_test(self, imgs, img_metas, **kwargs): """ @@ -69,23 +63,24 @@ def forward_test(self, imgs, img_metas, **kwargs): augs (multiscale, flip, etc.) and the inner list indicates images in a batch. """ - for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: if not isinstance(var, list): - raise TypeError(f'{name} must be a list, but got ' - f'{type(var)}') + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") num_augs = len(imgs) if num_augs != len(img_metas): - raise ValueError(f'num of augmentations ({len(imgs)}) != ' - f'num of image meta ({len(img_metas)})') + raise ValueError( + f"num of augmentations ({len(imgs)}) != " + f"num of image meta ({len(img_metas)})" + ) # all images in the same aug batch all of the same ori_shape and pad # shape for img_meta in img_metas: - ori_shapes = [_['ori_shape'] for _ in img_meta] + ori_shapes = [_["ori_shape"] for _ in img_meta] assert all(shape == ori_shapes[0] for shape in ori_shapes) - img_shapes = [_['img_shape'] for _ in img_meta] + img_shapes = [_["img_shape"] for _ in img_meta] assert all(shape == img_shapes[0] for shape in img_shapes) - pad_shapes = [_['pad_shape'] for _ in img_meta] + pad_shapes = [_["pad_shape"] for _ in img_meta] assert all(shape == pad_shapes[0] for shape in pad_shapes) if num_augs == 1: @@ -93,7 +88,7 @@ def forward_test(self, imgs, img_metas, **kwargs): else: return self.aug_test(imgs, img_metas, **kwargs) - @auto_fp16(apply_to=('img', )) + @auto_fp16(apply_to=("img",)) def forward(self, img, img_metas, return_loss=True, **kwargs): """Calls either :func:`forward_train` or :func:`forward_test` depending on whether ``return_loss`` is ``True``. @@ -139,9 +134,8 @@ def train_step(self, data_batch, optimizer, **kwargs): loss, log_vars = self._parse_losses(losses) outputs = dict( - loss=loss, - log_vars=log_vars, - num_samples=len(data_batch['img_metas'])) + loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]) + ) return outputs @@ -157,13 +151,12 @@ def val_step(self, data_batch, optimizer=None, **kwargs): log_vars_ = dict() for loss_name, loss_value in log_vars.items(): - k = loss_name + '_val' + k = loss_name + "_val" log_vars_[k] = loss_value outputs = dict( - loss=loss, - log_vars=log_vars_, - num_samples=len(data_batch['img_metas'])) + loss=loss, log_vars=log_vars_, num_samples=len(data_batch["img_metas"]) + ) return outputs @@ -187,24 +180,27 @@ def _parse_losses(losses): elif isinstance(loss_value, list): log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) else: - raise TypeError( - f'{loss_name} is not a tensor or list of tensors') + raise TypeError(f"{loss_name} is not a tensor or list of tensors") - loss = sum(_value for _key, _value in log_vars.items() - if 'loss' in _key) + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) # If the loss_vars has different length, raise assertion error # to prevent GPUs from infinite waiting. if dist.is_available() and dist.is_initialized(): log_var_length = torch.tensor(len(log_vars), device=loss.device) dist.all_reduce(log_var_length) - message = (f'rank {dist.get_rank()}' + - f' len(log_vars): {len(log_vars)}' + ' keys: ' + - ','.join(log_vars.keys()) + '\n') - assert log_var_length == len(log_vars) * dist.get_world_size(), \ - 'loss log variables are different across GPUs!\n' + message - - log_vars['loss'] = loss + message = ( + f"rank {dist.get_rank()}" + + f" len(log_vars): {len(log_vars)}" + + " keys: " + + ",".join(log_vars.keys()) + + "\n" + ) + assert log_var_length == len(log_vars) * dist.get_world_size(), ( + "loss log variables are different across GPUs!\n" + message + ) + + log_vars["loss"] = loss for loss_name, loss_value in log_vars.items(): # reduce loss when distributed training if dist.is_available() and dist.is_initialized(): @@ -214,15 +210,17 @@ def _parse_losses(losses): return loss, log_vars - def show_result(self, - img, - result, - palette=None, - win_name='', - show=False, - wait_time=0, - out_file=None, - opacity=0.5): + def show_result( + self, + img, + result, + palette=None, + win_name="", + show=False, + wait_time=0, + out_file=None, + opacity=0.5, + ): """Draw `result` over `img`. Args: @@ -258,8 +256,7 @@ def show_result(self, state = np.random.get_state() np.random.seed(42) # random palette - palette = np.random.randint( - 0, 255, size=(len(self.CLASSES), 3)) + palette = np.random.randint(0, 255, size=(len(self.CLASSES), 3)) np.random.set_state(state) else: palette = self.PALETTE @@ -286,6 +283,8 @@ def show_result(self, mmcv.imwrite(img, out_file) if not (show or out_file): - warnings.warn('show==False and out_file is not specified, only ' - 'result image will be returned') + warnings.warn( + "show==False and out_file is not specified, only " + "result image will be returned" + ) return img diff --git a/mmsegmentation/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmsegmentation/mmseg/models/segmentors/cascade_encoder_decoder.py index e9a9127..ba688f4 100644 --- a/mmsegmentation/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmsegmentation/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -17,18 +17,20 @@ class CascadeEncoderDecoder(EncoderDecoder): will be the input of next decoder_head. """ - def __init__(self, - num_stages, - backbone, - decode_head, - neck=None, - auxiliary_head=None, - train_cfg=None, - test_cfg=None, - pretrained=None, - init_cfg=None): + def __init__( + self, + num_stages, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): self.num_stages = num_stages - super(CascadeEncoderDecoder, self).__init__( + super().__init__( backbone=backbone, decode_head=decode_head, neck=neck, @@ -36,7 +38,8 @@ def __init__(self, train_cfg=train_cfg, test_cfg=test_cfg, pretrained=pretrained, - init_cfg=init_cfg) + init_cfg=init_cfg, + ) def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" @@ -55,13 +58,13 @@ def encode_decode(self, img, img_metas): x = self.extract_feat(img) out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) for i in range(1, self.num_stages): - out = self.decode_head[i].forward_test(x, out, img_metas, - self.test_cfg) + out = self.decode_head[i].forward_test(x, out, img_metas, self.test_cfg) out = resize( input=out, size=img.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) return out def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): @@ -70,20 +73,24 @@ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): losses = dict() loss_decode = self.decode_head[0].forward_train( - x, img_metas, gt_semantic_seg, self.train_cfg) + x, img_metas, gt_semantic_seg, self.train_cfg + ) - losses.update(add_prefix(loss_decode, 'decode_0')) + losses.update(add_prefix(loss_decode, "decode_0")) for i in range(1, self.num_stages): # forward test again, maybe unnecessary for most methods. if i == 1: prev_outputs = self.decode_head[0].forward_test( - x, img_metas, self.test_cfg) + x, img_metas, self.test_cfg + ) else: prev_outputs = self.decode_head[i - 1].forward_test( - x, prev_outputs, img_metas, self.test_cfg) + x, prev_outputs, img_metas, self.test_cfg + ) loss_decode = self.decode_head[i].forward_train( - x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) - losses.update(add_prefix(loss_decode, f'decode_{i}')) + x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg + ) + losses.update(add_prefix(loss_decode, f"decode_{i}")) return losses diff --git a/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py b/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py index e0ce8df..237501d 100644 --- a/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py +++ b/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py @@ -19,19 +19,22 @@ class EncoderDecoder(BaseSegmentor): which could be dumped during inference. """ - def __init__(self, - backbone, - decode_head, - neck=None, - auxiliary_head=None, - train_cfg=None, - test_cfg=None, - pretrained=None, - init_cfg=None): - super(EncoderDecoder, self).__init__(init_cfg) + def __init__( + self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg) if pretrained is not None: - assert backbone.get('pretrained') is None, \ - 'both backbone and segmentor set pretrained weight' + assert ( + backbone.get("pretrained") is None + ), "both backbone and segmentor set pretrained weight" backbone.pretrained = pretrained self.backbone = builder.build_backbone(backbone) if neck is not None: @@ -76,19 +79,20 @@ def encode_decode(self, img, img_metas): out = resize( input=out, size=img.shape[2:], - mode='bilinear', - align_corners=self.align_corners) + mode="bilinear", + align_corners=self.align_corners, + ) return out def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): """Run forward function and calculate loss for decode head in training.""" losses = dict() - loss_decode = self.decode_head.forward_train(x, img_metas, - gt_semantic_seg, - self.train_cfg) + loss_decode = self.decode_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg + ) - losses.update(add_prefix(loss_decode, 'decode')) + losses.update(add_prefix(loss_decode, "decode")) return losses def _decode_head_forward_test(self, x, img_metas): @@ -103,14 +107,15 @@ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): losses = dict() if isinstance(self.auxiliary_head, nn.ModuleList): for idx, aux_head in enumerate(self.auxiliary_head): - loss_aux = aux_head.forward_train(x, img_metas, - gt_semantic_seg, - self.train_cfg) - losses.update(add_prefix(loss_aux, f'aux_{idx}')) + loss_aux = aux_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg + ) + losses.update(add_prefix(loss_aux, f"aux_{idx}")) else: loss_aux = self.auxiliary_head.forward_train( - x, img_metas, gt_semantic_seg, self.train_cfg) - losses.update(add_prefix(loss_aux, 'aux')) + x, img_metas, gt_semantic_seg, self.train_cfg + ) + losses.update(add_prefix(loss_aux, "aux")) return losses @@ -141,13 +146,11 @@ def forward_train(self, img, img_metas, gt_semantic_seg): losses = dict() - loss_decode = self._decode_head_forward_train(x, img_metas, - gt_semantic_seg) + loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg) losses.update(loss_decode) if self.with_auxiliary_head: - loss_aux = self._auxiliary_head_forward_train( - x, img_metas, gt_semantic_seg) + loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) losses.update(loss_aux) return losses @@ -178,27 +181,35 @@ def slide_inference(self, img, img_meta, rescale): x1 = max(x2 - w_crop, 0) crop_img = img[:, :, y1:y2, x1:x2] crop_seg_logit = self.encode_decode(crop_img, img_meta) - preds += F.pad(crop_seg_logit, - (int(x1), int(preds.shape[3] - x2), int(y1), - int(preds.shape[2] - y2))) + preds += F.pad( + crop_seg_logit, + ( + int(x1), + int(preds.shape[3] - x2), + int(y1), + int(preds.shape[2] - y2), + ), + ) count_mat[:, :, y1:y2, x1:x2] += 1 assert (count_mat == 0).sum() == 0 if torch.onnx.is_in_onnx_export(): # cast count_mat to constant while exporting to ONNX - count_mat = torch.from_numpy( - count_mat.cpu().detach().numpy()).to(device=img.device) + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to( + device=img.device + ) preds = preds / count_mat if rescale: # remove padding area - resize_shape = img_meta[0]['img_shape'][:2] - preds = preds[:, :, :resize_shape[0], :resize_shape[1]] + resize_shape = img_meta[0]["img_shape"][:2] + preds = preds[:, :, : resize_shape[0], : resize_shape[1]] preds = resize( preds, - size=img_meta[0]['ori_shape'][:2], - mode='bilinear', + size=img_meta[0]["ori_shape"][:2], + mode="bilinear", align_corners=self.align_corners, - warning=False) + warning=False, + ) return preds def whole_inference(self, img, img_meta, rescale): @@ -211,15 +222,16 @@ def whole_inference(self, img, img_meta, rescale): size = img.shape[2:] else: # remove padding area - resize_shape = img_meta[0]['img_shape'][:2] - seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]] - size = img_meta[0]['ori_shape'][:2] + resize_shape = img_meta[0]["img_shape"][:2] + seg_logit = seg_logit[:, :, : resize_shape[0], : resize_shape[1]] + size = img_meta[0]["ori_shape"][:2] seg_logit = resize( seg_logit, size=size, - mode='bilinear', + mode="bilinear", align_corners=self.align_corners, - warning=False) + warning=False, + ) return seg_logit @@ -239,10 +251,10 @@ def inference(self, img, img_meta, rescale): Tensor: The output segmentation map. """ - assert self.test_cfg.mode in ['slide', 'whole'] - ori_shape = img_meta[0]['ori_shape'] - assert all(_['ori_shape'] == ori_shape for _ in img_meta) - if self.test_cfg.mode == 'slide': + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": seg_logit = self.slide_inference(img, img_meta, rescale) else: seg_logit = self.whole_inference(img, img_meta, rescale) @@ -250,14 +262,14 @@ def inference(self, img, img_meta, rescale): output = F.sigmoid(seg_logit) else: output = F.softmax(seg_logit, dim=1) - flip = img_meta[0]['flip'] + flip = img_meta[0]["flip"] if flip: - flip_direction = img_meta[0]['flip_direction'] - assert flip_direction in ['horizontal', 'vertical'] - if flip_direction == 'horizontal': - output = output.flip(dims=(3, )) - elif flip_direction == 'vertical': - output = output.flip(dims=(2, )) + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) return output @@ -265,8 +277,7 @@ def simple_test(self, img, img_meta, rescale=True): """Simple test with single image.""" seg_logit = self.inference(img, img_meta, rescale) if self.out_channels == 1: - seg_pred = (seg_logit > - self.decode_head.threshold).to(seg_logit).squeeze(1) + seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit).squeeze(1) else: seg_pred = seg_logit.argmax(dim=1) if torch.onnx.is_in_onnx_export(): @@ -301,8 +312,7 @@ def aug_test(self, imgs, img_metas, rescale=True): seg_logit += cur_seg_logit seg_logit /= len(imgs) if self.out_channels == 1: - seg_pred = (seg_logit > - self.decode_head.threshold).to(seg_logit).squeeze(1) + seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit).squeeze(1) else: seg_pred = seg_logit.argmax(dim=1) seg_pred = seg_pred.cpu().numpy() diff --git a/mmsegmentation/mmseg/models/utils/__init__.py b/mmsegmentation/mmseg/models/utils/__init__.py index 6d83290..feb5e30 100644 --- a/mmsegmentation/mmseg/models/utils/__init__.py +++ b/mmsegmentation/mmseg/models/utils/__init__.py @@ -5,12 +5,20 @@ from .res_layer import ResLayer from .se_layer import SELayer from .self_attention_block import SelfAttentionBlock -from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, - nlc_to_nchw) +from .shape_convert import nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, nlc_to_nchw from .up_conv_block import UpConvBlock __all__ = [ - 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', - 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', - 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' + "ResLayer", + "SelfAttentionBlock", + "make_divisible", + "InvertedResidual", + "UpConvBlock", + "InvertedResidualV3", + "SELayer", + "PatchEmbed", + "nchw_to_nlc", + "nlc_to_nchw", + "nchw2nlc2nchw", + "nlc2nchw2nlc", ] diff --git a/mmsegmentation/mmseg/models/utils/embed.py b/mmsegmentation/mmseg/models/utils/embed.py index 1515675..10ed901 100644 --- a/mmsegmentation/mmseg/models/utils/embed.py +++ b/mmsegmentation/mmseg/models/utils/embed.py @@ -40,11 +40,10 @@ class AdaptivePadding(nn.Module): >>> assert (out.shape[2], out.shape[3]) == (16, 32) """ - def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): + super().__init__() - super(AdaptivePadding, self).__init__() - - assert padding in ('same', 'corner') + assert padding in ("same", "corner") kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) @@ -61,22 +60,25 @@ def get_pad_shape(self, input_shape): stride_h, stride_w = self.stride output_h = math.ceil(input_h / stride_h) output_w = math.ceil(input_w / stride_w) - pad_h = max((output_h - 1) * stride_h + - (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) - pad_w = max((output_w - 1) * stride_w + - (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + pad_h = max( + (output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, + 0, + ) + pad_w = max( + (output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, + 0, + ) return pad_h, pad_w def forward(self, x): pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) if pad_h > 0 or pad_w > 0: - if self.padding == 'corner': + if self.padding == "corner": x = F.pad(x, [0, pad_w, 0, pad_h]) - elif self.padding == 'same': - x = F.pad(x, [ - pad_w // 2, pad_w - pad_w // 2, pad_h // 2, - pad_h - pad_h // 2 - ]) + elif self.padding == "same": + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) return x @@ -108,19 +110,21 @@ class PatchEmbed(BaseModule): Default: None. """ - def __init__(self, - in_channels=3, - embed_dims=768, - conv_type='Conv2d', - kernel_size=16, - stride=None, - padding='corner', - dilation=1, - bias=True, - norm_cfg=None, - input_size=None, - init_cfg=None): - super(PatchEmbed, self).__init__(init_cfg=init_cfg) + def __init__( + self, + in_channels=3, + embed_dims=768, + conv_type="Conv2d", + kernel_size=16, + stride=None, + padding="corner", + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) self.embed_dims = embed_dims if stride is None: @@ -135,7 +139,8 @@ def __init__(self, kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding) + padding=padding, + ) # disable the padding of conv padding = 0 else: @@ -150,7 +155,8 @@ def __init__(self, stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) if norm_cfg is not None: self.norm = build_norm_layer(norm_cfg, embed_dims)[1] @@ -171,10 +177,12 @@ def __init__(self, input_size = (input_h, input_w) # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - h_out = (input_size[0] + 2 * padding[0] - dilation[0] * - (kernel_size[0] - 1) - 1) // stride[0] + 1 - w_out = (input_size[1] + 2 * padding[1] - dilation[1] * - (kernel_size[1] - 1) - 1) // stride[1] + 1 + h_out = ( + input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) // stride[0] + 1 + w_out = ( + input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) // stride[1] + 1 self.init_out_size = (h_out, w_out) else: self.init_input_size = None @@ -233,16 +241,18 @@ class PatchMerging(BaseModule): Default: None. """ - def __init__(self, - in_channels, - out_channels, - kernel_size=2, - stride=None, - padding='corner', - dilation=1, - bias=False, - norm_cfg=dict(type='LN'), - init_cfg=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding="corner", + dilation=1, + bias=False, + norm_cfg=dict(type="LN"), + init_cfg=None, + ): super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.out_channels = out_channels @@ -260,7 +270,8 @@ def __init__(self, kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding) + padding=padding, + ) # disable the padding of unfold padding = 0 else: @@ -268,10 +279,8 @@ def __init__(self, padding = to_2tuple(padding) self.sampler = nn.Unfold( - kernel_size=kernel_size, - dilation=dilation, - padding=padding, - stride=stride) + kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride + ) sample_dim = kernel_size[0] * kernel_size[1] * in_channels @@ -297,13 +306,12 @@ def forward(self, x, input_size): (Merged_H, Merged_W). """ B, L, C = x.shape - assert isinstance(input_size, Sequence), f'Expect ' \ - f'input_size is ' \ - f'`Sequence` ' \ - f'but get {input_size}' + assert isinstance(input_size, Sequence), ( + f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" + ) H, W = input_size - assert L == H * W, 'input feature has wrong size' + assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W # Use nn.Unfold to merge patch. About 25% faster than original method, @@ -316,12 +324,18 @@ def forward(self, x, input_size): x = self.sampler(x) # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) - out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * - (self.sampler.kernel_size[0] - 1) - - 1) // self.sampler.stride[0] + 1 - out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * - (self.sampler.kernel_size[1] - 1) - - 1) // self.sampler.stride[1] + 1 + out_h = ( + H + + 2 * self.sampler.padding[0] + - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) + - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + W + + 2 * self.sampler.padding[1] + - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) + - 1 + ) // self.sampler.stride[1] + 1 output_size = (out_h, out_w) x = x.transpose(1, 2) # B, H/2*W/2, 4*C diff --git a/mmsegmentation/mmseg/models/utils/inverted_residual.py b/mmsegmentation/mmseg/models/utils/inverted_residual.py index c9cda76..28b8ff2 100644 --- a/mmsegmentation/mmseg/models/utils/inverted_residual.py +++ b/mmsegmentation/mmseg/models/utils/inverted_residual.py @@ -29,21 +29,22 @@ class InvertedResidual(nn.Module): Tensor: The output tensor. """ - def __init__(self, - in_channels, - out_channels, - stride, - expand_ratio, - dilation=1, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU6'), - with_cp=False, - **kwargs): - super(InvertedResidual, self).__init__() + def __init__( + self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU6"), + with_cp=False, + **kwargs, + ): + super().__init__() self.stride = stride - assert stride in [1, 2], f'stride must in [1, 2]. ' \ - f'But received {stride}.' + assert stride in [1, 2], f"stride must in [1, 2]. " f"But received {stride}." self.with_cp = with_cp self.use_res_connect = self.stride == 1 and in_channels == out_channels hidden_dim = int(round(in_channels * expand_ratio)) @@ -58,33 +59,38 @@ def __init__(self, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, - **kwargs)) - layers.extend([ - ConvModule( - in_channels=hidden_dim, - out_channels=hidden_dim, - kernel_size=3, - stride=stride, - padding=dilation, - dilation=dilation, - groups=hidden_dim, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg, - **kwargs), - ConvModule( - in_channels=hidden_dim, - out_channels=out_channels, - kernel_size=1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=None, - **kwargs) - ]) + **kwargs, + ) + ) + layers.extend( + [ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs, + ), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs, + ), + ] + ) self.conv = nn.Sequential(*layers) def forward(self, x): - def _inner_forward(x): if self.use_res_connect: return x + self.conv(x) @@ -126,20 +132,22 @@ class InvertedResidualV3(nn.Module): Tensor: The output tensor. """ - def __init__(self, - in_channels, - out_channels, - mid_channels, - kernel_size=3, - stride=1, - se_cfg=None, - with_expand_conv=True, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - with_cp=False): - super(InvertedResidualV3, self).__init__() - self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + def __init__( + self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + with_cp=False, + ): + super().__init__() + self.with_res_shortcut = stride == 1 and in_channels == out_channels assert stride in [1, 2] self.with_cp = with_cp self.with_se = se_cfg is not None @@ -159,7 +167,8 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.depthwise_conv = ConvModule( in_channels=mid_channels, out_channels=mid_channels, @@ -167,10 +176,10 @@ def __init__(self, stride=stride, padding=kernel_size // 2, groups=mid_channels, - conv_cfg=dict( - type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + conv_cfg=dict(type="Conv2dAdaptivePadding") if stride == 2 else conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) if self.with_se: self.se = SELayer(**se_cfg) @@ -183,10 +192,10 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None) + act_cfg=None, + ) def forward(self, x): - def _inner_forward(x): out = x diff --git a/mmsegmentation/mmseg/models/utils/res_layer.py b/mmsegmentation/mmseg/models/utils/res_layer.py index 190a0c5..f53f063 100644 --- a/mmsegmentation/mmseg/models/utils/res_layer.py +++ b/mmsegmentation/mmseg/models/utils/res_layer.py @@ -25,19 +25,21 @@ class ResLayer(Sequential): Default: False """ - def __init__(self, - block, - inplanes, - planes, - num_blocks, - stride=1, - dilation=1, - avg_down=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - multi_grid=None, - contract_dilation=False, - **kwargs): + def __init__( + self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + multi_grid=None, + contract_dilation=False, + **kwargs, + ): self.block = block downsample = None @@ -51,17 +53,22 @@ def __init__(self, kernel_size=stride, stride=stride, ceil_mode=True, - count_include_pad=False)) - downsample.extend([ - build_conv_layer( - conv_cfg, - inplanes, - planes * block.expansion, - kernel_size=1, - stride=conv_stride, - bias=False), - build_norm_layer(norm_cfg, planes * block.expansion)[1] - ]) + count_include_pad=False, + ) + ) + downsample.extend( + [ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False, + ), + build_norm_layer(norm_cfg, planes * block.expansion)[1], + ] + ) downsample = nn.Sequential(*downsample) layers = [] @@ -81,7 +88,9 @@ def __init__(self, downsample=downsample, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - **kwargs)) + **kwargs, + ) + ) inplanes = planes * block.expansion for i in range(1, num_blocks): layers.append( @@ -92,5 +101,7 @@ def __init__(self, dilation=dilation if multi_grid is None else multi_grid[i], conv_cfg=conv_cfg, norm_cfg=norm_cfg, - **kwargs)) - super(ResLayer, self).__init__(*layers) + **kwargs, + ) + ) + super().__init__(*layers) diff --git a/mmsegmentation/mmseg/models/utils/se_layer.py b/mmsegmentation/mmseg/models/utils/se_layer.py index 16f52aa..ce86abe 100644 --- a/mmsegmentation/mmseg/models/utils/se_layer.py +++ b/mmsegmentation/mmseg/models/utils/se_layer.py @@ -24,13 +24,14 @@ class SELayer(nn.Module): divisor=6.0)). """ - def __init__(self, - channels, - ratio=16, - conv_cfg=None, - act_cfg=(dict(type='ReLU'), - dict(type='HSigmoid', bias=3.0, divisor=6.0))): - super(SELayer, self).__init__() + def __init__( + self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type="ReLU"), dict(type="HSigmoid", bias=3.0, divisor=6.0)), + ): + super().__init__() if isinstance(act_cfg, dict): act_cfg = (act_cfg, act_cfg) assert len(act_cfg) == 2 @@ -42,14 +43,16 @@ def __init__(self, kernel_size=1, stride=1, conv_cfg=conv_cfg, - act_cfg=act_cfg[0]) + act_cfg=act_cfg[0], + ) self.conv2 = ConvModule( in_channels=make_divisible(channels // ratio, 8), out_channels=channels, kernel_size=1, stride=1, conv_cfg=conv_cfg, - act_cfg=act_cfg[1]) + act_cfg=act_cfg[1], + ) def forward(self, x): out = self.global_avgpool(x) diff --git a/mmsegmentation/mmseg/models/utils/self_attention_block.py b/mmsegmentation/mmseg/models/utils/self_attention_block.py index c945fa7..3679470 100644 --- a/mmsegmentation/mmseg/models/utils/self_attention_block.py +++ b/mmsegmentation/mmseg/models/utils/self_attention_block.py @@ -30,12 +30,26 @@ class SelfAttentionBlock(nn.Module): act_cfg (dict|None): Config of activation layers. """ - def __init__(self, key_in_channels, query_in_channels, channels, - out_channels, share_key_query, query_downsample, - key_downsample, key_query_num_convs, value_out_num_convs, - key_query_norm, value_out_norm, matmul_norm, with_out, - conv_cfg, norm_cfg, act_cfg): - super(SelfAttentionBlock, self).__init__() + def __init__( + self, + key_in_channels, + query_in_channels, + channels, + out_channels, + share_key_query, + query_downsample, + key_downsample, + key_query_num_convs, + value_out_num_convs, + key_query_norm, + value_out_norm, + matmul_norm, + with_out, + conv_cfg, + norm_cfg, + act_cfg, + ): + super().__init__() if share_key_query: assert key_in_channels == query_in_channels self.key_in_channels = key_in_channels @@ -53,7 +67,8 @@ def __init__(self, key_in_channels, query_in_channels, channels, use_conv_module=key_query_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) if share_key_query: self.query_project = self.key_project else: @@ -64,7 +79,8 @@ def __init__(self, key_in_channels, query_in_channels, channels, use_conv_module=key_query_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) self.value_project = self.build_project( key_in_channels, channels if with_out else out_channels, @@ -72,7 +88,8 @@ def __init__(self, key_in_channels, query_in_channels, channels, use_conv_module=value_out_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) if with_out: self.out_project = self.build_project( channels, @@ -81,7 +98,8 @@ def __init__(self, key_in_channels, query_in_channels, channels, use_conv_module=value_out_norm, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) else: self.out_project = None @@ -97,8 +115,16 @@ def init_weights(self): if not isinstance(self.out_project, ConvModule): constant_init(self.out_project, 0) - def build_project(self, in_channels, channels, num_convs, use_conv_module, - conv_cfg, norm_cfg, act_cfg): + def build_project( + self, + in_channels, + channels, + num_convs, + use_conv_module, + conv_cfg, + norm_cfg, + act_cfg, + ): """Build projection layer for key/query/value/out.""" if use_conv_module: convs = [ @@ -108,7 +134,8 @@ def build_project(self, in_channels, channels, num_convs, use_conv_module, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) ] for _ in range(num_convs - 1): convs.append( @@ -118,7 +145,9 @@ def build_project(self, in_channels, channels, num_convs, use_conv_module, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + ) + ) else: convs = [nn.Conv2d(in_channels, channels, 1)] for _ in range(num_convs - 1): @@ -149,7 +178,7 @@ def forward(self, query_feats, key_feats): sim_map = torch.matmul(query, key) if self.matmul_norm: - sim_map = (self.channels**-.5) * sim_map + sim_map = (self.channels**-0.5) * sim_map sim_map = F.softmax(sim_map, dim=-1) context = torch.matmul(sim_map, value) diff --git a/mmsegmentation/mmseg/models/utils/shape_convert.py b/mmsegmentation/mmseg/models/utils/shape_convert.py index cce1e22..8aa8fa6 100644 --- a/mmsegmentation/mmseg/models/utils/shape_convert.py +++ b/mmsegmentation/mmseg/models/utils/shape_convert.py @@ -12,7 +12,7 @@ def nlc_to_nchw(x, hw_shape): H, W = hw_shape assert len(x.shape) == 3 B, L, C = x.shape - assert L == H * W, 'The seq_len doesn\'t match H, W' + assert L == H * W, "The seq_len doesn't match H, W" return x.transpose(1, 2).reshape(B, C, H, W) @@ -95,7 +95,7 @@ def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): H, W = hw_shape assert len(x.shape) == 3 B, L, C = x.shape - assert L == H * W, 'The seq_len doesn\'t match H, W' + assert L == H * W, "The seq_len doesn't match H, W" if not contiguous: x = x.transpose(1, 2).reshape(B, C, H, W) x = module(x, **kwargs) diff --git a/mmsegmentation/mmseg/models/utils/up_conv_block.py b/mmsegmentation/mmseg/models/utils/up_conv_block.py index d8396d9..6564b79 100644 --- a/mmsegmentation/mmseg/models/utils/up_conv_block.py +++ b/mmsegmentation/mmseg/models/utils/up_conv_block.py @@ -42,24 +42,26 @@ class UpConvBlock(nn.Module): plugins (dict): plugins for convolutional layers. Default: None. """ - def __init__(self, - conv_block, - in_channels, - skip_channels, - out_channels, - num_convs=2, - stride=1, - dilation=1, - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), - upsample_cfg=dict(type='InterpConv'), - dcn=None, - plugins=None): - super(UpConvBlock, self).__init__() - assert dcn is None, 'Not implemented yet.' - assert plugins is None, 'Not implemented yet.' + def __init__( + self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), + upsample_cfg=dict(type="InterpConv"), + dcn=None, + plugins=None, + ): + super().__init__() + assert dcn is None, "Not implemented yet." + assert plugins is None, "Not implemented yet." self.conv_block = conv_block( in_channels=2 * skip_channels, @@ -72,7 +74,8 @@ def __init__(self, norm_cfg=norm_cfg, act_cfg=act_cfg, dcn=None, - plugins=None) + plugins=None, + ) if upsample_cfg is not None: self.upsample = build_upsample_layer( cfg=upsample_cfg, @@ -80,7 +83,8 @@ def __init__(self, out_channels=skip_channels, with_cp=with_cp, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) else: self.upsample = ConvModule( in_channels, @@ -90,7 +94,8 @@ def __init__(self, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg) + act_cfg=act_cfg, + ) def forward(self, skip, x): """Forward function.""" diff --git a/mmsegmentation/mmseg/ops/__init__.py b/mmsegmentation/mmseg/ops/__init__.py index bc075cd..a6a4761 100644 --- a/mmsegmentation/mmseg/ops/__init__.py +++ b/mmsegmentation/mmseg/ops/__init__.py @@ -2,4 +2,4 @@ from .encoding import Encoding from .wrappers import Upsample, resize -__all__ = ['Upsample', 'resize', 'Encoding'] +__all__ = ["Upsample", "resize", "Encoding"] diff --git a/mmsegmentation/mmseg/ops/encoding.py b/mmsegmentation/mmseg/ops/encoding.py index f397cc5..c8d944c 100644 --- a/mmsegmentation/mmseg/ops/encoding.py +++ b/mmsegmentation/mmseg/ops/encoding.py @@ -16,31 +16,32 @@ class Encoding(nn.Module): """ def __init__(self, channels, num_codes): - super(Encoding, self).__init__() + super().__init__() # init codewords and smoothing factor self.channels, self.num_codes = channels, num_codes - std = 1. / ((num_codes * channels)**0.5) + std = 1.0 / ((num_codes * channels) ** 0.5) # [num_codes, channels] self.codewords = nn.Parameter( - torch.empty(num_codes, channels, - dtype=torch.float).uniform_(-std, std), - requires_grad=True) + torch.empty(num_codes, channels, dtype=torch.float).uniform_(-std, std), + requires_grad=True, + ) # [num_codes] self.scale = nn.Parameter( torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), - requires_grad=True) + requires_grad=True, + ) @staticmethod def scaled_l2(x, codewords, scale): num_codes, channels = codewords.size() batch_size = x.size(0) reshaped_scale = scale.view((1, 1, num_codes)) - expanded_x = x.unsqueeze(2).expand( - (batch_size, x.size(1), num_codes, channels)) + expanded_x = x.unsqueeze(2).expand((batch_size, x.size(1), num_codes, channels)) reshaped_codewords = codewords.view((1, 1, num_codes, channels)) - scaled_l2_norm = reshaped_scale * ( - expanded_x - reshaped_codewords).pow(2).sum(dim=3) + scaled_l2_norm = reshaped_scale * (expanded_x - reshaped_codewords).pow(2).sum( + dim=3 + ) return scaled_l2_norm @staticmethod @@ -49,10 +50,10 @@ def aggregate(assignment_weights, x, codewords): reshaped_codewords = codewords.view((1, 1, num_codes, channels)) batch_size = x.size(0) - expanded_x = x.unsqueeze(2).expand( - (batch_size, x.size(1), num_codes, channels)) - encoded_feat = (assignment_weights.unsqueeze(3) * - (expanded_x - reshaped_codewords)).sum(dim=1) + expanded_x = x.unsqueeze(2).expand((batch_size, x.size(1), num_codes, channels)) + encoded_feat = ( + assignment_weights.unsqueeze(3) * (expanded_x - reshaped_codewords) + ).sum(dim=1) return encoded_feat def forward(self, x): @@ -63,13 +64,13 @@ def forward(self, x): x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() # assignment_weights: [batch_size, channels, num_codes] assignment_weights = F.softmax( - self.scaled_l2(x, self.codewords, self.scale), dim=2) + self.scaled_l2(x, self.codewords, self.scale), dim=2 + ) # aggregate encoded_feat = self.aggregate(assignment_weights, x, self.codewords) return encoded_feat def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ - f'x{self.channels})' + repr_str += f"(Nx{self.channels}xHxW =>Nx{self.num_codes}" f"x{self.channels})" return repr_str diff --git a/mmsegmentation/mmseg/ops/wrappers.py b/mmsegmentation/mmseg/ops/wrappers.py index bcababd..bf4da00 100644 --- a/mmsegmentation/mmseg/ops/wrappers.py +++ b/mmsegmentation/mmseg/ops/wrappers.py @@ -5,36 +5,38 @@ import torch.nn.functional as F -def resize(input, - size=None, - scale_factor=None, - mode='nearest', - align_corners=None, - warning=True): +def resize( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): if warning: if size is not None and align_corners: input_h, input_w = tuple(int(x) for x in input.shape[2:]) output_h, output_w = tuple(int(x) for x in size) if output_h > input_h or output_w > input_w: - if ((output_h > 1 and output_w > 1 and input_h > 1 - and input_w > 1) and (output_h - 1) % (input_h - 1) - and (output_w - 1) % (input_w - 1)): + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): warnings.warn( - f'When align_corners={align_corners}, ' - 'the output would more aligned if ' - f'input size {(input_h, input_w)} is `x+1` and ' - f'out size {(output_h, output_w)} is `nx+1`') + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) return F.interpolate(input, size, scale_factor, mode, align_corners) class Upsample(nn.Module): - - def __init__(self, - size=None, - scale_factor=None, - mode='nearest', - align_corners=None): - super(Upsample, self).__init__() + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=None + ): + super().__init__() self.size = size if isinstance(scale_factor, tuple): self.scale_factor = tuple(float(factor) for factor in scale_factor) diff --git a/mmsegmentation/mmseg/utils/__init__.py b/mmsegmentation/mmseg/utils/__init__.py index e3ef4b3..6589b41 100644 --- a/mmsegmentation/mmseg/utils/__init__.py +++ b/mmsegmentation/mmseg/utils/__init__.py @@ -6,6 +6,11 @@ from .util_distribution import build_ddp, build_dp, get_device __all__ = [ - 'get_root_logger', 'collect_env', 'find_latest_checkpoint', - 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device' + "get_root_logger", + "collect_env", + "find_latest_checkpoint", + "setup_multi_processes", + "build_ddp", + "build_dp", + "get_device", ] diff --git a/mmsegmentation/mmseg/utils/collect_env.py b/mmsegmentation/mmseg/utils/collect_env.py index 3379ecb..55a8e00 100644 --- a/mmsegmentation/mmseg/utils/collect_env.py +++ b/mmsegmentation/mmseg/utils/collect_env.py @@ -8,11 +8,11 @@ def collect_env(): """Collect the information of the running environments.""" env_info = collect_base_env() - env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + env_info["MMSegmentation"] = f"{mmseg.__version__}+{get_git_hash()[:7]}" return env_info -if __name__ == '__main__': +if __name__ == "__main__": for name, val in collect_env().items(): - print('{}: {}'.format(name, val)) + print(f"{name}: {val}") diff --git a/mmsegmentation/mmseg/utils/logger.py b/mmsegmentation/mmseg/utils/logger.py index 0cb3c78..00019e4 100644 --- a/mmsegmentation/mmseg/utils/logger.py +++ b/mmsegmentation/mmseg/utils/logger.py @@ -23,6 +23,6 @@ def get_root_logger(log_file=None, log_level=logging.INFO): logging.Logger: The root logger. """ - logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + logger = get_logger(name="mmseg", log_file=log_file, log_level=log_level) return logger diff --git a/mmsegmentation/mmseg/utils/misc.py b/mmsegmentation/mmseg/utils/misc.py index bd1b6b1..e669ce2 100644 --- a/mmsegmentation/mmseg/utils/misc.py +++ b/mmsegmentation/mmseg/utils/misc.py @@ -4,7 +4,7 @@ import warnings -def find_latest_checkpoint(path, suffix='pth'): +def find_latest_checkpoint(path, suffix="pth"): """This function is for finding the latest checkpoint. It will be used when automatically resume, modified from @@ -20,21 +20,21 @@ def find_latest_checkpoint(path, suffix='pth'): if not osp.exists(path): warnings.warn("The path of the checkpoints doesn't exist.") return None - if osp.exists(osp.join(path, f'latest.{suffix}')): - return osp.join(path, f'latest.{suffix}') + if osp.exists(osp.join(path, f"latest.{suffix}")): + return osp.join(path, f"latest.{suffix}") - checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) + checkpoints = glob.glob(osp.join(path, f"*.{suffix}")) if len(checkpoints) == 0: - warnings.warn('The are no checkpoints in the path') + warnings.warn("The are no checkpoints in the path") return None latest = -1 - latest_path = '' + latest_path = "" for checkpoint in checkpoints: if len(checkpoint) < len(latest_path): continue # `count` is iteration number, as checkpoints are saved as # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number. - count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) + count = int(osp.basename(checkpoint).split("_")[-1].split(".")[0]) if count > latest: latest = count latest_path = checkpoint diff --git a/mmsegmentation/mmseg/utils/set_env.py b/mmsegmentation/mmseg/utils/set_env.py index bf18453..855275f 100644 --- a/mmsegmentation/mmseg/utils/set_env.py +++ b/mmsegmentation/mmseg/utils/set_env.py @@ -13,43 +13,43 @@ def setup_multi_processes(cfg): logger = get_root_logger() # set multi-process start method - if platform.system() != 'Windows': - mp_start_method = cfg.get('mp_start_method', None) + if platform.system() != "Windows": + mp_start_method = cfg.get("mp_start_method", None) current_method = mp.get_start_method(allow_none=True) - if mp_start_method in ('fork', 'spawn', 'forkserver'): + if mp_start_method in ("fork", "spawn", "forkserver"): logger.info( - f'Multi-processing start method `{mp_start_method}` is ' - f'different from the previous setting `{current_method}`.' - f'It will be force set to `{mp_start_method}`.') + f"Multi-processing start method `{mp_start_method}` is " + f"different from the previous setting `{current_method}`." + f"It will be force set to `{mp_start_method}`." + ) mp.set_start_method(mp_start_method, force=True) else: - logger.info( - f'Multi-processing start method is `{mp_start_method}`') + logger.info(f"Multi-processing start method is `{mp_start_method}`") # disable opencv multithreading to avoid system being overloaded - opencv_num_threads = cfg.get('opencv_num_threads', None) + opencv_num_threads = cfg.get("opencv_num_threads", None) if isinstance(opencv_num_threads, int): - logger.info(f'OpenCV num_threads is `{opencv_num_threads}`') + logger.info(f"OpenCV num_threads is `{opencv_num_threads}`") cv2.setNumThreads(opencv_num_threads) else: - logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') + logger.info(f"OpenCV num_threads is `{cv2.getNumThreads()}") if cfg.data.workers_per_gpu > 1: # setup OMP threads # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa - omp_num_threads = cfg.get('omp_num_threads', None) - if 'OMP_NUM_THREADS' not in os.environ: + omp_num_threads = cfg.get("omp_num_threads", None) + if "OMP_NUM_THREADS" not in os.environ: if isinstance(omp_num_threads, int): - logger.info(f'OMP num threads is {omp_num_threads}') - os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + logger.info(f"OMP num threads is {omp_num_threads}") + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) else: logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }') # setup MKL threads - if 'MKL_NUM_THREADS' not in os.environ: - mkl_num_threads = cfg.get('mkl_num_threads', None) + if "MKL_NUM_THREADS" not in os.environ: + mkl_num_threads = cfg.get("mkl_num_threads", None) if isinstance(mkl_num_threads, int): - logger.info(f'MKL num threads is {mkl_num_threads}') - os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + logger.info(f"MKL num threads is {mkl_num_threads}") + os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads) else: logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') diff --git a/mmsegmentation/mmseg/utils/util_distribution.py b/mmsegmentation/mmseg/utils/util_distribution.py index 16651c2..051435e 100644 --- a/mmsegmentation/mmseg/utils/util_distribution.py +++ b/mmsegmentation/mmseg/utils/util_distribution.py @@ -5,12 +5,12 @@ from mmseg import digit_version -dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} +dp_factory = {"cuda": MMDataParallel, "cpu": MMDataParallel} -ddp_factory = {'cuda': MMDistributedDataParallel} +ddp_factory = {"cuda": MMDistributedDataParallel} -def build_dp(model, device='cuda', dim=0, *args, **kwargs): +def build_dp(model, device="cuda", dim=0, *args, **kwargs): """build DataParallel module by device type. if device is cuda, return a MMDataParallel module; if device is mlu, @@ -24,19 +24,21 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs): Returns: :class:`nn.Module`: parallelized module. """ - if device == 'cuda': + if device == "cuda": model = model.cuda() - elif device == 'mlu': - assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ - 'Please use MMCV >= 1.5.0 for MLU training!' + elif device == "mlu": + assert digit_version(mmcv.__version__) >= digit_version( + "1.5.0" + ), "Please use MMCV >= 1.5.0 for MLU training!" from mmcv.device.mlu import MLUDataParallel - dp_factory['mlu'] = MLUDataParallel + + dp_factory["mlu"] = MLUDataParallel model = model.mlu() return dp_factory[device](model, dim=dim, *args, **kwargs) -def build_ddp(model, device='cuda', *args, **kwargs): +def build_ddp(model, device="cuda", *args, **kwargs): """Build DistributedDataParallel module by device type. If device is cuda, return a MMDistributedDataParallel module; @@ -53,14 +55,16 @@ def build_ddp(model, device='cuda', *args, **kwargs): .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. DistributedDataParallel.html """ - assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' - if device == 'cuda': + assert device in ["cuda", "mlu"], "Only available for cuda or mlu devices." + if device == "cuda": model = model.cuda() - elif device == 'mlu': - assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ - 'Please use MMCV >= 1.5.0 for MLU training!' + elif device == "mlu": + assert digit_version(mmcv.__version__) >= digit_version( + "1.5.0" + ), "Please use MMCV >= 1.5.0 for MLU training!" from mmcv.device.mlu import MLUDistributedDataParallel - ddp_factory['mlu'] = MLUDistributedDataParallel + + ddp_factory["mlu"] = MLUDistributedDataParallel model = model.mlu() return ddp_factory[device](model, *args, **kwargs) @@ -68,14 +72,11 @@ def build_ddp(model, device='cuda', *args, **kwargs): def is_mlu_available(): """Returns a bool indicating if MLU is currently available.""" - return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() + return hasattr(torch, "is_mlu_available") and torch.is_mlu_available() def get_device(): """Returns an available device, cpu, cuda or mlu.""" - is_device_available = { - 'cuda': torch.cuda.is_available(), - 'mlu': is_mlu_available() - } + is_device_available = {"cuda": torch.cuda.is_available(), "mlu": is_mlu_available()} device_list = [k for k, v in is_device_available.items() if v] - return device_list[0] if len(device_list) == 1 else 'cpu' + return device_list[0] if len(device_list) == 1 else "cpu" diff --git a/mmsegmentation/mmseg/version.py b/mmsegmentation/mmseg/version.py index 9f27ecb..4cd1f71 100644 --- a/mmsegmentation/mmseg/version.py +++ b/mmsegmentation/mmseg/version.py @@ -1,17 +1,17 @@ # Copyright (c) Open-MMLab. All rights reserved. -__version__ = '0.29.1' +__version__ = "0.29.1" def parse_version_info(version_str): version_info = [] - for x in version_str.split('.'): + for x in version_str.split("."): if x.isdigit(): version_info.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') + elif x.find("rc") != -1: + patch_version = x.split("rc") version_info.append(int(patch_version[0])) - version_info.append(f'rc{patch_version[1]}') + version_info.append(f"rc{patch_version[1]}") return tuple(version_info) diff --git a/mmsegmentation/setup.py b/mmsegmentation/setup.py index 7461e76..a0e3069 100755 --- a/mmsegmentation/setup.py +++ b/mmsegmentation/setup.py @@ -9,21 +9,21 @@ def readme(): - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: content = f.read() return content -version_file = 'mmseg/version.py' +version_file = "mmseg/version.py" def get_version(): - with open(version_file, 'r') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + with open(version_file) as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] -def parse_requirements(fname='requirements.txt', with_version=True): +def parse_requirements(fname="requirements.txt", with_version=True): """Parse the package dependencies listed in a requirements file but strips specific versioning information. @@ -40,56 +40,56 @@ def parse_requirements(fname='requirements.txt', with_version=True): import re import sys from os.path import exists + require_fpath = fname def parse_line(line): """Parse information from a line in a requirements text file.""" - if line.startswith('-r '): + if line.startswith("-r "): # Allow specifying requirements in other files - target = line.split(' ')[1] + target = line.split(" ")[1] yield from parse_require_file(target) else: - info = {'line': line} - if line.startswith('-e '): - info['package'] = line.split('#egg=')[1] + info = {"line": line} + if line.startswith("-e "): + info["package"] = line.split("#egg=")[1] else: # Remove versioning from the package - pat = '(' + '|'.join(['>=', '==', '>']) + ')' + pat = "(" + "|".join([">=", "==", ">"]) + ")" parts = re.split(pat, line, maxsplit=1) parts = [p.strip() for p in parts] - info['package'] = parts[0] + info["package"] = parts[0] if len(parts) > 1: op, rest = parts[1:] - if ';' in rest: + if ";" in rest: # Handle platform specific dependencies # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, - rest.split(';')) - info['platform_deps'] = platform_deps + version, platform_deps = map(str.strip, rest.split(";")) + info["platform_deps"] = platform_deps else: version = rest - info['version'] = op, version + info["version"] = op, version yield info def parse_require_file(fpath): - with open(fpath, 'r') as f: + with open(fpath) as f: for line in f.readlines(): line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): yield from parse_line(line) def gen_packages_items(): if not exists(require_fpath): return for info in parse_require_file(require_fpath): - parts = [info['package']] - if with_version and 'version' in info: - parts.extend(info['version']) - if not sys.version.startswith('3.4'): - platform_deps = info.get('platform_deps') + parts = [info["package"]] + if with_version and "version" in info: + parts.extend(info["version"]) + if not sys.version.startswith("3.4"): + platform_deps = info.get("platform_deps") if platform_deps is not None: - parts.append(f';{platform_deps}') - item = ''.join(parts) + parts.append(f";{platform_deps}") + item = "".join(parts) yield item packages = list(gen_packages_items()) @@ -105,21 +105,24 @@ def add_mim_extension(): """ # parse installment mode - if 'develop' in sys.argv: + if "develop" in sys.argv: # installed by `pip install -e .` # set `copy` mode here since symlink fails on Windows. - mode = 'copy' if platform.system() == 'Windows' else 'symlink' - elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or platform.system( - ) == 'Windows': + mode = "copy" if platform.system() == "Windows" else "symlink" + elif ( + "sdist" in sys.argv + or "bdist_wheel" in sys.argv + or platform.system() == "Windows" + ): # installed by `pip install .` # or create source distribution by `python setup.py sdist` # set `copy` mode here since symlink fails with WinError on Windows. - mode = 'copy' + mode = "copy" else: return - filenames = ['tools', 'configs', 'model-index.yml'] + filenames = ["tools", "configs", "model-index.yml"] repo_path = osp.dirname(__file__) - mim_path = osp.join(repo_path, 'mmseg', '.mim') + mim_path = osp.join(repo_path, "mmseg", ".mim") os.makedirs(mim_path, exist_ok=True) for filename in filenames: if osp.exists(filename): @@ -129,7 +132,7 @@ def add_mim_extension(): os.remove(tar_path) elif osp.isdir(tar_path): shutil.rmtree(tar_path) - if mode == 'symlink': + if mode == "symlink": src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) try: os.symlink(src_relpath, tar_path) @@ -137,54 +140,56 @@ def add_mim_extension(): # Creating a symbolic link on windows may raise an # `OSError: [WinError 1314]` due to privilege. If # the error happens, the src file will be copied - mode = 'copy' + mode = "copy" warnings.warn( - f'Failed to create a symbolic link for {src_relpath},' - f' and it will be copied to {tar_path}') + f"Failed to create a symbolic link for {src_relpath}," + f" and it will be copied to {tar_path}" + ) else: continue - if mode != 'copy': - raise ValueError(f'Invalid mode {mode}') + if mode != "copy": + raise ValueError(f"Invalid mode {mode}") if osp.isfile(src_path): shutil.copyfile(src_path, tar_path) elif osp.isdir(src_path): shutil.copytree(src_path, tar_path) else: - warnings.warn(f'Cannot copy file {src_path}.') + warnings.warn(f"Cannot copy file {src_path}.") -if __name__ == '__main__': +if __name__ == "__main__": add_mim_extension() setup( - name='mmsegmentation', + name="mmsegmentation", version=get_version(), - description='Open MMLab Semantic Segmentation Toolbox and Benchmark', + description="Open MMLab Semantic Segmentation Toolbox and Benchmark", long_description=readme(), - long_description_content_type='text/markdown', - author='MMSegmentation Contributors', - author_email='openmmlab@gmail.com', - keywords='computer vision, semantic segmentation', - url='http://github.com/open-mmlab/mmsegmentation', - packages=find_packages(exclude=('configs', 'tools', 'demo')), + long_description_content_type="text/markdown", + author="MMSegmentation Contributors", + author_email="openmmlab@gmail.com", + keywords="computer vision, semantic segmentation", + url="http://github.com/open-mmlab/mmsegmentation", + packages=find_packages(exclude=("configs", "tools", "demo")), include_package_data=True, classifiers=[ - 'Development Status :: 4 - Beta', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], - license='Apache License 2.0', - install_requires=parse_requirements('requirements/runtime.txt'), + license="Apache License 2.0", + install_requires=parse_requirements("requirements/runtime.txt"), extras_require={ - 'all': parse_requirements('requirements.txt'), - 'tests': parse_requirements('requirements/tests.txt'), - 'build': parse_requirements('requirements/build.txt'), - 'optional': parse_requirements('requirements/optional.txt'), - 'mim': parse_requirements('requirements/mminstall.txt'), + "all": parse_requirements("requirements.txt"), + "tests": parse_requirements("requirements/tests.txt"), + "build": parse_requirements("requirements/build.txt"), + "optional": parse_requirements("requirements/optional.txt"), + "mim": parse_requirements("requirements/mminstall.txt"), }, ext_modules=[], - zip_safe=False) + zip_safe=False, + ) diff --git a/mmsegmentation/tests/test_apis/test_single_gpu.py b/mmsegmentation/tests/test_apis/test_single_gpu.py index 0b484f2..23e3e75 100644 --- a/mmsegmentation/tests/test_apis/test_single_gpu.py +++ b/mmsegmentation/tests/test_apis/test_single_gpu.py @@ -12,7 +12,6 @@ class ExampleDataset(Dataset): - def __getitem__(self, idx): results = dict(img=torch.tensor([1]), img_metas=dict()) return results @@ -22,9 +21,8 @@ def __len__(self): class ExampleModel(nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() + super().__init__() self.test_cfg = None self.conv = nn.Conv2d(3, 3, 3) @@ -48,26 +46,23 @@ def test_single_gpu(): assert len(results) == 1 pred = np.load(results[0]) assert isinstance(pred, np.ndarray) - assert pred.shape == (1, ) + assert pred.shape == (1,) assert pred[0] == 1 - shutil.rmtree('.efficient_test') + shutil.rmtree(".efficient_test") # Test pre_eval - test_dataset.pre_eval = MagicMock(return_value=['success']) + test_dataset.pre_eval = MagicMock(return_value=["success"]) results = single_gpu_test(model, data_loader, pre_eval=True) - assert results == ['success'] + assert results == ["success"] # Test format_only - test_dataset.format_results = MagicMock(return_value=['success']) + test_dataset.format_results = MagicMock(return_value=["success"]) results = single_gpu_test(model, data_loader, format_only=True) - assert results == ['success'] + assert results == ["success"] # efficient_test, pre_eval and format_only are mutually exclusive with pytest.raises(AssertionError): single_gpu_test( - model, - dataloader, - efficient_test=True, - format_only=True, - pre_eval=True) + model, dataloader, efficient_test=True, format_only=True, pre_eval=True + ) diff --git a/mmsegmentation/tests/test_config.py b/mmsegmentation/tests/test_config.py index 2482144..43dee1f 100644 --- a/mmsegmentation/tests/test_config.py +++ b/mmsegmentation/tests/test_config.py @@ -17,10 +17,11 @@ def _get_config_directory(): except NameError: # For IPython development when this __file__ is not defined import mmseg + repo_dpath = dirname(dirname(mmseg.__file__)) - config_dpath = join(repo_dpath, 'configs') + config_dpath = join(repo_dpath, "configs") if not exists(config_dpath): - raise Exception('Cannot find config path') + raise Exception("Cannot find config path") return config_dpath @@ -28,35 +29,36 @@ def test_config_build_segmentor(): """Test that all segmentation models defined in the configs can be initialized.""" config_dpath = _get_config_directory() - print('Found config_dpath = {!r}'.format(config_dpath)) + print(f"Found config_dpath = {config_dpath!r}") config_fpaths = [] # one config each sub folder for sub_folder in os.listdir(config_dpath): if isdir(sub_folder): config_fpaths.append( - list(glob.glob(join(config_dpath, sub_folder, '*.py')))[0]) - config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] + list(glob.glob(join(config_dpath, sub_folder, "*.py")))[0] + ) + config_fpaths = [p for p in config_fpaths if p.find("_base_") == -1] config_names = [relpath(p, config_dpath) for p in config_fpaths] - print('Using {} config files'.format(len(config_names))) + print(f"Using {len(config_names)} config files") for config_fname in config_names: config_fpath = join(config_dpath, config_fname) config_mod = Config.fromfile(config_fpath) config_mod.model - print('Building segmentor, config_fpath = {!r}'.format(config_fpath)) + print(f"Building segmentor, config_fpath = {config_fpath!r}") # Remove pretrained keys to allow for testing in an offline environment - if 'pretrained' in config_mod.model: - config_mod.model['pretrained'] = None + if "pretrained" in config_mod.model: + config_mod.model["pretrained"] = None - print('building {}'.format(config_fname)) + print(f"building {config_fname}") segmentor = build_segmentor(config_mod.model) assert segmentor is not None - head_config = config_mod.model['decode_head'] + head_config = config_mod.model["decode_head"] _check_decode_head(head_config, segmentor.decode_head) @@ -72,24 +74,24 @@ def test_config_data_pipeline(): from mmseg.datasets.pipelines import Compose config_dpath = _get_config_directory() - print('Found config_dpath = {!r}'.format(config_dpath)) + print(f"Found config_dpath = {config_dpath!r}") import glob - config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py'))) - config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] + + config_fpaths = list(glob.glob(join(config_dpath, "**", "*.py"))) + config_fpaths = [p for p in config_fpaths if p.find("_base_") == -1] config_names = [relpath(p, config_dpath) for p in config_fpaths] - print('Using {} config files'.format(len(config_names))) + print(f"Using {len(config_names)} config files") for config_fname in config_names: config_fpath = join(config_dpath, config_fname) - print( - 'Building data pipeline, config_fpath = {!r}'.format(config_fpath)) + print(f"Building data pipeline, config_fpath = {config_fpath!r}") config_mod = Config.fromfile(config_fpath) # remove loading pipeline load_img_pipeline = config_mod.train_pipeline.pop(0) - to_float32 = load_img_pipeline.get('to_float32', False) + to_float32 = load_img_pipeline.get("to_float32", False) config_mod.train_pipeline.pop(0) config_mod.test_pipeline.pop(0) @@ -102,26 +104,27 @@ def test_config_data_pipeline(): seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8) results = dict( - filename='test_img.png', - ori_filename='test_img.png', + filename="test_img.png", + ori_filename="test_img.png", img=img, img_shape=img.shape, ori_shape=img.shape, - gt_semantic_seg=seg) - results['seg_fields'] = ['gt_semantic_seg'] + gt_semantic_seg=seg, + ) + results["seg_fields"] = ["gt_semantic_seg"] - print('Test training data pipeline: \n{!r}'.format(train_pipeline)) + print(f"Test training data pipeline: \n{train_pipeline!r}") output_results = train_pipeline(results) assert output_results is not None results = dict( - filename='test_img.png', - ori_filename='test_img.png', + filename="test_img.png", + ori_filename="test_img.png", img=img, img_shape=img.shape, ori_shape=img.shape, ) - print('Test testing data pipeline: \n{!r}'.format(test_pipeline)) + print(f"Test testing data pipeline: \n{test_pipeline!r}") output_results = test_pipeline(results) assert output_results is not None @@ -135,27 +138,29 @@ def _check_decode_head(decode_head_cfg, decode_head): _check_decode_head(decode_head_cfg[i], decode_head[i]) return # check consistency between head_config and roi_head - assert decode_head_cfg['type'] == decode_head.__class__.__name__ + assert decode_head_cfg["type"] == decode_head.__class__.__name__ - assert decode_head_cfg['type'] == decode_head.__class__.__name__ + assert decode_head_cfg["type"] == decode_head.__class__.__name__ in_channels = decode_head_cfg.in_channels input_transform = decode_head.input_transform - assert input_transform in ['resize_concat', 'multiple_select', None] + assert input_transform in ["resize_concat", "multiple_select", None] if input_transform is not None: assert isinstance(in_channels, (list, tuple)) assert isinstance(decode_head.in_index, (list, tuple)) assert len(in_channels) == len(decode_head.in_index) - elif input_transform == 'resize_concat': + elif input_transform == "resize_concat": assert sum(in_channels) == decode_head.in_channels else: assert isinstance(in_channels, int) assert in_channels == decode_head.in_channels assert isinstance(decode_head.in_index, int) - if decode_head_cfg['type'] == 'PointHead': - assert decode_head_cfg.channels+decode_head_cfg.num_classes == \ - decode_head.fc_seg.in_channels + if decode_head_cfg["type"] == "PointHead": + assert ( + decode_head_cfg.channels + decode_head_cfg.num_classes + == decode_head.fc_seg.in_channels + ) assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes else: assert decode_head_cfg.channels == decode_head.conv_seg.in_channels diff --git a/mmsegmentation/tests/test_core/test_layer_decay_optimizer_constructor.py b/mmsegmentation/tests/test_core/test_layer_decay_optimizer_constructor.py index 4911f3b..9c5121b 100644 --- a/mmsegmentation/tests/test_core/test_layer_decay_optimizer_constructor.py +++ b/mmsegmentation/tests/test_core/test_layer_decay_optimizer_constructor.py @@ -5,121 +5,59 @@ from mmcv.cnn import ConvModule from mmseg.core.optimizers.layer_decay_optimizer_constructor import ( - LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + LayerDecayOptimizerConstructor, + LearningRateDecayOptimizerConstructor, +) base_lr = 1 decay_rate = 2 base_wd = 0.05 weight_decay = 0.05 -expected_stage_wise_lr_wd_convnext = [{ - 'weight_decay': 0.0, - 'lr_scale': 128 -}, { - 'weight_decay': 0.0, - 'lr_scale': 1 -}, { - 'weight_decay': 0.05, - 'lr_scale': 64 -}, { - 'weight_decay': 0.0, - 'lr_scale': 64 -}, { - 'weight_decay': 0.05, - 'lr_scale': 32 -}, { - 'weight_decay': 0.0, - 'lr_scale': 32 -}, { - 'weight_decay': 0.05, - 'lr_scale': 16 -}, { - 'weight_decay': 0.0, - 'lr_scale': 16 -}, { - 'weight_decay': 0.05, - 'lr_scale': 8 -}, { - 'weight_decay': 0.0, - 'lr_scale': 8 -}, { - 'weight_decay': 0.05, - 'lr_scale': 128 -}, { - 'weight_decay': 0.05, - 'lr_scale': 1 -}] - -expected_layer_wise_lr_wd_convnext = [{ - 'weight_decay': 0.0, - 'lr_scale': 128 -}, { - 'weight_decay': 0.0, - 'lr_scale': 1 -}, { - 'weight_decay': 0.05, - 'lr_scale': 64 -}, { - 'weight_decay': 0.0, - 'lr_scale': 64 -}, { - 'weight_decay': 0.05, - 'lr_scale': 32 -}, { - 'weight_decay': 0.0, - 'lr_scale': 32 -}, { - 'weight_decay': 0.05, - 'lr_scale': 16 -}, { - 'weight_decay': 0.0, - 'lr_scale': 16 -}, { - 'weight_decay': 0.05, - 'lr_scale': 2 -}, { - 'weight_decay': 0.0, - 'lr_scale': 2 -}, { - 'weight_decay': 0.05, - 'lr_scale': 128 -}, { - 'weight_decay': 0.05, - 'lr_scale': 1 -}] - -expected_layer_wise_wd_lr_beit = [{ - 'weight_decay': 0.0, - 'lr_scale': 16 -}, { - 'weight_decay': 0.05, - 'lr_scale': 8 -}, { - 'weight_decay': 0.0, - 'lr_scale': 8 -}, { - 'weight_decay': 0.05, - 'lr_scale': 4 -}, { - 'weight_decay': 0.0, - 'lr_scale': 4 -}, { - 'weight_decay': 0.05, - 'lr_scale': 2 -}, { - 'weight_decay': 0.0, - 'lr_scale': 2 -}, { - 'weight_decay': 0.05, - 'lr_scale': 1 -}, { - 'weight_decay': 0.0, - 'lr_scale': 1 -}] +expected_stage_wise_lr_wd_convnext = [ + {"weight_decay": 0.0, "lr_scale": 128}, + {"weight_decay": 0.0, "lr_scale": 1}, + {"weight_decay": 0.05, "lr_scale": 64}, + {"weight_decay": 0.0, "lr_scale": 64}, + {"weight_decay": 0.05, "lr_scale": 32}, + {"weight_decay": 0.0, "lr_scale": 32}, + {"weight_decay": 0.05, "lr_scale": 16}, + {"weight_decay": 0.0, "lr_scale": 16}, + {"weight_decay": 0.05, "lr_scale": 8}, + {"weight_decay": 0.0, "lr_scale": 8}, + {"weight_decay": 0.05, "lr_scale": 128}, + {"weight_decay": 0.05, "lr_scale": 1}, +] + +expected_layer_wise_lr_wd_convnext = [ + {"weight_decay": 0.0, "lr_scale": 128}, + {"weight_decay": 0.0, "lr_scale": 1}, + {"weight_decay": 0.05, "lr_scale": 64}, + {"weight_decay": 0.0, "lr_scale": 64}, + {"weight_decay": 0.05, "lr_scale": 32}, + {"weight_decay": 0.0, "lr_scale": 32}, + {"weight_decay": 0.05, "lr_scale": 16}, + {"weight_decay": 0.0, "lr_scale": 16}, + {"weight_decay": 0.05, "lr_scale": 2}, + {"weight_decay": 0.0, "lr_scale": 2}, + {"weight_decay": 0.05, "lr_scale": 128}, + {"weight_decay": 0.05, "lr_scale": 1}, +] + +expected_layer_wise_wd_lr_beit = [ + {"weight_decay": 0.0, "lr_scale": 16}, + {"weight_decay": 0.05, "lr_scale": 8}, + {"weight_decay": 0.0, "lr_scale": 8}, + {"weight_decay": 0.05, "lr_scale": 4}, + {"weight_decay": 0.0, "lr_scale": 4}, + {"weight_decay": 0.05, "lr_scale": 2}, + {"weight_decay": 0.0, "lr_scale": 2}, + {"weight_decay": 0.05, "lr_scale": 1}, + {"weight_decay": 0.0, "lr_scale": 1}, +] class ToyConvNeXt(nn.Module): - def __init__(self): super().__init__() self.stages = nn.ModuleList() @@ -145,7 +83,6 @@ def __init__(self): class ToyBEiT(nn.Module): - def __init__(self): super().__init__() # add some variables to meet unit test coverate rate @@ -158,7 +95,6 @@ def __init__(self): class ToyMAE(nn.Module): - def __init__(self): super().__init__() # add some variables to meet unit test coverate rate @@ -171,7 +107,6 @@ def __init__(self): class ToySegmentor(nn.Module): - def __init__(self, backbone): super().__init__() self.backbone = backbone @@ -179,50 +114,52 @@ def __init__(self, backbone): class PseudoDataParallel(nn.Module): - def __init__(self, model): super().__init__() self.module = model class ToyViT(nn.Module): - def __init__(self): super().__init__() def check_optimizer_lr_wd(optimizer, gt_lr_wd): assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults['lr'] == base_lr - assert optimizer.defaults['weight_decay'] == base_wd + assert optimizer.defaults["lr"] == base_lr + assert optimizer.defaults["weight_decay"] == base_wd param_groups = optimizer.param_groups print(param_groups) assert len(param_groups) == len(gt_lr_wd) for i, param_dict in enumerate(param_groups): - assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay'] - assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale'] - assert param_dict['lr_scale'] == param_dict['lr'] + assert param_dict["weight_decay"] == gt_lr_wd[i]["weight_decay"] + assert param_dict["lr_scale"] == gt_lr_wd[i]["lr_scale"] + assert param_dict["lr_scale"] == param_dict["lr"] def test_learning_rate_decay_optimizer_constructor(): - # Test lr wd for ConvNeXT backbone = ToyConvNeXt() model = PseudoDataParallel(ToySegmentor(backbone)) optimizer_cfg = dict( - type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05) + type="AdamW", lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05 + ) # stagewise decay stagewise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='stage_wise', num_layers=6) + decay_rate=decay_rate, decay_type="stage_wise", num_layers=6 + ) optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, stagewise_paramwise_cfg) + optimizer_cfg, stagewise_paramwise_cfg + ) optimizer = optim_constructor(model) check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext) # layerwise decay layerwise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='layer_wise', num_layers=6) + decay_rate=decay_rate, decay_type="layer_wise", num_layers=6 + ) optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, layerwise_paramwise_cfg) + optimizer_cfg, layerwise_paramwise_cfg + ) optimizer = optim_constructor(model) check_optimizer_lr_wd(optimizer, expected_layer_wise_lr_wd_convnext) @@ -231,9 +168,11 @@ def test_learning_rate_decay_optimizer_constructor(): model = PseudoDataParallel(ToySegmentor(backbone)) layerwise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='layer_wise', num_layers=3) + decay_rate=decay_rate, decay_type="layer_wise", num_layers=3 + ) optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, layerwise_paramwise_cfg) + optimizer_cfg, layerwise_paramwise_cfg + ) optimizer = optim_constructor(model) check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit) @@ -242,11 +181,13 @@ def test_learning_rate_decay_optimizer_constructor(): model = PseudoDataParallel(ToySegmentor(backbone)) with pytest.raises(NotImplementedError): optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, layerwise_paramwise_cfg) + optimizer_cfg, layerwise_paramwise_cfg + ) optimizer = optim_constructor(model) with pytest.raises(NotImplementedError): optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, stagewise_paramwise_cfg) + optimizer_cfg, stagewise_paramwise_cfg + ) optimizer = optim_constructor(model) # Test lr wd for MAE @@ -254,22 +195,21 @@ def test_learning_rate_decay_optimizer_constructor(): model = PseudoDataParallel(ToySegmentor(backbone)) layerwise_paramwise_cfg = dict( - decay_rate=decay_rate, decay_type='layer_wise', num_layers=3) + decay_rate=decay_rate, decay_type="layer_wise", num_layers=3 + ) optim_constructor = LearningRateDecayOptimizerConstructor( - optimizer_cfg, layerwise_paramwise_cfg) + optimizer_cfg, layerwise_paramwise_cfg + ) optimizer = optim_constructor(model) check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit) def test_beit_layer_decay_optimizer_constructor(): - # paramwise_cfg with BEiTExampleModel backbone = ToyBEiT() model = PseudoDataParallel(ToySegmentor(backbone)) - optimizer_cfg = dict( - type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05) + optimizer_cfg = dict(type="AdamW", lr=1, betas=(0.9, 0.999), weight_decay=0.05) paramwise_cfg = dict(layer_decay_rate=2, num_layers=3) - optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg, - paramwise_cfg) + optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg, paramwise_cfg) optimizer = optim_constructor(model) check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit) diff --git a/mmsegmentation/tests/test_core/test_optimizer.py b/mmsegmentation/tests/test_core/test_optimizer.py index 247f9fe..c17cb73 100644 --- a/mmsegmentation/tests/test_core/test_optimizer.py +++ b/mmsegmentation/tests/test_core/test_optimizer.py @@ -4,12 +4,14 @@ import torch.nn as nn from mmcv.runner import DefaultOptimizerConstructor -from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer, - build_optimizer_constructor) +from mmseg.core.builder import ( + OPTIMIZER_BUILDERS, + build_optimizer, + build_optimizer_constructor, +) class ExampleModel(nn.Module): - def __init__(self): super().__init__() self.param1 = nn.Parameter(torch.ones(1)) @@ -28,9 +30,11 @@ def forward(self, x): def test_build_optimizer_constructor(): optimizer_cfg = dict( - type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum) + type="SGD", lr=base_lr, weight_decay=base_wd, momentum=momentum + ) optim_constructor_cfg = dict( - type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg) + type="DefaultOptimizerConstructor", optimizer_cfg=optimizer_cfg + ) optim_constructor = build_optimizer_constructor(optim_constructor_cfg) # Test whether optimizer constructor can be built from parent. assert type(optim_constructor) is DefaultOptimizerConstructor @@ -40,20 +44,22 @@ class MyOptimizerConstructor(DefaultOptimizerConstructor): pass optim_constructor_cfg = dict( - type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg) + type="MyOptimizerConstructor", optimizer_cfg=optimizer_cfg + ) optim_constructor = build_optimizer_constructor(optim_constructor_cfg) # Test optimizer constructor can be built from child registry. assert type(optim_constructor) is MyOptimizerConstructor # Test unregistered constructor cannot be built with pytest.raises(KeyError): - build_optimizer_constructor(dict(type='A')) + build_optimizer_constructor(dict(type="A")) def test_build_optimizer(): model = ExampleModel() optimizer_cfg = dict( - type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum) + type="SGD", lr=base_lr, weight_decay=base_wd, momentum=momentum + ) optimizer = build_optimizer(model, optimizer_cfg) # test whether optimizer is successfully built from parent. assert isinstance(optimizer, torch.optim.SGD) diff --git a/mmsegmentation/tests/test_data/test_dataset.py b/mmsegmentation/tests/test_data/test_dataset.py index 6ea6eb9..3414331 100644 --- a/mmsegmentation/tests/test_data/test_dataset.py +++ b/mmsegmentation/tests/test_data/test_dataset.py @@ -12,51 +12,64 @@ from PIL import Image from mmseg.core.evaluation import get_classes, get_palette -from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, - COCOStuffDataset, ConcatDataset, CustomDataset, - ISPRSDataset, LoveDADataset, MultiImageMixDataset, - PascalVOCDataset, PotsdamDataset, RepeatDataset, - build_dataset, iSAIDDataset) +from mmseg.datasets import ( + DATASETS, + ADE20KDataset, + CityscapesDataset, + COCOStuffDataset, + ConcatDataset, + CustomDataset, + ISPRSDataset, + LoveDADataset, + MultiImageMixDataset, + PascalVOCDataset, + PotsdamDataset, + RepeatDataset, + build_dataset, + iSAIDDataset, +) def test_classes(): - assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes') - assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes( - 'pascal_voc') - assert list( - ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k') - assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff') - assert list(LoveDADataset.CLASSES) == get_classes('loveda') - assert list(PotsdamDataset.CLASSES) == get_classes('potsdam') - assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen') - assert list(iSAIDDataset.CLASSES) == get_classes('isaid') + assert list(CityscapesDataset.CLASSES) == get_classes("cityscapes") + assert ( + list(PascalVOCDataset.CLASSES) + == get_classes("voc") + == get_classes("pascal_voc") + ) + assert list(ADE20KDataset.CLASSES) == get_classes("ade") == get_classes("ade20k") + assert list(COCOStuffDataset.CLASSES) == get_classes("cocostuff") + assert list(LoveDADataset.CLASSES) == get_classes("loveda") + assert list(PotsdamDataset.CLASSES) == get_classes("potsdam") + assert list(ISPRSDataset.CLASSES) == get_classes("vaihingen") + assert list(iSAIDDataset.CLASSES) == get_classes("isaid") with pytest.raises(ValueError): - get_classes('unsupported') + get_classes("unsupported") def test_classes_file_path(): tmp_file = tempfile.NamedTemporaryFile() - classes_path = f'{tmp_file.name}.txt' - train_pipeline = [dict(type='LoadImageFromFile')] - kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path) + classes_path = f"{tmp_file.name}.txt" + train_pipeline = [dict(type="LoadImageFromFile")] + kwargs = dict(pipeline=train_pipeline, img_dir="./", classes=classes_path) # classes.txt with full categories - categories = get_classes('cityscapes') - with open(classes_path, 'w') as f: - f.write('\n'.join(categories)) + categories = get_classes("cityscapes") + with open(classes_path, "w") as f: + f.write("\n".join(categories)) assert list(CityscapesDataset(**kwargs).CLASSES) == categories # classes.txt with sub categories - categories = ['road', 'sidewalk', 'building'] - with open(classes_path, 'w') as f: - f.write('\n'.join(categories)) + categories = ["road", "sidewalk", "building"] + with open(classes_path, "w") as f: + f.write("\n".join(categories)) assert list(CityscapesDataset(**kwargs).CLASSES) == categories # classes.txt with unknown categories - categories = ['road', 'sidewalk', 'unknown'] - with open(classes_path, 'w') as f: - f.write('\n'.join(categories)) + categories = ["road", "sidewalk", "unknown"] + with open(classes_path, "w") as f: + f.write("\n".join(categories)) with pytest.raises(ValueError): CityscapesDataset(**kwargs) @@ -67,22 +80,22 @@ def test_classes_file_path(): def test_palette(): - assert CityscapesDataset.PALETTE == get_palette('cityscapes') - assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette( - 'pascal_voc') - assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k') - assert LoveDADataset.PALETTE == get_palette('loveda') - assert PotsdamDataset.PALETTE == get_palette('potsdam') - assert COCOStuffDataset.PALETTE == get_palette('cocostuff') - assert iSAIDDataset.PALETTE == get_palette('isaid') + assert CityscapesDataset.PALETTE == get_palette("cityscapes") + assert PascalVOCDataset.PALETTE == get_palette("voc") == get_palette("pascal_voc") + assert ADE20KDataset.PALETTE == get_palette("ade") == get_palette("ade20k") + assert LoveDADataset.PALETTE == get_palette("loveda") + assert PotsdamDataset.PALETTE == get_palette("potsdam") + assert COCOStuffDataset.PALETTE == get_palette("cocostuff") + assert iSAIDDataset.PALETTE == get_palette("isaid") with pytest.raises(ValueError): - get_palette('unsupported') + get_palette("unsupported") -@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) -@patch('mmseg.datasets.CustomDataset.__getitem__', - MagicMock(side_effect=lambda idx: idx)) +@patch("mmseg.datasets.CustomDataset.load_annotations", MagicMock) +@patch( + "mmseg.datasets.CustomDataset.__getitem__", MagicMock(side_effect=lambda idx: idx) +) def test_dataset_wrapper(): # CustomDataset.load_annotations = MagicMock() # CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx) @@ -108,9 +121,9 @@ def test_dataset_wrapper(): img_scale = (60, 60) pipeline = [ - dict(type='RandomMosaic', prob=1, img_scale=img_scale), - dict(type='RandomFlip', prob=0.5), - dict(type='Resize', img_scale=img_scale, keep_ratio=False), + dict(type="RandomMosaic", prob=1, img_scale=img_scale), + dict(type="RandomFlip", prob=0.5), + dict(type="Resize", img_scale=img_scale, keep_ratio=False), ] CustomDataset.load_annotations = MagicMock() @@ -122,7 +135,7 @@ def test_dataset_wrapper(): gt_semantic_seg = np.random.randint(5, size=(height, weight)) results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img)) - classes = ['0', '1', '2', '3', '4'] + classes = ["0", "1", "2", "3", "4"] palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)] CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx]) dataset_a = CustomDataset( @@ -130,7 +143,8 @@ def test_dataset_wrapper(): pipeline=[], test_mode=True, classes=classes, - palette=palette) + palette=palette, + ) len_a = 2 dataset_a.img_infos = MagicMock() dataset_a.img_infos.__len__.return_value = len_a @@ -143,106 +157,114 @@ def test_dataset_wrapper(): # test skip_type_keys multi_image_mix_dataset = MultiImageMixDataset( - dataset_a, pipeline, skip_type_keys=('RandomFlip')) + dataset_a, pipeline, skip_type_keys=("RandomFlip") + ) for idx in range(len_a): results_ = multi_image_mix_dataset[idx] - assert results_['img'].shape == (img_scale[0], img_scale[1], 3) + assert results_["img"].shape == (img_scale[0], img_scale[1], 3) - skip_type_keys = ('RandomFlip', 'Resize') + skip_type_keys = ("RandomFlip", "Resize") multi_image_mix_dataset.update_skip_type_keys(skip_type_keys) for idx in range(len_a): results_ = multi_image_mix_dataset[idx] - assert results_['img'].shape[:2] != img_scale + assert results_["img"].shape[:2] != img_scale # test pipeline with pytest.raises(TypeError): - pipeline = [['Resize']] + pipeline = [["Resize"]] multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) def test_custom_dataset(): img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True) + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True + ) crop_size = (512, 1024) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations'), - dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)), - dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', prob=0.5), - dict(type='PhotoMetricDistortion'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), - dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_semantic_seg']), + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations"), + dict(type="Resize", img_scale=(128, 256), ratio_range=(0.5, 2.0)), + dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75), + dict(type="RandomFlip", prob=0.5), + dict(type="PhotoMetricDistortion"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size=crop_size, pad_val=0, seg_pad_val=255), + dict(type="DefaultFormatBundle"), + dict(type="Collect", keys=["img", "gt_semantic_seg"]), ] test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type="LoadImageFromFile"), dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(128, 256), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']), - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), ] # with img_dir and ann_dir train_dataset = CustomDataset( train_pipeline, - data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), - img_dir='imgs/', - ann_dir='gts/', - img_suffix='img.jpg', - seg_map_suffix='gt.png') + data_root=osp.join(osp.dirname(__file__), "../data/pseudo_dataset"), + img_dir="imgs/", + ann_dir="gts/", + img_suffix="img.jpg", + seg_map_suffix="gt.png", + ) assert len(train_dataset) == 5 # with img_dir, ann_dir, split train_dataset = CustomDataset( train_pipeline, - data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), - img_dir='imgs/', - ann_dir='gts/', - img_suffix='img.jpg', - seg_map_suffix='gt.png', - split='splits/train.txt') + data_root=osp.join(osp.dirname(__file__), "../data/pseudo_dataset"), + img_dir="imgs/", + ann_dir="gts/", + img_suffix="img.jpg", + seg_map_suffix="gt.png", + split="splits/train.txt", + ) assert len(train_dataset) == 4 # no data_root train_dataset = CustomDataset( train_pipeline, - img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'), - ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'), - img_suffix='img.jpg', - seg_map_suffix='gt.png') + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs"), + ann_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/gts"), + img_suffix="img.jpg", + seg_map_suffix="gt.png", + ) assert len(train_dataset) == 5 # with data_root but img_dir/ann_dir are abs path train_dataset = CustomDataset( train_pipeline, - data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'), + data_root=osp.join(osp.dirname(__file__), "../data/pseudo_dataset"), img_dir=osp.abspath( - osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')), + osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs") + ), ann_dir=osp.abspath( - osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')), - img_suffix='img.jpg', - seg_map_suffix='gt.png') + osp.join(osp.dirname(__file__), "../data/pseudo_dataset/gts") + ), + img_suffix="img.jpg", + seg_map_suffix="gt.png", + ) assert len(train_dataset) == 5 # test_mode=True test_dataset = CustomDataset( test_pipeline, - img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'), - img_suffix='img.jpg', + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs"), + img_suffix="img.jpg", test_mode=True, - classes=('pseudo_class', )) + classes=("pseudo_class",), + ) assert len(test_dataset) == 5 # training data get @@ -261,7 +283,7 @@ def test_custom_dataset(): # format_results not implemented with pytest.raises(NotImplementedError): - test_dataset.format_results([], '') + test_dataset.format_results([], "") pseudo_results = [] for gt_seg_map in gt_seg_maps: @@ -270,135 +292,137 @@ def test_custom_dataset(): # test past evaluation without CLASSES with pytest.raises(TypeError): - eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU']) + eval_results = train_dataset.evaluate(pseudo_results, metric=["mIoU"]) with pytest.raises(TypeError): - eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') + eval_results = train_dataset.evaluate(pseudo_results, metric="mDice") with pytest.raises(TypeError): - eval_results = train_dataset.evaluate( - pseudo_results, metric=['mDice', 'mIoU']) + eval_results = train_dataset.evaluate(pseudo_results, metric=["mDice", "mIoU"]) # test past evaluation with CLASSES - train_dataset.CLASSES = tuple(['a'] * 7) - eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU') + train_dataset.CLASSES = tuple(["a"] * 7) + eval_results = train_dataset.evaluate(pseudo_results, metric="mIoU") assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + assert "mIoU" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results - eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') + eval_results = train_dataset.evaluate(pseudo_results, metric="mDice") assert isinstance(eval_results, dict) - assert 'mDice' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + assert "mDice" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results - eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore') + eval_results = train_dataset.evaluate(pseudo_results, metric="mFscore") assert isinstance(eval_results, dict) - assert 'mRecall' in eval_results - assert 'mPrecision' in eval_results - assert 'mFscore' in eval_results - assert 'aAcc' in eval_results + assert "mRecall" in eval_results + assert "mPrecision" in eval_results + assert "mFscore" in eval_results + assert "aAcc" in eval_results eval_results = train_dataset.evaluate( - pseudo_results, metric=['mIoU', 'mDice', 'mFscore']) + pseudo_results, metric=["mIoU", "mDice", "mFscore"] + ) assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mDice' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results - assert 'mFscore' in eval_results - assert 'mPrecision' in eval_results - assert 'mRecall' in eval_results - - assert not np.isnan(eval_results['mIoU']) - assert not np.isnan(eval_results['mDice']) - assert not np.isnan(eval_results['mAcc']) - assert not np.isnan(eval_results['aAcc']) - assert not np.isnan(eval_results['mFscore']) - assert not np.isnan(eval_results['mPrecision']) - assert not np.isnan(eval_results['mRecall']) + assert "mIoU" in eval_results + assert "mDice" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results + assert "mFscore" in eval_results + assert "mPrecision" in eval_results + assert "mRecall" in eval_results + + assert not np.isnan(eval_results["mIoU"]) + assert not np.isnan(eval_results["mDice"]) + assert not np.isnan(eval_results["mAcc"]) + assert not np.isnan(eval_results["aAcc"]) + assert not np.isnan(eval_results["mFscore"]) + assert not np.isnan(eval_results["mPrecision"]) + assert not np.isnan(eval_results["mRecall"]) # test evaluation with pre-eval and the dataset.CLASSES is necessary - train_dataset.CLASSES = tuple(['a'] * 7) + train_dataset.CLASSES = tuple(["a"] * 7) pseudo_results = [] for idx in range(len(train_dataset)): h, w = gt_seg_maps[idx].shape pseudo_result = np.random.randint(low=0, high=7, size=(h, w)) pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx)) - eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU']) + eval_results = train_dataset.evaluate(pseudo_results, metric=["mIoU"]) assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + assert "mIoU" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results - eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') + eval_results = train_dataset.evaluate(pseudo_results, metric="mDice") assert isinstance(eval_results, dict) - assert 'mDice' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results + assert "mDice" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results - eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore') + eval_results = train_dataset.evaluate(pseudo_results, metric="mFscore") assert isinstance(eval_results, dict) - assert 'mRecall' in eval_results - assert 'mPrecision' in eval_results - assert 'mFscore' in eval_results - assert 'aAcc' in eval_results + assert "mRecall" in eval_results + assert "mPrecision" in eval_results + assert "mFscore" in eval_results + assert "aAcc" in eval_results eval_results = train_dataset.evaluate( - pseudo_results, metric=['mIoU', 'mDice', 'mFscore']) + pseudo_results, metric=["mIoU", "mDice", "mFscore"] + ) assert isinstance(eval_results, dict) - assert 'mIoU' in eval_results - assert 'mDice' in eval_results - assert 'mAcc' in eval_results - assert 'aAcc' in eval_results - assert 'mFscore' in eval_results - assert 'mPrecision' in eval_results - assert 'mRecall' in eval_results - - assert not np.isnan(eval_results['mIoU']) - assert not np.isnan(eval_results['mDice']) - assert not np.isnan(eval_results['mAcc']) - assert not np.isnan(eval_results['aAcc']) - assert not np.isnan(eval_results['mFscore']) - assert not np.isnan(eval_results['mPrecision']) - assert not np.isnan(eval_results['mRecall']) - - -@pytest.mark.parametrize('separate_eval', [True, False]) + assert "mIoU" in eval_results + assert "mDice" in eval_results + assert "mAcc" in eval_results + assert "aAcc" in eval_results + assert "mFscore" in eval_results + assert "mPrecision" in eval_results + assert "mRecall" in eval_results + + assert not np.isnan(eval_results["mIoU"]) + assert not np.isnan(eval_results["mDice"]) + assert not np.isnan(eval_results["mAcc"]) + assert not np.isnan(eval_results["aAcc"]) + assert not np.isnan(eval_results["mFscore"]) + assert not np.isnan(eval_results["mPrecision"]) + assert not np.isnan(eval_results["mRecall"]) + + +@pytest.mark.parametrize("separate_eval", [True, False]) def test_eval_concat_custom_dataset(separate_eval): img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True) + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True + ) test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type="LoadImageFromFile"), dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(128, 256), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']), - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), ] - data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') - img_dir = 'imgs/' - ann_dir = 'gts/' + data_root = osp.join(osp.dirname(__file__), "../data/pseudo_dataset") + img_dir = "imgs/" + ann_dir = "gts/" cfg1 = dict( - type='CustomDataset', + type="CustomDataset", pipeline=test_pipeline, data_root=data_root, img_dir=img_dir, ann_dir=ann_dir, - img_suffix='img.jpg', - seg_map_suffix='gt.png', - classes=tuple(['a'] * 7)) + img_suffix="img.jpg", + seg_map_suffix="gt.png", + classes=tuple(["a"] * 7), + ) dataset1 = build_dataset(cfg1) assert len(dataset1) == 5 # get gt seg map @@ -413,50 +437,68 @@ def test_eval_concat_custom_dataset(separate_eval): h, w = gt_seg_map.shape pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) eval_results1 = dataset1.evaluate( - pseudo_results, metric=['mIoU', 'mDice', 'mFscore']) + pseudo_results, metric=["mIoU", "mDice", "mFscore"] + ) # We use same dir twice for simplicity # with ann_dir cfg2 = dict( - type='CustomDataset', + type="CustomDataset", pipeline=test_pipeline, data_root=data_root, img_dir=[img_dir, img_dir], ann_dir=[ann_dir, ann_dir], - img_suffix='img.jpg', - seg_map_suffix='gt.png', - classes=tuple(['a'] * 7), - separate_eval=separate_eval) + img_suffix="img.jpg", + seg_map_suffix="gt.png", + classes=tuple(["a"] * 7), + separate_eval=separate_eval, + ) dataset2 = build_dataset(cfg2) assert isinstance(dataset2, ConcatDataset) assert len(dataset2) == 10 eval_results2 = dataset2.evaluate( - pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore']) + pseudo_results * 2, metric=["mIoU", "mDice", "mFscore"] + ) if separate_eval: - assert eval_results1['mIoU'] == eval_results2[ - '0_mIoU'] == eval_results2['1_mIoU'] - assert eval_results1['mDice'] == eval_results2[ - '0_mDice'] == eval_results2['1_mDice'] - assert eval_results1['mAcc'] == eval_results2[ - '0_mAcc'] == eval_results2['1_mAcc'] - assert eval_results1['aAcc'] == eval_results2[ - '0_aAcc'] == eval_results2['1_aAcc'] - assert eval_results1['mFscore'] == eval_results2[ - '0_mFscore'] == eval_results2['1_mFscore'] - assert eval_results1['mPrecision'] == eval_results2[ - '0_mPrecision'] == eval_results2['1_mPrecision'] - assert eval_results1['mRecall'] == eval_results2[ - '0_mRecall'] == eval_results2['1_mRecall'] + assert ( + eval_results1["mIoU"] == eval_results2["0_mIoU"] == eval_results2["1_mIoU"] + ) + assert ( + eval_results1["mDice"] + == eval_results2["0_mDice"] + == eval_results2["1_mDice"] + ) + assert ( + eval_results1["mAcc"] == eval_results2["0_mAcc"] == eval_results2["1_mAcc"] + ) + assert ( + eval_results1["aAcc"] == eval_results2["0_aAcc"] == eval_results2["1_aAcc"] + ) + assert ( + eval_results1["mFscore"] + == eval_results2["0_mFscore"] + == eval_results2["1_mFscore"] + ) + assert ( + eval_results1["mPrecision"] + == eval_results2["0_mPrecision"] + == eval_results2["1_mPrecision"] + ) + assert ( + eval_results1["mRecall"] + == eval_results2["0_mRecall"] + == eval_results2["1_mRecall"] + ) else: - assert eval_results1['mIoU'] == eval_results2['mIoU'] - assert eval_results1['mDice'] == eval_results2['mDice'] - assert eval_results1['mAcc'] == eval_results2['mAcc'] - assert eval_results1['aAcc'] == eval_results2['aAcc'] - assert eval_results1['mFscore'] == eval_results2['mFscore'] - assert eval_results1['mPrecision'] == eval_results2['mPrecision'] - assert eval_results1['mRecall'] == eval_results2['mRecall'] + assert eval_results1["mIoU"] == eval_results2["mIoU"] + assert eval_results1["mDice"] == eval_results2["mDice"] + assert eval_results1["mAcc"] == eval_results2["mAcc"] + assert eval_results1["aAcc"] == eval_results2["aAcc"] + assert eval_results1["mFscore"] == eval_results2["mFscore"] + assert eval_results1["mPrecision"] == eval_results2["mPrecision"] + assert eval_results1["mRecall"] == eval_results2["mRecall"] # test get dataset_idx and sample_idx from ConcateDataset dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3) @@ -479,7 +521,8 @@ def test_eval_concat_custom_dataset(separate_eval): indice = -6 dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice) dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx( - len(dataset2) + indice) + len(dataset2) + indice + ) assert dataset_idx1 == dataset_idx2 assert sample_idx1 == sample_idx2 @@ -498,7 +541,8 @@ def test_eval_concat_custom_dataset(separate_eval): assert isinstance(eval_results1[0][0], torch.Tensor) eval_results1 = dataset1.evaluate( - eval_results1, metric=['mIoU', 'mDice', 'mFscore']) + eval_results1, metric=["mIoU", "mDice", "mFscore"] + ) pseudo_results = pseudo_results * 2 eval_results2 = [] @@ -511,35 +555,50 @@ def test_eval_concat_custom_dataset(separate_eval): assert isinstance(eval_results2[0][0], torch.Tensor) eval_results2 = dataset2.evaluate( - eval_results2, metric=['mIoU', 'mDice', 'mFscore']) + eval_results2, metric=["mIoU", "mDice", "mFscore"] + ) if separate_eval: - assert eval_results1['mIoU'] == eval_results2[ - '0_mIoU'] == eval_results2['1_mIoU'] - assert eval_results1['mDice'] == eval_results2[ - '0_mDice'] == eval_results2['1_mDice'] - assert eval_results1['mAcc'] == eval_results2[ - '0_mAcc'] == eval_results2['1_mAcc'] - assert eval_results1['aAcc'] == eval_results2[ - '0_aAcc'] == eval_results2['1_aAcc'] - assert eval_results1['mFscore'] == eval_results2[ - '0_mFscore'] == eval_results2['1_mFscore'] - assert eval_results1['mPrecision'] == eval_results2[ - '0_mPrecision'] == eval_results2['1_mPrecision'] - assert eval_results1['mRecall'] == eval_results2[ - '0_mRecall'] == eval_results2['1_mRecall'] + assert ( + eval_results1["mIoU"] == eval_results2["0_mIoU"] == eval_results2["1_mIoU"] + ) + assert ( + eval_results1["mDice"] + == eval_results2["0_mDice"] + == eval_results2["1_mDice"] + ) + assert ( + eval_results1["mAcc"] == eval_results2["0_mAcc"] == eval_results2["1_mAcc"] + ) + assert ( + eval_results1["aAcc"] == eval_results2["0_aAcc"] == eval_results2["1_aAcc"] + ) + assert ( + eval_results1["mFscore"] + == eval_results2["0_mFscore"] + == eval_results2["1_mFscore"] + ) + assert ( + eval_results1["mPrecision"] + == eval_results2["0_mPrecision"] + == eval_results2["1_mPrecision"] + ) + assert ( + eval_results1["mRecall"] + == eval_results2["0_mRecall"] + == eval_results2["1_mRecall"] + ) else: - assert eval_results1['mIoU'] == eval_results2['mIoU'] - assert eval_results1['mDice'] == eval_results2['mDice'] - assert eval_results1['mAcc'] == eval_results2['mAcc'] - assert eval_results1['aAcc'] == eval_results2['aAcc'] - assert eval_results1['mFscore'] == eval_results2['mFscore'] - assert eval_results1['mPrecision'] == eval_results2['mPrecision'] - assert eval_results1['mRecall'] == eval_results2['mRecall'] + assert eval_results1["mIoU"] == eval_results2["mIoU"] + assert eval_results1["mDice"] == eval_results2["mDice"] + assert eval_results1["mAcc"] == eval_results2["mAcc"] + assert eval_results1["aAcc"] == eval_results2["aAcc"] + assert eval_results1["mFscore"] == eval_results2["mFscore"] + assert eval_results1["mPrecision"] == eval_results2["mPrecision"] + assert eval_results1["mRecall"] == eval_results2["mRecall"] # test batch_indices for pre eval - eval_results2 = dataset2.pre_eval(pseudo_results, - list(range(len(pseudo_results)))) + eval_results2 = dataset2.pre_eval(pseudo_results, list(range(len(pseudo_results)))) assert len(eval_results2) == len(dataset2) assert isinstance(eval_results2[0], tuple) @@ -547,37 +606,54 @@ def test_eval_concat_custom_dataset(separate_eval): assert isinstance(eval_results2[0][0], torch.Tensor) eval_results2 = dataset2.evaluate( - eval_results2, metric=['mIoU', 'mDice', 'mFscore']) + eval_results2, metric=["mIoU", "mDice", "mFscore"] + ) if separate_eval: - assert eval_results1['mIoU'] == eval_results2[ - '0_mIoU'] == eval_results2['1_mIoU'] - assert eval_results1['mDice'] == eval_results2[ - '0_mDice'] == eval_results2['1_mDice'] - assert eval_results1['mAcc'] == eval_results2[ - '0_mAcc'] == eval_results2['1_mAcc'] - assert eval_results1['aAcc'] == eval_results2[ - '0_aAcc'] == eval_results2['1_aAcc'] - assert eval_results1['mFscore'] == eval_results2[ - '0_mFscore'] == eval_results2['1_mFscore'] - assert eval_results1['mPrecision'] == eval_results2[ - '0_mPrecision'] == eval_results2['1_mPrecision'] - assert eval_results1['mRecall'] == eval_results2[ - '0_mRecall'] == eval_results2['1_mRecall'] + assert ( + eval_results1["mIoU"] == eval_results2["0_mIoU"] == eval_results2["1_mIoU"] + ) + assert ( + eval_results1["mDice"] + == eval_results2["0_mDice"] + == eval_results2["1_mDice"] + ) + assert ( + eval_results1["mAcc"] == eval_results2["0_mAcc"] == eval_results2["1_mAcc"] + ) + assert ( + eval_results1["aAcc"] == eval_results2["0_aAcc"] == eval_results2["1_aAcc"] + ) + assert ( + eval_results1["mFscore"] + == eval_results2["0_mFscore"] + == eval_results2["1_mFscore"] + ) + assert ( + eval_results1["mPrecision"] + == eval_results2["0_mPrecision"] + == eval_results2["1_mPrecision"] + ) + assert ( + eval_results1["mRecall"] + == eval_results2["0_mRecall"] + == eval_results2["1_mRecall"] + ) else: - assert eval_results1['mIoU'] == eval_results2['mIoU'] - assert eval_results1['mDice'] == eval_results2['mDice'] - assert eval_results1['mAcc'] == eval_results2['mAcc'] - assert eval_results1['aAcc'] == eval_results2['aAcc'] - assert eval_results1['mFscore'] == eval_results2['mFscore'] - assert eval_results1['mPrecision'] == eval_results2['mPrecision'] - assert eval_results1['mRecall'] == eval_results2['mRecall'] + assert eval_results1["mIoU"] == eval_results2["mIoU"] + assert eval_results1["mDice"] == eval_results2["mDice"] + assert eval_results1["mAcc"] == eval_results2["mAcc"] + assert eval_results1["aAcc"] == eval_results2["aAcc"] + assert eval_results1["mFscore"] == eval_results2["mFscore"] + assert eval_results1["mPrecision"] == eval_results2["mPrecision"] + assert eval_results1["mRecall"] == eval_results2["mRecall"] def test_ade(): test_dataset = ADE20KDataset( pipeline=[], - img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs"), + ) assert len(test_dataset) == 5 # Test format_results @@ -586,23 +662,25 @@ def test_ade(): h, w = (2, 2) pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) - file_paths = test_dataset.format_results(pseudo_results, '.format_ade') + file_paths = test_dataset.format_results(pseudo_results, ".format_ade") assert len(file_paths) == len(test_dataset) temp = np.array(Image.open(file_paths[0])) assert np.allclose(temp, pseudo_results[0] + 1) - shutil.rmtree('.format_ade') + shutil.rmtree(".format_ade") -@pytest.mark.parametrize('separate_eval', [True, False]) +@pytest.mark.parametrize("separate_eval", [True, False]) def test_concat_ade(separate_eval): test_dataset = ADE20KDataset( pipeline=[], - img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs"), + ) assert len(test_dataset) == 5 - concat_dataset = ConcatDataset([test_dataset, test_dataset], - separate_eval=separate_eval) + concat_dataset = ConcatDataset( + [test_dataset, test_dataset], separate_eval=separate_eval + ) assert len(concat_dataset) == 10 # Test format_results pseudo_results = [] @@ -614,32 +692,35 @@ def test_concat_ade(separate_eval): file_paths = [] for i in range(len(pseudo_results)): file_paths.extend( - concat_dataset.format_results([pseudo_results[i]], - '.format_ade', - indices=[i])) + concat_dataset.format_results( + [pseudo_results[i]], ".format_ade", indices=[i] + ) + ) assert len(file_paths) == len(concat_dataset) temp = np.array(Image.open(file_paths[0])) assert np.allclose(temp, pseudo_results[0] + 1) - shutil.rmtree('.format_ade') + shutil.rmtree(".format_ade") # test default argument - file_paths = concat_dataset.format_results(pseudo_results, '.format_ade') + file_paths = concat_dataset.format_results(pseudo_results, ".format_ade") assert len(file_paths) == len(concat_dataset) temp = np.array(Image.open(file_paths[0])) assert np.allclose(temp, pseudo_results[0] + 1) - shutil.rmtree('.format_ade') + shutil.rmtree(".format_ade") def test_cityscapes(): test_dataset = CityscapesDataset( pipeline=[], img_dir=osp.join( - osp.dirname(__file__), - '../data/pseudo_cityscapes_dataset/leftImg8bit'), + osp.dirname(__file__), "../data/pseudo_cityscapes_dataset/leftImg8bit" + ), ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine')) + osp.dirname(__file__), "../data/pseudo_cityscapes_dataset/gtFine" + ), + ) assert len(test_dataset) == 1 gt_seg_maps = list(test_dataset.get_gt_seg_maps()) @@ -650,49 +731,55 @@ def test_cityscapes(): h, w = gt_seg_maps[idx].shape pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w))) - file_paths = test_dataset.format_results(pseudo_results, '.format_city') + file_paths = test_dataset.format_results(pseudo_results, ".format_city") assert len(file_paths) == len(test_dataset) temp = np.array(Image.open(file_paths[0])) - assert np.allclose(temp, - test_dataset._convert_to_label_id(pseudo_results[0])) + assert np.allclose(temp, test_dataset._convert_to_label_id(pseudo_results[0])) # Test cityscapes evaluate test_dataset.evaluate( - pseudo_results, metric='cityscapes', imgfile_prefix='.format_city') + pseudo_results, metric="cityscapes", imgfile_prefix=".format_city" + ) - shutil.rmtree('.format_city') + shutil.rmtree(".format_city") -@pytest.mark.parametrize('separate_eval', [True, False]) +@pytest.mark.parametrize("separate_eval", [True, False]) def test_concat_cityscapes(separate_eval): cityscape_dataset = CityscapesDataset( pipeline=[], img_dir=osp.join( - osp.dirname(__file__), - '../data/pseudo_cityscapes_dataset/leftImg8bit'), + osp.dirname(__file__), "../data/pseudo_cityscapes_dataset/leftImg8bit" + ), ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine')) + osp.dirname(__file__), "../data/pseudo_cityscapes_dataset/gtFine" + ), + ) assert len(cityscape_dataset) == 1 with pytest.raises(NotImplementedError): - _ = ConcatDataset([cityscape_dataset, cityscape_dataset], - separate_eval=separate_eval) + _ = ConcatDataset( + [cityscape_dataset, cityscape_dataset], separate_eval=separate_eval + ) ade_dataset = ADE20KDataset( pipeline=[], - img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_dataset/imgs"), + ) assert len(ade_dataset) == 5 with pytest.raises(NotImplementedError): - _ = ConcatDataset([cityscape_dataset, ade_dataset], - separate_eval=separate_eval) + _ = ConcatDataset([cityscape_dataset, ade_dataset], separate_eval=separate_eval) def test_loveda(): test_dataset = LoveDADataset( pipeline=[], img_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_loveda_dataset/img_dir'), + osp.dirname(__file__), "../data/pseudo_loveda_dataset/img_dir" + ), ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_loveda_dataset/ann_dir')) + osp.dirname(__file__), "../data/pseudo_loveda_dataset/ann_dir" + ), + ) assert len(test_dataset) == 3 gt_seg_maps = list(test_dataset.get_gt_seg_maps()) @@ -702,23 +789,27 @@ def test_loveda(): for idx in range(len(test_dataset)): h, w = gt_seg_maps[idx].shape pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) - file_paths = test_dataset.format_results(pseudo_results, '.format_loveda') + file_paths = test_dataset.format_results(pseudo_results, ".format_loveda") assert len(file_paths) == len(test_dataset) # Test loveda evaluate test_dataset.evaluate( - pseudo_results, metric='mIoU', imgfile_prefix='.format_loveda') + pseudo_results, metric="mIoU", imgfile_prefix=".format_loveda" + ) - shutil.rmtree('.format_loveda') + shutil.rmtree(".format_loveda") def test_potsdam(): test_dataset = PotsdamDataset( pipeline=[], img_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_potsdam_dataset/img_dir'), + osp.dirname(__file__), "../data/pseudo_potsdam_dataset/img_dir" + ), ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_potsdam_dataset/ann_dir')) + osp.dirname(__file__), "../data/pseudo_potsdam_dataset/ann_dir" + ), + ) assert len(test_dataset) == 1 @@ -726,44 +817,48 @@ def test_vaihingen(): test_dataset = ISPRSDataset( pipeline=[], img_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_vaihingen_dataset/img_dir'), + osp.dirname(__file__), "../data/pseudo_vaihingen_dataset/img_dir" + ), ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_vaihingen_dataset/ann_dir')) + osp.dirname(__file__), "../data/pseudo_vaihingen_dataset/ann_dir" + ), + ) assert len(test_dataset) == 1 def test_isaid(): test_dataset = iSAIDDataset( pipeline=[], - img_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'), - ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir')) + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_isaid_dataset/img_dir"), + ann_dir=osp.join(osp.dirname(__file__), "../data/pseudo_isaid_dataset/ann_dir"), + ) assert len(test_dataset) == 2 isaid_info = test_dataset.load_annotations( - img_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'), - img_suffix='.png', - ann_dir=osp.join( - osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'), - seg_map_suffix='.png', + img_dir=osp.join(osp.dirname(__file__), "../data/pseudo_isaid_dataset/img_dir"), + img_suffix=".png", + ann_dir=osp.join(osp.dirname(__file__), "../data/pseudo_isaid_dataset/ann_dir"), + seg_map_suffix=".png", split=osp.join( - osp.dirname(__file__), - '../data/pseudo_isaid_dataset/splits/train.txt')) + osp.dirname(__file__), "../data/pseudo_isaid_dataset/splits/train.txt" + ), + ) assert len(isaid_info) == 1 -@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) -@patch('mmseg.datasets.CustomDataset.__getitem__', - MagicMock(side_effect=lambda idx: idx)) -@pytest.mark.parametrize('dataset, classes', [ - ('ADE20KDataset', ('wall', 'building')), - ('CityscapesDataset', ('road', 'sidewalk')), - ('CustomDataset', ('bus', 'car')), - ('PascalVOCDataset', ('aeroplane', 'bicycle')), -]) +@patch("mmseg.datasets.CustomDataset.load_annotations", MagicMock) +@patch( + "mmseg.datasets.CustomDataset.__getitem__", MagicMock(side_effect=lambda idx: idx) +) +@pytest.mark.parametrize( + "dataset, classes", + [ + ("ADE20KDataset", ("wall", "building")), + ("CityscapesDataset", ("road", "sidewalk")), + ("CustomDataset", ("bus", "car")), + ("PascalVOCDataset", ("aeroplane", "bicycle")), + ], +) def test_custom_classes_override_default(dataset, classes): - dataset_class = DATASETS.get(dataset) original_classes = dataset_class.CLASSES @@ -774,7 +869,8 @@ def test_custom_classes_override_default(dataset, classes): img_dir=MagicMock(), split=MagicMock(), classes=classes, - test_mode=True) + test_mode=True, + ) assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == classes @@ -785,7 +881,8 @@ def test_custom_classes_override_default(dataset, classes): img_dir=MagicMock(), split=MagicMock(), classes=list(classes), - test_mode=True) + test_mode=True, + ) assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == list(classes) @@ -796,7 +893,8 @@ def test_custom_classes_override_default(dataset, classes): img_dir=MagicMock(), split=MagicMock(), classes=[classes[0]], - test_mode=True) + test_mode=True, + ) assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == [classes[0]] @@ -809,43 +907,49 @@ def test_custom_classes_override_default(dataset, classes): img_dir=MagicMock(), split=MagicMock(), classes=None, - test_mode=True) + test_mode=True, + ) else: custom_dataset = dataset_class( pipeline=[], img_dir=MagicMock(), split=MagicMock(), classes=None, - test_mode=True) + test_mode=True, + ) assert custom_dataset.CLASSES == original_classes -@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) -@patch('mmseg.datasets.CustomDataset.__getitem__', - MagicMock(side_effect=lambda idx: idx)) +@patch("mmseg.datasets.CustomDataset.load_annotations", MagicMock) +@patch( + "mmseg.datasets.CustomDataset.__getitem__", MagicMock(side_effect=lambda idx: idx) +) def test_custom_dataset_random_palette_is_generated(): dataset = CustomDataset( pipeline=[], img_dir=MagicMock(), split=MagicMock(), - classes=('bus', 'car'), - test_mode=True) + classes=("bus", "car"), + test_mode=True, + ) assert len(dataset.PALETTE) == 2 for class_color in dataset.PALETTE: assert len(class_color) == 3 assert all(x >= 0 and x <= 255 for x in class_color) -@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) -@patch('mmseg.datasets.CustomDataset.__getitem__', - MagicMock(side_effect=lambda idx: idx)) +@patch("mmseg.datasets.CustomDataset.load_annotations", MagicMock) +@patch( + "mmseg.datasets.CustomDataset.__getitem__", MagicMock(side_effect=lambda idx: idx) +) def test_custom_dataset_custom_palette(): dataset = CustomDataset( pipeline=[], img_dir=MagicMock(), split=MagicMock(), - classes=('bus', 'car'), + classes=("bus", "car"), palette=[[100, 100, 100], [200, 200, 200]], - test_mode=True) + test_mode=True, + ) assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]]) diff --git a/mmsegmentation/tests/test_data/test_dataset_builder.py b/mmsegmentation/tests/test_data/test_dataset_builder.py index 30910b0..f38c7f6 100644 --- a/mmsegmentation/tests/test_data/test_dataset_builder.py +++ b/mmsegmentation/tests/test_data/test_dataset_builder.py @@ -3,16 +3,19 @@ import os.path as osp import pytest -from torch.utils.data import (DistributedSampler, RandomSampler, - SequentialSampler) +from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler -from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset, - build_dataloader, build_dataset) +from mmseg.datasets import ( + DATASETS, + ConcatDataset, + MultiImageMixDataset, + build_dataloader, + build_dataset, +) @DATASETS.register_module() -class ToyDataset(object): - +class ToyDataset: def __init__(self, cnt=0): self.cnt = cnt @@ -24,7 +27,7 @@ def __len__(self): def test_build_dataset(): - cfg = dict(type='ToyDataset') + cfg = dict(type="ToyDataset") dataset = build_dataset(cfg) assert isinstance(dataset, ToyDataset) assert dataset.cnt == 0 @@ -32,72 +35,77 @@ def test_build_dataset(): assert isinstance(dataset, ToyDataset) assert dataset.cnt == 1 - data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') - img_dir = 'imgs/' - ann_dir = 'gts/' + data_root = osp.join(osp.dirname(__file__), "../data/pseudo_dataset") + img_dir = "imgs/" + ann_dir = "gts/" # We use same dir twice for simplicity # with ann_dir cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], - ann_dir=[ann_dir, ann_dir]) + ann_dir=[ann_dir, ann_dir], + ) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 - cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[]) + cfg = dict(type="MultiImageMixDataset", dataset=cfg, pipeline=[]) dataset = build_dataset(cfg) assert isinstance(dataset, MultiImageMixDataset) assert len(dataset) == 10 # with ann_dir, split cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=ann_dir, - split=['splits/train.txt', 'splits/val.txt']) + split=["splits/train.txt", "splits/val.txt"], + ) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 5 # with ann_dir, split cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=[ann_dir, ann_dir], - split=['splits/train.txt', 'splits/val.txt']) + split=["splits/train.txt", "splits/val.txt"], + ) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 5 # test mode cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], test_mode=True, - classes=('pseudo_class', )) + classes=("pseudo_class",), + ) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 # test mode with splits cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], - split=['splits/val.txt', 'splits/val.txt'], + split=["splits/val.txt", "splits/val.txt"], test_mode=True, - classes=('pseudo_class', )) + classes=("pseudo_class",), + ) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 2 @@ -105,33 +113,36 @@ def test_build_dataset(): # len(ann_dir) should be zero or len(img_dir) when len(img_dir) > 1 with pytest.raises(AssertionError): cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], - ann_dir=[ann_dir, ann_dir, ann_dir]) + ann_dir=[ann_dir, ann_dir, ann_dir], + ) build_dataset(cfg) # len(splits) should be zero or len(img_dir) when len(img_dir) > 1 with pytest.raises(AssertionError): cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], - split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) + split=["splits/val.txt", "splits/val.txt", "splits/val.txt"], + ) build_dataset(cfg) # len(splits) == len(ann_dir) when only len(img_dir) == 1 and len( # ann_dir) > 1 with pytest.raises(AssertionError): cfg = dict( - type='CustomDataset', + type="CustomDataset", pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=[ann_dir, ann_dir], - split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) + split=["splits/val.txt", "splits/val.txt", "splits/val.txt"], + ) build_dataset(cfg) @@ -140,7 +151,8 @@ def test_build_dataloader(): samples_per_gpu = 3 # dist=True, shuffle=True, 1GPU dataloader = build_dataloader( - dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2) + dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2 + ) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, DistributedSampler) @@ -148,10 +160,8 @@ def test_build_dataloader(): # dist=True, shuffle=False, 1GPU dataloader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=2, - shuffle=False) + dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, shuffle=False + ) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, DistributedSampler) @@ -159,20 +169,16 @@ def test_build_dataloader(): # dist=True, shuffle=True, 8GPU dataloader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=2, - num_gpus=8) + dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, num_gpus=8 + ) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert dataloader.num_workers == 2 # dist=False, shuffle=True, 1GPU dataloader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=2, - dist=False) + dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, dist=False + ) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, RandomSampler) @@ -180,11 +186,8 @@ def test_build_dataloader(): # dist=False, shuffle=False, 1GPU dataloader = build_dataloader( - dataset, - samples_per_gpu=3, - workers_per_gpu=2, - shuffle=False, - dist=False) + dataset, samples_per_gpu=3, workers_per_gpu=2, shuffle=False, dist=False + ) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, SequentialSampler) @@ -192,9 +195,9 @@ def test_build_dataloader(): # dist=False, shuffle=True, 8GPU dataloader = build_dataloader( - dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False) + dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False + ) assert dataloader.batch_size == samples_per_gpu * 8 - assert len(dataloader) == int( - math.ceil(len(dataset) / samples_per_gpu / 8)) + assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu / 8)) assert isinstance(dataloader.sampler, RandomSampler) assert dataloader.num_workers == 16 diff --git a/mmsegmentation/tests/test_data/test_loading.py b/mmsegmentation/tests/test_data/test_loading.py index fdda93e..9dca889 100644 --- a/mmsegmentation/tests/test_data/test_loading.py +++ b/mmsegmentation/tests/test_data/test_loading.py @@ -9,101 +9,101 @@ from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile -class TestLoading(object): - +class TestLoading: @classmethod def setup_class(cls): - cls.data_prefix = osp.join(osp.dirname(__file__), '../data') + cls.data_prefix = osp.join(osp.dirname(__file__), "../data") def test_load_img(self): - results = dict( - img_prefix=self.data_prefix, img_info=dict(filename='color.jpg')) + results = dict(img_prefix=self.data_prefix, img_info=dict(filename="color.jpg")) transform = LoadImageFromFile() results = transform(copy.deepcopy(results)) - assert results['filename'] == osp.join(self.data_prefix, 'color.jpg') - assert results['ori_filename'] == 'color.jpg' - assert results['img'].shape == (288, 512, 3) - assert results['img'].dtype == np.uint8 - assert results['img_shape'] == (288, 512, 3) - assert results['ori_shape'] == (288, 512, 3) - assert results['pad_shape'] == (288, 512, 3) - assert results['scale_factor'] == 1.0 - np.testing.assert_equal(results['img_norm_cfg']['mean'], - np.zeros(3, dtype=np.float32)) - assert repr(transform) == transform.__class__.__name__ + \ - "(to_float32=False,color_type='color',imdecode_backend='cv2')" + assert results["filename"] == osp.join(self.data_prefix, "color.jpg") + assert results["ori_filename"] == "color.jpg" + assert results["img"].shape == (288, 512, 3) + assert results["img"].dtype == np.uint8 + assert results["img_shape"] == (288, 512, 3) + assert results["ori_shape"] == (288, 512, 3) + assert results["pad_shape"] == (288, 512, 3) + assert results["scale_factor"] == 1.0 + np.testing.assert_equal( + results["img_norm_cfg"]["mean"], np.zeros(3, dtype=np.float32) + ) + assert ( + repr(transform) + == transform.__class__.__name__ + + "(to_float32=False,color_type='color',imdecode_backend='cv2')" + ) # no img_prefix - results = dict( - img_prefix=None, img_info=dict(filename='tests/data/color.jpg')) + results = dict(img_prefix=None, img_info=dict(filename="tests/data/color.jpg")) transform = LoadImageFromFile() results = transform(copy.deepcopy(results)) - assert results['filename'] == 'tests/data/color.jpg' - assert results['ori_filename'] == 'tests/data/color.jpg' - assert results['img'].shape == (288, 512, 3) + assert results["filename"] == "tests/data/color.jpg" + assert results["ori_filename"] == "tests/data/color.jpg" + assert results["img"].shape == (288, 512, 3) # to_float32 transform = LoadImageFromFile(to_float32=True) results = transform(copy.deepcopy(results)) - assert results['img'].dtype == np.float32 + assert results["img"].dtype == np.float32 # gray image - results = dict( - img_prefix=self.data_prefix, img_info=dict(filename='gray.jpg')) + results = dict(img_prefix=self.data_prefix, img_info=dict(filename="gray.jpg")) transform = LoadImageFromFile() results = transform(copy.deepcopy(results)) - assert results['img'].shape == (288, 512, 3) - assert results['img'].dtype == np.uint8 + assert results["img"].shape == (288, 512, 3) + assert results["img"].dtype == np.uint8 - transform = LoadImageFromFile(color_type='unchanged') + transform = LoadImageFromFile(color_type="unchanged") results = transform(copy.deepcopy(results)) - assert results['img'].shape == (288, 512) - assert results['img'].dtype == np.uint8 - np.testing.assert_equal(results['img_norm_cfg']['mean'], - np.zeros(1, dtype=np.float32)) + assert results["img"].shape == (288, 512) + assert results["img"].dtype == np.uint8 + np.testing.assert_equal( + results["img_norm_cfg"]["mean"], np.zeros(1, dtype=np.float32) + ) def test_load_seg(self): results = dict( - seg_prefix=self.data_prefix, - ann_info=dict(seg_map='seg.png'), - seg_fields=[]) + seg_prefix=self.data_prefix, ann_info=dict(seg_map="seg.png"), seg_fields=[] + ) transform = LoadAnnotations() results = transform(copy.deepcopy(results)) - assert results['seg_fields'] == ['gt_semantic_seg'] - assert results['gt_semantic_seg'].shape == (288, 512) - assert results['gt_semantic_seg'].dtype == np.uint8 - assert repr(transform) == transform.__class__.__name__ + \ - "(reduce_zero_label=False,imdecode_backend='pillow')" + assert results["seg_fields"] == ["gt_semantic_seg"] + assert results["gt_semantic_seg"].shape == (288, 512) + assert results["gt_semantic_seg"].dtype == np.uint8 + assert ( + repr(transform) + == transform.__class__.__name__ + + "(reduce_zero_label=False,imdecode_backend='pillow')" + ) # no img_prefix results = dict( - seg_prefix=None, - ann_info=dict(seg_map='tests/data/seg.png'), - seg_fields=[]) + seg_prefix=None, ann_info=dict(seg_map="tests/data/seg.png"), seg_fields=[] + ) transform = LoadAnnotations() results = transform(copy.deepcopy(results)) - assert results['gt_semantic_seg'].shape == (288, 512) - assert results['gt_semantic_seg'].dtype == np.uint8 + assert results["gt_semantic_seg"].shape == (288, 512) + assert results["gt_semantic_seg"].dtype == np.uint8 # reduce_zero_label transform = LoadAnnotations(reduce_zero_label=True) results = transform(copy.deepcopy(results)) - assert results['gt_semantic_seg'].shape == (288, 512) - assert results['gt_semantic_seg'].dtype == np.uint8 + assert results["gt_semantic_seg"].shape == (288, 512) + assert results["gt_semantic_seg"].dtype == np.uint8 # mmcv backend results = dict( - seg_prefix=self.data_prefix, - ann_info=dict(seg_map='seg.png'), - seg_fields=[]) - transform = LoadAnnotations(imdecode_backend='pillow') + seg_prefix=self.data_prefix, ann_info=dict(seg_map="seg.png"), seg_fields=[] + ) + transform = LoadAnnotations(imdecode_backend="pillow") results = transform(copy.deepcopy(results)) # this image is saved by PIL - assert results['gt_semantic_seg'].shape == (288, 512) - assert results['gt_semantic_seg'].dtype == np.uint8 + assert results["gt_semantic_seg"].shape == (288, 512) + assert results["gt_semantic_seg"].dtype == np.uint8 def test_load_seg_custom_classes(self): - test_img = np.random.rand(10, 10) test_gt = np.zeros_like(test_img) test_gt[2:4, 2:4] = 1 @@ -112,8 +112,8 @@ def test_load_seg_custom_classes(self): test_gt[6:8, 6:8] = 4 tmp_dir = tempfile.TemporaryDirectory() - img_path = osp.join(tmp_dir.name, 'img.jpg') - gt_path = osp.join(tmp_dir.name, 'gt.png') + img_path = osp.join(tmp_dir.name, "img.jpg") + gt_path = osp.join(tmp_dir.name, "gt.png") mmcv.imwrite(test_img, img_path) mmcv.imwrite(test_gt, gt_path) @@ -122,14 +122,9 @@ def test_load_seg_custom_classes(self): results = dict( img_info=dict(filename=img_path), ann_info=dict(seg_map=gt_path), - label_map={ - 0: 0, - 1: 0, - 2: 0, - 3: 1, - 4: 0 - }, - seg_fields=[]) + label_map={0: 0, 1: 0, 2: 0, 3: 1, 4: 0}, + seg_fields=[], + ) load_imgs = LoadImageFromFile() results = load_imgs(copy.deepcopy(results)) @@ -137,12 +132,12 @@ def test_load_seg_custom_classes(self): load_anns = LoadAnnotations() results = load_anns(copy.deepcopy(results)) - gt_array = results['gt_semantic_seg'] + gt_array = results["gt_semantic_seg"] true_mask = np.zeros_like(gt_array) true_mask[6:8, 2:4] = 1 - assert results['seg_fields'] == ['gt_semantic_seg'] + assert results["seg_fields"] == ["gt_semantic_seg"] assert gt_array.shape == (10, 10) assert gt_array.dtype == np.uint8 np.testing.assert_array_equal(gt_array, true_mask) @@ -151,14 +146,9 @@ def test_load_seg_custom_classes(self): results = dict( img_info=dict(filename=img_path), ann_info=dict(seg_map=gt_path), - label_map={ - 0: 0, - 1: 0, - 2: 0, - 3: 2, - 4: 1 - }, - seg_fields=[]) + label_map={0: 0, 1: 0, 2: 0, 3: 2, 4: 1}, + seg_fields=[], + ) load_imgs = LoadImageFromFile() results = load_imgs(copy.deepcopy(results)) @@ -166,13 +156,13 @@ def test_load_seg_custom_classes(self): load_anns = LoadAnnotations() results = load_anns(copy.deepcopy(results)) - gt_array = results['gt_semantic_seg'] + gt_array = results["gt_semantic_seg"] true_mask = np.zeros_like(gt_array) true_mask[6:8, 2:4] = 2 true_mask[6:8, 6:8] = 1 - assert results['seg_fields'] == ['gt_semantic_seg'] + assert results["seg_fields"] == ["gt_semantic_seg"] assert gt_array.shape == (10, 10) assert gt_array.dtype == np.uint8 np.testing.assert_array_equal(gt_array, true_mask) @@ -181,7 +171,8 @@ def test_load_seg_custom_classes(self): results = dict( img_info=dict(filename=img_path), ann_info=dict(seg_map=gt_path), - seg_fields=[]) + seg_fields=[], + ) load_imgs = LoadImageFromFile() results = load_imgs(copy.deepcopy(results)) @@ -189,9 +180,9 @@ def test_load_seg_custom_classes(self): load_anns = LoadAnnotations() results = load_anns(copy.deepcopy(results)) - gt_array = results['gt_semantic_seg'] + gt_array = results["gt_semantic_seg"] - assert results['seg_fields'] == ['gt_semantic_seg'] + assert results["seg_fields"] == ["gt_semantic_seg"] assert gt_array.shape == (10, 10) assert gt_array.dtype == np.uint8 np.testing.assert_array_equal(gt_array, test_gt) diff --git a/mmsegmentation/tests/test_data/test_transform.py b/mmsegmentation/tests/test_data/test_transform.py index fcc46e7..a4d6abf 100644 --- a/mmsegmentation/tests/test_data/test_transform.py +++ b/mmsegmentation/tests/test_data/test_transform.py @@ -12,242 +12,236 @@ def test_resize_to_multiple(): - transform = dict(type='ResizeToMultiple', size_divisor=32) + transform = dict(type="ResizeToMultiple", size_divisor=32) transform = build_from_cfg(transform, PIPELINES) img = np.random.randn(213, 232, 3) seg = np.random.randint(0, 19, (213, 232)) results = dict() - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['pad_shape'] = img.shape + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["pad_shape"] = img.shape results = transform(results) - assert results['img'].shape == (224, 256, 3) - assert results['gt_semantic_seg'].shape == (224, 256) - assert results['img_shape'] == (224, 256, 3) - assert results['pad_shape'] == (224, 256, 3) + assert results["img"].shape == (224, 256, 3) + assert results["gt_semantic_seg"].shape == (224, 256) + assert results["img_shape"] == (224, 256, 3) + assert results["pad_shape"] == (224, 256, 3) def test_resize(): # test assertion if img_scale is a list with pytest.raises(AssertionError): - transform = dict(type='Resize', img_scale=[1333, 800], keep_ratio=True) + transform = dict(type="Resize", img_scale=[1333, 800], keep_ratio=True) build_from_cfg(transform, PIPELINES) # test assertion if len(img_scale) while ratio_range is not None with pytest.raises(AssertionError): transform = dict( - type='Resize', + type="Resize", img_scale=[(1333, 800), (1333, 600)], ratio_range=(0.9, 1.1), - keep_ratio=True) + keep_ratio=True, + ) build_from_cfg(transform, PIPELINES) # test assertion for invalid multiscale_mode with pytest.raises(AssertionError): transform = dict( - type='Resize', + type="Resize", img_scale=[(1333, 800), (1333, 600)], keep_ratio=True, - multiscale_mode='2333') + multiscale_mode="2333", + ) build_from_cfg(transform, PIPELINES) - transform = dict(type='Resize', img_scale=(1333, 800), keep_ratio=True) + transform = dict(type="Resize", img_scale=(1333, 800), keep_ratio=True) resize_module = build_from_cfg(transform, PIPELINES) results = dict() # (288, 512, 3) - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] == (750, 1333, 3) + assert resized_results["img_shape"] == (750, 1333, 3) # test keep_ratio=False transform = dict( - type='Resize', - img_scale=(1280, 800), - multiscale_mode='value', - keep_ratio=False) + type="Resize", img_scale=(1280, 800), multiscale_mode="value", keep_ratio=False + ) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] == (800, 1280, 3) + assert resized_results["img_shape"] == (800, 1280, 3) # test multiscale_mode='range' transform = dict( - type='Resize', + type="Resize", img_scale=[(1333, 400), (1333, 1200)], - multiscale_mode='range', - keep_ratio=True) + multiscale_mode="range", + keep_ratio=True, + ) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert max(resized_results['img_shape'][:2]) <= 1333 - assert min(resized_results['img_shape'][:2]) >= 400 - assert min(resized_results['img_shape'][:2]) <= 1200 + assert max(resized_results["img_shape"][:2]) <= 1333 + assert min(resized_results["img_shape"][:2]) >= 400 + assert min(resized_results["img_shape"][:2]) <= 1200 # test multiscale_mode='value' transform = dict( - type='Resize', + type="Resize", img_scale=[(1333, 800), (1333, 400)], - multiscale_mode='value', - keep_ratio=True) + multiscale_mode="value", + keep_ratio=True, + ) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] in [(750, 1333, 3), (400, 711, 3)] + assert resized_results["img_shape"] in [(750, 1333, 3), (400, 711, 3)] # test multiscale_mode='range' transform = dict( - type='Resize', - img_scale=(1333, 800), - ratio_range=(0.9, 1.1), - keep_ratio=True) + type="Resize", img_scale=(1333, 800), ratio_range=(0.9, 1.1), keep_ratio=True + ) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1 + assert max(resized_results["img_shape"][:2]) <= 1333 * 1.1 # test img_scale=None and ratio_range is tuple. # img shape: (288, 512, 3) transform = dict( - type='Resize', img_scale=None, ratio_range=(0.5, 2.0), keep_ratio=True) + type="Resize", img_scale=None, ratio_range=(0.5, 2.0), keep_ratio=True + ) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0 - assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0 + assert int(288 * 0.5) <= resized_results["img_shape"][0] <= 288 * 2.0 + assert int(512 * 0.5) <= resized_results["img_shape"][1] <= 512 * 2.0 # test min_size=640 - transform = dict(type='Resize', img_scale=(2560, 640), min_size=640) + transform = dict(type="Resize", img_scale=(2560, 640), min_size=640) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] == (640, 1138, 3) + assert resized_results["img_shape"] == (640, 1138, 3) # test min_size=640 and img_scale=(512, 640) - transform = dict(type='Resize', img_scale=(512, 640), min_size=640) + transform = dict(type="Resize", img_scale=(512, 640), min_size=640) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] == (640, 1138, 3) + assert resized_results["img_shape"] == (640, 1138, 3) # test h > w img = np.random.randn(512, 288, 3) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 - transform = dict(type='Resize', img_scale=(2560, 640), min_size=640) + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 + transform = dict(type="Resize", img_scale=(2560, 640), min_size=640) resize_module = build_from_cfg(transform, PIPELINES) resized_results = resize_module(results.copy()) - assert resized_results['img_shape'] == (1138, 640, 3) + assert resized_results["img_shape"] == (1138, 640, 3) def test_flip(): # test assertion for invalid prob with pytest.raises(AssertionError): - transform = dict(type='RandomFlip', prob=1.5) + transform = dict(type="RandomFlip", prob=1.5) build_from_cfg(transform, PIPELINES) # test assertion for invalid direction with pytest.raises(AssertionError): - transform = dict(type='RandomFlip', prob=1, direction='horizonta') + transform = dict(type="RandomFlip", prob=1, direction="horizonta") build_from_cfg(transform, PIPELINES) - transform = dict(type='RandomFlip', prob=1) + transform = dict(type="RandomFlip", prob=1) flip_module = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) original_seg = copy.deepcopy(seg) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = flip_module(results) flip_module = build_from_cfg(transform, PIPELINES) results = flip_module(results) - assert np.equal(original_img, results['img']).all() - assert np.equal(original_seg, results['gt_semantic_seg']).all() + assert np.equal(original_img, results["img"]).all() + assert np.equal(original_seg, results["gt_semantic_seg"]).all() def test_random_crop(): # test assertion for invalid random crop with pytest.raises(AssertionError): - transform = dict(type='RandomCrop', crop_size=(-1, 0)) + transform = dict(type="RandomCrop", crop_size=(-1, 0)) build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 h, w, _ = img.shape - transform = dict(type='RandomCrop', crop_size=(h - 20, w - 20)) + transform = dict(type="RandomCrop", crop_size=(h - 20, w - 20)) crop_module = build_from_cfg(transform, PIPELINES) results = crop_module(results) - assert results['img'].shape[:2] == (h - 20, w - 20) - assert results['img_shape'][:2] == (h - 20, w - 20) - assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20) + assert results["img"].shape[:2] == (h - 20, w - 20) + assert results["img_shape"][:2] == (h - 20, w - 20) + assert results["gt_semantic_seg"].shape[:2] == (h - 20, w - 20) def test_pad(): # test assertion if both size_divisor and size is None with pytest.raises(AssertionError): - transform = dict(type='Pad') + transform = dict(type="Pad") build_from_cfg(transform, PIPELINES) - transform = dict(type='Pad', size_divisor=32) + transform = dict(type="Pad", size_divisor=32) transform = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) # original img already divisible by 32 - assert np.equal(results['img'], original_img).all() - img_shape = results['img'].shape + assert np.equal(results["img"], original_img).all() + img_shape = results["img"].shape assert img_shape[0] % 32 == 0 assert img_shape[1] % 32 == 0 - resize_transform = dict( - type='Resize', img_scale=(1333, 800), keep_ratio=True) + resize_transform = dict(type="Resize", img_scale=(1333, 800), keep_ratio=True) resize_module = build_from_cfg(resize_transform, PIPELINES) results = resize_module(results) results = transform(results) - img_shape = results['img'].shape + img_shape = results["img"].shape assert img_shape[0] % 32 == 0 assert img_shape[1] % 32 == 0 @@ -255,205 +249,201 @@ def test_pad(): def test_rotate(): # test assertion degree should be tuple[float] or float with pytest.raises(AssertionError): - transform = dict(type='RandomRotate', prob=0.5, degree=-10) + transform = dict(type="RandomRotate", prob=0.5, degree=-10) build_from_cfg(transform, PIPELINES) # test assertion degree should be tuple[float] or float with pytest.raises(AssertionError): - transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.)) + transform = dict(type="RandomRotate", prob=0.5, degree=(10.0, 20.0, 30.0)) build_from_cfg(transform, PIPELINES) - transform = dict(type='RandomRotate', degree=10., prob=1.) + transform = dict(type="RandomRotate", degree=10.0, prob=1.0) transform = build_from_cfg(transform, PIPELINES) - assert str(transform) == f'RandomRotate(' \ - f'prob={1.}, ' \ - f'degree=({-10.}, {10.}), ' \ - f'pad_val={0}, ' \ - f'seg_pad_val={255}, ' \ - f'center={None}, ' \ - f'auto_bound={False})' + assert ( + str(transform) == f"RandomRotate(" + f"prob={1.}, " + f"degree=({-10.}, {10.}), " + f"pad_val={0}, " + f"seg_pad_val={255}, " + f"center={None}, " + f"auto_bound={False})" + ) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") h, w, _ = img.shape - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) - assert results['img'].shape[:2] == (h, w) - assert results['gt_semantic_seg'].shape[:2] == (h, w) + assert results["img"].shape[:2] == (h, w) + assert results["gt_semantic_seg"].shape[:2] == (h, w) def test_normalize(): img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True) - transform = dict(type='Normalize', **img_norm_cfg) + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True + ) + transform = dict(type="Normalize", **img_norm_cfg) transform = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) - mean = np.array(img_norm_cfg['mean']) - std = np.array(img_norm_cfg['std']) + mean = np.array(img_norm_cfg["mean"]) + std = np.array(img_norm_cfg["std"]) converted_img = (original_img[..., ::-1] - mean) / std - assert np.allclose(results['img'], converted_img) + assert np.allclose(results["img"], converted_img) def test_rgb2gray(): # test assertion out_channels should be greater than 0 with pytest.raises(AssertionError): - transform = dict(type='RGB2Gray', out_channels=-1) + transform = dict(type="RGB2Gray", out_channels=-1) build_from_cfg(transform, PIPELINES) # test assertion weights should be tuple[float] with pytest.raises(AssertionError): - transform = dict(type='RGB2Gray', out_channels=1, weights=1.1) + transform = dict(type="RGB2Gray", out_channels=1, weights=1.1) build_from_cfg(transform, PIPELINES) # test out_channels is None - transform = dict(type='RGB2Gray') + transform = dict(type="RGB2Gray") transform = build_from_cfg(transform, PIPELINES) - assert str(transform) == f'RGB2Gray(' \ - f'out_channels={None}, ' \ - f'weights={(0.299, 0.587, 0.114)})' + assert ( + str(transform) == f"RGB2Gray(" + f"out_channels={None}, " + f"weights={(0.299, 0.587, 0.114)})" + ) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") h, w, c = img.shape - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) - assert results['img'].shape == (h, w, c) - assert results['img_shape'] == (h, w, c) - assert results['ori_shape'] == (h, w, c) + assert results["img"].shape == (h, w, c) + assert results["img_shape"] == (h, w, c) + assert results["ori_shape"] == (h, w, c) # test out_channels = 2 - transform = dict(type='RGB2Gray', out_channels=2) + transform = dict(type="RGB2Gray", out_channels=2) transform = build_from_cfg(transform, PIPELINES) - assert str(transform) == f'RGB2Gray(' \ - f'out_channels={2}, ' \ - f'weights={(0.299, 0.587, 0.114)})' + assert ( + str(transform) == f"RGB2Gray(" + f"out_channels={2}, " + f"weights={(0.299, 0.587, 0.114)})" + ) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") h, w, c = img.shape - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) - assert results['img'].shape == (h, w, 2) - assert results['img_shape'] == (h, w, 2) - assert results['ori_shape'] == (h, w, c) + assert results["img"].shape == (h, w, 2) + assert results["img_shape"] == (h, w, 2) + assert results["ori_shape"] == (h, w, c) def test_adjust_gamma(): # test assertion if gamma <= 0 with pytest.raises(AssertionError): - transform = dict(type='AdjustGamma', gamma=0) + transform = dict(type="AdjustGamma", gamma=0) build_from_cfg(transform, PIPELINES) # test assertion if gamma is list with pytest.raises(AssertionError): - transform = dict(type='AdjustGamma', gamma=[1.2]) + transform = dict(type="AdjustGamma", gamma=[1.2]) build_from_cfg(transform, PIPELINES) # test with gamma = 1.2 - transform = dict(type='AdjustGamma', gamma=1.2) + transform = dict(type="AdjustGamma", gamma=1.2) transform = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) inv_gamma = 1.0 / 1.2 - table = np.array([((i / 255.0)**inv_gamma) * 255 - for i in np.arange(0, 256)]).astype('uint8') - converted_img = mmcv.lut_transform( - np.array(original_img, dtype=np.uint8), table) - assert np.allclose(results['img'], converted_img) - assert str(transform) == f'AdjustGamma(gamma={1.2})' + table = np.array( + [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)] + ).astype("uint8") + converted_img = mmcv.lut_transform(np.array(original_img, dtype=np.uint8), table) + assert np.allclose(results["img"], converted_img) + assert str(transform) == f"AdjustGamma(gamma={1.2})" def test_rerange(): # test assertion if min_value or max_value is illegal with pytest.raises(AssertionError): - transform = dict(type='Rerange', min_value=[0], max_value=[255]) + transform = dict(type="Rerange", min_value=[0], max_value=[255]) build_from_cfg(transform, PIPELINES) # test assertion if min_value >= max_value with pytest.raises(AssertionError): - transform = dict(type='Rerange', min_value=1, max_value=1) + transform = dict(type="Rerange", min_value=1, max_value=1) build_from_cfg(transform, PIPELINES) # test assertion if img_min_value == img_max_value with pytest.raises(AssertionError): - transform = dict(type='Rerange', min_value=0, max_value=1) + transform = dict(type="Rerange", min_value=0, max_value=1) transform = build_from_cfg(transform, PIPELINES) results = dict() - results['img'] = np.array([[1, 1], [1, 1]]) + results["img"] = np.array([[1, 1], [1, 1]]) transform(results) img_rerange_cfg = dict() - transform = dict(type='Rerange', **img_rerange_cfg) + transform = dict(type="Rerange", **img_rerange_cfg) transform = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) @@ -461,230 +451,224 @@ def test_rerange(): max_value = np.max(original_img) converted_img = (original_img - min_value) / (max_value - min_value) * 255 - assert np.allclose(results['img'], converted_img) - assert str(transform) == f'Rerange(min_value={0}, max_value={255})' + assert np.allclose(results["img"], converted_img) + assert str(transform) == f"Rerange(min_value={0}, max_value={255})" def test_CLAHE(): # test assertion if clip_limit is None with pytest.raises(AssertionError): - transform = dict(type='CLAHE', clip_limit=None) + transform = dict(type="CLAHE", clip_limit=None) build_from_cfg(transform, PIPELINES) # test assertion if tile_grid_size is illegal with pytest.raises(AssertionError): - transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0)) + transform = dict(type="CLAHE", tile_grid_size=(8.0, 8.0)) build_from_cfg(transform, PIPELINES) # test assertion if tile_grid_size is illegal with pytest.raises(AssertionError): - transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9)) + transform = dict(type="CLAHE", tile_grid_size=(9, 9, 9)) build_from_cfg(transform, PIPELINES) - transform = dict(type='CLAHE', clip_limit=2) + transform = dict(type="CLAHE", clip_limit=2) transform = build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") original_img = copy.deepcopy(img) - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 results = transform(results) converted_img = np.empty(original_img.shape) for i in range(original_img.shape[2]): converted_img[:, :, i] = mmcv.clahe( - np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8)) + np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8) + ) - assert np.allclose(results['img'], converted_img) - assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})' + assert np.allclose(results["img"], converted_img) + assert str(transform) == f"CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})" def test_seg_rescale(): results = dict() - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] h, w = seg.shape - transform = dict(type='SegRescale', scale_factor=1. / 2) + transform = dict(type="SegRescale", scale_factor=1.0 / 2) rescale_module = build_from_cfg(transform, PIPELINES) rescale_results = rescale_module(results.copy()) - assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2) + assert rescale_results["gt_semantic_seg"].shape == (h // 2, w // 2) - transform = dict(type='SegRescale', scale_factor=1) + transform = dict(type="SegRescale", scale_factor=1) rescale_module = build_from_cfg(transform, PIPELINES) rescale_results = rescale_module(results.copy()) - assert rescale_results['gt_semantic_seg'].shape == (h, w) + assert rescale_results["gt_semantic_seg"].shape == (h, w) def test_cutout(): # test prob with pytest.raises(AssertionError): - transform = dict(type='RandomCutOut', prob=1.5, n_holes=1) + transform = dict(type="RandomCutOut", prob=1.5, n_holes=1) build_from_cfg(transform, PIPELINES) # test n_holes with pytest.raises(AssertionError): transform = dict( - type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)) + type="RandomCutOut", prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8) + ) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): transform = dict( - type='RandomCutOut', - prob=0.5, - n_holes=(3, 4, 5), - cutout_shape=(8, 8)) + type="RandomCutOut", prob=0.5, n_holes=(3, 4, 5), cutout_shape=(8, 8) + ) build_from_cfg(transform, PIPELINES) # test cutout_shape and cutout_ratio with pytest.raises(AssertionError): - transform = dict( - type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8) + transform = dict(type="RandomCutOut", prob=0.5, n_holes=1, cutout_shape=8) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): - transform = dict( - type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2) + transform = dict(type="RandomCutOut", prob=0.5, n_holes=1, cutout_ratio=0.2) build_from_cfg(transform, PIPELINES) # either of cutout_shape and cutout_ratio should be given with pytest.raises(AssertionError): - transform = dict(type='RandomCutOut', prob=0.5, n_holes=1) + transform = dict(type="RandomCutOut", prob=0.5, n_holes=1) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): transform = dict( - type='RandomCutOut', + type="RandomCutOut", prob=0.5, n_holes=1, cutout_shape=(2, 2), - cutout_ratio=(0.4, 0.4)) + cutout_ratio=(0.4, 0.4), + ) build_from_cfg(transform, PIPELINES) # test seg_fill_in with pytest.raises(AssertionError): transform = dict( - type='RandomCutOut', + type="RandomCutOut", prob=0.5, n_holes=1, cutout_shape=(8, 8), - seg_fill_in='a') + seg_fill_in="a", + ) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): transform = dict( - type='RandomCutOut', + type="RandomCutOut", prob=0.5, n_holes=1, cutout_shape=(8, 8), - seg_fill_in=256) + seg_fill_in=256, + ) build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] - results['img_shape'] = img.shape - results['ori_shape'] = img.shape - results['pad_shape'] = img.shape - results['img_fields'] = ['img'] + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] + results["img_shape"] = img.shape + results["ori_shape"] = img.shape + results["pad_shape"] = img.shape + results["img_fields"] = ["img"] - transform = dict( - type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)) + transform = dict(type="RandomCutOut", prob=1, n_holes=1, cutout_shape=(10, 10)) cutout_module = build_from_cfg(transform, PIPELINES) - assert 'cutout_shape' in repr(cutout_module) + assert "cutout_shape" in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) - assert cutout_result['img'].sum() < img.sum() + assert cutout_result["img"].sum() < img.sum() - transform = dict( - type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)) + transform = dict(type="RandomCutOut", prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)) cutout_module = build_from_cfg(transform, PIPELINES) - assert 'cutout_ratio' in repr(cutout_module) + assert "cutout_ratio" in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) - assert cutout_result['img'].sum() < img.sum() + assert cutout_result["img"].sum() < img.sum() - transform = dict( - type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8)) + transform = dict(type="RandomCutOut", prob=0, n_holes=1, cutout_ratio=(0.8, 0.8)) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) - assert cutout_result['img'].sum() == img.sum() - assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + assert cutout_result["img"].sum() == img.sum() + assert cutout_result["gt_semantic_seg"].sum() == seg.sum() transform = dict( - type='RandomCutOut', + type="RandomCutOut", prob=1, n_holes=(2, 4), cutout_shape=[(10, 10), (15, 15)], fill_in=(255, 255, 255), - seg_fill_in=None) + seg_fill_in=None, + ) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) - assert cutout_result['img'].sum() > img.sum() - assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + assert cutout_result["img"].sum() > img.sum() + assert cutout_result["gt_semantic_seg"].sum() == seg.sum() transform = dict( - type='RandomCutOut', + type="RandomCutOut", prob=1, n_holes=1, cutout_ratio=(0.8, 0.8), fill_in=(255, 255, 255), - seg_fill_in=255) + seg_fill_in=255, + ) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) - assert cutout_result['img'].sum() > img.sum() - assert cutout_result['gt_semantic_seg'].sum() > seg.sum() + assert cutout_result["img"].sum() > img.sum() + assert cutout_result["gt_semantic_seg"].sum() > seg.sum() def test_mosaic(): # test prob with pytest.raises(AssertionError): - transform = dict(type='RandomMosaic', prob=1.5) + transform = dict(type="RandomMosaic", prob=1.5) build_from_cfg(transform, PIPELINES) # test assertion for invalid img_scale with pytest.raises(AssertionError): - transform = dict(type='RandomMosaic', prob=1, img_scale=640) + transform = dict(type="RandomMosaic", prob=1, img_scale=640) build_from_cfg(transform, PIPELINES) results = dict() - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') - seg = np.array( - Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") + seg = np.array(Image.open(osp.join(osp.dirname(__file__), "../data/seg.png"))) - results['img'] = img - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] + results["img"] = img + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] - transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12)) + transform = dict(type="RandomMosaic", prob=1, img_scale=(10, 12)) mosaic_module = build_from_cfg(transform, PIPELINES) - assert 'Mosaic' in repr(mosaic_module) + assert "Mosaic" in repr(mosaic_module) # test assertion for invalid mix_results with pytest.raises(AssertionError): mosaic_module(results) - results['mix_results'] = [copy.deepcopy(results)] * 3 + results["mix_results"] = [copy.deepcopy(results)] * 3 results = mosaic_module(results) - assert results['img'].shape[:2] == (20, 24) + assert results["img"].shape[:2] == (20, 24) results = dict() - results['img'] = img[:, :, 0] - results['gt_semantic_seg'] = seg - results['seg_fields'] = ['gt_semantic_seg'] + results["img"] = img[:, :, 0] + results["gt_semantic_seg"] = seg + results["seg_fields"] = ["gt_semantic_seg"] - transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12)) + transform = dict(type="RandomMosaic", prob=0, img_scale=(10, 12)) mosaic_module = build_from_cfg(transform, PIPELINES) - results['mix_results'] = [copy.deepcopy(results)] * 3 + results["mix_results"] = [copy.deepcopy(results)] * 3 results = mosaic_module(results) - assert results['img'].shape[:2] == img.shape[:2] + assert results["img"].shape[:2] == img.shape[:2] - transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12)) + transform = dict(type="RandomMosaic", prob=1, img_scale=(10, 12)) mosaic_module = build_from_cfg(transform, PIPELINES) results = mosaic_module(results) - assert results['img'].shape[:2] == (20, 24) + assert results["img"].shape[:2] == (20, 24) diff --git a/mmsegmentation/tests/test_data/test_tta.py b/mmsegmentation/tests/test_data/test_tta.py index 9373e2b..91ed14c 100644 --- a/mmsegmentation/tests/test_data/test_tta.py +++ b/mmsegmentation/tests/test_data/test_tta.py @@ -12,178 +12,213 @@ def test_multi_scale_flip_aug(): # test assertion if img_scale=None, img_ratios=1 (not float). with pytest.raises(AssertionError): tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=None, img_ratios=1, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) build_from_cfg(tta_transform, PIPELINES) # test assertion if img_scale=None, img_ratios=None. with pytest.raises(AssertionError): tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=None, img_ratios=None, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) build_from_cfg(tta_transform, PIPELINES) # test assertion if img_scale=(512, 512), img_ratios=1 (not float). with pytest.raises(AssertionError): tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(512, 512), img_ratios=1, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) build_from_cfg(tta_transform, PIPELINES) tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(512, 512), img_ratios=[0.5, 1.0, 2.0], flip=False, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) results = dict() # (288, 512, 3) - img = mmcv.imread( - osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape + img = mmcv.imread(osp.join(osp.dirname(__file__), "../data/color.jpg"), "color") + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)] - assert tta_results['flip'] == [False, False, False] + assert tta_results["scale"] == [(256, 256), (512, 512), (1024, 1024)] + assert tta_results["flip"] == [False, False, False] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(512, 512), img_ratios=[0.5, 1.0, 2.0], flip=True, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512), - (512, 512), (1024, 1024), (1024, 1024)] - assert tta_results['flip'] == [False, True, False, True, False, True] + assert tta_results["scale"] == [ + (256, 256), + (256, 256), + (512, 512), + (512, 512), + (1024, 1024), + (1024, 1024), + ] + assert tta_results["flip"] == [False, True, False, True, False, True] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(512, 512), img_ratios=1.0, flip=False, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(512, 512)] - assert tta_results['flip'] == [False] + assert tta_results["scale"] == [(512, 512)] + assert tta_results["flip"] == [False] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=(512, 512), img_ratios=1.0, flip=True, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(512, 512), (512, 512)] - assert tta_results['flip'] == [False, True] + assert tta_results["scale"] == [(512, 512), (512, 512)] + assert tta_results["flip"] == [False, True] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=None, img_ratios=[0.5, 1.0, 2.0], flip=False, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)] - assert tta_results['flip'] == [False, False, False] + assert tta_results["scale"] == [(256, 144), (512, 288), (1024, 576)] + assert tta_results["flip"] == [False, False, False] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=None, img_ratios=[0.5, 1.0, 2.0], flip=True, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288), - (512, 288), (1024, 576), (1024, 576)] - assert tta_results['flip'] == [False, True, False, True, False, True] + assert tta_results["scale"] == [ + (256, 144), + (256, 144), + (512, 288), + (512, 288), + (1024, 576), + (1024, 576), + ] + assert tta_results["flip"] == [False, True, False, True, False, True] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=[(256, 256), (512, 512), (1024, 1024)], img_ratios=None, flip=False, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)] - assert tta_results['flip'] == [False, False, False] + assert tta_results["scale"] == [(256, 256), (512, 512), (1024, 1024)] + assert tta_results["flip"] == [False, False, False] tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=[(256, 256), (512, 512), (1024, 1024)], img_ratios=None, flip=True, - transforms=[dict(type='Resize', keep_ratio=False)], + transforms=[dict(type="Resize", keep_ratio=False)], ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512), - (512, 512), (1024, 1024), (1024, 1024)] - assert tta_results['flip'] == [False, True, False, True, False, True] + assert tta_results["scale"] == [ + (256, 256), + (256, 256), + (512, 512), + (512, 512), + (1024, 1024), + (1024, 1024), + ] + assert tta_results["flip"] == [False, True, False, True, False, True] # test assertion if flip is True and Pad executed before RandomFlip with pytest.raises(AssertionError): tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=[(256, 256), (512, 512), (1024, 1024)], img_ratios=None, flip=True, transforms=[ - dict(type='Resize', keep_ratio=False), - dict(type='Pad', size_divisor=32), - dict(type='RandomFlip'), - ]) + dict(type="Resize", keep_ratio=False), + dict(type="Pad", size_divisor=32), + dict(type="RandomFlip"), + ], + ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_transform = dict( - type='MultiScaleFlipAug', + type="MultiScaleFlipAug", img_scale=[(256, 256), (512, 512), (1024, 1024)], img_ratios=None, flip=True, transforms=[ - dict(type='Resize', keep_ratio=True), - dict(type='RandomFlip'), - dict(type='Pad', size_divisor=32), - ]) + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Pad", size_divisor=32), + ], + ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512), - (512, 512), (1024, 1024), (1024, 1024)] - assert tta_results['flip'] == [False, True, False, True, False, True] - assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3), - (288, 512, 3), (288, 512, 3), - (576, 1024, 3), (576, 1024, 3)] - assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3), - (288, 512, 3), (288, 512, 3), - (576, 1024, 3), (576, 1024, 3)] - for i in range(len(tta_results['img'])): - assert tta_results['img'][i].shape == tta_results['pad_shape'][i] + assert tta_results["scale"] == [ + (256, 256), + (256, 256), + (512, 512), + (512, 512), + (1024, 1024), + (1024, 1024), + ] + assert tta_results["flip"] == [False, True, False, True, False, True] + assert tta_results["img_shape"] == [ + (144, 256, 3), + (144, 256, 3), + (288, 512, 3), + (288, 512, 3), + (576, 1024, 3), + (576, 1024, 3), + ] + assert tta_results["pad_shape"] == [ + (160, 256, 3), + (160, 256, 3), + (288, 512, 3), + (288, 512, 3), + (576, 1024, 3), + (576, 1024, 3), + ] + for i in range(len(tta_results["img"])): + assert tta_results["img"][i].shape == tta_results["pad_shape"][i] diff --git a/mmsegmentation/tests/test_digit_version.py b/mmsegmentation/tests/test_digit_version.py index 45daf09..f5df808 100644 --- a/mmsegmentation/tests/test_digit_version.py +++ b/mmsegmentation/tests/test_digit_version.py @@ -3,19 +3,19 @@ def test_digit_version(): - assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) - assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) - assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) - assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) - assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) - assert digit_version('1.0') == digit_version('1.0.0') - assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') - assert digit_version('1.0.0dev') < digit_version('1.0.0a') - assert digit_version('1.0.0a') < digit_version('1.0.0a1') - assert digit_version('1.0.0a') < digit_version('1.0.0b') - assert digit_version('1.0.0b') < digit_version('1.0.0rc') - assert digit_version('1.0.0rc1') < digit_version('1.0.0') - assert digit_version('1.0.0') < digit_version('1.0.0post') - assert digit_version('1.0.0post') < digit_version('1.0.0post1') - assert digit_version('v1') == (1, 0, 0, 0, 0, 0) - assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0) + assert digit_version("0.2.16") == (0, 2, 16, 0, 0, 0) + assert digit_version("1.2.3") == (1, 2, 3, 0, 0, 0) + assert digit_version("1.2.3rc0") == (1, 2, 3, 0, -1, 0) + assert digit_version("1.2.3rc1") == (1, 2, 3, 0, -1, 1) + assert digit_version("1.0rc0") == (1, 0, 0, 0, -1, 0) + assert digit_version("1.0") == digit_version("1.0.0") + assert digit_version("1.5.0+cuda90_cudnn7.6.3_lms") == digit_version("1.5") + assert digit_version("1.0.0dev") < digit_version("1.0.0a") + assert digit_version("1.0.0a") < digit_version("1.0.0a1") + assert digit_version("1.0.0a") < digit_version("1.0.0b") + assert digit_version("1.0.0b") < digit_version("1.0.0rc") + assert digit_version("1.0.0rc1") < digit_version("1.0.0") + assert digit_version("1.0.0") < digit_version("1.0.0post") + assert digit_version("1.0.0post") < digit_version("1.0.0post1") + assert digit_version("v1") == (1, 0, 0, 0, 0, 0) + assert digit_version("v1.1.5") == (1, 1, 5, 0, 0, 0) diff --git a/mmsegmentation/tests/test_eval_hook.py b/mmsegmentation/tests/test_eval_hook.py index 5267438..d8da985 100644 --- a/mmsegmentation/tests/test_eval_hook.py +++ b/mmsegmentation/tests/test_eval_hook.py @@ -15,7 +15,6 @@ class ExampleDataset(Dataset): - def __getitem__(self, idx): results = dict(img=torch.tensor([1]), img_metas=dict()) return results @@ -25,9 +24,8 @@ def __len__(self): class ExampleModel(nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() + super().__init__() self.test_cfg = None self.conv = nn.Conv2d(3, 3, 3) @@ -44,24 +42,21 @@ def test_iter_eval_hook(): test_dataset = ExampleModel() data_loader = [ DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_worker=0, - shuffle=False) + test_dataset, batch_size=1, sampler=None, num_worker=0, shuffle=False + ) ] EvalHook(data_loader) test_dataset = ExampleDataset() test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) - test_dataset.evaluate = MagicMock(return_value=dict(test='success')) + test_dataset.evaluate = MagicMock(return_value=dict(test="success")) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() data_loader = DataLoader( - test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False) - optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) - optimizer = obj_from_dict(optim_cfg, torch.optim, - dict(params=model.parameters())) + test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False + ) + optim_cfg = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0005) + optimizer = obj_from_dict(optim_cfg, torch.optim, dict(params=model.parameters())) # test EvalHook with tempfile.TemporaryDirectory() as tmpdir: @@ -70,11 +65,13 @@ def test_iter_eval_hook(): model=model, optimizer=optimizer, work_dir=tmpdir, - logger=logging.getLogger()) + logger=logging.getLogger(), + ) runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 1) - test_dataset.evaluate.assert_called_with([torch.tensor([1])], - logger=runner.logger) + runner.run([loader], [("train", 1)], 1) + test_dataset.evaluate.assert_called_with( + [torch.tensor([1])], logger=runner.logger + ) def test_epoch_eval_hook(): @@ -82,24 +79,21 @@ def test_epoch_eval_hook(): test_dataset = ExampleModel() data_loader = [ DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_worker=0, - shuffle=False) + test_dataset, batch_size=1, sampler=None, num_worker=0, shuffle=False + ) ] EvalHook(data_loader, by_epoch=True) test_dataset = ExampleDataset() test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) - test_dataset.evaluate = MagicMock(return_value=dict(test='success')) + test_dataset.evaluate = MagicMock(return_value=dict(test="success")) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() data_loader = DataLoader( - test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False) - optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) - optimizer = obj_from_dict(optim_cfg, torch.optim, - dict(params=model.parameters())) + test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False + ) + optim_cfg = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0005) + optimizer = obj_from_dict(optim_cfg, torch.optim, dict(params=model.parameters())) # test EvalHook with interval with tempfile.TemporaryDirectory() as tmpdir: @@ -108,87 +102,80 @@ def test_epoch_eval_hook(): model=model, optimizer=optimizer, work_dir=tmpdir, - logger=logging.getLogger()) + logger=logging.getLogger(), + ) runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 2) - test_dataset.evaluate.assert_called_once_with([torch.tensor([1])], - logger=runner.logger) + runner.run([loader], [("train", 1)], 2) + test_dataset.evaluate.assert_called_once_with( + [torch.tensor([1])], logger=runner.logger + ) -def multi_gpu_test(model, - data_loader, - tmpdir=None, - gpu_collect=False, - pre_eval=False): +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False, pre_eval=False): # Pre eval is set by default when training. results = single_gpu_test(model, data_loader, pre_eval=True) return results -@patch('mmseg.apis.multi_gpu_test', multi_gpu_test) +@patch("mmseg.apis.multi_gpu_test", multi_gpu_test) def test_dist_eval_hook(): with pytest.raises(TypeError): test_dataset = ExampleModel() data_loader = [ DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_worker=0, - shuffle=False) + test_dataset, batch_size=1, sampler=None, num_worker=0, shuffle=False + ) ] DistEvalHook(data_loader) test_dataset = ExampleDataset() test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) - test_dataset.evaluate = MagicMock(return_value=dict(test='success')) + test_dataset.evaluate = MagicMock(return_value=dict(test="success")) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() data_loader = DataLoader( - test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False) - optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) - optimizer = obj_from_dict(optim_cfg, torch.optim, - dict(params=model.parameters())) + test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False + ) + optim_cfg = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0005) + optimizer = obj_from_dict(optim_cfg, torch.optim, dict(params=model.parameters())) # test DistEvalHook with tempfile.TemporaryDirectory() as tmpdir: - eval_hook = DistEvalHook( - data_loader, by_epoch=False, efficient_test=True) + eval_hook = DistEvalHook(data_loader, by_epoch=False, efficient_test=True) runner = mmcv.runner.IterBasedRunner( model=model, optimizer=optimizer, work_dir=tmpdir, - logger=logging.getLogger()) + logger=logging.getLogger(), + ) runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 1) - test_dataset.evaluate.assert_called_with([torch.tensor([1])], - logger=runner.logger) + runner.run([loader], [("train", 1)], 1) + test_dataset.evaluate.assert_called_with( + [torch.tensor([1])], logger=runner.logger + ) -@patch('mmseg.apis.multi_gpu_test', multi_gpu_test) +@patch("mmseg.apis.multi_gpu_test", multi_gpu_test) def test_dist_eval_hook_epoch(): with pytest.raises(TypeError): test_dataset = ExampleModel() data_loader = [ DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_worker=0, - shuffle=False) + test_dataset, batch_size=1, sampler=None, num_worker=0, shuffle=False + ) ] DistEvalHook(data_loader) test_dataset = ExampleDataset() test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) - test_dataset.evaluate = MagicMock(return_value=dict(test='success')) + test_dataset.evaluate = MagicMock(return_value=dict(test="success")) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() data_loader = DataLoader( - test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False) - optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) - optimizer = obj_from_dict(optim_cfg, torch.optim, - dict(params=model.parameters())) + test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False + ) + optim_cfg = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0005) + optimizer = obj_from_dict(optim_cfg, torch.optim, dict(params=model.parameters())) # test DistEvalHook with tempfile.TemporaryDirectory() as tmpdir: @@ -197,8 +184,10 @@ def test_dist_eval_hook_epoch(): model=model, optimizer=optimizer, work_dir=tmpdir, - logger=logging.getLogger()) + logger=logging.getLogger(), + ) runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 2) - test_dataset.evaluate.assert_called_with([torch.tensor([1])], - logger=runner.logger) + runner.run([loader], [("train", 1)], 2) + test_dataset.evaluate.assert_called_with( + [torch.tensor([1])], logger=runner.logger + ) diff --git a/mmsegmentation/tests/test_inference.py b/mmsegmentation/tests/test_inference.py index f71a7ea..c754a0e 100644 --- a/mmsegmentation/tests/test_inference.py +++ b/mmsegmentation/tests/test_inference.py @@ -7,13 +7,13 @@ def test_test_time_augmentation_on_cpu(): - config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py' + config_file = "configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py" config = mmcv.Config.fromfile(config_file) # Remove pretrain model download for testing config.model.pretrained = None # Replace SyncBN with BN to inference on CPU - norm_cfg = dict(type='BN', requires_grad=True) + norm_cfg = dict(type="BN", requires_grad=True) config.model.backbone.norm_cfg = norm_cfg config.model.decode_head.norm_cfg = norm_cfg config.model.auxiliary_head.norm_cfg = norm_cfg @@ -22,9 +22,8 @@ def test_test_time_augmentation_on_cpu(): config.data.test.pipeline[1].flip = True checkpoint_file = None - model = init_segmentor(config, checkpoint_file, device='cpu') + model = init_segmentor(config, checkpoint_file, device="cpu") - img = mmcv.imread( - osp.join(osp.dirname(__file__), 'data/color.jpg'), 'color') + img = mmcv.imread(osp.join(osp.dirname(__file__), "data/color.jpg"), "color") result = inference_segmentor(model, img) assert result[0].shape == (288, 512) diff --git a/mmsegmentation/tests/test_metrics.py b/mmsegmentation/tests/test_metrics.py index adb09ae..c660dda 100644 --- a/mmsegmentation/tests/test_metrics.py +++ b/mmsegmentation/tests/test_metrics.py @@ -1,21 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore, - mean_iou) +from mmseg.core.evaluation import eval_metrics, mean_dice, mean_fscore, mean_iou from mmseg.core.evaluation.metrics import f_score def get_confusion_matrix(pred_label, label, num_classes, ignore_index): """Intersection over Union - Args: - pred_label (np.ndarray): 2D predict map - label (np.ndarray): label 2D label map - num_classes (int): number of categories - ignore_index (int): index ignore in evaluation - """ - - mask = (label != ignore_index) + Args: + pred_label (np.ndarray): 2D predict map + label (np.ndarray): label 2D label map + num_classes (int): number of categories + ignore_index (int): index ignore in evaluation + """ + + mask = label != ignore_index pred_label = pred_label[mask] label = label[mask] @@ -34,12 +33,14 @@ def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index): total_mat = np.zeros((num_classes, num_classes), dtype=np.float32) for i in range(num_imgs): mat = get_confusion_matrix( - results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index) + results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index + ) total_mat += mat all_acc = np.diag(total_mat).sum() / total_mat.sum() acc = np.diag(total_mat) / total_mat.sum(axis=1) iou = np.diag(total_mat) / ( - total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat)) + total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat) + ) return all_acc, acc, iou @@ -51,28 +52,25 @@ def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index): total_mat = np.zeros((num_classes, num_classes), dtype=np.float32) for i in range(num_imgs): mat = get_confusion_matrix( - results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index) + results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index + ) total_mat += mat all_acc = np.diag(total_mat).sum() / total_mat.sum() acc = np.diag(total_mat) / total_mat.sum(axis=1) - dice = 2 * np.diag(total_mat) / ( - total_mat.sum(axis=1) + total_mat.sum(axis=0)) + dice = 2 * np.diag(total_mat) / (total_mat.sum(axis=1) + total_mat.sum(axis=0)) return all_acc, acc, dice # This func is deprecated since it's not memory efficient -def legacy_mean_fscore(results, - gt_seg_maps, - num_classes, - ignore_index, - beta=1): +def legacy_mean_fscore(results, gt_seg_maps, num_classes, ignore_index, beta=1): num_imgs = len(results) assert len(gt_seg_maps) == num_imgs total_mat = np.zeros((num_classes, num_classes), dtype=np.float32) for i in range(num_imgs): mat = get_confusion_matrix( - results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index) + results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index + ) total_mat += mat all_acc = np.diag(total_mat).sum() / total_mat.sum() recall = np.diag(total_mat) / total_mat.sum(axis=1) @@ -95,46 +93,54 @@ def test_metrics(): # Test the correctness of the implementation of mIoU calculation. ret_metrics = eval_metrics( - results, label, num_classes, ignore_index, metrics='mIoU') - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'IoU'] - all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes, - ignore_index) + results, label, num_classes, ignore_index, metrics="mIoU" + ) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] + all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes, ignore_index) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) # Test the correctness of the implementation of mDice calculation. ret_metrics = eval_metrics( - results, label, num_classes, ignore_index, metrics='mDice') - all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'Dice'] - all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes, - ignore_index) + results, label, num_classes, ignore_index, metrics="mDice" + ) + all_acc, acc, dice = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["Dice"] + all_acc_l, acc_l, dice_l = legacy_mean_dice( + results, label, num_classes, ignore_index + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(dice, dice_l) # Test the correctness of the implementation of mDice calculation. ret_metrics = eval_metrics( - results, label, num_classes, ignore_index, metrics='mFscore') - all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[ - 'Recall'], ret_metrics['Precision'], ret_metrics['Fscore'] + results, label, num_classes, ignore_index, metrics="mFscore" + ) + all_acc, recall, precision, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Recall"], + ret_metrics["Precision"], + ret_metrics["Fscore"], + ) all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore( - results, label, num_classes, ignore_index) + results, label, num_classes, ignore_index + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(recall, recall_l) assert np.allclose(precision, precision_l) assert np.allclose(fscore, fscore_l) # Test the correctness of the implementation of joint calculation. ret_metrics = eval_metrics( - results, - label, - num_classes, - ignore_index, - metrics=['mIoU', 'mDice', 'mFscore']) - all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[ - 'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[ - 'Dice'], ret_metrics['Precision'], ret_metrics[ - 'Recall'], ret_metrics['Fscore'] + results, label, num_classes, ignore_index, metrics=["mIoU", "mDice", "mFscore"] + ) + all_acc, acc, iou, dice, precision, recall, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Acc"], + ret_metrics["IoU"], + ret_metrics["Dice"], + ret_metrics["Precision"], + ret_metrics["Recall"], + ret_metrics["Fscore"], + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) @@ -148,38 +154,28 @@ def test_metrics(): results = np.random.randint(0, 5, size=pred_size) label = np.random.randint(0, 4, size=pred_size) ret_metrics = eval_metrics( - results, - label, - num_classes, - ignore_index=255, - metrics='mIoU', - nan_to_num=-1) - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'IoU'] + results, label, num_classes, ignore_index=255, metrics="mIoU", nan_to_num=-1 + ) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] assert acc[-1] == -1 assert iou[-1] == -1 ret_metrics = eval_metrics( - results, - label, - num_classes, - ignore_index=255, - metrics='mDice', - nan_to_num=-1) - all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'Dice'] + results, label, num_classes, ignore_index=255, metrics="mDice", nan_to_num=-1 + ) + all_acc, acc, dice = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["Dice"] assert acc[-1] == -1 assert dice[-1] == -1 ret_metrics = eval_metrics( - results, - label, - num_classes, - ignore_index=255, - metrics='mFscore', - nan_to_num=-1) - all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[ - 'Precision'], ret_metrics['Recall'], ret_metrics['Fscore'] + results, label, num_classes, ignore_index=255, metrics="mFscore", nan_to_num=-1 + ) + all_acc, precision, recall, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Precision"], + ret_metrics["Recall"], + ret_metrics["Fscore"], + ) assert precision[-1] == -1 assert recall[-1] == -1 assert fscore[-1] == -1 @@ -189,12 +185,18 @@ def test_metrics(): label, num_classes, ignore_index=255, - metrics=['mDice', 'mIoU', 'mFscore'], - nan_to_num=-1) - all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[ - 'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[ - 'Dice'], ret_metrics['Precision'], ret_metrics[ - 'Recall'], ret_metrics['Fscore'] + metrics=["mDice", "mIoU", "mFscore"], + nan_to_num=-1, + ) + all_acc, acc, iou, dice, precision, recall, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Acc"], + ret_metrics["IoU"], + ret_metrics["Dice"], + ret_metrics["Precision"], + ret_metrics["Recall"], + ret_metrics["Fscore"], + ) assert acc[-1] == -1 assert dice[-1] == -1 assert iou[-1] == -1 @@ -210,9 +212,9 @@ def test_metrics(): label = np.array([np.arange(59)]) num_classes = 59 ret_metrics = eval_metrics( - results, label, num_classes, ignore_index=255, metrics='mIoU') - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'IoU'] + results, label, num_classes, ignore_index=255, metrics="mIoU" + ) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] assert not np.any(np.isnan(iou)) @@ -224,20 +226,16 @@ def test_mean_iou(): label = np.random.randint(0, num_classes, size=pred_size) label[:, 2, 5:10] = ignore_index ret_metrics = mean_iou(results, label, num_classes, ignore_index) - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'IoU'] - all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes, - ignore_index) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] + all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes, ignore_index) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) results = np.random.randint(0, 5, size=pred_size) label = np.random.randint(0, 4, size=pred_size) - ret_metrics = mean_iou( - results, label, num_classes, ignore_index=255, nan_to_num=-1) - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'IoU'] + ret_metrics = mean_iou(results, label, num_classes, ignore_index=255, nan_to_num=-1) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] assert acc[-1] == -1 assert acc[-1] == -1 @@ -250,10 +248,10 @@ def test_mean_dice(): label = np.random.randint(0, num_classes, size=pred_size) label[:, 2, 5:10] = ignore_index ret_metrics = mean_dice(results, label, num_classes, ignore_index) - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'Dice'] - all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes, - ignore_index) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["Dice"] + all_acc_l, acc_l, dice_l = legacy_mean_dice( + results, label, num_classes, ignore_index + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(iou, dice_l) @@ -261,9 +259,9 @@ def test_mean_dice(): results = np.random.randint(0, 5, size=pred_size) label = np.random.randint(0, 4, size=pred_size) ret_metrics = mean_dice( - results, label, num_classes, ignore_index=255, nan_to_num=-1) - all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[ - 'Dice'] + results, label, num_classes, ignore_index=255, nan_to_num=-1 + ) + all_acc, acc, dice = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["Dice"] assert acc[-1] == -1 assert dice[-1] == -1 @@ -276,21 +274,30 @@ def test_mean_fscore(): label = np.random.randint(0, num_classes, size=pred_size) label[:, 2, 5:10] = ignore_index ret_metrics = mean_fscore(results, label, num_classes, ignore_index) - all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[ - 'Recall'], ret_metrics['Precision'], ret_metrics['Fscore'] + all_acc, recall, precision, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Recall"], + ret_metrics["Precision"], + ret_metrics["Fscore"], + ) all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore( - results, label, num_classes, ignore_index) + results, label, num_classes, ignore_index + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(recall, recall_l) assert np.allclose(precision, precision_l) assert np.allclose(fscore, fscore_l) - ret_metrics = mean_fscore( - results, label, num_classes, ignore_index, beta=2) - all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[ - 'Recall'], ret_metrics['Precision'], ret_metrics['Fscore'] + ret_metrics = mean_fscore(results, label, num_classes, ignore_index, beta=2) + all_acc, recall, precision, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Recall"], + ret_metrics["Precision"], + ret_metrics["Fscore"], + ) all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore( - results, label, num_classes, ignore_index, beta=2) + results, label, num_classes, ignore_index, beta=2 + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(recall, recall_l) assert np.allclose(precision, precision_l) @@ -299,9 +306,14 @@ def test_mean_fscore(): results = np.random.randint(0, 5, size=pred_size) label = np.random.randint(0, 4, size=pred_size) ret_metrics = mean_fscore( - results, label, num_classes, ignore_index=255, nan_to_num=-1) - all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[ - 'Recall'], ret_metrics['Precision'], ret_metrics['Fscore'] + results, label, num_classes, ignore_index=255, nan_to_num=-1 + ) + all_acc, recall, precision, fscore = ( + ret_metrics["aAcc"], + ret_metrics["Recall"], + ret_metrics["Precision"], + ret_metrics["Fscore"], + ) assert recall[-1] == -1 assert precision[-1] == -1 assert fscore[-1] == -1 @@ -314,9 +326,9 @@ def test_filename_inputs(): def save_arr(input_arrays: list, title: str, is_image: bool, dir: str): filenames = [] - SUFFIX = '.png' if is_image else '.npy' + SUFFIX = ".png" if is_image else ".npy" for idx, arr in enumerate(input_arrays): - filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX) + filename = f"{dir}/{title}-{idx}{SUFFIX}" if is_image: cv2.imwrite(filename, arr) else: @@ -332,20 +344,16 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str): labels[:, 2, 5:10] = ignore_index with tempfile.TemporaryDirectory() as temp_dir: - - result_files = save_arr(results, 'pred', False, temp_dir) - label_files = save_arr(labels, 'label', True, temp_dir) + result_files = save_arr(results, "pred", False, temp_dir) + label_files = save_arr(labels, "label", True, temp_dir) ret_metrics = eval_metrics( - result_files, - label_files, - num_classes, - ignore_index, - metrics='mIoU') - all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[ - 'Acc'], ret_metrics['IoU'] - all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes, - ignore_index) + result_files, label_files, num_classes, ignore_index, metrics="mIoU" + ) + all_acc, acc, iou = ret_metrics["aAcc"], ret_metrics["Acc"], ret_metrics["IoU"] + all_acc_l, acc_l, iou_l = legacy_mean_iou( + results, labels, num_classes, ignore_index + ) assert np.allclose(all_acc, all_acc_l) assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) diff --git a/mmsegmentation/tests/test_models/test_backbones/__init__.py b/mmsegmentation/tests/test_models/test_backbones/__init__.py index 8b673fa..d78e0f2 100644 --- a/mmsegmentation/tests/test_models/test_backbones/__init__.py +++ b/mmsegmentation/tests/test_models/test_backbones/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .utils import all_zeros, check_norm_state, is_block, is_norm -__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state'] +__all__ = ["is_norm", "is_block", "all_zeros", "check_norm_state"] diff --git a/mmsegmentation/tests/test_models/test_backbones/test_beit.py b/mmsegmentation/tests/test_models/test_backbones/test_beit.py index cf39608..118d1f3 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_beit.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_beit.py @@ -18,7 +18,7 @@ def test_beit_backbone(): with pytest.raises(TypeError): # out_indices must be int ,list or tuple - model = BEiT(out_indices=1.) + model = BEiT(out_indices=1.0) with pytest.raises(AssertionError): # The length of img_size tuple must be lower than 3. @@ -30,7 +30,7 @@ def test_beit_backbone(): # Test img_size isinstance tuple imgs = torch.randn(1, 3, 224, 224) - model = BEiT(img_size=(224, )) + model = BEiT(img_size=(224,)) model.init_weights() model(imgs) @@ -115,7 +115,7 @@ def test_beit_backbone(): def test_beit_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = BEiT(pretrained=None, init_cfg=None) @@ -124,9 +124,8 @@ def test_beit_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file - model = BEiT( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + model = BEiT(pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path)) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -134,9 +133,9 @@ def test_beit_init(): # test resize_rel_pos_embed value = torch.randn(732, 16) ckpt = { - 'state_dict': { - 'layers.0.attn.relative_position_index': 0, - 'layers.0.attn.relative_position_bias_table': value + "state_dict": { + "layers.0.attn.relative_position_index": 0, + "layers.0.attn.relative_position_bias_table": value, } } model = BEiT(img_size=(512, 512)) @@ -152,7 +151,7 @@ def test_beit_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = BEiT(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -160,8 +159,7 @@ def test_beit_init(): # pretrained loads pretrain from an non-existent file # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = BEiT( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = BEiT(pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path)) with pytest.raises(AssertionError): model = BEiT(pretrained=path, init_cfg=123) @@ -173,8 +171,7 @@ def test_beit_init(): # pretrain=123, whose type is unsupported # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = BEiT( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = BEiT(pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path)) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported diff --git a/mmsegmentation/tests/test_models/test_backbones/test_bisenetv1.py b/mmsegmentation/tests/test_models/test_backbones/test_bisenetv1.py index c067749..436f945 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_bisenetv1.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_bisenetv1.py @@ -3,15 +3,18 @@ import torch from mmseg.models.backbones import BiSeNetV1 -from mmseg.models.backbones.bisenetv1 import (AttentionRefinementModule, - ContextPath, FeatureFusionModule, - SpatialPath) +from mmseg.models.backbones.bisenetv1 import ( + AttentionRefinementModule, + ContextPath, + FeatureFusionModule, + SpatialPath, +) def test_bisenetv1_backbone(): # Test BiSeNetV1 Standard Forward backbone_cfg = dict( - type='ResNet', + type="ResNet", in_channels=3, depth=18, num_stages=4, @@ -19,8 +22,9 @@ def test_bisenetv1_backbone(): dilations=(1, 1, 1, 1), strides=(1, 2, 2, 2), norm_eval=False, - style='pytorch', - contract_dilation=True) + style="pytorch", + contract_dilation=True, + ) model = BiSeNetV1(in_channels=3, backbone_cfg=backbone_cfg) model.init_weights() model.train() @@ -45,16 +49,14 @@ def test_bisenetv1_backbone(): with pytest.raises(AssertionError): # BiSeNetV1 spatial path channel constraints. BiSeNetV1( - backbone_cfg=backbone_cfg, - in_channels=3, - spatial_channels=(16, 16, 16)) + backbone_cfg=backbone_cfg, in_channels=3, spatial_channels=(16, 16, 16) + ) with pytest.raises(AssertionError): # BiSeNetV1 context path constraints. BiSeNetV1( - backbone_cfg=backbone_cfg, - in_channels=3, - context_channels=(16, 32, 64, 128)) + backbone_cfg=backbone_cfg, in_channels=3, context_channels=(16, 32, 64, 128) + ) def test_bisenetv1_spatial_path(): @@ -65,7 +67,7 @@ def test_bisenetv1_spatial_path(): def test_bisenetv1_context_path(): backbone_cfg = dict( - type='ResNet', + type="ResNet", in_channels=3, depth=50, num_stages=4, @@ -73,13 +75,13 @@ def test_bisenetv1_context_path(): dilations=(1, 1, 1, 1), strides=(1, 2, 2, 2), norm_eval=False, - style='pytorch', - contract_dilation=True) + style="pytorch", + contract_dilation=True, + ) with pytest.raises(AssertionError): # BiSeNetV1 context path constraints. - ContextPath( - backbone_cfg=backbone_cfg, context_channels=(16, 32, 64, 128)) + ContextPath(backbone_cfg=backbone_cfg, context_channels=(16, 32, 64, 128)) def test_bisenetv1_attention_refinement_module(): diff --git a/mmsegmentation/tests/test_models/test_backbones/test_bisenetv2.py b/mmsegmentation/tests/test_models/test_backbones/test_bisenetv2.py index cf2dfb3..81ead85 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_bisenetv2.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_bisenetv2.py @@ -3,8 +3,7 @@ from mmcv.cnn import ConvModule from mmseg.models.backbones import BiSeNetV2 -from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch, - SemanticBranch) +from mmseg.models.backbones.bisenetv2 import BGALayer, DetailBranch, SemanticBranch def test_bisenetv2_backbone(): diff --git a/mmsegmentation/tests/test_models/test_backbones/test_blocks.py b/mmsegmentation/tests/test_models/test_backbones/test_blocks.py index 77c8564..4fcfcf1 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_blocks.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_blocks.py @@ -4,8 +4,12 @@ import torch from mmcv.utils import TORCH_VERSION, digit_version -from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer, - make_divisible) +from mmseg.models.utils import ( + InvertedResidual, + InvertedResidualV3, + SELayer, + make_divisible, +) def test_make_divisible(): @@ -78,7 +82,7 @@ def test_inv_residualv3(): assert inv_module.with_res_shortcut is True assert inv_module.with_se is False assert inv_module.with_expand_conv is False - assert not hasattr(inv_module, 'expand_conv') + assert not hasattr(inv_module, "expand_conv") assert isinstance(inv_module.depthwise_conv.conv, torch.nn.Conv2d) assert inv_module.depthwise_conv.conv.kernel_size == (3, 3) assert inv_module.depthwise_conv.conv.stride == (1, 1) @@ -98,11 +102,10 @@ def test_inv_residualv3(): se_cfg = dict( channels=16, ratio=4, - act_cfg=(dict(type='ReLU'), - dict(type='HSigmoid', bias=3.0, divisor=6.0))) - act_cfg = dict(type='HSwish') - inv_module = InvertedResidualV3( - 32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg) + act_cfg=(dict(type="ReLU"), dict(type="HSigmoid", bias=3.0, divisor=6.0)), + ) + act_cfg = dict(type="HSwish") + inv_module = InvertedResidualV3(32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg) assert inv_module.with_res_shortcut is False assert inv_module.with_se is True assert inv_module.with_expand_conv is True @@ -110,8 +113,9 @@ def test_inv_residualv3(): assert inv_module.expand_conv.conv.stride == (1, 1) assert inv_module.expand_conv.conv.padding == (0, 0) - assert isinstance(inv_module.depthwise_conv.conv, - mmcv.cnn.bricks.Conv2dAdaptivePadding) + assert isinstance( + inv_module.depthwise_conv.conv, mmcv.cnn.bricks.Conv2dAdaptivePadding + ) assert inv_module.depthwise_conv.conv.kernel_size == (3, 3) assert inv_module.depthwise_conv.conv.stride == (2, 2) assert inv_module.depthwise_conv.conv.padding == (0, 0) @@ -122,8 +126,9 @@ def test_inv_residualv3(): assert inv_module.linear_conv.conv.padding == (0, 0) assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d) - if (TORCH_VERSION == 'parrots' - or digit_version(TORCH_VERSION) < digit_version('1.7')): + if TORCH_VERSION == "parrots" or digit_version(TORCH_VERSION) < digit_version( + "1.7" + ): # Note: Use PyTorch official HSwish # when torch>=1.7 after MMCV >= 1.4.5. # Hardswish is not supported when PyTorch version < 1.6. @@ -134,8 +139,7 @@ def test_inv_residualv3(): assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish) else: assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish) - assert isinstance(inv_module.depthwise_conv.activate, - torch.nn.Hardswish) + assert isinstance(inv_module.depthwise_conv.activate, torch.nn.Hardswish) x = torch.rand(1, 32, 64, 64) output = inv_module(x) @@ -143,7 +147,8 @@ def test_inv_residualv3(): # test with checkpoint forward inv_module = InvertedResidualV3( - 32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True) + 32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True + ) assert inv_module.with_cp x = torch.randn(2, 32, 64, 64, requires_grad=True) output = inv_module(x) @@ -153,7 +158,7 @@ def test_inv_residualv3(): def test_se_layer(): with pytest.raises(AssertionError): # test act_cfg assertion. - SELayer(32, act_cfg=(dict(type='ReLU'), )) + SELayer(32, act_cfg=(dict(type="ReLU"),)) # test config with channels = 16. se_layer = SELayer(16) @@ -171,7 +176,7 @@ def test_se_layer(): assert output.shape == (1, 16, 64, 64) # test config with channels = 16, act_cfg = dict(type='ReLU'). - se_layer = SELayer(16, act_cfg=dict(type='ReLU')) + se_layer = SELayer(16, act_cfg=dict(type="ReLU")) assert se_layer.conv1.conv.kernel_size == (1, 1) assert se_layer.conv1.conv.stride == (1, 1) assert se_layer.conv1.conv.padding == (0, 0) diff --git a/mmsegmentation/tests/test_models/test_backbones/test_cgnet.py b/mmsegmentation/tests/test_models/test_backbones/test_cgnet.py index f938525..e63ac24 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_cgnet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_cgnet.py @@ -3,8 +3,7 @@ import torch from mmseg.models.backbones import CGNet -from mmseg.models.backbones.cgnet import (ContextGuidedBlock, - GlobalContextExtractor) +from mmseg.models.backbones.cgnet import ContextGuidedBlock, GlobalContextExtractor def test_cgnet_GlobalContextExtractor(): @@ -21,8 +20,7 @@ def test_cgnet_context_guided_block(): ContextGuidedBlock(8, 8) # test cgnet ContextGuidedBlock with checkpoint forward - block = ContextGuidedBlock( - 16, 16, act_cfg=dict(type='PReLU'), with_cp=True) + block = ContextGuidedBlock(16, 16, act_cfg=dict(type="PReLU"), with_cp=True) assert block.with_cp x = torch.randn(2, 16, 64, 64, requires_grad=True) x_out = block(x) diff --git a/mmsegmentation/tests/test_models/test_backbones/test_erfnet.py b/mmsegmentation/tests/test_models/test_backbones/test_erfnet.py index 6ae7345..1687aa7 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_erfnet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_erfnet.py @@ -3,8 +3,11 @@ import torch from mmseg.models.backbones import ERFNet -from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d, - UpsamplerBlock) +from mmseg.models.backbones.erfnet import ( + DownsamplerBlock, + NonBottleneck1d, + UpsamplerBlock, +) def test_erfnet_backbone(): diff --git a/mmsegmentation/tests/test_models/test_backbones/test_fast_scnn.py b/mmsegmentation/tests/test_models/test_backbones/test_fast_scnn.py index 7ee638b..df294b9 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_fast_scnn.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_fast_scnn.py @@ -9,11 +9,15 @@ def test_fastscnn_backbone(): with pytest.raises(AssertionError): # Fast-SCNN channel constraints. FastSCNN( - 3, (32, 48), - 64, (64, 96, 128), (2, 2, 1), + 3, + (32, 48), + 64, + (64, 96, 128), + (2, 2, 1), global_out_channels=127, higher_in_channels=64, - lower_in_channels=128) + lower_in_channels=128, + ) # Test FastSCNN Standard Forward model = FastSCNN( diff --git a/mmsegmentation/tests/test_models/test_backbones/test_hrnet.py b/mmsegmentation/tests/test_models/test_backbones/test_hrnet.py index 8329c84..feffc5d 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_hrnet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_hrnet.py @@ -7,7 +7,7 @@ from mmseg.models.backbones.resnet import BasicBlock, Bottleneck -@pytest.mark.parametrize('block', [BasicBlock, Bottleneck]) +@pytest.mark.parametrize("block", [BasicBlock, Bottleneck]) def test_hrmodule(block): # Test multiscale forward num_channles = (32, 64) @@ -22,7 +22,7 @@ def test_hrmodule(block): feats = [ torch.randn(1, in_channels[0], 64, 64), - torch.randn(1, in_channels[1], 32, 32) + torch.randn(1, in_channels[1], 32, 32), ] feats = hrmodule(feats) @@ -44,7 +44,7 @@ def test_hrmodule(block): feats = [ torch.randn(1, in_channels[0], 64, 64), - torch.randn(1, in_channels[1], 32, 32) + torch.randn(1, in_channels[1], 32, 32), ] feats = hrmodule(feats) @@ -58,37 +58,42 @@ def test_hrnet_backbone(): stage1=dict( num_modules=1, num_branches=1, - block='BOTTLENECK', - num_blocks=(4, ), - num_channels=(64, )), + block="BOTTLENECK", + num_blocks=(4,), + num_channels=(64,), + ), stage2=dict( num_modules=1, num_branches=2, - block='BASIC', + block="BASIC", num_blocks=(4, 4), - num_channels=(32, 64)), + num_channels=(32, 64), + ), stage3=dict( num_modules=4, num_branches=3, - block='BASIC', + block="BASIC", num_blocks=(4, 4, 4), - num_channels=(32, 64, 128))) + num_channels=(32, 64, 128), + ), + ) with pytest.raises(AssertionError): # HRNet now only support 4 stages HRNet(extra=extra) - extra['stage4'] = dict( + extra["stage4"] = dict( num_modules=3, num_branches=3, # should be 4 - block='BASIC', + block="BASIC", num_blocks=(4, 4, 4, 4), - num_channels=(32, 64, 128, 256)) + num_channels=(32, 64, 128, 256), + ) with pytest.raises(AssertionError): # len(num_blocks) should equal num_branches HRNet(extra=extra) - extra['stage4']['num_branches'] = 4 + extra["stage4"]["num_branches"] = 4 # Test hrnetv2p_w32 model = HRNet(extra=extra) @@ -123,13 +128,13 @@ def test_hrnet_backbone(): assert param.requires_grad is False for i in range(1, frozen_stages + 1): if i == 1: - layer = getattr(model, f'layer{i}') - transition = getattr(model, f'transition{i}') + layer = getattr(model, f"layer{i}") + transition = getattr(model, f"transition{i}") elif i == 4: - layer = getattr(model, f'stage{i}') + layer = getattr(model, f"stage{i}") else: - layer = getattr(model, f'stage{i}') - transition = getattr(model, f'transition{i}') + layer = getattr(model, f"stage{i}") + transition = getattr(model, f"transition{i}") for mod in layer.modules(): if isinstance(mod, _BatchNorm): diff --git a/mmsegmentation/tests/test_models/test_backbones/test_icnet.py b/mmsegmentation/tests/test_models/test_backbones/test_icnet.py index a96d8d8..10f0450 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_icnet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_icnet.py @@ -14,26 +14,29 @@ def test_icnet_backbone(): light_branch_middle_channels=8, psp_out_channels=128, out_channels=(16, 128, 128), - backbone_cfg=None) + backbone_cfg=None, + ) # Test ICNet Standard Forward model = ICNet( layer_channels=(128, 512), backbone_cfg=dict( - type='ResNetV1c', + type="ResNetV1c", in_channels=3, depth=18, num_stages=4, out_indices=(0, 1, 2, 3), dilations=(1, 1, 2, 4), strides=(1, 2, 1, 1), - norm_cfg=dict(type='BN', requires_grad=True), + norm_cfg=dict(type="BN", requires_grad=True), norm_eval=False, - style='pytorch', - contract_dilation=True), + style="pytorch", + contract_dilation=True, + ), + ) + assert ( + hasattr(model.backbone, "maxpool") and model.backbone.maxpool.ceil_mode is True ) - assert hasattr(model.backbone, - 'maxpool') and model.backbone.maxpool.ceil_mode is True model.init_weights() model.train() batch_size = 2 diff --git a/mmsegmentation/tests/test_models/test_backbones/test_mae.py b/mmsegmentation/tests/test_models/test_backbones/test_mae.py index 562d067..a5c2e5f 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_mae.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_mae.py @@ -18,7 +18,7 @@ def test_mae_backbone(): with pytest.raises(TypeError): # out_indices must be int ,list or tuple - model = MAE(out_indices=1.) + model = MAE(out_indices=1.0) with pytest.raises(AssertionError): # The length of img_size tuple must be lower than 3. @@ -30,7 +30,7 @@ def test_mae_backbone(): # Test img_size isinstance tuple imgs = torch.randn(1, 3, 224, 224) - model = MAE(img_size=(224, )) + model = MAE(img_size=(224,)) model.init_weights() model(imgs) @@ -111,7 +111,7 @@ def test_mae_backbone(): def test_mae_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = MAE(pretrained=None, init_cfg=None) @@ -120,9 +120,8 @@ def test_mae_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file - model = MAE( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + model = MAE(pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path)) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -131,10 +130,10 @@ def test_mae_init(): value = torch.randn(732, 16) abs_pos_embed_value = torch.rand(1, 17, 768) ckpt = { - 'state_dict': { - 'layers.0.attn.relative_position_index': 0, - 'layers.0.attn.relative_position_bias_table': value, - 'pos_embed': abs_pos_embed_value + "state_dict": { + "layers.0.attn.relative_position_index": 0, + "layers.0.attn.relative_position_bias_table": value, + "pos_embed": abs_pos_embed_value, } } model = MAE(img_size=(512, 512)) @@ -142,7 +141,7 @@ def test_mae_init(): model.resize_rel_pos_embed(ckpt) # test resize abs pos embed - ckpt = model.resize_abs_pos_embed(ckpt['state_dict']) + ckpt = model.resize_abs_pos_embed(ckpt["state_dict"]) # pretrained=None # init_cfg=123, whose type is unsupported @@ -153,7 +152,7 @@ def test_mae_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = MAE(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -161,8 +160,7 @@ def test_mae_init(): # pretrained loads pretrain from an non-existent file # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = MAE( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = MAE(pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path)) with pytest.raises(AssertionError): model = MAE(pretrained=path, init_cfg=123) @@ -174,8 +172,7 @@ def test_mae_init(): # pretrain=123, whose type is unsupported # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = MAE( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = MAE(pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path)) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported diff --git a/mmsegmentation/tests/test_models/test_backbones/test_mit.py b/mmsegmentation/tests/test_models/test_backbones/test_mit.py index 72f74fe..5b00610 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_mit.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_mit.py @@ -3,8 +3,11 @@ import torch from mmseg.models.backbones import MixVisionTransformer -from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN, - TransformerEncoderLayer) +from mmseg.models.backbones.mit import ( + EfficientMultiheadAttention, + MixFFN, + TransformerEncoderLayer, +) def test_mit(): @@ -16,7 +19,8 @@ def test_mit(): H, W = (224, 224) temp = torch.randn((1, 3, H, W)) model = MixVisionTransformer( - embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3)) + embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3) + ) model.init_weights() outs = model(temp) assert outs[0].shape == (1, 32, H // 4, W // 4) @@ -59,7 +63,8 @@ def test_mit(): # Test TransformerEncoderLayer with checkpoint forward block = TransformerEncoderLayer( - embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True) + embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True + ) assert block.with_cp x = torch.randn(1, 56 * 56, 64) x_out = block(x, (56, 56)) @@ -67,7 +72,7 @@ def test_mit(): def test_mit_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = MixVisionTransformer(pretrained=None, init_cfg=None) @@ -77,8 +82,9 @@ def test_mit_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file model = MixVisionTransformer( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path) + ) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -92,7 +98,7 @@ def test_mit_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = MixVisionTransformer(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -101,7 +107,8 @@ def test_mit_init(): # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): MixVisionTransformer( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path) + ) with pytest.raises(AssertionError): MixVisionTransformer(pretrained=path, init_cfg=123) @@ -114,7 +121,8 @@ def test_mit_init(): # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): MixVisionTransformer( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path) + ) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported diff --git a/mmsegmentation/tests/test_models/test_backbones/test_mobilenet_v3.py b/mmsegmentation/tests/test_models/test_backbones/test_mobilenet_v3.py index 769ee14..04e574c 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_mobilenet_v3.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_mobilenet_v3.py @@ -8,7 +8,7 @@ def test_mobilenet_v3(): with pytest.raises(AssertionError): # check invalid arch - MobileNetV3('big') + MobileNetV3("big") with pytest.raises(AssertionError): # check invalid reduction_factor @@ -40,7 +40,7 @@ def test_mobilenet_v3(): assert feat[2].shape == (2, 576, 7, 7) # Test MobileNetV3 with arch = 'large' - model = MobileNetV3(arch='large', out_indices=(1, 3, 16)) + model = MobileNetV3(arch="large", out_indices=(1, 3, 16)) model.init_weights() model.train() diff --git a/mmsegmentation/tests/test_models/test_backbones/test_resnest.py b/mmsegmentation/tests/test_models/test_backbones/test_resnest.py index 3013f34..07848ec 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_resnest.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_resnest.py @@ -9,11 +9,10 @@ def test_resnest_bottleneck(): with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] - BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow') + BottleneckS(64, 64, radix=2, reduction_factor=4, style="tensorflow") # Test ResNeSt Bottleneck structure - block = BottleneckS( - 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch') + block = BottleneckS(64, 256, radix=2, reduction_factor=4, stride=2, style="pytorch") assert block.avd_layer.stride == 2 assert block.conv2.channels == 256 @@ -30,8 +29,7 @@ def test_resnest_backbone(): ResNeSt(depth=18) # Test ResNeSt with radix 2, reduction_factor 4 - model = ResNeSt( - depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) + model = ResNeSt(depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) model.init_weights() model.train() diff --git a/mmsegmentation/tests/test_models/test_backbones/test_resnet.py b/mmsegmentation/tests/test_models/test_backbones/test_resnet.py index fa632f5..b7525e5 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_resnet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_resnet.py @@ -14,15 +14,13 @@ def test_resnet_basic_block(): with pytest.raises(AssertionError): # Not implemented yet. - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) BasicBlock(64, 64, dcn=dcn) with pytest.raises(AssertionError): # Not implemented yet. plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3") ] BasicBlock(64, 64, plugins=plugins) @@ -31,12 +29,14 @@ def test_resnet_basic_block(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ) ] BasicBlock(64, 64, plugins=plugins) @@ -63,32 +63,26 @@ def test_resnet_basic_block(): def test_resnet_bottleneck(): with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] - Bottleneck(64, 64, style='tensorflow') + Bottleneck(64, 64, style="tensorflow") with pytest.raises(AssertionError): # Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3' plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv4') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv4") ] Bottleneck(64, 16, plugins=plugins) with pytest.raises(AssertionError): # Need to specify different postfix to avoid duplicate plugin name plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3'), - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3"), + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3"), ] Bottleneck(64, 16, plugins=plugins) with pytest.raises(KeyError): # Plugin type is not supported - plugins = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')] + plugins = [dict(cfg=dict(type="WrongPlugin"), position="after_conv3")] Bottleneck(64, 16, plugins=plugins) # Test Bottleneck with checkpoint forward @@ -99,17 +93,17 @@ def test_resnet_bottleneck(): assert x_out.shape == torch.Size([1, 64, 56, 56]) # Test Bottleneck style - block = Bottleneck(64, 64, stride=2, style='pytorch') + block = Bottleneck(64, 64, stride=2, style="pytorch") assert block.conv1.stride == (1, 1) assert block.conv2.stride == (2, 2) - block = Bottleneck(64, 64, stride=2, style='caffe') + block = Bottleneck(64, 64, stride=2, style="caffe") assert block.conv1.stride == (2, 2) assert block.conv2.stride == (1, 1) # Test Bottleneck DCN - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) with pytest.raises(AssertionError): - Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type='Conv')) + Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type="Conv")) block = Bottleneck(64, 64, dcn=dcn) assert isinstance(block.conv2, DeformConv2dPack) @@ -121,9 +115,7 @@ def test_resnet_bottleneck(): # Test Bottleneck with 1 ContextBlock after conv3 plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3") ] block = Bottleneck(64, 16, plugins=plugins) assert block.context_block.in_channels == 64 @@ -135,12 +127,14 @@ def test_resnet_bottleneck(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ) ] block = Bottleneck(64, 16, plugins=plugins) assert block.gen_attention_block.in_channels == 16 @@ -153,16 +147,16 @@ def test_resnet_bottleneck(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2'), - dict(cfg=dict(type='NonLocal2d'), position='after_conv2'), - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ), + dict(cfg=dict(type="NonLocal2d"), position="after_conv2"), + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3"), ] block = Bottleneck(64, 16, plugins=plugins) assert block.gen_attention_block.in_channels == 16 @@ -176,14 +170,17 @@ def test_resnet_bottleneck(): # conv3 plugins = [ dict( - cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1), - position='after_conv2'), + cfg=dict(type="ContextBlock", ratio=1.0 / 16, postfix=1), + position="after_conv2", + ), dict( - cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2), - position='after_conv3'), + cfg=dict(type="ContextBlock", ratio=1.0 / 16, postfix=2), + position="after_conv3", + ), dict( - cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=3), - position='after_conv3') + cfg=dict(type="ContextBlock", ratio=1.0 / 16, postfix=3), + position="after_conv3", + ), ] block = Bottleneck(64, 16, plugins=plugins) assert block.context_block1.in_channels == 16 @@ -278,16 +275,17 @@ def test_resnet_backbone(): with pytest.raises(AssertionError): # len(stage_with_dcn) == num_stages - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) - ResNet(50, dcn=dcn, stage_with_dcn=(True, )) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) + ResNet(50, dcn=dcn, stage_with_dcn=(True,)) with pytest.raises(AssertionError): # len(stage_with_plugin) == num_stages plugins = [ dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), + cfg=dict(type="ContextBlock", ratio=1.0 / 16), stages=(False, True, True), - position='after_conv3') + position="after_conv3", + ) ] ResNet(50, plugins=plugins) @@ -297,7 +295,7 @@ def test_resnet_backbone(): with pytest.raises(AssertionError): # len(strides) == len(dilations) == num_stages - ResNet(18, strides=(1, ), dilations=(1, 1), num_stages=3) + ResNet(18, strides=(1,), dilations=(1, 1), num_stages=3) with pytest.raises(TypeError): # pretrained must be a string path @@ -306,7 +304,7 @@ def test_resnet_backbone(): with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] - ResNet(50, style='tensorflow') + ResNet(50, style="tensorflow") # Test ResNet18 norm_eval=True model = ResNet(18, norm_eval=True) @@ -315,8 +313,7 @@ def test_resnet_backbone(): assert check_norm_state(model.modules(), False) # Test ResNet18 with torchvision pretrained weight - model = ResNet( - depth=18, norm_eval=True, pretrained='torchvision://resnet18') + model = ResNet(depth=18, norm_eval=True, pretrained="torchvision://resnet18") model.init_weights() model.train() assert check_norm_state(model.modules(), False) @@ -331,7 +328,7 @@ def test_resnet_backbone(): for param in layer.parameters(): assert param.requires_grad is False for i in range(1, frozen_stages + 1): - layer = getattr(model, 'layer{}'.format(i)) + layer = getattr(model, f"layer{i}") for mod in layer.modules(): if isinstance(mod, _BatchNorm): assert mod.training is False @@ -347,7 +344,7 @@ def test_resnet_backbone(): for param in model.stem.parameters(): assert param.requires_grad is False for i in range(1, frozen_stages + 1): - layer = getattr(model, 'layer{}'.format(i)) + layer = getattr(model, f"layer{i}") for mod in layer.modules(): if isinstance(mod, _BatchNorm): assert mod.training is False @@ -428,8 +425,7 @@ def test_resnet_backbone(): assert feat[3].shape == torch.Size([1, 512, 7, 7]) # Test ResNet18 with GroupNorm forward - model = ResNet( - 18, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)) + model = ResNet(18, norm_cfg=dict(type="GN", num_groups=32, requires_grad=True)) for m in model.modules(): if is_norm(m): assert isinstance(m, GroupNorm) @@ -449,24 +445,27 @@ def test_resnet_backbone(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), + attention_type="0010", + kv_stride=2, + ), stages=(False, True, True, True), - position='after_conv2'), - dict(cfg=dict(type='NonLocal2d'), position='after_conv2'), + position="after_conv2", + ), + dict(cfg=dict(type="NonLocal2d"), position="after_conv2"), dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), + cfg=dict(type="ContextBlock", ratio=1.0 / 16), stages=(False, True, True, False), - position='after_conv3') + position="after_conv3", + ), ] model = ResNet(50, plugins=plugins) for m in model.layer1.modules(): if is_block(m): - assert not hasattr(m, 'context_block') - assert not hasattr(m, 'gen_attention_block') + assert not hasattr(m, "context_block") + assert not hasattr(m, "gen_attention_block") assert m.nonlocal_block.in_channels == 64 for m in model.layer2.modules(): if is_block(m): @@ -484,7 +483,7 @@ def test_resnet_backbone(): if is_block(m): assert m.nonlocal_block.in_channels == 512 assert m.gen_attention_block.in_channels == 512 - assert not hasattr(m, 'context_block') + assert not hasattr(m, "context_block") model.init_weights() model.train() @@ -500,38 +499,40 @@ def test_resnet_backbone(): # conv3 in layers 2, 3, 4 plugins = [ dict( - cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1), + cfg=dict(type="ContextBlock", ratio=1.0 / 16, postfix=1), stages=(False, True, True, False), - position='after_conv3'), + position="after_conv3", + ), dict( - cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2), + cfg=dict(type="ContextBlock", ratio=1.0 / 16, postfix=2), stages=(False, True, True, False), - position='after_conv3') + position="after_conv3", + ), ] model = ResNet(50, plugins=plugins) for m in model.layer1.modules(): if is_block(m): - assert not hasattr(m, 'context_block') - assert not hasattr(m, 'context_block1') - assert not hasattr(m, 'context_block2') + assert not hasattr(m, "context_block") + assert not hasattr(m, "context_block1") + assert not hasattr(m, "context_block2") for m in model.layer2.modules(): if is_block(m): - assert not hasattr(m, 'context_block') + assert not hasattr(m, "context_block") assert m.context_block1.in_channels == 512 assert m.context_block2.in_channels == 512 for m in model.layer3.modules(): if is_block(m): - assert not hasattr(m, 'context_block') + assert not hasattr(m, "context_block") assert m.context_block1.in_channels == 1024 assert m.context_block2.in_channels == 1024 for m in model.layer4.modules(): if is_block(m): - assert not hasattr(m, 'context_block') - assert not hasattr(m, 'context_block1') - assert not hasattr(m, 'context_block2') + assert not hasattr(m, "context_block") + assert not hasattr(m, "context_block1") + assert not hasattr(m, "context_block2") model.init_weights() model.train() diff --git a/mmsegmentation/tests/test_models/test_backbones/test_resnext.py b/mmsegmentation/tests/test_models/test_backbones/test_resnext.py index 2aecaf0..36493db 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_resnext.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_resnext.py @@ -10,26 +10,21 @@ def test_renext_bottleneck(): with pytest.raises(AssertionError): # Style must be in ['pytorch', 'caffe'] - BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow') + BottleneckX(64, 64, groups=32, base_width=4, style="tensorflow") # Test ResNeXt Bottleneck structure - block = BottleneckX( - 64, 64, groups=32, base_width=4, stride=2, style='pytorch') + block = BottleneckX(64, 64, groups=32, base_width=4, stride=2, style="pytorch") assert block.conv2.stride == (2, 2) assert block.conv2.groups == 32 assert block.conv2.out_channels == 128 # Test ResNeXt Bottleneck with DCN - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) with pytest.raises(AssertionError): # conv_cfg must be None if dcn is not None BottleneckX( - 64, - 64, - groups=32, - base_width=4, - dcn=dcn, - conv_cfg=dict(type='Conv')) + 64, 64, groups=32, base_width=4, dcn=dcn, conv_cfg=dict(type="Conv") + ) BottleneckX(64, 64, dcn=dcn) # Test ResNeXt Bottleneck forward diff --git a/mmsegmentation/tests/test_models/test_backbones/test_stdc.py b/mmsegmentation/tests/test_models/test_backbones/test_stdc.py index 1e3862b..9a50a14 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_stdc.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_stdc.py @@ -3,27 +3,32 @@ import torch from mmseg.models.backbones import STDCContextPathNet -from mmseg.models.backbones.stdc import (AttentionRefinementModule, - FeatureFusionModule, STDCModule, - STDCNet) +from mmseg.models.backbones.stdc import ( + AttentionRefinementModule, + FeatureFusionModule, + STDCModule, + STDCNet, +) def test_stdc_context_path_net(): # Test STDCContextPathNet Standard Forward model = STDCContextPathNet( backbone_cfg=dict( - type='STDCNet', - stdc_type='STDCNet1', + type="STDCNet", + stdc_type="STDCNet1", in_channels=3, channels=(32, 64, 256, 512, 1024), - bottleneck_type='cat', + bottleneck_type="cat", num_convs=4, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - with_final_conv=True), + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + with_final_conv=True, + ), last_in_channels=(1024, 512), out_channels=128, - ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)) + ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4), + ) model.init_weights() model.train() batch_size = 2 @@ -45,18 +50,20 @@ def test_stdc_context_path_net(): imgs = torch.randn(batch_size, 3, 527, 279) model = STDCContextPathNet( backbone_cfg=dict( - type='STDCNet', - stdc_type='STDCNet1', + type="STDCNet", + stdc_type="STDCNet1", in_channels=3, channels=(32, 64, 256, 512, 1024), - bottleneck_type='add', + bottleneck_type="add", num_convs=4, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - with_final_conv=False), + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + with_final_conv=False, + ), last_in_channels=(1024, 512), out_channels=128, - ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)) + ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4), + ) model.init_weights() model.train() feat = model(imgs) @@ -67,38 +74,41 @@ def test_stdcnet(): with pytest.raises(AssertionError): # STDC backbone constraints. STDCNet( - stdc_type='STDCNet3', + stdc_type="STDCNet3", in_channels=3, channels=(32, 64, 256, 512, 1024), - bottleneck_type='cat', + bottleneck_type="cat", num_convs=4, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - with_final_conv=False) + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + with_final_conv=False, + ) with pytest.raises(AssertionError): # STDC bottleneck type constraints. STDCNet( - stdc_type='STDCNet1', + stdc_type="STDCNet1", in_channels=3, channels=(32, 64, 256, 512, 1024), - bottleneck_type='dog', + bottleneck_type="dog", num_convs=4, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - with_final_conv=False) + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + with_final_conv=False, + ) with pytest.raises(AssertionError): # STDC channels length constraints. STDCNet( - stdc_type='STDCNet1', + stdc_type="STDCNet1", in_channels=3, channels=(16, 32, 64, 256, 512, 1024), - bottleneck_type='cat', + bottleneck_type="cat", num_convs=4, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='ReLU'), - with_final_conv=False) + norm_cfg=dict(type="BN", requires_grad=True), + act_cfg=dict(type="ReLU"), + with_final_conv=False, + ) def test_feature_fusion_module(): diff --git a/mmsegmentation/tests/test_models/test_backbones/test_swin.py b/mmsegmentation/tests/test_models/test_backbones/test_swin.py index 8d14d47..0df6385 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_swin.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_swin.py @@ -17,7 +17,8 @@ def test_swin_block(): # Test BasicBlock with checkpoint forward block = SwinBlock( - embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True) + embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True + ) assert block.with_cp x = torch.randn(1, 56 * 56, 64) x_out = block(x, (56, 56)) diff --git a/mmsegmentation/tests/test_models/test_backbones/test_timm_backbone.py b/mmsegmentation/tests/test_models/test_backbones/test_timm_backbone.py index 85ef9aa..2072698 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_timm_backbone.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_timm_backbone.py @@ -15,26 +15,26 @@ def test_timm_backbone(): # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN' # Test resnet18 from timm, norm_layer='BN2d' model = TIMMBackbone( - model_name='resnet18', + model_name="resnet18", features_only=True, pretrained=False, output_stride=32, - norm_layer='BN2d') + norm_layer="BN2d", + ) # Test resnet18 from timm, norm_layer='SyncBN' model = TIMMBackbone( - model_name='resnet18', + model_name="resnet18", features_only=True, pretrained=False, output_stride=32, - norm_layer='SyncBN') + norm_layer="SyncBN", + ) # Test resnet18 from timm, features_only=True, output_stride=32 model = TIMMBackbone( - model_name='resnet18', - features_only=True, - pretrained=False, - output_stride=32) + model_name="resnet18", features_only=True, pretrained=False, output_stride=32 + ) model.init_weights() model.train() assert check_norm_state(model.modules(), True) @@ -51,10 +51,8 @@ def test_timm_backbone(): # Test resnet18 from timm, features_only=True, output_stride=16 model = TIMMBackbone( - model_name='resnet18', - features_only=True, - pretrained=False, - output_stride=16) + model_name="resnet18", features_only=True, pretrained=False, output_stride=16 + ) imgs = torch.randn(1, 3, 224, 224) feats = model(imgs) feats = [feat.shape for feat in feats] @@ -67,10 +65,8 @@ def test_timm_backbone(): # Test resnet18 from timm, features_only=True, output_stride=8 model = TIMMBackbone( - model_name='resnet18', - features_only=True, - pretrained=False, - output_stride=8) + model_name="resnet18", features_only=True, pretrained=False, output_stride=8 + ) imgs = torch.randn(1, 3, 224, 224) feats = model(imgs) feats = [feat.shape for feat in feats] @@ -82,14 +78,15 @@ def test_timm_backbone(): assert feats[4] == torch.Size((1, 512, 28, 28)) # Test efficientnet_b1 with pretrained weights - model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) + model = TIMMBackbone(model_name="efficientnet_b1", pretrained=True) # Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8 model = TIMMBackbone( - model_name='resnetv2_50x1_bitm', + model_name="resnetv2_50x1_bitm", features_only=True, pretrained=False, - output_stride=8) + output_stride=8, + ) imgs = torch.randn(1, 3, 8, 8) feats = model(imgs) feats = [feat.shape for feat in feats] @@ -102,10 +99,11 @@ def test_timm_backbone(): # Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8 model = TIMMBackbone( - model_name='resnetv2_50x3_bitm', + model_name="resnetv2_50x3_bitm", features_only=True, pretrained=False, - output_stride=8) + output_stride=8, + ) imgs = torch.randn(1, 3, 8, 8) feats = model(imgs) feats = [feat.shape for feat in feats] @@ -118,10 +116,11 @@ def test_timm_backbone(): # Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8 model = TIMMBackbone( - model_name='resnetv2_101x1_bitm', + model_name="resnetv2_101x1_bitm", features_only=True, pretrained=False, - output_stride=8) + output_stride=8, + ) imgs = torch.randn(1, 3, 8, 8) feats = model(imgs) feats = [feat.shape for feat in feats] diff --git a/mmsegmentation/tests/test_models/test_backbones/test_twins.py b/mmsegmentation/tests/test_models/test_backbones/test_twins.py index aa3eaf9..30ce13b 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_twins.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_twins.py @@ -2,9 +2,12 @@ import pytest import torch -from mmseg.models.backbones.twins import (PCPVT, SVT, - ConditionalPositionEncoding, - LocallyGroupedSelfAttention) +from mmseg.models.backbones.twins import ( + PCPVT, + SVT, + ConditionalPositionEncoding, + LocallyGroupedSelfAttention, +) def test_pcpvt(): @@ -18,7 +21,8 @@ def test_pcpvt(): qkv_bias=True, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - norm_after_stage=False) + norm_after_stage=False, + ) model.init_weights() outs = model(temp) assert outs[0].shape == (1, 32, H // 4, W // 4) @@ -38,7 +42,8 @@ def test_svt(): qkv_bias=False, depths=[4, 4, 4], windiow_sizes=[7, 7, 7], - norm_after_stage=True) + norm_after_stage=True, + ) model.init_weights() outs = model(temp) @@ -48,7 +53,7 @@ def test_svt(): def test_svt_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = SVT(pretrained=None, init_cfg=None) @@ -57,9 +62,8 @@ def test_svt_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file - model = SVT( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + model = SVT(pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path)) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -73,7 +77,7 @@ def test_svt_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = SVT(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -81,8 +85,7 @@ def test_svt_init(): # pretrained loads pretrain from an non-existent file # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = SVT( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = SVT(pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path)) with pytest.raises(AssertionError): model = SVT(pretrained=path, init_cfg=123) @@ -94,8 +97,7 @@ def test_svt_init(): # pretrain=123, whose type is unsupported # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = SVT( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = SVT(pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path)) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported @@ -104,7 +106,7 @@ def test_svt_init(): def test_pcpvt_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = PCPVT(pretrained=None, init_cfg=None) @@ -113,9 +115,8 @@ def test_pcpvt_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file - model = PCPVT( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + model = PCPVT(pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path)) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -129,7 +130,7 @@ def test_pcpvt_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = PCPVT(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -138,7 +139,8 @@ def test_pcpvt_init(): # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): model = PCPVT( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path) + ) with pytest.raises(AssertionError): model = PCPVT(pretrained=path, init_cfg=123) @@ -150,8 +152,7 @@ def test_pcpvt_init(): # pretrain=123, whose type is unsupported # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): - model = PCPVT( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + model = PCPVT(pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path)) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported diff --git a/mmsegmentation/tests/test_models/test_backbones/test_unet.py b/mmsegmentation/tests/test_models/test_backbones/test_unet.py index 9beb727..995645c 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_unet.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_unet.py @@ -3,8 +3,13 @@ import torch from mmcv.cnn import ConvModule -from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, - InterpConv, UNet, UpConvBlock) +from mmseg.models.backbones.unet import ( + BasicConvBlock, + DeconvModule, + InterpConv, + UNet, + UpConvBlock, +) from mmseg.ops import Upsample from .utils import check_norm_state @@ -12,15 +17,13 @@ def test_unet_basic_conv_block(): with pytest.raises(AssertionError): # Not implemented yet. - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) BasicConvBlock(64, 64, dcn=dcn) with pytest.raises(AssertionError): # Not implemented yet. plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3") ] BasicConvBlock(64, 64, plugins=plugins) @@ -29,12 +32,14 @@ def test_unet_basic_conv_block(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ) ] BasicConvBlock(64, 64, plugins=plugins) @@ -163,41 +168,37 @@ def test_interp_conv(): 64, 32, conv_first=False, - upsample_cfg=dict( - scale_factor=2, mode='bilinear', align_corners=False)) + upsample_cfg=dict(scale_factor=2, mode="bilinear", align_corners=False), + ) x = torch.randn(1, 64, 128, 128) x_out = block(x) assert isinstance(block.interp_upsample[0], Upsample) assert isinstance(block.interp_upsample[1], ConvModule) assert x_out.shape == torch.Size([1, 32, 256, 256]) - assert block.interp_upsample[0].mode == 'bilinear' + assert block.interp_upsample[0].mode == "bilinear" # test InterpConv with nearest upsample for upsample 2X. block = InterpConv( - 64, - 32, - conv_first=False, - upsample_cfg=dict(scale_factor=2, mode='nearest')) + 64, 32, conv_first=False, upsample_cfg=dict(scale_factor=2, mode="nearest") + ) x = torch.randn(1, 64, 128, 128) x_out = block(x) assert isinstance(block.interp_upsample[0], Upsample) assert isinstance(block.interp_upsample[1], ConvModule) assert x_out.shape == torch.Size([1, 32, 256, 256]) - assert block.interp_upsample[0].mode == 'nearest' + assert block.interp_upsample[0].mode == "nearest" def test_up_conv_block(): with pytest.raises(AssertionError): # Not implemented yet. - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn) with pytest.raises(AssertionError): # Not implemented yet. plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3") ] UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins) @@ -206,12 +207,14 @@ def test_up_conv_block(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ) ] UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins) @@ -225,7 +228,8 @@ def test_up_conv_block(): # test UpConvBlock with upsample=True for upsample 2X. The spatial size of # skip_x is 2X larger than x. block = UpConvBlock( - BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv')) + BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type="InterpConv") + ) skip_x = torch.randn(1, 32, 256, 256) x = torch.randn(1, 64, 128, 128) x_out = block(skip_x, x) @@ -247,9 +251,10 @@ def test_up_conv_block(): 32, 32, upsample_cfg=dict( - type='InterpConv', - upsample_cfg=dict( - scale_factor=2, mode='bilinear', align_corners=False))) + type="InterpConv", + upsample_cfg=dict(scale_factor=2, mode="bilinear", align_corners=False), + ), + ) skip_x = torch.randn(1, 32, 256, 256) x = torch.randn(1, 64, 128, 128) x_out = block(skip_x, x) @@ -262,7 +267,8 @@ def test_up_conv_block(): 64, 32, 32, - upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2)) + upsample_cfg=dict(type="DeconvModule", kernel_size=4, scale_factor=2), + ) skip_x = torch.randn(1, 32, 256, 256) x = torch.randn(1, 64, 128, 128) x_out = block(skip_x, x) @@ -277,9 +283,10 @@ def test_up_conv_block(): num_convs=3, dilation=3, upsample_cfg=dict( - type='InterpConv', - upsample_cfg=dict( - scale_factor=2, mode='bilinear', align_corners=False))) + type="InterpConv", + upsample_cfg=dict(scale_factor=2, mode="bilinear", align_corners=False), + ), + ) skip_x = torch.randn(1, 32, 256, 256) x = torch.randn(1, 64, 128, 128) x_out = block(skip_x, x) @@ -313,15 +320,13 @@ def test_up_conv_block(): def test_unet(): with pytest.raises(AssertionError): # Not implemented yet. - dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) + dcn = dict(type="DCN", deform_groups=1, fallback_on_stride=False) UNet(3, 64, 5, dcn=dcn) with pytest.raises(AssertionError): # Not implemented yet. plugins = [ - dict( - cfg=dict(type='ContextBlock', ratio=1. / 16), - position='after_conv3') + dict(cfg=dict(type="ContextBlock", ratio=1.0 / 16), position="after_conv3") ] UNet(3, 64, 5, plugins=plugins) @@ -330,12 +335,14 @@ def test_unet(): plugins = [ dict( cfg=dict( - type='GeneralizedAttention', + type="GeneralizedAttention", spatial_range=-1, num_heads=8, - attention_type='0010', - kv_stride=2), - position='after_conv2') + attention_type="0010", + kv_stride=2, + ), + position="after_conv2", + ) ] UNet(3, 64, 5, plugins=plugins) @@ -352,7 +359,8 @@ def test_unet(): dec_num_convs=(2, 2, 2), downsamples=(True, True, True), enc_dilations=(1, 1, 1, 1), - dec_dilations=(1, 1, 1)) + dec_dilations=(1, 1, 1), + ) x = torch.randn(2, 3, 65, 65) unet(x) @@ -369,7 +377,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 65, 65) unet(x) @@ -386,7 +395,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 65, 65) unet(x) @@ -403,7 +413,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 65, 65) unet(x) @@ -420,7 +431,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2, 2), downsamples=(True, True, True, True, True), enc_dilations=(1, 1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1, 1), + ) x = torch.randn(2, 3, 65, 65) unet(x) @@ -435,7 +447,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -450,7 +463,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -465,7 +479,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -480,7 +495,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -495,7 +511,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -510,7 +527,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1, 1), + ) x = torch.randn(2, 3, 64, 64) unet(x) @@ -525,7 +543,8 @@ def test_unet(): downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), - norm_eval=True) + norm_eval=True, + ) unet.train() assert check_norm_state(unet.modules(), False) @@ -540,7 +559,8 @@ def test_unet(): downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), - norm_eval=False) + norm_eval=False, + ) unet.train() assert check_norm_state(unet.modules(), True) @@ -554,7 +574,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -574,7 +595,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -594,7 +616,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -614,7 +637,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -634,7 +658,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -654,7 +679,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -674,7 +700,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -694,7 +721,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, False, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -714,7 +742,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(False, False, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) @@ -734,7 +763,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, True), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 64, 8, 8]) @@ -753,7 +783,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 64, 16, 16]) @@ -772,7 +803,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, True, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 64, 16, 16]) @@ -791,7 +823,8 @@ def test_unet(): dec_num_convs=(2, 2, 2, 2), downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), - dec_dilations=(1, 1, 1, 1)) + dec_dilations=(1, 1, 1, 1), + ) x = torch.randn(2, 3, 128, 128) x_outs = unet(x) assert x_outs[0].shape == torch.Size([2, 64, 32, 32]) @@ -811,7 +844,8 @@ def test_unet(): downsamples=(True, True, False, False), enc_dilations=(1, 1, 1, 1, 1), dec_dilations=(1, 1, 1, 1), - pretrained=None) + pretrained=None, + ) unet.init_weights() x = torch.randn(2, 3, 128, 128) x_outs = unet(x) diff --git a/mmsegmentation/tests/test_models/test_backbones/test_vit.py b/mmsegmentation/tests/test_models/test_backbones/test_vit.py index 0d1ba70..4a96e2f 100644 --- a/mmsegmentation/tests/test_models/test_backbones/test_vit.py +++ b/mmsegmentation/tests/test_models/test_backbones/test_vit.py @@ -2,8 +2,7 @@ import pytest import torch -from mmseg.models.backbones.vit import (TransformerEncoderLayer, - VisionTransformer) +from mmseg.models.backbones.vit import TransformerEncoderLayer, VisionTransformer from .utils import check_norm_state @@ -19,12 +18,12 @@ def test_vit_backbone(): with pytest.raises(TypeError): # out_indices must be int ,list or tuple - model = VisionTransformer(out_indices=1.) + model = VisionTransformer(out_indices=1.0) with pytest.raises(TypeError): # test upsample_pos_embed function x = torch.randn(1, 196) - VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear') + VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, "bilinear") with pytest.raises(AssertionError): # The length of img_size tuple must be lower than 3. @@ -40,7 +39,7 @@ def test_vit_backbone(): # Test img_size isinstance tuple imgs = torch.randn(1, 3, 224, 224) - model = VisionTransformer(img_size=(224, )) + model = VisionTransformer(img_size=(224,)) model.init_weights() model(imgs) @@ -122,7 +121,8 @@ def test_vit_backbone(): # Test TransformerEncoderLayer with checkpoint forward block = TransformerEncoderLayer( - embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True) + embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True + ) assert block.with_cp x = torch.randn(1, 56 * 56, 64) x_out = block(x) @@ -130,7 +130,7 @@ def test_vit_backbone(): def test_vit_init(): - path = 'PATH_THAT_DO_NOT_EXIST' + path = "PATH_THAT_DO_NOT_EXIST" # Test all combinations of pretrained and init_cfg # pretrained=None, init_cfg=None model = VisionTransformer(pretrained=None, init_cfg=None) @@ -140,8 +140,9 @@ def test_vit_init(): # pretrained=None # init_cfg loads pretrain from an non-existent file model = VisionTransformer( - pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + pretrained=None, init_cfg=dict(type="Pretrained", checkpoint=path) + ) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -155,7 +156,7 @@ def test_vit_init(): # pretrained loads pretrain from an non-existent file # init_cfg=None model = VisionTransformer(pretrained=path, init_cfg=None) - assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + assert model.init_cfg == dict(type="Pretrained", checkpoint=path) # Test loading a checkpoint from an non-existent file with pytest.raises(OSError): model.init_weights() @@ -164,7 +165,8 @@ def test_vit_init(): # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): model = VisionTransformer( - pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + pretrained=path, init_cfg=dict(type="Pretrained", checkpoint=path) + ) with pytest.raises(AssertionError): model = VisionTransformer(pretrained=path, init_cfg=123) @@ -177,7 +179,8 @@ def test_vit_init(): # init_cfg loads pretrain from an non-existent file with pytest.raises(AssertionError): model = VisionTransformer( - pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + pretrained=123, init_cfg=dict(type="Pretrained", checkpoint=path) + ) # pretrain=123, whose type is unsupported # init_cfg=123, whose type is unsupported diff --git a/mmsegmentation/tests/test_models/test_backbones/utils.py b/mmsegmentation/tests/test_models/test_backbones/utils.py index 54b6404..4ecd7b7 100644 --- a/mmsegmentation/tests/test_models/test_backbones/utils.py +++ b/mmsegmentation/tests/test_models/test_backbones/utils.py @@ -23,11 +23,13 @@ def is_norm(modules): def all_zeros(modules): """Check if the weight(and bias) is all zero.""" - weight_zero = torch.allclose(modules.weight.data, - torch.zeros_like(modules.weight.data)) - if hasattr(modules, 'bias'): - bias_zero = torch.allclose(modules.bias.data, - torch.zeros_like(modules.bias.data)) + weight_zero = torch.allclose( + modules.weight.data, torch.zeros_like(modules.weight.data) + ) + if hasattr(modules, "bias"): + bias_zero = torch.allclose( + modules.bias.data, torch.zeros_like(modules.bias.data) + ) else: bias_zero = True diff --git a/mmsegmentation/tests/test_models/test_forward.py b/mmsegmentation/tests/test_models/test_forward.py index ee707b3..253b21d 100644 --- a/mmsegmentation/tests/test_models/test_forward.py +++ b/mmsegmentation/tests/test_models/test_forward.py @@ -26,23 +26,25 @@ def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10): rng = np.random.RandomState(0) imgs = rng.rand(*input_shape) - segs = rng.randint( - low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) - - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - 'scale_factor': 1.0, - 'flip': False, - 'flip_direction': 'horizontal' - } for _ in range(N)] + segs = rng.randint(low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + + img_metas = [ + { + "img_shape": (H, W, C), + "ori_shape": (H, W, C), + "pad_shape": (H, W, C), + "filename": ".png", + "scale_factor": 1.0, + "flip": False, + "flip_direction": "horizontal", + } + for _ in range(N) + ] mm_inputs = { - 'imgs': torch.FloatTensor(imgs), - 'img_metas': img_metas, - 'gt_semantic_seg': torch.LongTensor(segs) + "imgs": torch.FloatTensor(imgs), + "img_metas": img_metas, + "gt_semantic_seg": torch.LongTensor(segs), } return mm_inputs @@ -55,16 +57,18 @@ def _get_config_directory(): except NameError: # For IPython development when this __file__ is not defined import mmseg + repo_dpath = dirname(dirname(dirname(mmseg.__file__))) - config_dpath = join(repo_dpath, 'configs') + config_dpath = join(repo_dpath, "configs") if not exists(config_dpath): - raise Exception('Cannot find config path') + raise Exception("Cannot find config path") return config_dpath def _get_config_module(fname): """Load a configuration as a python module.""" from mmcv import Config + config_dpath = _get_config_directory() config_fpath = join(config_dpath, fname) config_mod = Config.fromfile(config_fpath) @@ -83,105 +87,96 @@ def _get_segmentor_cfg(fname): def test_pspnet_forward(): - _test_encoder_decoder_forward( - 'pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py") def test_fcn_forward(): - _test_encoder_decoder_forward('fcn/fcn_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("fcn/fcn_r50-d8_512x1024_40k_cityscapes.py") def test_deeplabv3_forward(): _test_encoder_decoder_forward( - 'deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py') + "deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py" + ) def test_deeplabv3plus_forward(): _test_encoder_decoder_forward( - 'deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py') + "deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py" + ) def test_gcnet_forward(): - _test_encoder_decoder_forward( - 'gcnet/gcnet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("gcnet/gcnet_r50-d8_512x1024_40k_cityscapes.py") def test_ann_forward(): - _test_encoder_decoder_forward('ann/ann_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("ann/ann_r50-d8_512x1024_40k_cityscapes.py") def test_ccnet_forward(): if not torch.cuda.is_available(): - pytest.skip('CCNet requires CUDA') - _test_encoder_decoder_forward( - 'ccnet/ccnet_r50-d8_512x1024_40k_cityscapes.py') + pytest.skip("CCNet requires CUDA") + _test_encoder_decoder_forward("ccnet/ccnet_r50-d8_512x1024_40k_cityscapes.py") def test_danet_forward(): - _test_encoder_decoder_forward( - 'danet/danet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("danet/danet_r50-d8_512x1024_40k_cityscapes.py") def test_nonlocal_net_forward(): _test_encoder_decoder_forward( - 'nonlocal_net/nonlocal_r50-d8_512x1024_40k_cityscapes.py') + "nonlocal_net/nonlocal_r50-d8_512x1024_40k_cityscapes.py" + ) def test_upernet_forward(): - _test_encoder_decoder_forward( - 'upernet/upernet_r50_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("upernet/upernet_r50_512x1024_40k_cityscapes.py") def test_hrnet_forward(): - _test_encoder_decoder_forward('hrnet/fcn_hr18s_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("hrnet/fcn_hr18s_512x1024_40k_cityscapes.py") def test_ocrnet_forward(): - _test_encoder_decoder_forward( - 'ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes.py") def test_psanet_forward(): - _test_encoder_decoder_forward( - 'psanet/psanet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("psanet/psanet_r50-d8_512x1024_40k_cityscapes.py") def test_encnet_forward(): - _test_encoder_decoder_forward( - 'encnet/encnet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("encnet/encnet_r50-d8_512x1024_40k_cityscapes.py") def test_sem_fpn_forward(): - _test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py') + _test_encoder_decoder_forward("sem_fpn/fpn_r50_512x1024_80k_cityscapes.py") def test_point_rend_forward(): - _test_encoder_decoder_forward( - 'point_rend/pointrend_r50_512x1024_80k_cityscapes.py') + _test_encoder_decoder_forward("point_rend/pointrend_r50_512x1024_80k_cityscapes.py") def test_mobilenet_v2_forward(): _test_encoder_decoder_forward( - 'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py') + "mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py" + ) def test_dnlnet_forward(): - _test_encoder_decoder_forward( - 'dnlnet/dnl_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("dnlnet/dnl_r50-d8_512x1024_40k_cityscapes.py") def test_emanet_forward(): - _test_encoder_decoder_forward( - 'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py') + _test_encoder_decoder_forward("emanet/emanet_r50-d8_512x1024_80k_cityscapes.py") def test_isanet_forward(): - _test_encoder_decoder_forward( - 'isanet/isanet_r50-d8_512x1024_40k_cityscapes.py') + _test_encoder_decoder_forward("isanet/isanet_r50-d8_512x1024_40k_cityscapes.py") def get_world_size(process_group): - return 1 @@ -189,15 +184,15 @@ def _check_input_dim(self, inputs): pass -@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim', - _check_input_dim) -@patch('torch.distributed.get_world_size', get_world_size) +@patch("torch.nn.modules.batchnorm._BatchNorm._check_input_dim", _check_input_dim) +@patch("torch.distributed.get_world_size", get_world_size) def _test_encoder_decoder_forward(cfg_file): model = _get_segmentor_cfg(cfg_file) - model['pretrained'] = None - model['test_cfg']['mode'] = 'whole' + model["pretrained"] = None + model["test_cfg"]["mode"] = "whole" from mmseg.models import build_segmentor + segmentor = build_segmentor(model) segmentor.init_weights() @@ -209,9 +204,9 @@ def _test_encoder_decoder_forward(cfg_file): input_shape = (2, 3, 32, 32) mm_inputs = _demo_mm_inputs(input_shape, num_classes=num_classes) - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_semantic_seg = mm_inputs['gt_semantic_seg'] + imgs = mm_inputs.pop("imgs") + img_metas = mm_inputs.pop("img_metas") + gt_semantic_seg = mm_inputs["gt_semantic_seg"] # convert to cuda Tensor if applicable if torch.cuda.is_available(): @@ -223,7 +218,8 @@ def _test_encoder_decoder_forward(cfg_file): # Test forward train losses = segmentor.forward( - imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) + imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True + ) assert isinstance(losses, dict) # Test forward test diff --git a/mmsegmentation/tests/test_models/test_heads/test_ann_head.py b/mmsegmentation/tests/test_models/test_heads/test_ann_head.py index c1e44bc..28b1189 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_ann_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_ann_head.py @@ -6,14 +6,14 @@ def test_ann_head(): - inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)] head = ANNHead( in_channels=[4, 8], channels=2, num_classes=19, in_index=[-2, -1], - project_channels=8) + project_channels=8, + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_apc_head.py b/mmsegmentation/tests/test_models/test_heads/test_apc_head.py index dc55ccc..28b8227 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_apc_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_apc_head.py @@ -7,7 +7,6 @@ def test_apc_head(): - with pytest.raises(AssertionError): # pool_scales must be list|tuple APCHead(in_channels=8, channels=2, num_classes=19, pool_scales=1) @@ -18,20 +17,15 @@ def test_apc_head(): # test with norm_cfg head = APCHead( - in_channels=8, - channels=2, - num_classes=19, - norm_cfg=dict(type='SyncBN')) + in_channels=8, channels=2, num_classes=19, norm_cfg=dict(type="SyncBN") + ) assert _conv_has_norm(head, sync_bn=True) # fusion=True inputs = [torch.randn(1, 8, 45, 45)] head = APCHead( - in_channels=8, - channels=2, - num_classes=19, - pool_scales=(1, 2, 3), - fusion=True) + in_channels=8, channels=2, num_classes=19, pool_scales=(1, 2, 3), fusion=True + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.fusion is True @@ -44,11 +38,8 @@ def test_apc_head(): # fusion=False inputs = [torch.randn(1, 8, 45, 45)] head = APCHead( - in_channels=8, - channels=2, - num_classes=19, - pool_scales=(1, 2, 3), - fusion=False) + in_channels=8, channels=2, num_classes=19, pool_scales=(1, 2, 3), fusion=False + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.fusion is False diff --git a/mmsegmentation/tests/test_models/test_heads/test_aspp_head.py b/mmsegmentation/tests/test_models/test_heads/test_aspp_head.py index db9e893..11a9abb 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_aspp_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_aspp_head.py @@ -7,7 +7,6 @@ def test_aspp_head(): - with pytest.raises(AssertionError): # pool_scales must be list|tuple ASPPHead(in_channels=8, channels=4, num_classes=19, dilations=1) @@ -18,15 +17,12 @@ def test_aspp_head(): # test with norm_cfg head = ASPPHead( - in_channels=8, - channels=4, - num_classes=19, - norm_cfg=dict(type='SyncBN')) + in_channels=8, channels=4, num_classes=19, norm_cfg=dict(type="SyncBN") + ) assert _conv_has_norm(head, sync_bn=True) inputs = [torch.randn(1, 8, 45, 45)] - head = ASPPHead( - in_channels=8, channels=4, num_classes=19, dilations=(1, 12, 24)) + head = ASPPHead(in_channels=8, channels=4, num_classes=19, dilations=(1, 12, 24)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.aspp_modules[0].conv.dilation == (1, 1) @@ -37,7 +33,6 @@ def test_aspp_head(): def test_dw_aspp_head(): - # test w.o. c1 inputs = [torch.randn(1, 8, 45, 45)] head = DepthwiseSeparableASPPHead( @@ -46,7 +41,8 @@ def test_dw_aspp_head(): in_channels=8, channels=4, num_classes=19, - dilations=(1, 12, 24)) + dilations=(1, 12, 24), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.c1_bottleneck is None @@ -64,7 +60,8 @@ def test_dw_aspp_head(): in_channels=16, channels=8, num_classes=19, - dilations=(1, 12, 24)) + dilations=(1, 12, 24), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.c1_bottleneck.in_channels == 4 diff --git a/mmsegmentation/tests/test_models/test_heads/test_cc_head.py b/mmsegmentation/tests/test_models/test_heads/test_cc_head.py index 0630417..ab4c486 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_cc_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_cc_head.py @@ -9,9 +9,9 @@ def test_cc_head(): head = CCHead(in_channels=16, channels=8, num_classes=19) assert len(head.convs) == 2 - assert hasattr(head, 'cca') + assert hasattr(head, "cca") if not torch.cuda.is_available(): - pytest.skip('CCHead requires CUDA') + pytest.skip("CCHead requires CUDA") inputs = [torch.randn(1, 16, 23, 23)] head, inputs = to_cuda(head, inputs) outputs = head(inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_da_head.py b/mmsegmentation/tests/test_models/test_heads/test_da_head.py index 7ab4a96..ec2202b 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_da_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_da_head.py @@ -6,7 +6,6 @@ def test_da_head(): - inputs = [torch.randn(1, 16, 23, 23)] head = DAHead(in_channels=16, channels=8, num_classes=19, pam_channels=8) if torch.cuda.is_available(): diff --git a/mmsegmentation/tests/test_models/test_heads/test_decode_head.py b/mmsegmentation/tests/test_models/test_heads/test_decode_head.py index 87cadbc..b526784 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_decode_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_decode_head.py @@ -10,7 +10,6 @@ @patch.multiple(BaseDecodeHead, __abstractmethods__=set()) def test_decode_head(): - with pytest.raises(AssertionError): # default input_transform doesn't accept multiple inputs BaseDecodeHead([32, 16], 16, num_classes=19) @@ -21,27 +20,23 @@ def test_decode_head(): with pytest.raises(AssertionError): # supported mode is resize_concat only - BaseDecodeHead(32, 16, num_classes=19, input_transform='concat') + BaseDecodeHead(32, 16, num_classes=19, input_transform="concat") with pytest.raises(AssertionError): # in_channels should be list|tuple - BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat') + BaseDecodeHead(32, 16, num_classes=19, input_transform="resize_concat") with pytest.raises(AssertionError): # in_index should be list|tuple - BaseDecodeHead([32], - 16, - in_index=-1, - num_classes=19, - input_transform='resize_concat') + BaseDecodeHead( + [32], 16, in_index=-1, num_classes=19, input_transform="resize_concat" + ) with pytest.raises(AssertionError): # len(in_index) should equal len(in_channels) - BaseDecodeHead([32, 16], - 16, - num_classes=19, - in_index=[-1], - input_transform='resize_concat') + BaseDecodeHead( + [32, 16], 16, num_classes=19, in_index=[-1], input_transform="resize_concat" + ) with pytest.raises(ValueError): # out_channels should be equal to num_classes @@ -57,11 +52,11 @@ def test_decode_head(): # test default dropout head = BaseDecodeHead(32, 16, num_classes=19) - assert hasattr(head, 'dropout') and head.dropout.p == 0.1 + assert hasattr(head, "dropout") and head.dropout.p == 0.1 # test set dropout head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2) - assert hasattr(head, 'dropout') and head.dropout.p == 0.2 + assert hasattr(head, "dropout") and head.dropout.p == 0.2 # test no input_transform inputs = [torch.randn(1, 32, 45, 45)] @@ -75,22 +70,20 @@ def test_decode_head(): # test input_transform = resize_concat inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)] - head = BaseDecodeHead([32, 16], - 16, - num_classes=19, - in_index=[0, 1], - input_transform='resize_concat') + head = BaseDecodeHead( + [32, 16], 16, num_classes=19, in_index=[0, 1], input_transform="resize_concat" + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.in_channels == 48 - assert head.input_transform == 'resize_concat' + assert head.input_transform == "resize_concat" transformed_inputs = head._transform_inputs(inputs) assert transformed_inputs.shape == (1, 48, 45, 45) # test multi-loss, loss_decode is dict with pytest.raises(TypeError): # loss_decode must be a dict or sequence of dict. - BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + BaseDecodeHead(3, 16, num_classes=19, loss_decode=["CrossEntropyLoss"]) inputs = torch.randn(2, 19, 8, 8).float() target = torch.ones(2, 1, 64, 64).long() @@ -98,13 +91,13 @@ def test_decode_head(): 3, 16, num_classes=19, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) head, target = to_cuda(head, target) loss = head.losses(seg_logit=inputs, seg_label=target) - assert 'loss_ce' in loss + assert "loss_ce" in loss # test multi-loss, loss_decode is list of dict inputs = torch.randn(2, 19, 8, 8).float() @@ -114,19 +107,20 @@ def test_decode_head(): 16, num_classes=19, loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_1'), - dict(type='CrossEntropyLoss', loss_name='loss_2') - ]) + dict(type="CrossEntropyLoss", loss_name="loss_1"), + dict(type="CrossEntropyLoss", loss_name="loss_2"), + ], + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) head, target = to_cuda(head, target) loss = head.losses(seg_logit=inputs, seg_label=target) - assert 'loss_1' in loss - assert 'loss_2' in loss + assert "loss_1" in loss + assert "loss_2" in loss # 'loss_decode' must be a dict or sequence of dict with pytest.raises(TypeError): - BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + BaseDecodeHead(3, 16, num_classes=19, loss_decode=["CrossEntropyLoss"]) with pytest.raises(TypeError): BaseDecodeHead(3, 16, num_classes=19, loss_decode=0) @@ -137,16 +131,19 @@ def test_decode_head(): 3, 16, num_classes=19, - loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'), - dict(type='CrossEntropyLoss', loss_name='loss_2'), - dict(type='CrossEntropyLoss', loss_name='loss_3'))) + loss_decode=( + dict(type="CrossEntropyLoss", loss_name="loss_1"), + dict(type="CrossEntropyLoss", loss_name="loss_2"), + dict(type="CrossEntropyLoss", loss_name="loss_3"), + ), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) head, target = to_cuda(head, target) loss = head.losses(seg_logit=inputs, seg_label=target) - assert 'loss_1' in loss - assert 'loss_2' in loss - assert 'loss_3' in loss + assert "loss_1" in loss + assert "loss_2" in loss + assert "loss_3" in loss # test multi-loss, loss_decode is list of dict, names of them are identical inputs = torch.randn(2, 19, 8, 8).float() @@ -155,9 +152,12 @@ def test_decode_head(): 3, 16, num_classes=19, - loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'), - dict(type='CrossEntropyLoss', loss_name='loss_ce'), - dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + loss_decode=( + dict(type="CrossEntropyLoss", loss_name="loss_ce"), + dict(type="CrossEntropyLoss", loss_name="loss_ce"), + dict(type="CrossEntropyLoss", loss_name="loss_ce"), + ), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) head, target = to_cuda(head, target) @@ -167,11 +167,12 @@ def test_decode_head(): 3, 16, num_classes=19, - loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + loss_decode=(dict(type="CrossEntropyLoss", loss_name="loss_ce")), + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) head, target = to_cuda(head, target) loss = head.losses(seg_logit=inputs, seg_label=target) - assert 'loss_ce' in loss - assert 'loss_ce' in loss_3 - assert loss_3['loss_ce'] == 3 * loss['loss_ce'] + assert "loss_ce" in loss + assert "loss_ce" in loss_3 + assert loss_3["loss_ce"] == 3 * loss["loss_ce"] diff --git a/mmsegmentation/tests/test_models/test_heads/test_dm_head.py b/mmsegmentation/tests/test_models/test_heads/test_dm_head.py index a922ff7..510ee10 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_dm_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_dm_head.py @@ -7,7 +7,6 @@ def test_dm_head(): - with pytest.raises(AssertionError): # filter_sizes must be list|tuple DMHead(in_channels=8, channels=4, num_classes=19, filter_sizes=1) @@ -18,20 +17,15 @@ def test_dm_head(): # test with norm_cfg head = DMHead( - in_channels=8, - channels=4, - num_classes=19, - norm_cfg=dict(type='SyncBN')) + in_channels=8, channels=4, num_classes=19, norm_cfg=dict(type="SyncBN") + ) assert _conv_has_norm(head, sync_bn=True) # fusion=True inputs = [torch.randn(1, 8, 23, 23)] head = DMHead( - in_channels=8, - channels=4, - num_classes=19, - filter_sizes=(1, 3, 5), - fusion=True) + in_channels=8, channels=4, num_classes=19, filter_sizes=(1, 3, 5), fusion=True + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.fusion is True @@ -44,11 +38,8 @@ def test_dm_head(): # fusion=False inputs = [torch.randn(1, 8, 23, 23)] head = DMHead( - in_channels=8, - channels=4, - num_classes=19, - filter_sizes=(1, 3, 5), - fusion=False) + in_channels=8, channels=4, num_classes=19, filter_sizes=(1, 3, 5), fusion=False + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.fusion is False diff --git a/mmsegmentation/tests/test_models/test_heads/test_dnl_head.py b/mmsegmentation/tests/test_models/test_heads/test_dnl_head.py index 720cb07..9564f91 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_dnl_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_dnl_head.py @@ -9,7 +9,7 @@ def test_dnl_head(): # DNL with 'embedded_gaussian' mode head = DNLHead(in_channels=8, channels=4, num_classes=19) assert len(head.convs) == 2 - assert hasattr(head, 'dnl_block') + assert hasattr(head, "dnl_block") assert head.dnl_block.temperature == 0.05 inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): @@ -18,8 +18,7 @@ def test_dnl_head(): assert outputs.shape == (1, head.num_classes, 23, 23) # NonLocal2d with 'dot_product' mode - head = DNLHead( - in_channels=8, channels=4, num_classes=19, mode='dot_product') + head = DNLHead(in_channels=8, channels=4, num_classes=19, mode="dot_product") inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) @@ -27,7 +26,7 @@ def test_dnl_head(): assert outputs.shape == (1, head.num_classes, 23, 23) # NonLocal2d with 'gaussian' mode - head = DNLHead(in_channels=8, channels=4, num_classes=19, mode='gaussian') + head = DNLHead(in_channels=8, channels=4, num_classes=19, mode="gaussian") inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) @@ -35,8 +34,7 @@ def test_dnl_head(): assert outputs.shape == (1, head.num_classes, 23, 23) # NonLocal2d with 'concatenation' mode - head = DNLHead( - in_channels=8, channels=4, num_classes=19, mode='concatenation') + head = DNLHead(in_channels=8, channels=4, num_classes=19, mode="concatenation") inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_dpt_head.py b/mmsegmentation/tests/test_models/test_heads/test_dpt_head.py index 0a6af61..f6b61f8 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_dpt_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_dpt_head.py @@ -6,24 +6,24 @@ def test_dpt_head(): - with pytest.raises(AssertionError): # input_transform must be 'multiple_select' head = DPTHead( in_channels=[768, 768, 768, 768], channels=4, num_classes=19, - in_index=[0, 1, 2, 3]) + in_index=[0, 1, 2, 3], + ) head = DPTHead( in_channels=[768, 768, 768, 768], channels=4, num_classes=19, in_index=[0, 1, 2, 3], - input_transform='multiple_select') + input_transform="multiple_select", + ) - inputs = [[torch.randn(4, 768, 2, 2), - torch.randn(4, 768)] for _ in range(4)] + inputs = [[torch.randn(4, 768, 2, 2), torch.randn(4, 768)] for _ in range(4)] output = head(inputs) assert output.shape == torch.Size((4, 19, 16, 16)) @@ -33,8 +33,9 @@ def test_dpt_head(): channels=4, num_classes=19, in_index=[0, 1, 2, 3], - input_transform='multiple_select', - readout_type='add') + input_transform="multiple_select", + readout_type="add", + ) output = head(inputs) assert output.shape == torch.Size((4, 19, 16, 16)) @@ -43,7 +44,8 @@ def test_dpt_head(): channels=4, num_classes=19, in_index=[0, 1, 2, 3], - input_transform='multiple_select', - readout_type='project') + input_transform="multiple_select", + readout_type="project", + ) output = head(inputs) assert output.shape == torch.Size((4, 19, 16, 16)) diff --git a/mmsegmentation/tests/test_models/test_heads/test_ema_head.py b/mmsegmentation/tests/test_models/test_heads/test_ema_head.py index 1811cd2..9e0f6c0 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_ema_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_ema_head.py @@ -12,10 +12,11 @@ def test_emanet_head(): channels=2, num_stages=3, num_bases=2, - num_classes=19) + num_classes=19, + ) for param in head.ema_mid_conv.parameters(): assert not param.requires_grad - assert hasattr(head, 'ema_module') + assert hasattr(head, "ema_module") inputs = [torch.randn(1, 4, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_enc_head.py b/mmsegmentation/tests/test_models/test_heads/test_enc_head.py index 9c84c75..93b7720 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_enc_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_enc_head.py @@ -19,11 +19,8 @@ def test_enc_head(): # w.o se_loss, w.o. lateral inputs = [torch.randn(1, 8, 21, 21)] head = EncHead( - in_channels=[8], - channels=4, - use_se_loss=False, - num_classes=19, - in_index=[-1]) + in_channels=[8], channels=4, use_se_loss=False, num_classes=19, in_index=[-1] + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -36,7 +33,8 @@ def test_enc_head(): channels=4, add_lateral=True, num_classes=19, - in_index=[-2, -1]) + in_index=[-2, -1], + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_fcn_head.py b/mmsegmentation/tests/test_models/test_heads/test_fcn_head.py index 4e633fb..5e87760 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_fcn_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_fcn_head.py @@ -9,7 +9,6 @@ def test_fcn_head(): - with pytest.raises(AssertionError): # num_convs must be not less than 0 FCNHead(num_classes=19, num_convs=-1) @@ -22,29 +21,25 @@ def test_fcn_head(): # test with norm_cfg head = FCNHead( - in_channels=8, - channels=4, - num_classes=19, - norm_cfg=dict(type='SyncBN')) + in_channels=8, channels=4, num_classes=19, norm_cfg=dict(type="SyncBN") + ) for m in head.modules(): if isinstance(m, ConvModule): assert m.with_norm and isinstance(m.bn, SyncBatchNorm) # test concat_input=False inputs = [torch.randn(1, 8, 23, 23)] - head = FCNHead( - in_channels=8, channels=4, num_classes=19, concat_input=False) + head = FCNHead(in_channels=8, channels=4, num_classes=19, concat_input=False) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert len(head.convs) == 2 - assert not head.concat_input and not hasattr(head, 'conv_cat') + assert not head.concat_input and not hasattr(head, "conv_cat") outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) # test concat_input=True inputs = [torch.randn(1, 8, 23, 23)] - head = FCNHead( - in_channels=8, channels=4, num_classes=19, concat_input=True) + head = FCNHead(in_channels=8, channels=4, num_classes=19, concat_input=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert len(head.convs) == 2 @@ -87,11 +82,8 @@ def test_fcn_head(): # test num_conv = 0 inputs = [torch.randn(1, 8, 23, 23)] head = FCNHead( - in_channels=8, - channels=8, - num_classes=19, - num_convs=0, - concat_input=False) + in_channels=8, channels=8, num_classes=19, num_convs=0, concat_input=False + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert isinstance(head.convs, torch.nn.Identity) @@ -107,7 +99,8 @@ def test_sep_fcn_head(): concat_input=False, num_classes=19, in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) + norm_cfg=dict(type="BN", requires_grad=True, momentum=0.01), + ) x = [torch.rand(2, 128, 8, 8)] output = head(x) assert output.shape == (2, head.num_classes, 8, 8) @@ -122,7 +115,8 @@ def test_sep_fcn_head(): concat_input=True, num_classes=19, in_index=-1, - norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01)) + norm_cfg=dict(type="BN", requires_grad=True, momentum=0.01), + ) x = [torch.rand(3, 64, 8, 8)] output = head(x) assert output.shape == (3, head.num_classes, 8, 8) diff --git a/mmsegmentation/tests/test_models/test_heads/test_gc_head.py b/mmsegmentation/tests/test_models/test_heads/test_gc_head.py index c62ac9a..99c382b 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_gc_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_gc_head.py @@ -8,7 +8,7 @@ def test_gc_head(): head = GCHead(in_channels=4, channels=4, num_classes=19) assert len(head.convs) == 2 - assert hasattr(head, 'gc_block') + assert hasattr(head, "gc_block") inputs = [torch.randn(1, 4, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_isa_head.py b/mmsegmentation/tests/test_models/test_heads/test_isa_head.py index b177f6d..ddd9f8b 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_isa_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_isa_head.py @@ -6,14 +6,10 @@ def test_isa_head(): - inputs = [torch.randn(1, 8, 23, 23)] isa_head = ISAHead( - in_channels=8, - channels=4, - num_classes=19, - isa_channels=4, - down_factor=(8, 8)) + in_channels=8, channels=4, num_classes=19, isa_channels=4, down_factor=(8, 8) + ) if torch.cuda.is_available(): isa_head, inputs = to_cuda(isa_head, inputs) output = isa_head(inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_knet_head.py b/mmsegmentation/tests/test_models/test_heads/test_knet_head.py index e6845a6..4bbc50a 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_knet_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_knet_head.py @@ -1,22 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead, - KernelUpdateHead) +from mmseg.models.decode_heads.knet_head import IterativeDecodeHead, KernelUpdateHead from .utils import to_cuda num_stages = 3 conv_kernel_size = 1 kernel_updator_cfg = dict( - type='KernelUpdator', + type="KernelUpdator", in_channels=16, feat_channels=16, out_channels=16, gate_norm_act=True, activate_out=True, - act_cfg=dict(type='ReLU', inplace=True), - norm_cfg=dict(type='LN')) + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), +) def test_knet_head(): @@ -31,18 +31,19 @@ def test_knet_head(): out_channels=32, dropout=0.0, conv_kernel_size=conv_kernel_size, - ffn_act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type="ReLU", inplace=True), with_ffn=True, - feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None), + feat_transform_cfg=dict(conv_cfg=dict(type="Conv2d"), act_cfg=None), kernel_init=True, - kernel_updator_cfg=kernel_updator_cfg) + kernel_updator_cfg=kernel_updator_cfg, + ) kernel_update_head.init_weights() head = IterativeDecodeHead( num_stages=num_stages, kernel_update_head=[ dict( - type='KernelUpdateHead', + type="KernelUpdateHead", num_classes=150, num_ffn_fcs=2, num_heads=8, @@ -52,16 +53,16 @@ def test_knet_head(): out_channels=32, dropout=0.0, conv_kernel_size=conv_kernel_size, - ffn_act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type="ReLU", inplace=True), with_ffn=True, - feat_transform_cfg=dict( - conv_cfg=dict(type='Conv2d'), act_cfg=None), + feat_transform_cfg=dict(conv_cfg=dict(type="Conv2d"), act_cfg=None), kernel_init=False, - kernel_updator_cfg=kernel_updator_cfg) + kernel_updator_cfg=kernel_updator_cfg, + ) for _ in range(num_stages) ], kernel_generate_head=dict( - type='FCNHead', + type="FCNHead", in_channels=128, in_index=3, channels=32, @@ -69,13 +70,15 @@ def test_knet_head(): concat_input=True, dropout_ratio=0.1, num_classes=150, - align_corners=False)) + align_corners=False, + ), + ) head.init_weights() inputs = [ torch.randn(1, 16, 27, 32), torch.randn(1, 32, 27, 16), torch.randn(1, 64, 27, 16), - torch.randn(1, 128, 27, 16) + torch.randn(1, 128, 27, 16), ] if torch.cuda.is_available(): @@ -95,7 +98,7 @@ def test_knet_head(): num_stages=num_stages, kernel_update_head=[ dict( - type='KernelUpdateHead', + type="KernelUpdateHead", num_classes=150, num_ffn_fcs=2, num_heads=8, @@ -105,14 +108,15 @@ def test_knet_head(): out_channels=32, dropout=0.0, conv_kernel_size=conv_kernel_size, - ffn_act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type="ReLU", inplace=True), with_ffn=True, feat_transform_cfg=None, - kernel_updator_cfg=kernel_updator_cfg) + kernel_updator_cfg=kernel_updator_cfg, + ) for _ in range(num_stages) ], kernel_generate_head=dict( - type='FCNHead', + type="FCNHead", in_channels=128, in_index=3, channels=32, @@ -120,14 +124,16 @@ def test_knet_head(): concat_input=True, dropout_ratio=0.1, num_classes=150, - align_corners=False)) + align_corners=False, + ), + ) head.init_weights() inputs = [ torch.randn(1, 16, 27, 32), torch.randn(1, 32, 27, 16), torch.randn(1, 64, 27, 16), - torch.randn(1, 128, 27, 16) + torch.randn(1, 128, 27, 16), ] if torch.cuda.is_available(): @@ -141,7 +147,7 @@ def test_knet_head(): num_stages=num_stages, kernel_update_head=[ dict( - type='KernelUpdateHead', + type="KernelUpdateHead", num_classes=150, num_ffn_fcs=2, num_heads=8, @@ -151,18 +157,18 @@ def test_knet_head(): out_channels=32, dropout=0.0, conv_kernel_size=conv_kernel_size, - ffn_act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type="ReLU", inplace=True), with_ffn=True, - feat_transform_cfg=dict( - conv_cfg=dict(type='Conv2d'), act_cfg=None), + feat_transform_cfg=dict(conv_cfg=dict(type="Conv2d"), act_cfg=None), kernel_init=False, mask_transform_stride=2, feat_gather_stride=1, - kernel_updator_cfg=kernel_updator_cfg) + kernel_updator_cfg=kernel_updator_cfg, + ) for _ in range(num_stages) ], kernel_generate_head=dict( - type='FCNHead', + type="FCNHead", in_channels=128, in_index=3, channels=32, @@ -170,14 +176,16 @@ def test_knet_head(): concat_input=True, dropout_ratio=0.1, num_classes=150, - align_corners=False)) + align_corners=False, + ), + ) head.init_weights() inputs = [ torch.randn(1, 16, 27, 32), torch.randn(1, 32, 27, 16), torch.randn(1, 64, 27, 16), - torch.randn(1, 128, 27, 16) + torch.randn(1, 128, 27, 16), ] if torch.cuda.is_available(): @@ -186,10 +194,9 @@ def test_knet_head(): assert outputs[-1].shape == (1, head.num_classes, 26, 16) # test loss function in K-Net - fake_label = torch.ones_like( - outputs[-1][:, 0:1, :, :], dtype=torch.int16).long() + fake_label = torch.ones_like(outputs[-1][:, 0:1, :, :], dtype=torch.int16).long() loss = head.losses(seg_logit=outputs, seg_label=fake_label) - assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0']) - assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1']) - assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2']) - assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3']) + assert loss["loss_ce.s0"] != torch.zeros_like(loss["loss_ce.s0"]) + assert loss["loss_ce.s1"] != torch.zeros_like(loss["loss_ce.s1"]) + assert loss["loss_ce.s2"] != torch.zeros_like(loss["loss_ce.s2"]) + assert loss["loss_ce.s3"] != torch.zeros_like(loss["loss_ce.s3"]) diff --git a/mmsegmentation/tests/test_models/test_heads/test_lraspp_head.py b/mmsegmentation/tests/test_models/test_heads/test_lraspp_head.py index a46e6a1..f637e95 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_lraspp_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_lraspp_head.py @@ -12,14 +12,16 @@ def test_lraspp_head(): in_channels=(4, 4, 123), in_index=(0, 1, 2), channels=32, - input_transform='resize_concat', + input_transform="resize_concat", dropout_ratio=0.1, num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0 + ), + ) with pytest.raises(AssertionError): # check invalid branch_channels @@ -28,32 +30,34 @@ def test_lraspp_head(): in_index=(0, 1, 2), channels=32, branch_channels=64, - input_transform='multiple_select', + input_transform="multiple_select", dropout_ratio=0.1, num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0 + ), + ) # test with default settings lraspp_head = LRASPPHead( in_channels=(4, 4, 123), in_index=(0, 1, 2), channels=32, - input_transform='multiple_select', + input_transform="multiple_select", dropout_ratio=0.1, num_classes=19, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), + norm_cfg=dict(type="BN"), + act_cfg=dict(type="ReLU"), align_corners=False, - loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ) inputs = [ torch.randn(2, 4, 45, 45), torch.randn(2, 4, 28, 28), - torch.randn(2, 123, 14, 14) + torch.randn(2, 123, 14, 14), ] with pytest.raises(RuntimeError): # check invalid inputs @@ -62,7 +66,7 @@ def test_lraspp_head(): inputs = [ torch.randn(2, 4, 111, 111), torch.randn(2, 4, 77, 77), - torch.randn(2, 123, 55, 55) + torch.randn(2, 123, 55, 55), ] output = lraspp_head(inputs) assert output.shape == (2, 19, 111, 111) diff --git a/mmsegmentation/tests/test_models/test_heads/test_nl_head.py b/mmsegmentation/tests/test_models/test_heads/test_nl_head.py index d4ef0b9..1d631f0 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_nl_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_nl_head.py @@ -8,7 +8,7 @@ def test_nl_head(): head = NLHead(in_channels=8, channels=4, num_classes=19) assert len(head.convs) == 2 - assert hasattr(head, 'nl_block') + assert hasattr(head, "nl_block") inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_ocr_head.py b/mmsegmentation/tests/test_models/test_heads/test_ocr_head.py index 5e5d669..cd223ee 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_ocr_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_ocr_head.py @@ -6,10 +6,8 @@ def test_ocr_head(): - inputs = [torch.randn(1, 8, 23, 23)] - ocr_head = OCRHead( - in_channels=8, channels=4, num_classes=19, ocr_channels=8) + ocr_head = OCRHead(in_channels=8, channels=4, num_classes=19, ocr_channels=8) fcn_head = FCNHead(in_channels=8, channels=4, num_classes=19) if torch.cuda.is_available(): head, inputs = to_cuda(ocr_head, inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_point_head.py b/mmsegmentation/tests/test_models/test_heads/test_point_head.py index 142ab16..588c4d2 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_point_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_point_head.py @@ -7,10 +7,8 @@ def test_point_head(): - inputs = [torch.randn(1, 32, 45, 45)] - point_head = PointHead( - in_channels=[32], in_index=[0], channels=16, num_classes=19) + point_head = PointHead(in_channels=[32], in_index=[0], channels=16, num_classes=19) assert len(point_head.fcs) == 3 fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19) if torch.cuda.is_available(): @@ -18,7 +16,8 @@ def test_point_head(): head, inputs = to_cuda(fcn_head, inputs) prev_output = fcn_head(inputs) test_cfg = ConfigDict( - subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) + subdivision_steps=2, subdivision_num_points=8196, scale_factor=2 + ) output = point_head.forward_test(inputs, prev_output, None, test_cfg) assert output.shape == (1, point_head.num_classes, 180, 180) @@ -30,26 +29,30 @@ def test_point_head(): channels=16, num_classes=19, loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_1'), - dict(type='CrossEntropyLoss', loss_name='loss_2') - ]) + dict(type="CrossEntropyLoss", loss_name="loss_1"), + dict(type="CrossEntropyLoss", loss_name="loss_2"), + ], + ) assert len(point_head_multiple_losses.fcs) == 3 fcn_head_multiple_losses = FCNHead( in_channels=32, channels=16, num_classes=19, loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_1'), - dict(type='CrossEntropyLoss', loss_name='loss_2') - ]) + dict(type="CrossEntropyLoss", loss_name="loss_1"), + dict(type="CrossEntropyLoss", loss_name="loss_2"), + ], + ) if torch.cuda.is_available(): head, inputs = to_cuda(point_head_multiple_losses, inputs) head, inputs = to_cuda(fcn_head_multiple_losses, inputs) prev_output = fcn_head_multiple_losses(inputs) test_cfg = ConfigDict( - subdivision_steps=2, subdivision_num_points=8196, scale_factor=2) - output = point_head_multiple_losses.forward_test(inputs, prev_output, None, - test_cfg) + subdivision_steps=2, subdivision_num_points=8196, scale_factor=2 + ) + output = point_head_multiple_losses.forward_test( + inputs, prev_output, None, test_cfg + ) assert output.shape == (1, point_head.num_classes, 180, 180) fake_label = torch.ones([1, 180, 180], dtype=torch.long) @@ -57,5 +60,5 @@ def test_point_head(): if torch.cuda.is_available(): fake_label = fake_label.cuda() loss = point_head_multiple_losses.losses(output, fake_label) - assert 'pointloss_1' in loss - assert 'pointloss_2' in loss + assert "pointloss_1" in loss + assert "pointloss_2" in loss diff --git a/mmsegmentation/tests/test_models/test_heads/test_psa_head.py b/mmsegmentation/tests/test_models/test_heads/test_psa_head.py index 34f592b..dc3f73b 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_psa_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_psa_head.py @@ -7,7 +7,6 @@ def test_psa_head(): - with pytest.raises(AssertionError): # psa_type must be in 'bi-direction', 'collect', 'distribute' PSAHead( @@ -15,11 +14,11 @@ def test_psa_head(): channels=2, num_classes=19, mask_size=(13, 13), - psa_type='gather') + psa_type="gather", + ) # test no norm_cfg - head = PSAHead( - in_channels=4, channels=2, num_classes=19, mask_size=(13, 13)) + head = PSAHead(in_channels=4, channels=2, num_classes=19, mask_size=(13, 13)) assert not _conv_has_norm(head, sync_bn=False) # test with norm_cfg @@ -28,13 +27,13 @@ def test_psa_head(): channels=2, num_classes=19, mask_size=(13, 13), - norm_cfg=dict(type='SyncBN')) + norm_cfg=dict(type="SyncBN"), + ) assert _conv_has_norm(head, sync_bn=True) # test 'bi-direction' psa_type inputs = [torch.randn(1, 4, 13, 13)] - head = PSAHead( - in_channels=4, channels=2, num_classes=19, mask_size=(13, 13)) + head = PSAHead(in_channels=4, channels=2, num_classes=19, mask_size=(13, 13)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -43,11 +42,8 @@ def test_psa_head(): # test 'bi-direction' psa_type, shrink_factor=1 inputs = [torch.randn(1, 4, 13, 13)] head = PSAHead( - in_channels=4, - channels=2, - num_classes=19, - mask_size=(13, 13), - shrink_factor=1) + in_channels=4, channels=2, num_classes=19, mask_size=(13, 13), shrink_factor=1 + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -56,11 +52,8 @@ def test_psa_head(): # test 'bi-direction' psa_type with soft_max inputs = [torch.randn(1, 4, 13, 13)] head = PSAHead( - in_channels=4, - channels=2, - num_classes=19, - mask_size=(13, 13), - psa_softmax=True) + in_channels=4, channels=2, num_classes=19, mask_size=(13, 13), psa_softmax=True + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -73,7 +66,8 @@ def test_psa_head(): channels=2, num_classes=19, mask_size=(13, 13), - psa_type='collect') + psa_type="collect", + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -87,7 +81,8 @@ def test_psa_head(): num_classes=19, mask_size=(13, 13), shrink_factor=1, - psa_type='collect') + psa_type="collect", + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -100,9 +95,10 @@ def test_psa_head(): channels=2, num_classes=19, mask_size=(13, 13), - psa_type='collect', + psa_type="collect", shrink_factor=1, - compact=True) + compact=True, + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) @@ -115,7 +111,8 @@ def test_psa_head(): channels=2, num_classes=19, mask_size=(13, 13), - psa_type='distribute') + psa_type="distribute", + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) diff --git a/mmsegmentation/tests/test_models/test_heads/test_psp_head.py b/mmsegmentation/tests/test_models/test_heads/test_psp_head.py index fde4087..eb6c859 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_psp_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_psp_head.py @@ -7,7 +7,6 @@ def test_psp_head(): - with pytest.raises(AssertionError): # pool_scales must be list|tuple PSPHead(in_channels=4, channels=2, num_classes=19, pool_scales=1) @@ -18,15 +17,12 @@ def test_psp_head(): # test with norm_cfg head = PSPHead( - in_channels=4, - channels=2, - num_classes=19, - norm_cfg=dict(type='SyncBN')) + in_channels=4, channels=2, num_classes=19, norm_cfg=dict(type="SyncBN") + ) assert _conv_has_norm(head, sync_bn=True) inputs = [torch.randn(1, 4, 23, 23)] - head = PSPHead( - in_channels=4, channels=2, num_classes=19, pool_scales=(1, 2, 3)) + head = PSPHead(in_channels=4, channels=2, num_classes=19, pool_scales=(1, 2, 3)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) assert head.psp_modules[0][0].output_size == 1 diff --git a/mmsegmentation/tests/test_models/test_heads/test_segformer_head.py b/mmsegmentation/tests/test_models/test_heads/test_segformer_head.py index 73afaba..28324dc 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_segformer_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_segformer_head.py @@ -8,18 +8,14 @@ def test_segformer_head(): with pytest.raises(AssertionError): # `in_channels` must have same length as `in_index` - SegformerHead( - in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2) + SegformerHead(in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2) H, W = (64, 64) in_channels = (32, 64, 160, 256) - shapes = [(H // 2**(i + 2), W // 2**(i + 2)) - for i in range(len(in_channels))] + shapes = [(H // 2 ** (i + 2), W // 2 ** (i + 2)) for i in range(len(in_channels))] model = SegformerHead( - in_channels=in_channels, - in_index=[0, 1, 2, 3], - channels=256, - num_classes=19) + in_channels=in_channels, in_index=[0, 1, 2, 3], channels=256, num_classes=19 + ) with pytest.raises(IndexError): # in_index must match the input feature maps. diff --git a/mmsegmentation/tests/test_models/test_heads/test_segmenter_mask_head.py b/mmsegmentation/tests/test_models/test_heads/test_segmenter_mask_head.py index 7b681ac..80b6ed0 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_segmenter_mask_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_segmenter_mask_head.py @@ -13,7 +13,8 @@ def test_segmenter_mask_transformer_head(): num_layers=2, num_heads=3, embed_dims=192, - dropout_ratio=0.0) + dropout_ratio=0.0, + ) assert _conv_has_norm(head, sync_bn=True) head.init_weights() diff --git a/mmsegmentation/tests/test_models/test_heads/test_setr_mla_head.py b/mmsegmentation/tests/test_models/test_heads/test_setr_mla_head.py index 301bc0b..47b2a45 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_setr_mla_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_setr_mla_head.py @@ -7,15 +7,13 @@ def test_setr_mla_head(capsys): - with pytest.raises(AssertionError): # MLA requires input multiple stage feature information. SETRMLAHead(in_channels=8, channels=4, num_classes=19, in_index=1) with pytest.raises(AssertionError): # multiple in_indexs requires multiple in_channels. - SETRMLAHead( - in_channels=8, channels=4, num_classes=19, in_index=(0, 1, 2, 3)) + SETRMLAHead(in_channels=8, channels=4, num_classes=19, in_index=(0, 1, 2, 3)) with pytest.raises(AssertionError): # channels should be len(in_channels) * mla_channels @@ -24,7 +22,8 @@ def test_setr_mla_head(capsys): channels=8, mla_channels=4, in_index=(0, 1, 2, 3), - num_classes=19) + num_classes=19, + ) # test inference of MLA head img_size = (8, 8) @@ -35,7 +34,8 @@ def test_setr_mla_head(capsys): mla_channels=4, in_index=(0, 1, 2, 3), num_classes=19, - norm_cfg=dict(type='BN')) + norm_cfg=dict(type="BN"), + ) h, w = img_size[0] // patch_size, img_size[1] // patch_size # Input square NCHW format feature information @@ -43,7 +43,7 @@ def test_setr_mla_head(capsys): torch.randn(1, 8, h, w), torch.randn(1, 8, h, w), torch.randn(1, 8, h, w), - torch.randn(1, 8, h, w) + torch.randn(1, 8, h, w), ] if torch.cuda.is_available(): head, x = to_cuda(head, x) @@ -55,7 +55,7 @@ def test_setr_mla_head(capsys): torch.randn(1, 8, h, w * 2), torch.randn(1, 8, h, w * 2), torch.randn(1, 8, h, w * 2), - torch.randn(1, 8, h, w * 2) + torch.randn(1, 8, h, w * 2), ] if torch.cuda.is_available(): head, x = to_cuda(head, x) diff --git a/mmsegmentation/tests/test_models/test_heads/test_setr_up_head.py b/mmsegmentation/tests/test_models/test_heads/test_setr_up_head.py index a051922..a91d8d0 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_setr_up_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_setr_up_head.py @@ -7,7 +7,6 @@ def test_setr_up_head(capsys): - with pytest.raises(AssertionError): # kernel_size must be [1/3] SETRUPHead(num_classes=19, kernel_size=2) @@ -21,9 +20,10 @@ def test_setr_up_head(capsys): head = SETRUPHead( in_channels=4, channels=2, - norm_cfg=dict(type='SyncBN'), + norm_cfg=dict(type="SyncBN"), num_classes=19, - init_cfg=dict(type='Kaiming')) + init_cfg=dict(type="Kaiming"), + ) super(SETRUPHead, head).init_weights() # test inference of Naive head @@ -37,7 +37,8 @@ def test_setr_up_head(capsys): num_convs=1, up_scale=4, kernel_size=1, - norm_cfg=dict(type='BN')) + norm_cfg=dict(type="BN"), + ) h, w = img_size[0] // patch_size, img_size[1] // patch_size diff --git a/mmsegmentation/tests/test_models/test_heads/test_stdc_head.py b/mmsegmentation/tests/test_models/test_heads/test_stdc_head.py index 1628209..328f1ff 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_stdc_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_stdc_head.py @@ -14,18 +14,17 @@ def test_stdc_head(): num_classes=2, in_index=-1, loss_decode=[ - dict( - type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), - dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0) - ]) + dict(type="CrossEntropyLoss", loss_name="loss_ce", loss_weight=1.0), + dict(type="DiceLoss", loss_name="loss_dice", loss_weight=1.0), + ], + ) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) assert isinstance(outputs, torch.Tensor) and len(outputs) == 1 assert outputs.shape == torch.Size([1, head.num_classes, 21, 21]) - fake_label = torch.ones_like( - outputs[:, 0:1, :, :], dtype=torch.int16).long() + fake_label = torch.ones_like(outputs[:, 0:1, :, :], dtype=torch.int16).long() loss = head.losses(seg_logit=outputs, seg_label=fake_label) - assert loss['loss_ce'] != torch.zeros_like(loss['loss_ce']) - assert loss['loss_dice'] != torch.zeros_like(loss['loss_dice']) + assert loss["loss_ce"] != torch.zeros_like(loss["loss_ce"]) + assert loss["loss_dice"] != torch.zeros_like(loss["loss_dice"]) diff --git a/mmsegmentation/tests/test_models/test_heads/test_uper_head.py b/mmsegmentation/tests/test_models/test_heads/test_uper_head.py index 09456a8..8fa4ac0 100644 --- a/mmsegmentation/tests/test_models/test_heads/test_uper_head.py +++ b/mmsegmentation/tests/test_models/test_heads/test_uper_head.py @@ -7,14 +7,12 @@ def test_uper_head(): - with pytest.raises(AssertionError): # fpn_in_channels must be list|tuple UPerHead(in_channels=4, channels=2, num_classes=19) # test no norm_cfg - head = UPerHead( - in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) + head = UPerHead(in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) assert not _conv_has_norm(head, sync_bn=False) # test with norm_cfg @@ -22,13 +20,13 @@ def test_uper_head(): in_channels=[4, 2], channels=2, num_classes=19, - norm_cfg=dict(type='SyncBN'), - in_index=[-2, -1]) + norm_cfg=dict(type="SyncBN"), + in_index=[-2, -1], + ) assert _conv_has_norm(head, sync_bn=True) inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)] - head = UPerHead( - in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) + head = UPerHead(in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) diff --git a/mmsegmentation/tests/test_models/test_losses/test_ce_loss.py b/mmsegmentation/tests/test_models/test_losses/test_ce_loss.py index afa5706..4400b5c 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_ce_loss.py +++ b/mmsegmentation/tests/test_models/test_losses/test_ce_loss.py @@ -5,38 +5,35 @@ from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels -@pytest.mark.parametrize('use_sigmoid', [True, False]) -@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none')) -@pytest.mark.parametrize('avg_non_ignore', [True, False]) -@pytest.mark.parametrize('bce_input_same_dim', [True, False]) +@pytest.mark.parametrize("use_sigmoid", [True, False]) +@pytest.mark.parametrize("reduction", ("mean", "sum", "none")) +@pytest.mark.parametrize("avg_non_ignore", [True, False]) +@pytest.mark.parametrize("bce_input_same_dim", [True, False]) def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): from mmseg.models import build_loss # use_mask and use_sigmoid cannot be true at the same time with pytest.raises(AssertionError): loss_cfg = dict( - type='CrossEntropyLoss', - use_mask=True, - use_sigmoid=True, - loss_weight=1.0) + type="CrossEntropyLoss", use_mask=True, use_sigmoid=True, loss_weight=1.0 + ) build_loss(loss_cfg) # test loss with simple case for ce/bce fake_pred = torch.Tensor([[100, -100]]) fake_label = torch.Tensor([1]).long() loss_cls_cfg = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=use_sigmoid, loss_weight=1.0, avg_non_ignore=avg_non_ignore, - loss_name='loss_ce') + loss_name="loss_ce", + ) loss_cls = build_loss(loss_cls_cfg) if use_sigmoid: - assert torch.allclose( - loss_cls(fake_pred, fake_label), torch.tensor(100.)) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.0)) else: - assert torch.allclose( - loss_cls(fake_pred, fake_label), torch.tensor(200.)) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.0)) # test loss with complicated case for ce/bce # when avg_non_ignore is False, `avg_factor` would not be calculated @@ -52,24 +49,22 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): fake_label[0, [1, 2, 5, 7]] = 255 # set ignore_index fake_label[1, [0, 5, 8, 9]] = 255 loss_cls = build_loss(loss_cls_cfg) - loss = loss_cls( - fake_pred, fake_label, weight=fake_weight, ignore_index=255) + loss = loss_cls(fake_pred, fake_label, weight=fake_weight, ignore_index=255) if use_sigmoid: if fake_pred.dim() != fake_label.dim(): fake_label, weight, valid_mask = _expand_onehot_labels( labels=fake_label, label_weights=None, target_shape=fake_pred.shape, - ignore_index=255) + ignore_index=255, + ) else: # should mask out the ignored elements valid_mask = ((fake_label >= 0) & (fake_label != 255)).float() weight = valid_mask torch_loss = torch.nn.functional.binary_cross_entropy_with_logits( - fake_pred, - fake_label.float(), - reduction='none', - weight=fake_weight) + fake_pred, fake_label.float(), reduction="none", weight=fake_weight + ) if avg_non_ignore: avg_factor = valid_mask.sum().item() torch_loss = (torch_loss * weight).sum() / avg_factor @@ -78,11 +73,15 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): else: if avg_non_ignore: torch_loss = torch.nn.functional.cross_entropy( - fake_pred, fake_label, reduction='mean', ignore_index=255) + fake_pred, fake_label, reduction="mean", ignore_index=255 + ) else: - torch_loss = torch.nn.functional.cross_entropy( - fake_pred, fake_label, reduction='sum', - ignore_index=255) / fake_label.numel() + torch_loss = ( + torch.nn.functional.cross_entropy( + fake_pred, fake_label, reduction="sum", ignore_index=255 + ) + / fake_label.numel() + ) assert torch.allclose(loss, torch_loss) if use_sigmoid: @@ -94,19 +93,20 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): fake_weight = torch.rand(2, 8, 8) loss_cls = build_loss(loss_cls_cfg) - loss = loss_cls( - fake_pred, fake_label, weight=fake_weight, ignore_index=255) + loss = loss_cls(fake_pred, fake_label, weight=fake_weight, ignore_index=255) if use_sigmoid: fake_label, weight, valid_mask = _expand_onehot_labels( labels=fake_label, label_weights=None, target_shape=fake_pred.shape, - ignore_index=255) + ignore_index=255, + ) torch_loss = torch.nn.functional.binary_cross_entropy_with_logits( fake_pred, fake_label.float(), - reduction='none', - weight=fake_weight.unsqueeze(1).expand(fake_pred.shape)) + reduction="none", + weight=fake_weight.unsqueeze(1).expand(fake_pred.shape), + ) if avg_non_ignore: avg_factor = valid_mask.sum().item() torch_loss = (torch_loss * weight).sum() / avg_factor @@ -122,56 +122,61 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): import mmcv import numpy as np + tmp_file = tempfile.NamedTemporaryFile() - mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + mmcv.dump([0.8, 0.2], f"{tmp_file.name}.pkl", "pkl") # from pkl file loss_cls_cfg = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=False, - class_weight=f'{tmp_file.name}.pkl', + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, - loss_name='loss_ce') + loss_name="loss_ce", + ) loss_cls = build_loss(loss_cls_cfg) - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.0)) - np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file + np.save(f"{tmp_file.name}.npy", np.array([0.8, 0.2])) # from npy file loss_cls_cfg = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=False, - class_weight=f'{tmp_file.name}.npy', + class_weight=f"{tmp_file.name}.npy", loss_weight=1.0, - loss_name='loss_ce') + loss_name="loss_ce", + ) loss_cls = build_loss(loss_cls_cfg) - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.0)) tmp_file.close() - os.remove(f'{tmp_file.name}.pkl') - os.remove(f'{tmp_file.name}.npy') + os.remove(f"{tmp_file.name}.pkl") + os.remove(f"{tmp_file.name}.npy") - loss_cls_cfg = dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) + loss_cls_cfg = dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0) loss_cls = build_loss(loss_cls_cfg) - assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.)) + assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.0)) # test `avg_non_ignore` without ignore index would not affect ce/bce loss # when reduction='sum'/'none'/'mean' loss_cls_cfg1 = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=use_sigmoid, reduction=reduction, loss_weight=1.0, - avg_non_ignore=True) + avg_non_ignore=True, + ) loss_cls1 = build_loss(loss_cls_cfg1) loss_cls_cfg2 = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=use_sigmoid, reduction=reduction, loss_weight=1.0, - avg_non_ignore=False) + avg_non_ignore=False, + ) loss_cls2 = build_loss(loss_cls_cfg2) assert torch.allclose( loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(), loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(), - atol=1e-4) + atol=1e-4, + ) # test ce/bce loss with ignore index and class weight # in 5-way classification @@ -189,33 +194,36 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): labels=fake_label, label_weights=None, target_shape=fake_pred.shape, - ignore_index=-100) + ignore_index=-100, + ) torch_loss = torch.nn.functional.binary_cross_entropy_with_logits( - fake_pred, - fake_label.float(), - reduction='mean', - pos_weight=class_weight) + fake_pred, fake_label.float(), reduction="mean", pos_weight=class_weight + ) else: fake_pred = torch.randn(2, 5, 10).float() # 5-way classification fake_label = torch.randint(0, 5, (2, 10)).long() class_weight = torch.rand(5) class_weight /= class_weight.sum() - torch_loss = torch.nn.functional.cross_entropy( - fake_pred, fake_label, reduction='sum', - weight=class_weight) / fake_label.numel() + torch_loss = ( + torch.nn.functional.cross_entropy( + fake_pred, fake_label, reduction="sum", weight=class_weight + ) + / fake_label.numel() + ) loss_cls_cfg = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=use_sigmoid, - reduction='mean', + reduction="mean", class_weight=class_weight, loss_weight=1.0, - avg_non_ignore=avg_non_ignore) + avg_non_ignore=avg_non_ignore, + ) loss_cls = build_loss(loss_cls_cfg) # test cross entropy loss has name `loss_ce` - assert loss_cls.loss_name == 'loss_ce' + assert loss_cls.loss_name == "loss_ce" # test avg_non_ignore is in extra_repr - assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}' + assert loss_cls.extra_repr() == f"avg_non_ignore={avg_non_ignore}" loss = loss_cls(fake_pred, fake_label) assert torch.allclose(loss, torch_loss) @@ -229,33 +237,46 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim): fake_pred[fake_label != 10], fake_label[fake_label != 10].float(), pos_weight=class_weight[fake_label != 10], - reduction='mean') + reduction="mean", + ) else: - torch_loss = torch.nn.functional.binary_cross_entropy_with_logits( - fake_pred[fake_label != 10], - fake_label[fake_label != 10].float(), - pos_weight=class_weight[fake_label != 10], - reduction='sum') / fake_label.numel() + torch_loss = ( + torch.nn.functional.binary_cross_entropy_with_logits( + fake_pred[fake_label != 10], + fake_label[fake_label != 10].float(), + pos_weight=class_weight[fake_label != 10], + reduction="sum", + ) + / fake_label.numel() + ) else: if avg_non_ignore: - torch_loss = torch.nn.functional.cross_entropy( - fake_pred, - fake_label, - ignore_index=10, - reduction='sum', - weight=class_weight) / fake_label[fake_label != 10].numel() + torch_loss = ( + torch.nn.functional.cross_entropy( + fake_pred, + fake_label, + ignore_index=10, + reduction="sum", + weight=class_weight, + ) + / fake_label[fake_label != 10].numel() + ) else: - torch_loss = torch.nn.functional.cross_entropy( - fake_pred, - fake_label, - ignore_index=10, - reduction='sum', - weight=class_weight) / fake_label.numel() + torch_loss = ( + torch.nn.functional.cross_entropy( + fake_pred, + fake_label, + ignore_index=10, + reduction="sum", + weight=class_weight, + ) + / fake_label.numel() + ) assert torch.allclose(loss, torch_loss) -@pytest.mark.parametrize('avg_non_ignore', [True, False]) -@pytest.mark.parametrize('with_weight', [True, False]) +@pytest.mark.parametrize("avg_non_ignore", [True, False]) +@pytest.mark.parametrize("with_weight", [True, False]) def test_binary_class_ce_loss(avg_non_ignore, with_weight): from mmseg.models import build_loss @@ -268,27 +289,29 @@ def test_binary_class_ce_loss(avg_non_ignore, with_weight): torch_loss = torch.nn.functional.binary_cross_entropy_with_logits( fake_pred, fake_label.unsqueeze(1).float(), - reduction='none', - weight=fake_weight.unsqueeze(1).float() if with_weight else None) + reduction="none", + weight=fake_weight.unsqueeze(1).float() if with_weight else None, + ) if avg_non_ignore: eps = torch.finfo(torch.float32).eps avg_factor = valid_mask.sum().item() - torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / ( - avg_factor + eps) + torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (avg_factor + eps) else: torch_loss = (torch_loss * weight.unsqueeze(1)).mean() loss_cls_cfg = dict( - type='CrossEntropyLoss', + type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0, avg_non_ignore=avg_non_ignore, - reduction='mean', - loss_name='loss_ce') + reduction="mean", + loss_name="loss_ce", + ) loss_cls = build_loss(loss_cls_cfg) loss = loss_cls( fake_pred, fake_label, weight=fake_weight if with_weight else None, - ignore_index=255) + ignore_index=255, + ) assert torch.allclose(loss, torch_loss) diff --git a/mmsegmentation/tests/test_models/test_losses/test_dice_loss.py b/mmsegmentation/tests/test_models/test_losses/test_dice_loss.py index 3936f5d..f8322dc 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_dice_loss.py +++ b/mmsegmentation/tests/test_models/test_losses/test_dice_loss.py @@ -7,12 +7,13 @@ def test_dice_lose(): # test dice loss with loss_type = 'multi_class' loss_cfg = dict( - type='DiceLoss', - reduction='none', + type="DiceLoss", + reduction="none", class_weight=[1.0, 2.0, 3.0], loss_weight=1.0, ignore_index=1, - loss_name='loss_dice') + loss_name="loss_dice", + ) dice_loss = build_loss(loss_cfg) logits = torch.rand(8, 3, 4, 4) labels = (torch.rand(8, 4, 4) * 3).long() @@ -24,42 +25,46 @@ def test_dice_lose(): import mmcv import numpy as np + tmp_file = tempfile.NamedTemporaryFile() - mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + mmcv.dump([1.0, 2.0, 3.0], f"{tmp_file.name}.pkl", "pkl") # from pkl file loss_cfg = dict( - type='DiceLoss', - reduction='none', - class_weight=f'{tmp_file.name}.pkl', + type="DiceLoss", + reduction="none", + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, ignore_index=1, - loss_name='loss_dice') + loss_name="loss_dice", + ) dice_loss = build_loss(loss_cfg) dice_loss(logits, labels, ignore_index=None) - np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + np.save(f"{tmp_file.name}.npy", np.array([1.0, 2.0, 3.0])) # from npy file loss_cfg = dict( - type='DiceLoss', - reduction='none', - class_weight=f'{tmp_file.name}.pkl', + type="DiceLoss", + reduction="none", + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, ignore_index=1, - loss_name='loss_dice') + loss_name="loss_dice", + ) dice_loss = build_loss(loss_cfg) dice_loss(logits, labels, ignore_index=None) tmp_file.close() - os.remove(f'{tmp_file.name}.pkl') - os.remove(f'{tmp_file.name}.npy') + os.remove(f"{tmp_file.name}.pkl") + os.remove(f"{tmp_file.name}.npy") # test dice loss with loss_type = 'binary' loss_cfg = dict( - type='DiceLoss', + type="DiceLoss", smooth=2, exponent=3, - reduction='sum', + reduction="sum", loss_weight=1.0, ignore_index=0, - loss_name='loss_dice') + loss_name="loss_dice", + ) dice_loss = build_loss(loss_cfg) logits = torch.rand(8, 2, 4, 4) labels = (torch.rand(8, 4, 4) * 2).long() @@ -67,12 +72,13 @@ def test_dice_lose(): # test dice loss has name `loss_dice` loss_cfg = dict( - type='DiceLoss', + type="DiceLoss", smooth=2, exponent=3, - reduction='sum', + reduction="sum", loss_weight=1.0, ignore_index=0, - loss_name='loss_dice') + loss_name="loss_dice", + ) dice_loss = build_loss(loss_cfg) - assert dice_loss.loss_name == 'loss_dice' + assert dice_loss.loss_name == "loss_dice" diff --git a/mmsegmentation/tests/test_models/test_losses/test_focal_loss.py b/mmsegmentation/tests/test_models/test_losses/test_focal_loss.py index 687312b..9e75542 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_focal_loss.py +++ b/mmsegmentation/tests/test_models/test_losses/test_focal_loss.py @@ -10,12 +10,12 @@ def test_use_sigmoid(): # can't init with use_sigmoid=True with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', use_sigmoid=False) + loss_cfg = dict(type="FocalLoss", use_sigmoid=False) build_loss(loss_cfg) # can't forward with use_sigmoid=True with pytest.raises(NotImplementedError): - loss_cfg = dict(type='FocalLoss', use_sigmoid=True) + loss_cfg = dict(type="FocalLoss", use_sigmoid=True) focal_loss = build_loss(loss_cfg) focal_loss.use_sigmoid = False fake_pred = torch.rand(3, 4, 5, 6) @@ -27,70 +27,71 @@ def test_use_sigmoid(): def test_wrong_reduction_type(): # can't init with wrong reduction with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', reduction='test') + loss_cfg = dict(type="FocalLoss", reduction="test") build_loss(loss_cfg) # can't forward with wrong reduction override with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss') + loss_cfg = dict(type="FocalLoss") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) - focal_loss(fake_pred, fake_target, reduction_override='test') + focal_loss(fake_pred, fake_target, reduction_override="test") # test focal loss can handle input parameters with # unacceptable types def test_unacceptable_parameters(): with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', gamma='test') + loss_cfg = dict(type="FocalLoss", gamma="test") build_loss(loss_cfg) with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', alpha='test') + loss_cfg = dict(type="FocalLoss", alpha="test") build_loss(loss_cfg) with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', class_weight='test') + loss_cfg = dict(type="FocalLoss", class_weight="test") build_loss(loss_cfg) with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', loss_weight='test') + loss_cfg = dict(type="FocalLoss", loss_weight="test") build_loss(loss_cfg) with pytest.raises(AssertionError): - loss_cfg = dict(type='FocalLoss', loss_name=123) + loss_cfg = dict(type="FocalLoss", loss_name=123) build_loss(loss_cfg) # test if focal loss can be correctly initialize def test_init_focal_loss(): loss_cfg = dict( - type='FocalLoss', + type="FocalLoss", use_sigmoid=True, gamma=3.0, alpha=3.0, class_weight=[1, 2, 3, 4], - reduction='sum') + reduction="sum", + ) focal_loss = build_loss(loss_cfg) assert focal_loss.use_sigmoid is True assert focal_loss.gamma == 3.0 assert focal_loss.alpha == 3.0 - assert focal_loss.reduction == 'sum' + assert focal_loss.reduction == "sum" assert focal_loss.class_weight == [1, 2, 3, 4] assert focal_loss.loss_weight == 1.0 - assert focal_loss.loss_name == 'loss_focal' + assert focal_loss.loss_name == "loss_focal" # test reduction override def test_reduction_override(): - loss_cfg = dict(type='FocalLoss', reduction='mean') + loss_cfg = dict(type="FocalLoss", reduction="mean") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) - loss = focal_loss(fake_pred, fake_target, reduction_override='none') + loss = focal_loss(fake_pred, fake_target, reduction_override="none") assert loss.shape == fake_pred.shape # test wrong pred and target shape def test_wrong_pred_and_target_shape(): - loss_cfg = dict(type='FocalLoss') + loss_cfg = dict(type="FocalLoss") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 2, 2)) @@ -102,7 +103,7 @@ def test_wrong_pred_and_target_shape(): # test forward with different shape of target def test_forward_with_different_shape_of_target(): - loss_cfg = dict(type='FocalLoss') + loss_cfg = dict(type="FocalLoss") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) @@ -117,7 +118,7 @@ def test_forward_with_different_shape_of_target(): # test forward with weight def test_forward_with_weight(): - loss_cfg = dict(type='FocalLoss') + loss_cfg = dict(type="FocalLoss") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) @@ -134,7 +135,7 @@ def test_forward_with_weight(): # test none reduction type def test_none_reduction_type(): - loss_cfg = dict(type='FocalLoss', reduction='none') + loss_cfg = dict(type="FocalLoss", reduction="none") focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) @@ -145,8 +146,9 @@ def test_none_reduction_type(): # test the usage of class weight def test_class_weight(): loss_cfg_cw = dict( - type='FocalLoss', reduction='none', class_weight=[1.0, 2.0, 3.0, 4.0]) - loss_cfg = dict(type='FocalLoss', reduction='none') + type="FocalLoss", reduction="none", class_weight=[1.0, 2.0, 3.0, 4.0] + ) + loss_cfg = dict(type="FocalLoss", reduction="none") focal_loss_cw = build_loss(loss_cfg_cw) focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 4, 5, 6) @@ -159,14 +161,14 @@ def test_class_weight(): # test ignore index def test_ignore_index(): - loss_cfg = dict(type='FocalLoss', reduction='none') + loss_cfg = dict(type="FocalLoss", reduction="none") # ignore_index within C classes focal_loss = build_loss(loss_cfg) fake_pred = torch.rand(3, 5, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) - dim1 = torch.randint(0, 3, (4, )) - dim2 = torch.randint(0, 5, (4, )) - dim3 = torch.randint(0, 6, (4, )) + dim1 = torch.randint(0, 3, (4,)) + dim2 = torch.randint(0, 5, (4,)) + dim3 = torch.randint(0, 6, (4,)) fake_target[dim1, dim2, dim3] = 4 loss1 = focal_loss(fake_pred, fake_target, ignore_index=4) one_hot_target = F.one_hot(fake_target, num_classes=5) @@ -190,9 +192,9 @@ def test_ignore_index(): # ignore index is not in prediction's classes fake_pred = torch.rand(3, 4, 5, 6) fake_target = torch.randint(0, 4, (3, 5, 6)) - dim1 = torch.randint(0, 3, (4, )) - dim2 = torch.randint(0, 5, (4, )) - dim3 = torch.randint(0, 6, (4, )) + dim1 = torch.randint(0, 3, (4,)) + dim2 = torch.randint(0, 5, (4,)) + dim3 = torch.randint(0, 6, (4,)) fake_target[dim1, dim2, dim3] = 255 loss1 = focal_loss(fake_pred, fake_target, ignore_index=255) assert (loss1[dim1, :, dim2, dim3] == 0).all() @@ -200,7 +202,7 @@ def test_ignore_index(): # test list alpha def test_alpha(): - loss_cfg = dict(type='FocalLoss') + loss_cfg = dict(type="FocalLoss") focal_loss = build_loss(loss_cfg) alpha_float = 0.4 alpha = [0.4, 0.4, 0.4, 0.4] diff --git a/mmsegmentation/tests/test_models/test_losses/test_lovasz_loss.py b/mmsegmentation/tests/test_models/test_losses/test_lovasz_loss.py index bea3f4b..ac78bfc 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_lovasz_loss.py +++ b/mmsegmentation/tests/test_models/test_losses/test_lovasz_loss.py @@ -9,27 +9,25 @@ def test_lovasz_loss(): # loss_type should be 'binary' or 'multi_class' with pytest.raises(AssertionError): loss_cfg = dict( - type='LovaszLoss', - loss_type='Binary', - reduction='none', + type="LovaszLoss", + loss_type="Binary", + reduction="none", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) build_loss(loss_cfg) # reduction should be 'none' when per_image is False. with pytest.raises(AssertionError): loss_cfg = dict( - type='LovaszLoss', - loss_type='multi_class', - loss_name='loss_lovasz') + type="LovaszLoss", loss_type="multi_class", loss_name="loss_lovasz" + ) build_loss(loss_cfg) # test lovasz loss with loss_type = 'multi_class' and per_image = False loss_cfg = dict( - type='LovaszLoss', - reduction='none', - loss_weight=1.0, - loss_name='loss_lovasz') + type="LovaszLoss", reduction="none", loss_weight=1.0, loss_name="loss_lovasz" + ) lovasz_loss = build_loss(loss_cfg) logits = torch.rand(1, 3, 4, 4) labels = (torch.rand(1, 4, 4) * 2).long() @@ -37,12 +35,13 @@ def test_lovasz_loss(): # test lovasz loss with loss_type = 'multi_class' and per_image = True loss_cfg = dict( - type='LovaszLoss', + type="LovaszLoss", per_image=True, - reduction='mean', + reduction="mean", class_weight=[1.0, 2.0, 3.0], loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) logits = torch.rand(1, 3, 4, 4) labels = (torch.rand(1, 4, 4) * 2).long() @@ -54,40 +53,44 @@ def test_lovasz_loss(): import mmcv import numpy as np + tmp_file = tempfile.NamedTemporaryFile() - mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + mmcv.dump([1.0, 2.0, 3.0], f"{tmp_file.name}.pkl", "pkl") # from pkl file loss_cfg = dict( - type='LovaszLoss', + type="LovaszLoss", per_image=True, - reduction='mean', - class_weight=f'{tmp_file.name}.pkl', + reduction="mean", + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) lovasz_loss(logits, labels, ignore_index=None) - np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + np.save(f"{tmp_file.name}.npy", np.array([1.0, 2.0, 3.0])) # from npy file loss_cfg = dict( - type='LovaszLoss', + type="LovaszLoss", per_image=True, - reduction='mean', - class_weight=f'{tmp_file.name}.npy', + reduction="mean", + class_weight=f"{tmp_file.name}.npy", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) lovasz_loss(logits, labels, ignore_index=None) tmp_file.close() - os.remove(f'{tmp_file.name}.pkl') - os.remove(f'{tmp_file.name}.npy') + os.remove(f"{tmp_file.name}.pkl") + os.remove(f"{tmp_file.name}.npy") # test lovasz loss with loss_type = 'binary' and per_image = False loss_cfg = dict( - type='LovaszLoss', - loss_type='binary', - reduction='none', + type="LovaszLoss", + loss_type="binary", + reduction="none", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) logits = torch.rand(2, 4, 4) labels = (torch.rand(2, 4, 4)).long() @@ -95,12 +98,13 @@ def test_lovasz_loss(): # test lovasz loss with loss_type = 'binary' and per_image = True loss_cfg = dict( - type='LovaszLoss', - loss_type='binary', + type="LovaszLoss", + loss_type="binary", per_image=True, - reduction='mean', + reduction="mean", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) logits = torch.rand(2, 4, 4) labels = (torch.rand(2, 4, 4)).long() @@ -108,11 +112,12 @@ def test_lovasz_loss(): # test lovasz loss has name `loss_lovasz` loss_cfg = dict( - type='LovaszLoss', - loss_type='binary', + type="LovaszLoss", + loss_type="binary", per_image=True, - reduction='mean', + reduction="mean", loss_weight=1.0, - loss_name='loss_lovasz') + loss_name="loss_lovasz", + ) lovasz_loss = build_loss(loss_cfg) - assert lovasz_loss.loss_name == 'loss_lovasz' + assert lovasz_loss.loss_name == "loss_lovasz" diff --git a/mmsegmentation/tests/test_models/test_losses/test_tversky_loss.py b/mmsegmentation/tests/test_models/test_losses/test_tversky_loss.py index 24a4b57..7fb6c8b 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_tversky_loss.py +++ b/mmsegmentation/tests/test_models/test_losses/test_tversky_loss.py @@ -9,12 +9,13 @@ def test_tversky_lose(): # test alpha + beta != 1 with pytest.raises(AssertionError): loss_cfg = dict( - type='TverskyLoss', + type="TverskyLoss", class_weight=[1.0, 2.0, 3.0], loss_weight=1.0, alpha=0.4, beta=0.7, - loss_name='loss_tversky') + loss_name="loss_tversky", + ) tversky_loss = build_loss(loss_cfg) logits = torch.rand(8, 3, 4, 4) labels = (torch.rand(8, 4, 4) * 3).long() @@ -22,11 +23,12 @@ def test_tversky_lose(): # test tversky loss loss_cfg = dict( - type='TverskyLoss', + type="TverskyLoss", class_weight=[1.0, 2.0, 3.0], loss_weight=1.0, ignore_index=1, - loss_name='loss_tversky') + loss_name="loss_tversky", + ) tversky_loss = build_loss(loss_cfg) logits = torch.rand(8, 3, 4, 4) labels = (torch.rand(8, 4, 4) * 3).long() @@ -38,39 +40,43 @@ def test_tversky_lose(): import mmcv import numpy as np + tmp_file = tempfile.NamedTemporaryFile() - mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + mmcv.dump([1.0, 2.0, 3.0], f"{tmp_file.name}.pkl", "pkl") # from pkl file loss_cfg = dict( - type='TverskyLoss', - class_weight=f'{tmp_file.name}.pkl', + type="TverskyLoss", + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, ignore_index=1, - loss_name='loss_tversky') + loss_name="loss_tversky", + ) tversky_loss = build_loss(loss_cfg) tversky_loss(logits, labels) - np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + np.save(f"{tmp_file.name}.npy", np.array([1.0, 2.0, 3.0])) # from npy file loss_cfg = dict( - type='TverskyLoss', - class_weight=f'{tmp_file.name}.pkl', + type="TverskyLoss", + class_weight=f"{tmp_file.name}.pkl", loss_weight=1.0, ignore_index=1, - loss_name='loss_tversky') + loss_name="loss_tversky", + ) tversky_loss = build_loss(loss_cfg) tversky_loss(logits, labels) tmp_file.close() - os.remove(f'{tmp_file.name}.pkl') - os.remove(f'{tmp_file.name}.npy') + os.remove(f"{tmp_file.name}.pkl") + os.remove(f"{tmp_file.name}.npy") # test tversky loss has name `loss_tversky` loss_cfg = dict( - type='TverskyLoss', + type="TverskyLoss", smooth=2, loss_weight=1.0, ignore_index=1, alpha=0.3, beta=0.7, - loss_name='loss_tversky') + loss_name="loss_tversky", + ) tversky_loss = build_loss(loss_cfg) - assert tversky_loss.loss_name == 'loss_tversky' + assert tversky_loss.loss_name == "loss_tversky" diff --git a/mmsegmentation/tests/test_models/test_losses/test_utils.py b/mmsegmentation/tests/test_models/test_losses/test_utils.py index ab9927f..1bae2ac 100644 --- a/mmsegmentation/tests/test_models/test_losses/test_utils.py +++ b/mmsegmentation/tests/test_models/test_losses/test_utils.py @@ -12,33 +12,33 @@ def test_weight_reduce_loss(): weight[:, :, :2, :2] = 1 # test reduce_loss() - reduced = reduce_loss(loss, 'none') + reduced = reduce_loss(loss, "none") assert reduced is loss - reduced = reduce_loss(loss, 'mean') + reduced = reduce_loss(loss, "mean") np.testing.assert_almost_equal(reduced.numpy(), loss.mean()) - reduced = reduce_loss(loss, 'sum') + reduced = reduce_loss(loss, "sum") np.testing.assert_almost_equal(reduced.numpy(), loss.sum()) # test weight_reduce_loss() - reduced = weight_reduce_loss(loss, weight=None, reduction='none') + reduced = weight_reduce_loss(loss, weight=None, reduction="none") assert reduced is loss - reduced = weight_reduce_loss(loss, weight=weight, reduction='mean') + reduced = weight_reduce_loss(loss, weight=weight, reduction="mean") target = (loss * weight).mean() np.testing.assert_almost_equal(reduced.numpy(), target) - reduced = weight_reduce_loss(loss, weight=weight, reduction='sum') + reduced = weight_reduce_loss(loss, weight=weight, reduction="sum") np.testing.assert_almost_equal(reduced.numpy(), (loss * weight).sum()) with pytest.raises(AssertionError): weight_wrong = weight[0, 0, ...] - weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') + weight_reduce_loss(loss, weight=weight_wrong, reduction="mean") with pytest.raises(AssertionError): weight_wrong = weight[:, 0:2, ...] - weight_reduce_loss(loss, weight=weight_wrong, reduction='mean') + weight_reduce_loss(loss, weight=weight_wrong, reduction="mean") def test_accuracy(): @@ -49,9 +49,15 @@ def test_accuracy(): acc = accuracy(pred, label) assert acc.item() == 0 - pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6], - [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1], - [0.0, 0.0, 0.99, 0]]) + pred = torch.Tensor( + [ + [0.2, 0.3, 0.6, 0.5], + [0.1, 0.1, 0.2, 0.6], + [0.9, 0.0, 0.0, 0.1], + [0.4, 0.7, 0.1, 0.1], + [0.0, 0.0, 0.99, 0], + ] + ) # test for ignore_index true_label = torch.Tensor([2, 3, 0, 1, 2]).long() accuracy = Accuracy(topk=1, ignore_index=None) @@ -114,7 +120,7 @@ def test_accuracy(): # wrong topk type with pytest.raises(AssertionError): - accuracy = Accuracy(topk='wrong type') + accuracy = Accuracy(topk="wrong type") accuracy(pred, true_label) # label size is larger than required diff --git a/mmsegmentation/tests/test_models/test_necks/test_feature2pyramid.py b/mmsegmentation/tests/test_models/test_necks/test_feature2pyramid.py index 44fd02c..5cb17e0 100644 --- a/mmsegmentation/tests/test_models/test_necks/test_feature2pyramid.py +++ b/mmsegmentation/tests/test_models/test_necks/test_feature2pyramid.py @@ -12,7 +12,8 @@ def test_feature2pyramid(): inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))] fpn = Feature2Pyramid( - embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) + embed_dim, rescales, norm_cfg=dict(type="BN", requires_grad=True) + ) outputs = fpn(inputs) assert outputs[0].shape == torch.Size([1, 64, 128, 128]) assert outputs[1].shape == torch.Size([1, 64, 64, 64]) @@ -24,7 +25,8 @@ def test_feature2pyramid(): inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))] fpn = Feature2Pyramid( - embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) + embed_dim, rescales, norm_cfg=dict(type="BN", requires_grad=True) + ) outputs = fpn(inputs) assert outputs[0].shape == torch.Size([1, 64, 64, 64]) assert outputs[1].shape == torch.Size([1, 64, 32, 32]) @@ -35,4 +37,5 @@ def test_feature2pyramid(): rescales = [4, 2, 0.25, 0] with pytest.raises(KeyError): fpn = Feature2Pyramid( - embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) + embed_dim, rescales, norm_cfg=dict(type="BN", requires_grad=True) + ) diff --git a/mmsegmentation/tests/test_models/test_necks/test_fpn.py b/mmsegmentation/tests/test_models/test_necks/test_fpn.py index c294006..d2cf409 100644 --- a/mmsegmentation/tests/test_models/test_necks/test_fpn.py +++ b/mmsegmentation/tests/test_models/test_necks/test_fpn.py @@ -7,8 +7,7 @@ def test_fpn(): in_channels = [64, 128, 256, 512] inputs = [ - torch.randn(1, c, 56 // 2**i, 56 // 2**i) - for i, c in enumerate(in_channels) + torch.randn(1, c, 56 // 2**i, 56 // 2**i) for i, c in enumerate(in_channels) ] fpn = FPN(in_channels, 64, len(in_channels)) @@ -22,7 +21,8 @@ def test_fpn(): in_channels, 64, len(in_channels), - upsample_cfg=dict(mode='nearest', scale_factor=2.0)) + upsample_cfg=dict(mode="nearest", scale_factor=2.0), + ) outputs = fpn(inputs) assert outputs[0].shape == torch.Size([1, 64, 56, 56]) assert outputs[1].shape == torch.Size([1, 64, 28, 28]) diff --git a/mmsegmentation/tests/test_models/test_necks/test_ic_neck.py b/mmsegmentation/tests/test_models/test_necks/test_ic_neck.py index 3d13008..683c871 100644 --- a/mmsegmentation/tests/test_models/test_necks/test_ic_neck.py +++ b/mmsegmentation/tests/test_models/test_necks/test_ic_neck.py @@ -12,20 +12,22 @@ def test_ic_neck(): neck = ICNeck( in_channels=(4, 16, 16), out_channels=8, - norm_cfg=dict(type='SyncBN'), - align_corners=False) + norm_cfg=dict(type="SyncBN"), + align_corners=False, + ) assert _conv_has_norm(neck, sync_bn=True) inputs = [ torch.randn(1, 4, 32, 64), torch.randn(1, 16, 16, 32), - torch.randn(1, 16, 8, 16) + torch.randn(1, 16, 8, 16), ] neck = ICNeck( in_channels=(4, 16, 16), out_channels=4, - norm_cfg=dict(type='BN', requires_grad=True), - align_corners=False) + norm_cfg=dict(type="BN", requires_grad=True), + align_corners=False, + ) if torch.cuda.is_available(): neck, inputs = to_cuda(neck, inputs) @@ -49,5 +51,6 @@ def test_ic_neck_input_channels(): ICNeck( in_channels=(16, 64, 64, 64), out_channels=32, - norm_cfg=dict(type='BN', requires_grad=True), - align_corners=False) + norm_cfg=dict(type="BN", requires_grad=True), + align_corners=False, + ) diff --git a/mmsegmentation/tests/test_models/test_necks/test_jpu.py b/mmsegmentation/tests/test_models/test_necks/test_jpu.py index 4c3fa9f..aefa9d6 100644 --- a/mmsegmentation/tests/test_models/test_necks/test_jpu.py +++ b/mmsegmentation/tests/test_models/test_necks/test_jpu.py @@ -20,7 +20,7 @@ def test_fastfcn_neck(): input = [ torch.randn(batch_size, 64, 64, 128), torch.randn(batch_size, 128, 32, 64), - torch.randn(batch_size, 256, 16, 32) + torch.randn(batch_size, 256, 16, 32), ] feat = model(input) @@ -38,7 +38,7 @@ def test_fastfcn_neck(): input = [ torch.randn(batch_size, 64, 64, 128), torch.randn(batch_size, 128, 32, 64), - torch.randn(batch_size, 256, 16, 32) + torch.randn(batch_size, 256, 16, 32), ] feat = model(input) assert len(feat) == 2 diff --git a/mmsegmentation/tests/test_models/test_necks/test_multilevel_neck.py b/mmsegmentation/tests/test_models/test_necks/test_multilevel_neck.py index 9c71d51..15d2ef3 100644 --- a/mmsegmentation/tests/test_models/test_necks/test_multilevel_neck.py +++ b/mmsegmentation/tests/test_models/test_necks/test_multilevel_neck.py @@ -5,7 +5,6 @@ def test_multilevel_neck(): - # Test init_weights MultiLevelNeck([266], 32).init_weights() diff --git a/mmsegmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py b/mmsegmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py index 07ad5c3..94ae20a 100644 --- a/mmsegmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py +++ b/mmsegmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py @@ -6,52 +6,51 @@ def test_cascade_encoder_decoder(): - # test 1 decode head, w.o. aux head cfg = ConfigDict( - type='CascadeEncoderDecoder', + type="CascadeEncoderDecoder", num_stages=2, - backbone=dict(type='ExampleBackbone'), + backbone=dict(type="ExampleBackbone"), decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') + dict(type="ExampleDecodeHead"), + dict(type="ExampleCascadeDecodeHead"), + ], + ) + cfg.test_cfg = ConfigDict(mode="whole") segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test slide mode - cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + cfg.test_cfg = ConfigDict(mode="slide", crop_size=(3, 3), stride=(2, 2)) segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test 1 decode head, 1 aux head cfg = ConfigDict( - type='CascadeEncoderDecoder', + type="CascadeEncoderDecoder", num_stages=2, - backbone=dict(type='ExampleBackbone'), + backbone=dict(type="ExampleBackbone"), decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') + dict(type="ExampleDecodeHead"), + dict(type="ExampleCascadeDecodeHead"), ], - auxiliary_head=dict(type='ExampleDecodeHead')) - cfg.test_cfg = ConfigDict(mode='whole') + auxiliary_head=dict(type="ExampleDecodeHead"), + ) + cfg.test_cfg = ConfigDict(mode="whole") segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test 1 decode head, 2 aux head cfg = ConfigDict( - type='CascadeEncoderDecoder', + type="CascadeEncoderDecoder", num_stages=2, - backbone=dict(type='ExampleBackbone'), + backbone=dict(type="ExampleBackbone"), decode_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleCascadeDecodeHead') + dict(type="ExampleDecodeHead"), + dict(type="ExampleCascadeDecodeHead"), ], - auxiliary_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') + auxiliary_head=[dict(type="ExampleDecodeHead"), dict(type="ExampleDecodeHead")], + ) + cfg.test_cfg = ConfigDict(mode="whole") segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) diff --git a/mmsegmentation/tests/test_models/test_segmentors/test_encoder_decoder.py b/mmsegmentation/tests/test_models/test_segmentors/test_encoder_decoder.py index 2739b58..1b8e20b 100644 --- a/mmsegmentation/tests/test_models/test_segmentors/test_encoder_decoder.py +++ b/mmsegmentation/tests/test_models/test_segmentors/test_encoder_decoder.py @@ -6,15 +6,15 @@ def test_encoder_decoder(): - # test 1 decode head, w.o. aux head cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), + type="EncoderDecoder", + backbone=dict(type="ExampleBackbone"), + decode_head=dict(type="ExampleDecodeHead"), train_cfg=None, - test_cfg=dict(mode='whole')) + test_cfg=dict(mode="whole"), + ) segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) @@ -25,29 +25,28 @@ def test_encoder_decoder(): _segmentor_forward_train_test(segmentor) # test slide mode - cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) + cfg.test_cfg = ConfigDict(mode="slide", crop_size=(3, 3), stride=(2, 2)) segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test 1 decode head, 1 aux head cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), - auxiliary_head=dict(type='ExampleDecodeHead')) - cfg.test_cfg = ConfigDict(mode='whole') + type="EncoderDecoder", + backbone=dict(type="ExampleBackbone"), + decode_head=dict(type="ExampleDecodeHead"), + auxiliary_head=dict(type="ExampleDecodeHead"), + ) + cfg.test_cfg = ConfigDict(mode="whole") segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) # test 1 decode head, 2 aux head cfg = ConfigDict( - type='EncoderDecoder', - backbone=dict(type='ExampleBackbone'), - decode_head=dict(type='ExampleDecodeHead'), - auxiliary_head=[ - dict(type='ExampleDecodeHead'), - dict(type='ExampleDecodeHead') - ]) - cfg.test_cfg = ConfigDict(mode='whole') + type="EncoderDecoder", + backbone=dict(type="ExampleBackbone"), + decode_head=dict(type="ExampleDecodeHead"), + auxiliary_head=[dict(type="ExampleDecodeHead"), dict(type="ExampleDecodeHead")], + ) + cfg.test_cfg = ConfigDict(mode="whole") segmentor = build_segmentor(cfg) _segmentor_forward_train_test(segmentor) diff --git a/mmsegmentation/tests/test_models/test_segmentors/utils.py b/mmsegmentation/tests/test_models/test_segmentors/utils.py index 1826dbf..9cd3de5 100644 --- a/mmsegmentation/tests/test_models/test_segmentors/utils.py +++ b/mmsegmentation/tests/test_models/test_segmentors/utils.py @@ -23,32 +23,33 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): rng = np.random.RandomState(0) imgs = rng.rand(*input_shape) - segs = rng.randint( - low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) - - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - 'scale_factor': 1.0, - 'flip': False, - 'flip_direction': 'horizontal' - } for _ in range(N)] + segs = rng.randint(low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + + img_metas = [ + { + "img_shape": (H, W, C), + "ori_shape": (H, W, C), + "pad_shape": (H, W, C), + "filename": ".png", + "scale_factor": 1.0, + "flip": False, + "flip_direction": "horizontal", + } + for _ in range(N) + ] mm_inputs = { - 'imgs': torch.FloatTensor(imgs), - 'img_metas': img_metas, - 'gt_semantic_seg': torch.LongTensor(segs) + "imgs": torch.FloatTensor(imgs), + "img_metas": img_metas, + "gt_semantic_seg": torch.LongTensor(segs), } return mm_inputs @BACKBONES.register_module() class ExampleBackbone(nn.Module): - def __init__(self): - super(ExampleBackbone, self).__init__() + super().__init__() self.conv = nn.Conv2d(3, 3, 3) def init_weights(self, pretrained=None): @@ -60,9 +61,8 @@ def forward(self, x): @HEADS.register_module() class ExampleDecodeHead(BaseDecodeHead): - def __init__(self): - super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19) + super().__init__(3, 3, num_classes=19) def forward(self, inputs): return self.cls_seg(inputs[0]) @@ -70,9 +70,8 @@ def forward(self, inputs): @HEADS.register_module() class ExampleCascadeDecodeHead(BaseCascadeDecodeHead): - def __init__(self): - super(ExampleCascadeDecodeHead, self).__init__(3, 3, num_classes=19) + super().__init__(3, 3, num_classes=19) def forward(self, inputs, prev_out): return self.cls_seg(inputs[0]) @@ -86,9 +85,9 @@ def _segmentor_forward_train_test(segmentor): # batch_size=2 for BatchNorm mm_inputs = _demo_mm_inputs(num_classes=num_classes) - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_semantic_seg = mm_inputs['gt_semantic_seg'] + imgs = mm_inputs.pop("imgs") + img_metas = mm_inputs.pop("img_metas") + gt_semantic_seg = mm_inputs["gt_semantic_seg"] # convert to cuda Tensor if applicable if torch.cuda.is_available(): @@ -98,28 +97,29 @@ def _segmentor_forward_train_test(segmentor): # Test forward train losses = segmentor.forward( - imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) + imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True + ) assert isinstance(losses, dict) # Test train_step - data_batch = dict( - img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) + data_batch = dict(img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) outputs = segmentor.train_step(data_batch, None) assert isinstance(outputs, dict) - assert 'loss' in outputs - assert 'log_vars' in outputs - assert 'num_samples' in outputs + assert "loss" in outputs + assert "log_vars" in outputs + assert "num_samples" in outputs # Test val_step with torch.no_grad(): segmentor.eval() data_batch = dict( - img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg) + img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg + ) outputs = segmentor.val_step(data_batch, None) assert isinstance(outputs, dict) - assert 'loss' in outputs - assert 'log_vars' in outputs - assert 'num_samples' in outputs + assert "loss" in outputs + assert "log_vars" in outputs + assert "num_samples" in outputs # Test forward simple test with torch.no_grad(): diff --git a/mmsegmentation/tests/test_models/test_utils/test_embed.py b/mmsegmentation/tests/test_models/test_utils/test_embed.py index be20c97..daf7e7c 100644 --- a/mmsegmentation/tests/test_models/test_utils/test_embed.py +++ b/mmsegmentation/tests/test_models/test_utils/test_embed.py @@ -6,17 +6,14 @@ def test_adaptive_padding(): - - for padding in ('same', 'corner'): + for padding in ("same", "corner"): kernel_size = 16 stride = 16 dilation = 1 input = torch.rand(1, 1, 15, 17) adap_pool = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) out = adap_pool(input) # padding to divisible by 16 assert (out.shape[2], out.shape[3]) == (16, 32) @@ -30,10 +27,8 @@ def test_adaptive_padding(): dilation = (1, 1) adap_pad = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) input = torch.rand(1, 1, 11, 13) out = adap_pad(input) # padding to divisible by 2 @@ -44,10 +39,8 @@ def test_adaptive_padding(): dilation = (1, 1) adap_pad = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) input = torch.rand(1, 1, 10, 13) out = adap_pad(input) # no padding @@ -55,10 +48,8 @@ def test_adaptive_padding(): kernel_size = (11, 11) adap_pad = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) input = torch.rand(1, 1, 11, 13) out = adap_pad(input) # all padding @@ -71,19 +62,15 @@ def test_adaptive_padding(): dilation = (2, 2) # actually (7, 9) adap_pad = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) dilation_out = adap_pad(input) assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21) kernel_size = (7, 9) dilation = (1, 1) adap_pad = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) kernel79_out = adap_pad(input) assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21) assert kernel79_out.shape == dilation_out.shape @@ -91,10 +78,8 @@ def test_adaptive_padding(): # assert only support "same" "corner" with pytest.raises(AssertionError): AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=1) + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=1 + ) def test_patch_embed(): @@ -113,7 +98,8 @@ def test_patch_embed(): stride=stride, padding=0, dilation=1, - norm_cfg=None) + norm_cfg=None, + ) x1, shape = patch_merge_1(dummy_input) # test out shape @@ -162,8 +148,9 @@ def test_patch_embed(): stride=stride, padding=0, dilation=2, - norm_cfg=dict(type='LN'), - input_size=input_size) + norm_cfg=dict(type="LN"), + input_size=input_size, + ) x3, shape = patch_merge_3(dummy_input) # test out shape @@ -174,10 +161,8 @@ def test_patch_embed(): assert shape[0] * shape[1] == x3.shape[1] # test the init_out_size with nn.Unfold - assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 - - 1) // 2 + 1 - assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 - - 1) // 2 + 1 + assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 - 1) // 2 + 1 + assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 - 1) // 2 + 1 H = 11 W = 12 input_size = (H, W) @@ -190,8 +175,9 @@ def test_patch_embed(): stride=stride, padding=0, dilation=2, - norm_cfg=dict(type='LN'), - input_size=input_size) + norm_cfg=dict(type="LN"), + input_size=input_size, + ) _, shape = patch_merge_3(dummy_input) # when input_size equal to real input @@ -208,8 +194,9 @@ def test_patch_embed(): stride=stride, padding=0, dilation=2, - norm_cfg=dict(type='LN'), - input_size=input_size) + norm_cfg=dict(type="LN"), + input_size=input_size, + ) _, shape = patch_merge_3(dummy_input) # when input_size equal to real input @@ -217,7 +204,7 @@ def test_patch_embed(): assert shape == patch_merge_3.init_out_size # test adap padding - for padding in ('same', 'corner'): + for padding in ("same", "corner"): in_c = 2 embed_dims = 3 B = 2 @@ -237,7 +224,8 @@ def test_patch_embed(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_embed(x) assert x_out.size() == (B, 25, 3) @@ -259,7 +247,8 @@ def test_patch_embed(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_embed(x) assert x_out.size() == (B, 1, 3) @@ -281,7 +270,8 @@ def test_patch_embed(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_embed(x) assert x_out.size() == (B, 2, 3) @@ -303,7 +293,8 @@ def test_patch_embed(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_embed(x) assert x_out.size() == (B, 3, 3) @@ -312,7 +303,6 @@ def test_patch_embed(): def test_patch_merging(): - # Test the model with int padding in_c = 3 out_c = 4 @@ -329,7 +319,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) B, L, C = 1, 100, 3 input_size = (10, 10) x = torch.rand(B, L, C) @@ -352,7 +343,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) B, L, C = 1, 100, 4 input_size = (10, 10) x = torch.rand(B, L, C) @@ -363,7 +355,7 @@ def test_patch_merging(): assert x_out.size(1) == out_size[0] * out_size[1] # Test with adaptive padding - for padding in ('same', 'corner'): + for padding in ("same", "corner"): in_c = 2 out_c = 3 B = 2 @@ -384,7 +376,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_merge(x, input_size) assert x_out.size() == (B, 25, 3) @@ -407,7 +400,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_merge(x, input_size) assert x_out.size() == (B, 1, 3) @@ -430,7 +424,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_merge(x, input_size) assert x_out.size() == (B, 2, 3) @@ -453,7 +448,8 @@ def test_patch_merging(): stride=stride, padding=padding, dilation=dilation, - bias=bias) + bias=bias, + ) x_out, out_size = patch_merge(x, input_size) assert x_out.size() == (B, 3, 3) diff --git a/mmsegmentation/tests/test_models/test_utils/test_shape_convert.py b/mmsegmentation/tests/test_models/test_utils/test_shape_convert.py index 60e87f3..dead1ef 100644 --- a/mmsegmentation/tests/test_models/test_utils/test_shape_convert.py +++ b/mmsegmentation/tests/test_models/test_utils/test_shape_convert.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, - nlc_to_nchw) +from mmseg.models.utils import nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, nlc_to_nchw def test_nchw2nlc2nchw(): diff --git a/mmsegmentation/tests/test_sampler.py b/mmsegmentation/tests/test_sampler.py index 1409224..d97eac9 100644 --- a/mmsegmentation/tests/test_sampler.py +++ b/mmsegmentation/tests/test_sampler.py @@ -16,13 +16,13 @@ def _context_for_ohem_multiple_loss(): channels=16, num_classes=19, loss_decode=[ - dict(type='CrossEntropyLoss', loss_name='loss_1'), - dict(type='CrossEntropyLoss', loss_name='loss_2') - ]) + dict(type="CrossEntropyLoss", loss_name="loss_1"), + dict(type="CrossEntropyLoss", loss_name="loss_2"), + ], + ) def test_ohem_sampler(): - with pytest.raises(AssertionError): # seg_logit and seg_label must be of the same size sampler = OHEMPixelSampler(context=_context_for_ohem()) @@ -31,8 +31,7 @@ def test_ohem_sampler(): sampler.sample(seg_logit, seg_label) # test with thresh - sampler = OHEMPixelSampler( - context=_context_for_ohem(), thresh=0.7, min_kept=200) + sampler = OHEMPixelSampler(context=_context_for_ohem(), thresh=0.7, min_kept=200) seg_logit = torch.randn(1, 19, 45, 45) seg_label = torch.randint(0, 19, size=(1, 1, 45, 45)) seg_weight = sampler.sample(seg_logit, seg_label) @@ -59,7 +58,8 @@ def test_ohem_sampler(): # test with thresh in multiple losses case sampler = OHEMPixelSampler( - context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200) + context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200 + ) seg_logit = torch.randn(1, 19, 45, 45) seg_label = torch.randint(0, 19, size=(1, 1, 45, 45)) seg_weight = sampler.sample(seg_logit, seg_label) @@ -68,8 +68,7 @@ def test_ohem_sampler(): assert seg_weight.sum() > 200 # test w.o thresh in multiple losses case - sampler = OHEMPixelSampler( - context=_context_for_ohem_multiple_loss(), min_kept=200) + sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss(), min_kept=200) seg_logit = torch.randn(1, 19, 45, 45) seg_label = torch.randint(0, 19, size=(1, 1, 45, 45)) seg_weight = sampler.sample(seg_logit, seg_label) diff --git a/mmsegmentation/tests/test_utils/test_misc.py b/mmsegmentation/tests/test_utils/test_misc.py index 7ce1fa6..29dfa24 100644 --- a/mmsegmentation/tests/test_utils/test_misc.py +++ b/mmsegmentation/tests/test_utils/test_misc.py @@ -13,28 +13,28 @@ def test_find_latest_checkpoint(): assert latest is None # The path doesn't exist - path = osp.join(tempdir, 'none') + path = osp.join(tempdir, "none") latest = find_latest_checkpoint(path) assert latest is None # test when latest.pth exists with tempfile.TemporaryDirectory() as tempdir: - with open(osp.join(tempdir, 'latest.pth'), 'w') as f: - f.write('latest') + with open(osp.join(tempdir, "latest.pth"), "w") as f: + f.write("latest") path = tempdir latest = find_latest_checkpoint(path) - assert latest == osp.join(tempdir, 'latest.pth') + assert latest == osp.join(tempdir, "latest.pth") with tempfile.TemporaryDirectory() as tempdir: for iter in range(1600, 160001, 1600): - with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f: - f.write(f'iter_{iter}.pth') + with open(osp.join(tempdir, f"iter_{iter}.pth"), "w") as f: + f.write(f"iter_{iter}.pth") latest = find_latest_checkpoint(tempdir) - assert latest == osp.join(tempdir, 'iter_160000.pth') + assert latest == osp.join(tempdir, "iter_160000.pth") with tempfile.TemporaryDirectory() as tempdir: for epoch in range(1, 21): - with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f: - f.write(f'epoch_{epoch}.pth') + with open(osp.join(tempdir, f"epoch_{epoch}.pth"), "w") as f: + f.write(f"epoch_{epoch}.pth") latest = find_latest_checkpoint(tempdir) - assert latest == osp.join(tempdir, 'epoch_20.pth') + assert latest == osp.join(tempdir, "epoch_20.pth") diff --git a/mmsegmentation/tests/test_utils/test_set_env.py b/mmsegmentation/tests/test_utils/test_set_env.py index 0af4424..af8ac9b 100644 --- a/mmsegmentation/tests/test_utils/test_set_env.py +++ b/mmsegmentation/tests/test_utils/test_set_env.py @@ -10,26 +10,37 @@ from mmseg.utils import setup_multi_processes -@pytest.mark.parametrize('workers_per_gpu', (0, 2)) -@pytest.mark.parametrize(('valid', 'env_cfg'), [(True, - dict( - mp_start_method='fork', - opencv_num_threads=0, - omp_num_threads=1, - mkl_num_threads=1)), - (False, - dict( - mp_start_method=1, - opencv_num_threads=0.1, - omp_num_threads='s', - mkl_num_threads='1'))]) +@pytest.mark.parametrize("workers_per_gpu", (0, 2)) +@pytest.mark.parametrize( + ("valid", "env_cfg"), + [ + ( + True, + dict( + mp_start_method="fork", + opencv_num_threads=0, + omp_num_threads=1, + mkl_num_threads=1, + ), + ), + ( + False, + dict( + mp_start_method=1, + opencv_num_threads=0.1, + omp_num_threads="s", + mkl_num_threads="1", + ), + ), + ], +) def test_setup_multi_processes(workers_per_gpu, valid, env_cfg): # temp save system setting sys_start_mehod = mp.get_start_method(allow_none=True) sys_cv_threads = cv2.getNumThreads() # pop and temp save system env vars - sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) - sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) + sys_omp_threads = os.environ.pop("OMP_NUM_THREADS", default=None) + sys_mkl_threads = os.environ.pop("MKL_NUM_THREADS", default=None) config = dict(data=dict(workers_per_gpu=workers_per_gpu)) config.update(env_cfg) @@ -41,45 +52,50 @@ def test_setup_multi_processes(workers_per_gpu, valid, env_cfg): if valid and workers_per_gpu > 0: # test config without setting env - assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads']) - assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads']) + assert os.getenv("OMP_NUM_THREADS") == str(env_cfg["omp_num_threads"]) + assert os.getenv("MKL_NUM_THREADS") == str(env_cfg["mkl_num_threads"]) # when set to 0, the num threads will be 1 - assert cv2.getNumThreads() == env_cfg[ - 'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1 - if platform.system() != 'Windows': - assert mp.get_start_method() == env_cfg['mp_start_method'] + assert ( + cv2.getNumThreads() == env_cfg["opencv_num_threads"] + if env_cfg["opencv_num_threads"] > 0 + else 1 + ) + if platform.system() != "Windows": + assert mp.get_start_method() == env_cfg["mp_start_method"] # revert setting to avoid affecting other programs if sys_start_mehod: mp.set_start_method(sys_start_mehod, force=True) cv2.setNumThreads(sys_cv_threads) if sys_omp_threads: - os.environ['OMP_NUM_THREADS'] = sys_omp_threads + os.environ["OMP_NUM_THREADS"] = sys_omp_threads else: - os.environ.pop('OMP_NUM_THREADS') + os.environ.pop("OMP_NUM_THREADS") if sys_mkl_threads: - os.environ['MKL_NUM_THREADS'] = sys_mkl_threads + os.environ["MKL_NUM_THREADS"] = sys_mkl_threads else: - os.environ.pop('MKL_NUM_THREADS') + os.environ.pop("MKL_NUM_THREADS") elif valid and workers_per_gpu == 0: - - if platform.system() != 'Windows': - assert mp.get_start_method() == env_cfg['mp_start_method'] - assert cv2.getNumThreads() == env_cfg[ - 'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1 - assert 'OMP_NUM_THREADS' not in os.environ - assert 'MKL_NUM_THREADS' not in os.environ + if platform.system() != "Windows": + assert mp.get_start_method() == env_cfg["mp_start_method"] + assert ( + cv2.getNumThreads() == env_cfg["opencv_num_threads"] + if env_cfg["opencv_num_threads"] > 0 + else 1 + ) + assert "OMP_NUM_THREADS" not in os.environ + assert "MKL_NUM_THREADS" not in os.environ if sys_start_mehod: mp.set_start_method(sys_start_mehod, force=True) cv2.setNumThreads(sys_cv_threads) if sys_omp_threads: - os.environ['OMP_NUM_THREADS'] = sys_omp_threads + os.environ["OMP_NUM_THREADS"] = sys_omp_threads if sys_mkl_threads: - os.environ['MKL_NUM_THREADS'] = sys_mkl_threads + os.environ["MKL_NUM_THREADS"] = sys_mkl_threads else: assert mp.get_start_method() == sys_start_mehod assert cv2.getNumThreads() == sys_cv_threads - assert 'OMP_NUM_THREADS' not in os.environ - assert 'MKL_NUM_THREADS' not in os.environ + assert "OMP_NUM_THREADS" not in os.environ + assert "MKL_NUM_THREADS" not in os.environ diff --git a/mmsegmentation/tests/test_utils/test_util_distribution.py b/mmsegmentation/tests/test_utils/test_util_distribution.py index 5523879..6855b43 100644 --- a/mmsegmentation/tests/test_utils/test_util_distribution.py +++ b/mmsegmentation/tests/test_utils/test_util_distribution.py @@ -4,8 +4,7 @@ import mmcv import torch import torch.nn as nn -from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel, - is_module_wrapper) +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel, is_module_wrapper from mmseg import digit_version from mmseg.utils import build_ddp, build_dp @@ -16,7 +15,6 @@ def mock(*args, **kwargs): class Model(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(2, 2, 1) @@ -25,44 +23,44 @@ def forward(self, x): return self.conv(x) -@patch('torch.distributed._broadcast_coalesced', mock) -@patch('torch.distributed.broadcast', mock) -@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +@patch("torch.distributed._broadcast_coalesced", mock) +@patch("torch.distributed.broadcast", mock) +@patch("torch.nn.parallel.DistributedDataParallel._ddp_init_helper", mock) def test_build_dp(): model = Model() assert not is_module_wrapper(model) - mmdp = build_dp(model, 'cpu') + mmdp = build_dp(model, "cpu") assert isinstance(mmdp, MMDataParallel) if torch.cuda.is_available(): - mmdp = build_dp(model, 'cuda') + mmdp = build_dp(model, "cuda") assert isinstance(mmdp, MMDataParallel) - if digit_version(mmcv.__version__) >= digit_version('1.5.0'): + if digit_version(mmcv.__version__) >= digit_version("1.5.0"): from mmcv.device.mlu import MLUDataParallel from mmcv.utils import IS_MLU_AVAILABLE + if IS_MLU_AVAILABLE: - mludp = build_dp(model, 'mlu') + mludp = build_dp(model, "mlu") assert isinstance(mludp, MLUDataParallel) -@patch('torch.distributed._broadcast_coalesced', mock) -@patch('torch.distributed.broadcast', mock) -@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +@patch("torch.distributed._broadcast_coalesced", mock) +@patch("torch.distributed.broadcast", mock) +@patch("torch.nn.parallel.DistributedDataParallel._ddp_init_helper", mock) def test_build_ddp(): model = Model() assert not is_module_wrapper(model) if torch.cuda.is_available(): - mmddp = build_ddp( - model, 'cuda', device_ids=[0], process_group=MagicMock()) + mmddp = build_ddp(model, "cuda", device_ids=[0], process_group=MagicMock()) assert isinstance(mmddp, MMDistributedDataParallel) - if digit_version(mmcv.__version__) >= digit_version('1.5.0'): + if digit_version(mmcv.__version__) >= digit_version("1.5.0"): from mmcv.device.mlu import MLUDistributedDataParallel from mmcv.utils import IS_MLU_AVAILABLE + if IS_MLU_AVAILABLE: - mluddp = build_ddp( - model, 'mlu', device_ids=[0], process_group=MagicMock()) + mluddp = build_ddp(model, "mlu", device_ids=[0], process_group=MagicMock()) assert isinstance(mluddp, MLUDistributedDataParallel) diff --git a/mmsegmentation/tools/analyze_logs.py b/mmsegmentation/tools/analyze_logs.py index e2127d4..85f217a 100644 --- a/mmsegmentation/tools/analyze_logs.py +++ b/mmsegmentation/tools/analyze_logs.py @@ -19,7 +19,7 @@ def plot_curve(log_dicts, args): legend = [] for json_log in args.json_logs: for metric in args.keys: - legend.append(f'{json_log}_{metric}') + legend.append(f"{json_log}_{metric}") assert len(legend) == (len(args.json_logs) * len(args.keys)) metrics = args.keys @@ -27,7 +27,7 @@ def plot_curve(log_dicts, args): for i, log_dict in enumerate(log_dicts): epochs = list(log_dict.keys()) for j, metric in enumerate(metrics): - print(f'plot curve of {args.json_logs[i]}, metric is {metric}') + print(f"plot curve of {args.json_logs[i]}, metric is {metric}") plot_epochs = [] plot_iters = [] plot_values = [] @@ -38,22 +38,22 @@ def plot_curve(log_dicts, args): epoch_logs = log_dict[epoch] if metric not in epoch_logs.keys(): continue - if metric in ['mIoU', 'mAcc', 'aAcc']: + if metric in ["mIoU", "mAcc", "aAcc"]: plot_epochs.append(epoch) plot_values.append(epoch_logs[metric][0]) else: for idx in range(len(epoch_logs[metric])): - if epoch_logs['mode'][idx] == 'train': - plot_iters.append(epoch_logs['iter'][idx]) + if epoch_logs["mode"][idx] == "train": + plot_iters.append(epoch_logs["iter"][idx]) plot_values.append(epoch_logs[metric][idx]) ax = plt.gca() label = legend[i * num_metrics + j] - if metric in ['mIoU', 'mAcc', 'aAcc']: + if metric in ["mIoU", "mAcc", "aAcc"]: ax.set_xticks(plot_epochs) - plt.xlabel('epoch') - plt.plot(plot_epochs, plot_values, label=label, marker='o') + plt.xlabel("epoch") + plt.plot(plot_epochs, plot_values, label=label, marker="o") else: - plt.xlabel('iter') + plt.xlabel("iter") plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) plt.legend() if args.title is not None: @@ -61,36 +61,30 @@ def plot_curve(log_dicts, args): if args.out is None: plt.show() else: - print(f'save curve to: {args.out}') + print(f"save curve to: {args.out}") plt.savefig(args.out) plt.cla() def parse_args(): - parser = argparse.ArgumentParser(description='Analyze Json Log') + parser = argparse.ArgumentParser(description="Analyze Json Log") parser.add_argument( - 'json_logs', - type=str, - nargs='+', - help='path of train log in json format') - parser.add_argument( - '--keys', - type=str, - nargs='+', - default=['mIoU'], - help='the metric that you want to plot') - parser.add_argument('--title', type=str, help='title of figure') + "json_logs", type=str, nargs="+", help="path of train log in json format" + ) parser.add_argument( - '--legend', + "--keys", type=str, - nargs='+', - default=None, - help='legend of each plot') - parser.add_argument( - '--backend', type=str, default=None, help='backend of plt') + nargs="+", + default=["mIoU"], + help="the metric that you want to plot", + ) + parser.add_argument("--title", type=str, help="title of figure") parser.add_argument( - '--style', type=str, default='dark', help='style of plt') - parser.add_argument('--out', type=str, default=None) + "--legend", type=str, nargs="+", default=None, help="legend of each plot" + ) + parser.add_argument("--backend", type=str, default=None, help="backend of plt") + parser.add_argument("--style", type=str, default="dark", help="style of plt") + parser.add_argument("--out", type=str, default=None) args = parser.parse_args() return args @@ -101,13 +95,13 @@ def load_json_logs(json_logs): # value of sub dict is a list of corresponding values of all iterations log_dicts = [dict() for _ in json_logs] for json_log, log_dict in zip(json_logs, log_dicts): - with open(json_log, 'r') as log_file: + with open(json_log) as log_file: for line in log_file: log = json.loads(line.strip()) # skip lines without `epoch` field - if 'epoch' not in log: + if "epoch" not in log: continue - epoch = log.pop('epoch') + epoch = log.pop("epoch") if epoch not in log_dict: log_dict[epoch] = defaultdict(list) for k, v in log.items(): @@ -119,10 +113,10 @@ def main(): args = parse_args() json_logs = args.json_logs for json_log in json_logs: - assert json_log.endswith('.json') + assert json_log.endswith(".json") log_dicts = load_json_logs(json_logs) plot_curve(log_dicts, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/benchmark.py b/mmsegmentation/tools/benchmark.py index f6d6888..7c19d28 100644 --- a/mmsegmentation/tools/benchmark.py +++ b/mmsegmentation/tools/benchmark.py @@ -15,16 +15,17 @@ def parse_args(): - parser = argparse.ArgumentParser(description='MMSeg benchmark a model') - parser.add_argument('config', help='test config file path') - parser.add_argument('checkpoint', help='checkpoint file') + parser = argparse.ArgumentParser(description="MMSeg benchmark a model") + parser.add_argument("config", help="test config file path") + parser.add_argument("checkpoint", help="checkpoint file") parser.add_argument( - '--log-interval', type=int, default=50, help='interval of logging') + "--log-interval", type=int, default=50, help="interval of logging" + ) parser.add_argument( - '--work-dir', - help=('if specified, the results will be dumped ' - 'into the directory as json')) - parser.add_argument('--repeat-times', type=int, default=1) + "--work-dir", + help=("if specified, the results will be dumped " "into the directory as json"), + ) + parser.add_argument("--repeat-times", type=int, default=1) args = parser.parse_args() return args @@ -33,16 +34,15 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) if args.work_dir is not None: mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) - json_file = osp.join(args.work_dir, f'fps_{timestamp}.json') + json_file = osp.join(args.work_dir, f"fps_{timestamp}.json") else: # use config filename as default work_dir if cfg.work_dir is None - work_dir = osp.join('./work_dirs', - osp.splitext(osp.basename(args.config))[0]) + work_dir = osp.join("./work_dirs", osp.splitext(osp.basename(args.config))[0]) mmcv.mkdir_or_exist(osp.abspath(work_dir)) - json_file = osp.join(work_dir, f'fps_{timestamp}.json') + json_file = osp.join(work_dir, f"fps_{timestamp}.json") repeat_times = args.repeat_times # set cudnn_benchmark @@ -50,10 +50,10 @@ def main(): cfg.model.pretrained = None cfg.data.test.test_mode = True - benchmark_dict = dict(config=args.config, unit='img / s') + benchmark_dict = dict(config=args.config, unit="img / s") overall_fps_list = [] for time_index in range(repeat_times): - print(f'Run {time_index + 1}:') + print(f"Run {time_index + 1}:") # build the dataloader # TODO: support multiple images per gpu (only minor changes are needed) dataset = build_dataset(cfg.data.test) @@ -62,16 +62,17 @@ def main(): samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=False, - shuffle=False) + shuffle=False, + ) # build the model and load checkpoint cfg.model.train_cfg = None - model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) - fp16_cfg = cfg.get('fp16', None) + model = build_segmentor(cfg.model, test_cfg=cfg.get("test_cfg")) + fp16_cfg = cfg.get("fp16", None) if fp16_cfg is not None: wrap_fp16_model(model) - if 'checkpoint' in args and osp.exists(args.checkpoint): - load_checkpoint(model, args.checkpoint, map_location='cpu') + if "checkpoint" in args and osp.exists(args.checkpoint): + load_checkpoint(model, args.checkpoint, map_location="cpu") model = MMDataParallel(model, device_ids=[0]) @@ -84,7 +85,6 @@ def main(): # benchmark with 200 image and take the average for i, data in enumerate(data_loader): - torch.cuda.synchronize() start_time = time.perf_counter() @@ -98,23 +98,29 @@ def main(): pure_inf_time += elapsed if (i + 1) % args.log_interval == 0: fps = (i + 1 - num_warmup) / pure_inf_time - print(f'Done image [{i + 1:<3}/ {total_iters}], ' - f'fps: {fps:.2f} img / s') + print( + f"Done image [{i + 1:<3}/ {total_iters}], " + f"fps: {fps:.2f} img / s" + ) if (i + 1) == total_iters: fps = (i + 1 - num_warmup) / pure_inf_time - print(f'Overall fps: {fps:.2f} img / s\n') - benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2) + print(f"Overall fps: {fps:.2f} img / s\n") + benchmark_dict[f"overall_fps_{time_index + 1}"] = round(fps, 2) overall_fps_list.append(fps) break - benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2) - benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4) - print(f'Average fps of {repeat_times} evaluations: ' - f'{benchmark_dict["average_fps"]}') - print(f'The variance of {repeat_times} evaluations: ' - f'{benchmark_dict["fps_variance"]}') + benchmark_dict["average_fps"] = round(np.mean(overall_fps_list), 2) + benchmark_dict["fps_variance"] = round(np.var(overall_fps_list), 4) + print( + f"Average fps of {repeat_times} evaluations: " + f'{benchmark_dict["average_fps"]}' + ) + print( + f"The variance of {repeat_times} evaluations: " + f'{benchmark_dict["fps_variance"]}' + ) mmcv.dump(benchmark_dict, json_file, indent=4) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/browse_dataset.py b/mmsegmentation/tools/browse_dataset.py index 0aa9430..d568025 100644 --- a/mmsegmentation/tools/browse_dataset.py +++ b/mmsegmentation/tools/browse_dataset.py @@ -12,60 +12,62 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Browse a dataset') - parser.add_argument('config', help='train config file path') + parser = argparse.ArgumentParser(description="Browse a dataset") + parser.add_argument("config", help="train config file path") parser.add_argument( - '--show-origin', + "--show-origin", default=False, - action='store_true', - help='if True, omit all augmentation in pipeline,' - ' show origin image and seg map') + action="store_true", + help="if True, omit all augmentation in pipeline," + " show origin image and seg map", + ) parser.add_argument( - '--skip-type', + "--skip-type", type=str, - nargs='+', - default=['DefaultFormatBundle', 'Normalize', 'Collect'], - help='skip some useless pipeline,if `show-origin` is true, ' - 'all pipeline except `Load` will be skipped') + nargs="+", + default=["DefaultFormatBundle", "Normalize", "Collect"], + help="skip some useless pipeline,if `show-origin` is true, " + "all pipeline except `Load` will be skipped", + ) parser.add_argument( - '--output-dir', - default='./output', + "--output-dir", + default="./output", type=str, - help='If there is no display interface, you can save it') - parser.add_argument('--show', default=False, action='store_true') + help="If there is no display interface, you can save it", + ) + parser.add_argument("--show", default=False, action="store_true") parser.add_argument( - '--show-interval', - type=int, - default=999, - help='the interval of show (ms)') + "--show-interval", type=int, default=999, help="the interval of show (ms)" + ) parser.add_argument( - '--opacity', - type=float, - default=0.5, - help='the opacity of semantic map') + "--opacity", type=float, default=0.5, help="the opacity of semantic map" + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) args = parser.parse_args() return args -def imshow_semantic(img, - seg, - class_names, - palette=None, - win_name='', - show=False, - wait_time=0, - out_file=None, - opacity=0.5): +def imshow_semantic( + img, + seg, + class_names, + palette=None, + win_name="", + show=False, + wait_time=0, + out_file=None, + opacity=0.5, +): """Draw `result` over `img`. Args: @@ -116,20 +118,20 @@ def imshow_semantic(img, mmcv.imwrite(img, out_file) if not (show or out_file): - warnings.warn('show==False and out_file is not specified, only ' - 'result image will be returned') + warnings.warn( + "show==False and out_file is not specified, only " + "result image will be returned" + ) return img def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): if show_origin is True: # only keep pipeline of Loading data and ann - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if 'Load' in x['type'] - ] + _data_cfg["pipeline"] = [x for x in _data_cfg.pipeline if "Load" in x["type"]] else: - _data_cfg['pipeline'] = [ - x for x in _data_cfg.pipeline if x['type'] not in skip_type + _data_cfg["pipeline"] = [ + x for x in _data_cfg.pipeline if x["type"] not in skip_type ] @@ -140,34 +142,40 @@ def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False): train_data_cfg = cfg.data.train if isinstance(train_data_cfg, list): for _data_cfg in train_data_cfg: - while 'dataset' in _data_cfg and _data_cfg[ - 'type'] != 'MultiImageMixDataset': - _data_cfg = _data_cfg['dataset'] - if 'pipeline' in _data_cfg: + while ( + "dataset" in _data_cfg and _data_cfg["type"] != "MultiImageMixDataset" + ): + _data_cfg = _data_cfg["dataset"] + if "pipeline" in _data_cfg: _retrieve_data_cfg(_data_cfg, skip_type, show_origin) else: raise ValueError else: - while 'dataset' in train_data_cfg and train_data_cfg[ - 'type'] != 'MultiImageMixDataset': - train_data_cfg = train_data_cfg['dataset'] + while ( + "dataset" in train_data_cfg + and train_data_cfg["type"] != "MultiImageMixDataset" + ): + train_data_cfg = train_data_cfg["dataset"] _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) return cfg def main(): args = parse_args() - cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options, - args.show_origin) + cfg = retrieve_data_cfg( + args.config, args.skip_type, args.cfg_options, args.show_origin + ) dataset = build_dataset(cfg.data.train) progress_bar = mmcv.ProgressBar(len(dataset)) for item in dataset: - filename = os.path.join(args.output_dir, - Path(item['filename']).name - ) if args.output_dir is not None else None + filename = ( + os.path.join(args.output_dir, Path(item["filename"]).name) + if args.output_dir is not None + else None + ) imshow_semantic( - item['img'], - item['gt_semantic_seg'], + item["img"], + item["gt_semantic_seg"], dataset.CLASSES, dataset.PALETTE, show=args.show, @@ -178,5 +186,5 @@ def main(): progress_bar.update() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/confusion_matrix.py b/mmsegmentation/tools/confusion_matrix.py index 4166722..670c3ce 100644 --- a/mmsegmentation/tools/confusion_matrix.py +++ b/mmsegmentation/tools/confusion_matrix.py @@ -13,32 +13,35 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Generate confusion matrix from segmentation results') - parser.add_argument('config', help='test config file path') + description="Generate confusion matrix from segmentation results" + ) + parser.add_argument("config", help="test config file path") parser.add_argument( - 'prediction_path', help='prediction path where test .pkl result') + "prediction_path", help="prediction path where test .pkl result" + ) parser.add_argument( - 'save_dir', help='directory where confusion matrix will be saved') + "save_dir", help="directory where confusion matrix will be saved" + ) + parser.add_argument("--show", action="store_true", help="show confusion matrix") parser.add_argument( - '--show', action='store_true', help='show confusion matrix') + "--color-theme", default="winter", help="theme of the matrix color map" + ) parser.add_argument( - '--color-theme', - default='winter', - help='theme of the matrix color map') + "--title", + default="Normalized Confusion Matrix", + help="title of the matrix color map", + ) parser.add_argument( - '--title', - default='Normalized Confusion Matrix', - help='title of the matrix color map') - parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) args = parser.parse_args() return args @@ -68,12 +71,14 @@ def calculate_confusion_matrix(dataset, results): return confusion_matrix -def plot_confusion_matrix(confusion_matrix, - labels, - save_dir=None, - show=True, - title='Normalized Confusion Matrix', - color_theme='winter'): +def plot_confusion_matrix( + confusion_matrix, + labels, + save_dir=None, + show=True, + title="Normalized Confusion Matrix", + color_theme="winter", +): """Draw confusion matrix with matplotlib. Args: @@ -87,21 +92,19 @@ def plot_confusion_matrix(confusion_matrix, """ # normalize the confusion matrix per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis] - confusion_matrix = \ - confusion_matrix.astype(np.float32) / per_label_sums * 100 + confusion_matrix = confusion_matrix.astype(np.float32) / per_label_sums * 100 num_classes = len(labels) - fig, ax = plt.subplots( - figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180) + fig, ax = plt.subplots(figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180) cmap = plt.get_cmap(color_theme) im = ax.imshow(confusion_matrix, cmap=cmap) plt.colorbar(mappable=im, ax=ax) - title_font = {'weight': 'bold', 'size': 12} + title_font = {"weight": "bold", "size": 12} ax.set_title(title, fontdict=title_font) - label_font = {'size': 10} - plt.ylabel('Ground Truth Label', fontdict=label_font) - plt.xlabel('Prediction Label', fontdict=label_font) + label_font = {"size": 10} + plt.ylabel("Ground Truth Label", fontdict=label_font) + plt.xlabel("Prediction Label", fontdict=label_font) # draw locator xmajor_locator = MultipleLocator(1) @@ -114,7 +117,7 @@ def plot_confusion_matrix(confusion_matrix, ax.yaxis.set_minor_locator(yminor_locator) # draw grid - ax.grid(True, which='minor', linestyle='-') + ax.grid(True, which="minor", linestyle="-") # draw label ax.set_xticks(np.arange(num_classes)) @@ -122,10 +125,8 @@ def plot_confusion_matrix(confusion_matrix, ax.set_xticklabels(labels) ax.set_yticklabels(labels) - ax.tick_params( - axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) - plt.setp( - ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor') + ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True) + plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor") # draw confusion matrix value for i in range(num_classes): @@ -133,20 +134,22 @@ def plot_confusion_matrix(confusion_matrix, ax.text( j, i, - '{}%'.format( - round(confusion_matrix[i, j], 2 - ) if not np.isnan(confusion_matrix[i, j]) else -1), - ha='center', - va='center', - color='w', - size=7) + "{}%".format( + round(confusion_matrix[i, j], 2) + if not np.isnan(confusion_matrix[i, j]) + else -1 + ), + ha="center", + va="center", + color="w", + size=7, + ) ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1 fig.tight_layout() if save_dir is not None: - plt.savefig( - os.path.join(save_dir, 'confusion_matrix.png'), format='png') + plt.savefig(os.path.join(save_dir, "confusion_matrix.png"), format="png") if show: plt.show() @@ -164,7 +167,7 @@ def main(): if isinstance(results[0], np.ndarray): pass else: - raise TypeError('invalid type of prediction results') + raise TypeError("invalid type of prediction results") if isinstance(cfg.data.test, dict): cfg.data.test.test_mode = True @@ -180,8 +183,9 @@ def main(): save_dir=args.save_dir, show=args.show, title=args.title, - color_theme=args.color_theme) + color_theme=args.color_theme, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/chase_db1.py b/mmsegmentation/tools/convert_datasets/chase_db1.py index 580e6e7..9c79c7f 100644 --- a/mmsegmentation/tools/convert_datasets/chase_db1.py +++ b/mmsegmentation/tools/convert_datasets/chase_db1.py @@ -13,10 +13,11 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert CHASE_DB1 dataset to mmsegmentation format') - parser.add_argument('dataset_path', help='path of CHASEDB1.zip') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert CHASE_DB1 dataset to mmsegmentation format" + ) + parser.add_argument("dataset_path", help="path of CHASEDB1.zip") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -25,36 +26,42 @@ def main(): args = parse_args() dataset_path = args.dataset_path if args.out_dir is None: - out_dir = osp.join('data', 'CHASE_DB1') + out_dir = osp.join("data", "CHASE_DB1") else: out_dir = args.out_dir - print('Making directories...') + print("Making directories...") mmcv.mkdir_or_exist(out_dir) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + mmcv.mkdir_or_exist(osp.join(out_dir, "images")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "validation")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "validation")) with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - print('Extracting CHASEDB1.zip...') + print("Extracting CHASEDB1.zip...") zip_file = zipfile.ZipFile(dataset_path) zip_file.extractall(tmp_dir) - print('Generating training dataset...') + print("Generating training dataset...") - assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ - 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) + assert ( + len(os.listdir(tmp_dir)) == CHASE_DB1_LEN + ), f"len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}" for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(tmp_dir, img_name)) - if osp.splitext(img_name)[1] == '.jpg': + if osp.splitext(img_name)[1] == ".jpg": mmcv.imwrite( img, - osp.join(out_dir, 'images', 'training', - osp.splitext(img_name)[0] + '.png')) + osp.join( + out_dir, + "images", + "training", + osp.splitext(img_name)[0] + ".png", + ), + ) else: # The annotation img should be divided by 128, because some of # the annotation imgs are not standard. We should set a @@ -63,26 +70,41 @@ def main(): # else 0' mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'training', - osp.splitext(img_name)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "training", + osp.splitext(img_name)[0] + ".png", + ), + ) for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(tmp_dir, img_name)) - if osp.splitext(img_name)[1] == '.jpg': + if osp.splitext(img_name)[1] == ".jpg": mmcv.imwrite( img, - osp.join(out_dir, 'images', 'validation', - osp.splitext(img_name)[0] + '.png')) + osp.join( + out_dir, + "images", + "validation", + osp.splitext(img_name)[0] + ".png", + ), + ) else: mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(img_name)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(img_name)[0] + ".png", + ), + ) - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/cityscapes.py b/mmsegmentation/tools/convert_datasets/cityscapes.py index 17b6168..9644130 100644 --- a/mmsegmentation/tools/convert_datasets/cityscapes.py +++ b/mmsegmentation/tools/convert_datasets/cityscapes.py @@ -7,18 +7,18 @@ def convert_json_to_label(json_file): - label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') - json2labelImg(json_file, label_file, 'trainIds') + label_file = json_file.replace("_polygons.json", "_labelTrainIds.png") + json2labelImg(json_file, label_file, "trainIds") def parse_args(): parser = argparse.ArgumentParser( - description='Convert Cityscapes annotations to TrainIds') - parser.add_argument('cityscapes_path', help='cityscapes data path') - parser.add_argument('--gt-dir', default='gtFine', type=str) - parser.add_argument('-o', '--out-dir', help='output path') - parser.add_argument( - '--nproc', default=1, type=int, help='number of process') + description="Convert Cityscapes annotations to TrainIds" + ) + parser.add_argument("cityscapes_path", help="cityscapes data path") + parser.add_argument("--gt-dir", default="gtFine", type=str) + parser.add_argument("-o", "--out-dir", help="output path") + parser.add_argument("--nproc", default=1, type=int, help="number of process") args = parser.parse_args() return args @@ -32,25 +32,25 @@ def main(): gt_dir = osp.join(cityscapes_path, args.gt_dir) poly_files = [] - for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): + for poly in mmcv.scandir(gt_dir, "_polygons.json", recursive=True): poly_file = osp.join(gt_dir, poly) poly_files.append(poly_file) if args.nproc > 1: - mmcv.track_parallel_progress(convert_json_to_label, poly_files, - args.nproc) + mmcv.track_parallel_progress(convert_json_to_label, poly_files, args.nproc) else: mmcv.track_progress(convert_json_to_label, poly_files) - split_names = ['train', 'val', 'test'] + split_names = ["train", "val", "test"] for split in split_names: filenames = [] for poly in mmcv.scandir( - osp.join(gt_dir, split), '_polygons.json', recursive=True): - filenames.append(poly.replace('_gtFine_polygons.json', '')) - with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: - f.writelines(f + '\n' for f in filenames) + osp.join(gt_dir, split), "_polygons.json", recursive=True + ): + filenames.append(poly.replace("_gtFine_polygons.json", "")) + with open(osp.join(out_dir, f"{split}.txt"), "w") as f: + f.writelines(f + "\n" for f in filenames) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/coco_stuff10k.py b/mmsegmentation/tools/convert_datasets/coco_stuff10k.py index 374f819..02a55df 100644 --- a/mmsegmentation/tools/convert_datasets/coco_stuff10k.py +++ b/mmsegmentation/tools/convert_datasets/coco_stuff10k.py @@ -183,48 +183,55 @@ 179: 168, 180: 169, 181: 170, - 182: 171 + 182: 171, } -def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir, - out_mask_dir, is_train): +def convert_to_trainID( + tuple_path, in_img_dir, in_ann_dir, out_img_dir, out_mask_dir, is_train +): imgpath, maskpath = tuple_path shutil.copyfile( osp.join(in_img_dir, imgpath), - osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join( - out_img_dir, 'test2014', imgpath)) + osp.join(out_img_dir, "train2014", imgpath) + if is_train + else osp.join(out_img_dir, "test2014", imgpath), + ) annotate = loadmat(osp.join(in_ann_dir, maskpath)) - mask = annotate['S'].astype(np.uint8) + mask = annotate["S"].astype(np.uint8) mask_copy = mask.copy() for clsID, trID in clsID_to_trID.items(): mask_copy[mask == clsID] = trID - seg_filename = osp.join(out_mask_dir, 'train2014', - maskpath.split('.')[0] + - '_labelTrainIds.png') if is_train else osp.join( - out_mask_dir, 'test2014', - maskpath.split('.')[0] + '_labelTrainIds.png') - Image.fromarray(mask_copy).save(seg_filename, 'PNG') + seg_filename = ( + osp.join( + out_mask_dir, "train2014", maskpath.split(".")[0] + "_labelTrainIds.png" + ) + if is_train + else osp.join( + out_mask_dir, "test2014", maskpath.split(".")[0] + "_labelTrainIds.png" + ) + ) + Image.fromarray(mask_copy).save(seg_filename, "PNG") def generate_coco_list(folder): - train_list = osp.join(folder, 'imageLists', 'train.txt') - test_list = osp.join(folder, 'imageLists', 'test.txt') + train_list = osp.join(folder, "imageLists", "train.txt") + test_list = osp.join(folder, "imageLists", "test.txt") train_paths = [] test_paths = [] with open(train_list) as f: for filename in f: basename = filename.strip() - imgpath = basename + '.jpg' - maskpath = basename + '.mat' + imgpath = basename + ".jpg" + maskpath = basename + ".mat" train_paths.append((imgpath, maskpath)) with open(test_list) as f: for filename in f: basename = filename.strip() - imgpath = basename + '.jpg' - maskpath = basename + '.mat' + imgpath = basename + ".jpg" + maskpath = basename + ".mat" test_paths.append((imgpath, maskpath)) return train_paths, test_paths @@ -232,12 +239,11 @@ def generate_coco_list(folder): def parse_args(): parser = argparse.ArgumentParser( - description=\ - 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa - parser.add_argument('coco_path', help='coco stuff path') - parser.add_argument('-o', '--out_dir', help='output path') - parser.add_argument( - '--nproc', default=16, type=int, help='number of process') + description="Convert COCO Stuff 10k annotations to mmsegmentation format" + ) # noqa + parser.add_argument("coco_path", help="coco stuff path") + parser.add_argument("-o", "--out_dir", help="output path") + parser.add_argument("--nproc", default=16, type=int, help="number of process") args = parser.parse_args() return args @@ -248,60 +254,72 @@ def main(): nproc = args.nproc out_dir = args.out_dir or coco_path - out_img_dir = osp.join(out_dir, 'images') - out_mask_dir = osp.join(out_dir, 'annotations') + out_img_dir = osp.join(out_dir, "images") + out_mask_dir = osp.join(out_dir, "annotations") - mmcv.mkdir_or_exist(osp.join(out_img_dir, 'train2014')) - mmcv.mkdir_or_exist(osp.join(out_img_dir, 'test2014')) - mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2014')) - mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'test2014')) + mmcv.mkdir_or_exist(osp.join(out_img_dir, "train2014")) + mmcv.mkdir_or_exist(osp.join(out_img_dir, "test2014")) + mmcv.mkdir_or_exist(osp.join(out_mask_dir, "train2014")) + mmcv.mkdir_or_exist(osp.join(out_mask_dir, "test2014")) train_list, test_list = generate_coco_list(coco_path) - assert (len(train_list) + - len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( - len(train_list), len(test_list)) + assert ( + len(train_list) + len(test_list) + ) == COCO_LEN, "Wrong length of list {} & {}".format( + len(train_list), len(test_list) + ) if args.nproc > 1: mmcv.track_parallel_progress( partial( convert_to_trainID, - in_img_dir=osp.join(coco_path, 'images'), - in_ann_dir=osp.join(coco_path, 'annotations'), + in_img_dir=osp.join(coco_path, "images"), + in_ann_dir=osp.join(coco_path, "annotations"), out_img_dir=out_img_dir, out_mask_dir=out_mask_dir, - is_train=True), + is_train=True, + ), train_list, - nproc=nproc) + nproc=nproc, + ) mmcv.track_parallel_progress( partial( convert_to_trainID, - in_img_dir=osp.join(coco_path, 'images'), - in_ann_dir=osp.join(coco_path, 'annotations'), + in_img_dir=osp.join(coco_path, "images"), + in_ann_dir=osp.join(coco_path, "annotations"), out_img_dir=out_img_dir, out_mask_dir=out_mask_dir, - is_train=False), + is_train=False, + ), test_list, - nproc=nproc) + nproc=nproc, + ) else: mmcv.track_progress( partial( convert_to_trainID, - in_img_dir=osp.join(coco_path, 'images'), - in_ann_dir=osp.join(coco_path, 'annotations'), + in_img_dir=osp.join(coco_path, "images"), + in_ann_dir=osp.join(coco_path, "annotations"), out_img_dir=out_img_dir, out_mask_dir=out_mask_dir, - is_train=True), train_list) + is_train=True, + ), + train_list, + ) mmcv.track_progress( partial( convert_to_trainID, - in_img_dir=osp.join(coco_path, 'images'), - in_ann_dir=osp.join(coco_path, 'annotations'), + in_img_dir=osp.join(coco_path, "images"), + in_ann_dir=osp.join(coco_path, "annotations"), out_img_dir=out_img_dir, out_mask_dir=out_mask_dir, - is_train=False), test_list) + is_train=False, + ), + test_list, + ) - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/coco_stuff164k.py b/mmsegmentation/tools/convert_datasets/coco_stuff164k.py index 6d8e2f2..a1eff6a 100644 --- a/mmsegmentation/tools/convert_datasets/coco_stuff164k.py +++ b/mmsegmentation/tools/convert_datasets/coco_stuff164k.py @@ -183,7 +183,7 @@ 179: 168, 180: 169, 181: 170, - 255: 255 + 255: 255, } @@ -192,23 +192,29 @@ def convert_to_trainID(maskpath, out_mask_dir, is_train): mask_copy = mask.copy() for clsID, trID in clsID_to_trID.items(): mask_copy[mask == clsID] = trID - seg_filename = osp.join( - out_mask_dir, 'train2017', - osp.basename(maskpath).split('.')[0] + - '_labelTrainIds.png') if is_train else osp.join( - out_mask_dir, 'val2017', - osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') - Image.fromarray(mask_copy).save(seg_filename, 'PNG') + seg_filename = ( + osp.join( + out_mask_dir, + "train2017", + osp.basename(maskpath).split(".")[0] + "_labelTrainIds.png", + ) + if is_train + else osp.join( + out_mask_dir, + "val2017", + osp.basename(maskpath).split(".")[0] + "_labelTrainIds.png", + ) + ) + Image.fromarray(mask_copy).save(seg_filename, "PNG") def parse_args(): parser = argparse.ArgumentParser( - description=\ - 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa - parser.add_argument('coco_path', help='coco stuff path') - parser.add_argument('-o', '--out_dir', help='output path') - parser.add_argument( - '--nproc', default=16, type=int, help='number of process') + description="Convert COCO Stuff 164k annotations to mmsegmentation format" + ) # noqa + parser.add_argument("coco_path", help="coco stuff path") + parser.add_argument("-o", "--out_dir", help="output path") + parser.add_argument("--nproc", default=16, type=int, help="number of process") args = parser.parse_args() return args @@ -219,46 +225,48 @@ def main(): nproc = args.nproc out_dir = args.out_dir or coco_path - out_img_dir = osp.join(out_dir, 'images') - out_mask_dir = osp.join(out_dir, 'annotations') + out_img_dir = osp.join(out_dir, "images") + out_mask_dir = osp.join(out_dir, "annotations") - mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) - mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) + mmcv.mkdir_or_exist(osp.join(out_mask_dir, "train2017")) + mmcv.mkdir_or_exist(osp.join(out_mask_dir, "val2017")) if out_dir != coco_path: - shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) + shutil.copytree(osp.join(coco_path, "images"), out_img_dir) - train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) - train_list = [file for file in train_list if '_labelTrainIds' not in file] - test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) - test_list = [file for file in test_list if '_labelTrainIds' not in file] - assert (len(train_list) + - len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( - len(train_list), len(test_list)) + train_list = glob(osp.join(coco_path, "annotations", "train2017", "*.png")) + train_list = [file for file in train_list if "_labelTrainIds" not in file] + test_list = glob(osp.join(coco_path, "annotations", "val2017", "*.png")) + test_list = [file for file in test_list if "_labelTrainIds" not in file] + assert ( + len(train_list) + len(test_list) + ) == COCO_LEN, "Wrong length of list {} & {}".format( + len(train_list), len(test_list) + ) if args.nproc > 1: mmcv.track_parallel_progress( - partial( - convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), train_list, - nproc=nproc) + nproc=nproc, + ) mmcv.track_parallel_progress( - partial( - convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), test_list, - nproc=nproc) + nproc=nproc, + ) else: mmcv.track_progress( - partial( - convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), - train_list) + partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + train_list, + ) mmcv.track_progress( - partial( - convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), - test_list) + partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + test_list, + ) - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/drive.py b/mmsegmentation/tools/convert_datasets/drive.py index f547579..dfa4d70 100644 --- a/mmsegmentation/tools/convert_datasets/drive.py +++ b/mmsegmentation/tools/convert_datasets/drive.py @@ -11,13 +11,12 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert DRIVE dataset to mmsegmentation format') - parser.add_argument( - 'training_path', help='the training part of DRIVE dataset') - parser.add_argument( - 'testing_path', help='the testing part of DRIVE dataset') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert DRIVE dataset to mmsegmentation format" + ) + parser.add_argument("training_path", help="the training part of DRIVE dataset") + parser.add_argument("testing_path", help="the testing part of DRIVE dataset") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -27,59 +26,71 @@ def main(): training_path = args.training_path testing_path = args.testing_path if args.out_dir is None: - out_dir = osp.join('data', 'DRIVE') + out_dir = osp.join("data", "DRIVE") else: out_dir = args.out_dir - print('Making directories...') + print("Making directories...") mmcv.mkdir_or_exist(out_dir) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + mmcv.mkdir_or_exist(osp.join(out_dir, "images")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "validation")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "validation")) with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - print('Extracting training.zip...') + print("Extracting training.zip...") zip_file = zipfile.ZipFile(training_path) zip_file.extractall(tmp_dir) - print('Generating training dataset...') - now_dir = osp.join(tmp_dir, 'training', 'images') + print("Generating training dataset...") + now_dir = osp.join(tmp_dir, "training", "images") for img_name in os.listdir(now_dir): img = mmcv.imread(osp.join(now_dir, img_name)) mmcv.imwrite( img, osp.join( - out_dir, 'images', 'training', - osp.splitext(img_name)[0].replace('_training', '') + - '.png')) - - now_dir = osp.join(tmp_dir, 'training', '1st_manual') + out_dir, + "images", + "training", + osp.splitext(img_name)[0].replace("_training", "") + ".png", + ), + ) + + now_dir = osp.join(tmp_dir, "training", "1st_manual") for img_name in os.listdir(now_dir): cap = cv2.VideoCapture(osp.join(now_dir, img_name)) ret, img = cap.read() mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'training', - osp.splitext(img_name)[0] + '.png')) - - print('Extracting test.zip...') + osp.join( + out_dir, + "annotations", + "training", + osp.splitext(img_name)[0] + ".png", + ), + ) + + print("Extracting test.zip...") zip_file = zipfile.ZipFile(testing_path) zip_file.extractall(tmp_dir) - print('Generating validation dataset...') - now_dir = osp.join(tmp_dir, 'test', 'images') + print("Generating validation dataset...") + now_dir = osp.join(tmp_dir, "test", "images") for img_name in os.listdir(now_dir): img = mmcv.imread(osp.join(now_dir, img_name)) mmcv.imwrite( img, osp.join( - out_dir, 'images', 'validation', - osp.splitext(img_name)[0].replace('_test', '') + '.png')) - - now_dir = osp.join(tmp_dir, 'test', '1st_manual') + out_dir, + "images", + "validation", + osp.splitext(img_name)[0].replace("_test", "") + ".png", + ), + ) + + now_dir = osp.join(tmp_dir, "test", "1st_manual") if osp.exists(now_dir): for img_name in os.listdir(now_dir): cap = cv2.VideoCapture(osp.join(now_dir, img_name)) @@ -91,23 +102,33 @@ def main(): # else 0' mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(img_name)[0] + '.png')) - - now_dir = osp.join(tmp_dir, 'test', '2nd_manual') + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(img_name)[0] + ".png", + ), + ) + + now_dir = osp.join(tmp_dir, "test", "2nd_manual") if osp.exists(now_dir): for img_name in os.listdir(now_dir): cap = cv2.VideoCapture(osp.join(now_dir, img_name)) ret, img = cap.read() mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(img_name)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(img_name)[0] + ".png", + ), + ) - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/hrf.py b/mmsegmentation/tools/convert_datasets/hrf.py index 5e016e3..930f8bc 100644 --- a/mmsegmentation/tools/convert_datasets/hrf.py +++ b/mmsegmentation/tools/convert_datasets/hrf.py @@ -13,21 +13,25 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert HRF dataset to mmsegmentation format') - parser.add_argument('healthy_path', help='the path of healthy.zip') + description="Convert HRF dataset to mmsegmentation format" + ) + parser.add_argument("healthy_path", help="the path of healthy.zip") parser.add_argument( - 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') - parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') + "healthy_manualsegm_path", help="the path of healthy_manualsegm.zip" + ) + parser.add_argument("glaucoma_path", help="the path of glaucoma.zip") parser.add_argument( - 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') + "glaucoma_manualsegm_path", help="the path of glaucoma_manualsegm.zip" + ) parser.add_argument( - 'diabetic_retinopathy_path', - help='the path of diabetic_retinopathy.zip') + "diabetic_retinopathy_path", help="the path of diabetic_retinopathy.zip" + ) parser.add_argument( - 'diabetic_retinopathy_manualsegm_path', - help='the path of diabetic_retinopathy_manualsegm.zip') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + "diabetic_retinopathy_manualsegm_path", + help="the path of diabetic_retinopathy_manualsegm.zip", + ) + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -35,56 +39,71 @@ def parse_args(): def main(): args = parse_args() images_path = [ - args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path + args.healthy_path, + args.glaucoma_path, + args.diabetic_retinopathy_path, ] annotations_path = [ - args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, - args.diabetic_retinopathy_manualsegm_path + args.healthy_manualsegm_path, + args.glaucoma_manualsegm_path, + args.diabetic_retinopathy_manualsegm_path, ] if args.out_dir is None: - out_dir = osp.join('data', 'HRF') + out_dir = osp.join("data", "HRF") else: out_dir = args.out_dir - print('Making directories...') + print("Making directories...") mmcv.mkdir_or_exist(out_dir) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) - - print('Generating images...') + mmcv.mkdir_or_exist(osp.join(out_dir, "images")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "validation")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "validation")) + + print("Generating images...") for now_path in images_path: with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: zip_file = zipfile.ZipFile(now_path) zip_file.extractall(tmp_dir) - assert len(os.listdir(tmp_dir)) == HRF_LEN, \ - 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) + assert ( + len(os.listdir(tmp_dir)) == HRF_LEN + ), f"len(os.listdir(tmp_dir)) != {HRF_LEN}" for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(tmp_dir, filename)) mmcv.imwrite( img, - osp.join(out_dir, 'images', 'training', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "images", + "training", + osp.splitext(filename)[0] + ".png", + ), + ) for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(tmp_dir, filename)) mmcv.imwrite( img, - osp.join(out_dir, 'images', 'validation', - osp.splitext(filename)[0] + '.png')) - - print('Generating annotations...') + osp.join( + out_dir, + "images", + "validation", + osp.splitext(filename)[0] + ".png", + ), + ) + + print("Generating annotations...") for now_path in annotations_path: with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: zip_file = zipfile.ZipFile(now_path) zip_file.extractall(tmp_dir) - assert len(os.listdir(tmp_dir)) == HRF_LEN, \ - 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) + assert ( + len(os.listdir(tmp_dir)) == HRF_LEN + ), f"len(os.listdir(tmp_dir)) != {HRF_LEN}" for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(tmp_dir, filename)) @@ -95,17 +114,27 @@ def main(): # else 0' mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'training', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "training", + osp.splitext(filename)[0] + ".png", + ), + ) for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(tmp_dir, filename)) mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(filename)[0] + ".png", + ), + ) - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/isaid.py b/mmsegmentation/tools/convert_datasets/isaid.py index 314fb89..d089a41 100644 --- a/mmsegmentation/tools/convert_datasets/isaid.py +++ b/mmsegmentation/tools/convert_datasets/isaid.py @@ -11,25 +11,24 @@ import numpy as np from PIL import Image -iSAID_palette = \ - { - 0: (0, 0, 0), - 1: (0, 0, 63), - 2: (0, 63, 63), - 3: (0, 63, 0), - 4: (0, 63, 127), - 5: (0, 63, 191), - 6: (0, 63, 255), - 7: (0, 127, 63), - 8: (0, 127, 127), - 9: (0, 0, 127), - 10: (0, 0, 191), - 11: (0, 0, 255), - 12: (0, 191, 127), - 13: (0, 127, 191), - 14: (0, 127, 255), - 15: (0, 100, 155) - } +iSAID_palette = { + 0: (0, 0, 0), + 1: (0, 0, 63), + 2: (0, 63, 63), + 3: (0, 63, 0), + 4: (0, 63, 127), + 5: (0, 63, 191), + 6: (0, 63, 255), + 7: (0, 127, 63), + 8: (0, 127, 127), + 9: (0, 0, 127), + 10: (0, 0, 191), + 11: (0, 0, 255), + 12: (0, 191, 127), + 13: (0, 127, 191), + 14: (0, 127, 255), + 15: (0, 100, 155), +} iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()} @@ -46,24 +45,21 @@ def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette): def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap): - img = np.asarray(Image.open(src_path).convert('RGB')) + img = np.asarray(Image.open(src_path).convert("RGB")) img_H, img_W, _ = img.shape if img_H < patch_H and img_W > patch_W: - img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0) img_H, img_W, _ = img.shape elif img_H > patch_H and img_W < patch_W: - img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0) img_H, img_W, _ = img.shape elif img_H < patch_H and img_W < patch_W: - img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0) img_H, img_W, _ = img.shape @@ -85,33 +81,39 @@ def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap): img_patch = img[y_str:y_end, x_str:x_end, :] img_patch = Image.fromarray(img_patch.astype(np.uint8)) - image = osp.basename(src_path).split('.')[0] + '_' + str( - y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str( - x_end) + '.png' + image = ( + osp.basename(src_path).split(".")[0] + + "_" + + str(y_str) + + "_" + + str(y_end) + + "_" + + str(x_str) + + "_" + + str(x_end) + + ".png" + ) # print(image) - save_path_image = osp.join(out_dir, 'img_dir', mode, str(image)) + save_path_image = osp.join(out_dir, "img_dir", mode, str(image)) img_patch.save(save_path_image) def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap): - label = mmcv.imread(src_path, channel_order='rgb') + label = mmcv.imread(src_path, channel_order="rgb") label = iSAID_convert_from_color(label) img_H, img_W = label.shape if img_H < patch_H and img_W > patch_W: - label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255) img_H = patch_H elif img_H > patch_H and img_W < patch_W: - label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255) img_W = patch_W elif img_H < patch_H and img_W < patch_W: - label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255) img_H = patch_H @@ -133,33 +135,42 @@ def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap): y_end = img_H lab_patch = label[y_str:y_end, x_str:x_end] - lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P') - - image = osp.basename(src_path).split('.')[0].split( - '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str( - x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png' - lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image))) + lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode="P") + + image = ( + osp.basename(src_path).split(".")[0].split("_")[0] + + "_" + + str(y_str) + + "_" + + str(y_end) + + "_" + + str(x_str) + + "_" + + str(x_end) + + "_instance_color_RGB" + + ".png" + ) + lab_patch.save(osp.join(out_dir, "ann_dir", mode, str(image))) def parse_args(): parser = argparse.ArgumentParser( - description='Convert iSAID dataset to mmsegmentation format') - parser.add_argument('dataset_path', help='iSAID folder path') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert iSAID dataset to mmsegmentation format" + ) + parser.add_argument("dataset_path", help="iSAID folder path") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") parser.add_argument( - '--patch_width', - default=896, - type=int, - help='Width of the cropped image patch') + "--patch_width", default=896, type=int, help="Width of the cropped image patch" + ) parser.add_argument( - '--patch_height', + "--patch_height", default=896, type=int, - help='Height of the cropped image patch') - parser.add_argument( - '--overlap_area', default=384, type=int, help='Overlap area') + help="Height of the cropped image patch", + ) + parser.add_argument("--overlap_area", default=384, type=int, help="Overlap area") args = parser.parse_args() return args @@ -173,73 +184,79 @@ def main(): overlap = args.overlap_area # overlap area if args.out_dir is None: - out_dir = osp.join('data', 'iSAID') + out_dir = osp.join("data", "iSAID") else: out_dir = args.out_dir - print('Making directories...') - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) - - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test')) - - assert os.path.exists(os.path.join(dataset_path, 'train')), \ - 'train is not in {}'.format(dataset_path) - assert os.path.exists(os.path.join(dataset_path, 'val')), \ - 'val is not in {}'.format(dataset_path) - assert os.path.exists(os.path.join(dataset_path, 'test')), \ - 'test is not in {}'.format(dataset_path) + print("Making directories...") + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "val")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "test")) + + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "val")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "test")) + + assert os.path.exists( + os.path.join(dataset_path, "train") + ), f"train is not in {dataset_path}" + assert os.path.exists( + os.path.join(dataset_path, "val") + ), f"val is not in {dataset_path}" + assert os.path.exists( + os.path.join(dataset_path, "test") + ), f"test is not in {dataset_path}" with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - for dataset_mode in ['train', 'val', 'test']: - + for dataset_mode in ["train", "val", "test"]: # for dataset_mode in [ 'test']: - print('Extracting {}ing.zip...'.format(dataset_mode)) + print(f"Extracting {dataset_mode}ing.zip...") img_zipp_list = glob.glob( - os.path.join(dataset_path, dataset_mode, 'images', '*.zip')) - print('Find the data', img_zipp_list) + os.path.join(dataset_path, dataset_mode, "images", "*.zip") + ) + print("Find the data", img_zipp_list) for img_zipp in img_zipp_list: zip_file = zipfile.ZipFile(img_zipp) - zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img')) + zip_file.extractall(os.path.join(tmp_dir, dataset_mode, "img")) src_path_list = glob.glob( - os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png')) + os.path.join(tmp_dir, dataset_mode, "img", "images", "*.png") + ) src_prog_bar = mmcv.ProgressBar(len(src_path_list)) for i, img_path in enumerate(src_path_list): - if dataset_mode != 'test': - slide_crop_image(img_path, out_dir, dataset_mode, patch_H, - patch_W, overlap) + if dataset_mode != "test": + slide_crop_image( + img_path, out_dir, dataset_mode, patch_H, patch_W, overlap + ) else: - shutil.move(img_path, - os.path.join(out_dir, 'img_dir', dataset_mode)) + shutil.move( + img_path, os.path.join(out_dir, "img_dir", dataset_mode) + ) src_prog_bar.update() - if dataset_mode != 'test': + if dataset_mode != "test": label_zipp_list = glob.glob( - os.path.join(dataset_path, dataset_mode, 'Semantic_masks', - '*.zip')) + os.path.join(dataset_path, dataset_mode, "Semantic_masks", "*.zip") + ) for label_zipp in label_zipp_list: zip_file = zipfile.ZipFile(label_zipp) - zip_file.extractall( - os.path.join(tmp_dir, dataset_mode, 'lab')) + zip_file.extractall(os.path.join(tmp_dir, dataset_mode, "lab")) lab_path_list = glob.glob( - os.path.join(tmp_dir, dataset_mode, 'lab', 'images', - '*.png')) + os.path.join(tmp_dir, dataset_mode, "lab", "images", "*.png") + ) lab_prog_bar = mmcv.ProgressBar(len(lab_path_list)) for i, lab_path in enumerate(lab_path_list): - slide_crop_label(lab_path, out_dir, dataset_mode, patch_H, - patch_W, overlap) + slide_crop_label( + lab_path, out_dir, dataset_mode, patch_H, patch_W, overlap + ) lab_prog_bar.update() - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/loveda.py b/mmsegmentation/tools/convert_datasets/loveda.py index 3a06268..59760eb 100644 --- a/mmsegmentation/tools/convert_datasets/loveda.py +++ b/mmsegmentation/tools/convert_datasets/loveda.py @@ -11,10 +11,11 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert LoveDA dataset to mmsegmentation format') - parser.add_argument('dataset_path', help='LoveDA folder path') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert LoveDA dataset to mmsegmentation format" + ) + parser.add_argument("dataset_path", help="LoveDA folder path") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -23,51 +24,48 @@ def main(): args = parse_args() dataset_path = args.dataset_path if args.out_dir is None: - out_dir = osp.join('data', 'loveDA') + out_dir = osp.join("data", "loveDA") else: out_dir = args.out_dir - print('Making directories...') + print("Making directories...") mmcv.mkdir_or_exist(out_dir) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "val")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "test")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "val")) - assert 'Train.zip' in os.listdir(dataset_path), \ - 'Train.zip is not in {}'.format(dataset_path) - assert 'Val.zip' in os.listdir(dataset_path), \ - 'Val.zip is not in {}'.format(dataset_path) - assert 'Test.zip' in os.listdir(dataset_path), \ - 'Test.zip is not in {}'.format(dataset_path) + assert "Train.zip" in os.listdir( + dataset_path + ), f"Train.zip is not in {dataset_path}" + assert "Val.zip" in os.listdir(dataset_path), f"Val.zip is not in {dataset_path}" + assert "Test.zip" in os.listdir(dataset_path), f"Test.zip is not in {dataset_path}" with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - for dataset in ['Train', 'Val', 'Test']: - zip_file = zipfile.ZipFile( - os.path.join(dataset_path, dataset + '.zip')) + for dataset in ["Train", "Val", "Test"]: + zip_file = zipfile.ZipFile(os.path.join(dataset_path, dataset + ".zip")) zip_file.extractall(tmp_dir) data_type = dataset.lower() - for location in ['Rural', 'Urban']: - for image_type in ['images_png', 'masks_png']: - if image_type == 'images_png': - dst = osp.join(out_dir, 'img_dir', data_type) + for location in ["Rural", "Urban"]: + for image_type in ["images_png", "masks_png"]: + if image_type == "images_png": + dst = osp.join(out_dir, "img_dir", data_type) else: - dst = osp.join(out_dir, 'ann_dir', data_type) - if dataset == 'Test' and image_type == 'masks_png': + dst = osp.join(out_dir, "ann_dir", data_type) + if dataset == "Test" and image_type == "masks_png": continue else: - src_dir = osp.join(tmp_dir, dataset, location, - image_type) + src_dir = osp.join(tmp_dir, dataset, location, image_type) src_lst = os.listdir(src_dir) for file in src_lst: shutil.move(osp.join(src_dir, file), dst) - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/pascal_context.py b/mmsegmentation/tools/convert_datasets/pascal_context.py index 03b79d5..9a3074d 100644 --- a/mmsegmentation/tools/convert_datasets/pascal_context.py +++ b/mmsegmentation/tools/convert_datasets/pascal_context.py @@ -9,38 +9,98 @@ from PIL import Image _mapping = np.sort( - np.array([ - 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, - 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, - 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, - 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 - ])) -_key = np.array(range(len(_mapping))).astype('uint8') + np.array( + [ + 0, + 2, + 259, + 260, + 415, + 324, + 9, + 258, + 144, + 18, + 19, + 22, + 23, + 397, + 25, + 284, + 158, + 159, + 416, + 33, + 162, + 420, + 454, + 295, + 296, + 427, + 44, + 45, + 46, + 308, + 59, + 440, + 445, + 31, + 232, + 65, + 354, + 424, + 68, + 326, + 72, + 458, + 34, + 207, + 80, + 355, + 85, + 347, + 220, + 349, + 360, + 98, + 187, + 104, + 105, + 366, + 189, + 368, + 113, + 115, + ] + ) +) +_key = np.array(range(len(_mapping))).astype("uint8") def generate_labels(img_id, detail, out_dir): - def _class_to_index(mask, _mapping, _key): # assert the values values = np.unique(mask) for i in range(len(values)): - assert (values[i] in _mapping) + assert values[i] in _mapping index = np.digitize(mask.ravel(), _mapping, right=True) return _key[index].reshape(mask.shape) mask = Image.fromarray( - _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) - filename = img_id['file_name'] - mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) + _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key) + ) + filename = img_id["file_name"] + mask.save(osp.join(out_dir, filename.replace("jpg", "png"))) return osp.splitext(osp.basename(filename))[0] def parse_args(): parser = argparse.ArgumentParser( - description='Convert PASCAL VOC annotations to mmsegmentation format') - parser.add_argument('devkit_path', help='pascal voc devkit path') - parser.add_argument('json_path', help='annoation json filepath') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert PASCAL VOC annotations to mmsegmentation format" + ) + parser.add_argument("devkit_path", help="pascal voc devkit path") + parser.add_argument("json_path", help="annoation json filepath") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -49,39 +109,39 @@ def main(): args = parse_args() devkit_path = args.devkit_path if args.out_dir is None: - out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') + out_dir = osp.join(devkit_path, "VOC2010", "SegmentationClassContext") else: out_dir = args.out_dir json_path = args.json_path mmcv.mkdir_or_exist(out_dir) - img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') + img_dir = osp.join(devkit_path, "VOC2010", "JPEGImages") - train_detail = Detail(json_path, img_dir, 'train') + train_detail = Detail(json_path, img_dir, "train") train_ids = train_detail.getImgs() - val_detail = Detail(json_path, img_dir, 'val') + val_detail = Detail(json_path, img_dir, "val") val_ids = val_detail.getImgs() - mmcv.mkdir_or_exist( - osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) + mmcv.mkdir_or_exist(osp.join(devkit_path, "VOC2010/ImageSets/SegmentationContext")) train_list = mmcv.track_progress( - partial(generate_labels, detail=train_detail, out_dir=out_dir), - train_ids) + partial(generate_labels, detail=train_detail, out_dir=out_dir), train_ids + ) with open( - osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', - 'train.txt'), 'w') as f: - f.writelines(line + '\n' for line in sorted(train_list)) + osp.join(devkit_path, "VOC2010/ImageSets/SegmentationContext", "train.txt"), "w" + ) as f: + f.writelines(line + "\n" for line in sorted(train_list)) val_list = mmcv.track_progress( - partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) + partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids + ) with open( - osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', - 'val.txt'), 'w') as f: - f.writelines(line + '\n' for line in sorted(val_list)) + osp.join(devkit_path, "VOC2010/ImageSets/SegmentationContext", "val.txt"), "w" + ) as f: + f.writelines(line + "\n" for line in sorted(val_list)) - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/potsdam.py b/mmsegmentation/tools/convert_datasets/potsdam.py index 87e67d5..14804ec 100644 --- a/mmsegmentation/tools/convert_datasets/potsdam.py +++ b/mmsegmentation/tools/convert_datasets/potsdam.py @@ -13,20 +13,23 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert potsdam dataset to mmsegmentation format') - parser.add_argument('dataset_path', help='potsdam folder path') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert potsdam dataset to mmsegmentation format" + ) + parser.add_argument("dataset_path", help="potsdam folder path") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") parser.add_argument( - '--clip_size', + "--clip_size", type=int, - help='clipped size of image after preparation', - default=512) + help="clipped size of image after preparation", + default=512, + ) parser.add_argument( - '--stride_size', + "--stride_size", type=int, - help='stride of clipping original images', - default=256) + help="stride of clipping original images", + default=256, + ) args = parser.parse_args() return args @@ -44,14 +47,16 @@ def clip_big_image(image_path, clip_save_dir, args, to_label=False): clip_size = args.clip_size stride_size = args.stride_size - num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil( - (h - clip_size) / - stride_size) * stride_size + clip_size >= h else math.ceil( - (h - clip_size) / stride_size) + 1 - num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil( - (w - clip_size) / - stride_size) * stride_size + clip_size >= w else math.ceil( - (w - clip_size) / stride_size) + 1 + num_rows = ( + math.ceil((h - clip_size) / stride_size) + if math.ceil((h - clip_size) / stride_size) * stride_size + clip_size >= h + else math.ceil((h - clip_size) / stride_size) + 1 + ) + num_cols = ( + math.ceil((w - clip_size) / stride_size) + if math.ceil((w - clip_size) / stride_size) * stride_size + clip_size >= w + else math.ceil((w - clip_size) / stride_size) + 1 + ) x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) xmin = x * clip_size @@ -59,99 +64,145 @@ def clip_big_image(image_path, clip_save_dir, args, to_label=False): xmin = xmin.ravel() ymin = ymin.ravel() - xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size, - np.zeros_like(xmin)) - ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size, - np.zeros_like(ymin)) - boxes = np.stack([ - xmin + xmin_offset, ymin + ymin_offset, - np.minimum(xmin + clip_size, w), - np.minimum(ymin + clip_size, h) - ], - axis=1) + xmin_offset = np.where( + xmin + clip_size > w, w - xmin - clip_size, np.zeros_like(xmin) + ) + ymin_offset = np.where( + ymin + clip_size > h, h - ymin - clip_size, np.zeros_like(ymin) + ) + boxes = np.stack( + [ + xmin + xmin_offset, + ymin + ymin_offset, + np.minimum(xmin + clip_size, w), + np.minimum(ymin + clip_size, h), + ], + axis=1, + ) if to_label: - color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], - [255, 255, 0], [0, 255, 0], [0, 255, 255], - [0, 0, 255]]) - flatten_v = np.matmul( - image.reshape(-1, c), - np.array([2, 3, 4]).reshape(3, 1)) + color_map = np.array( + [ + [0, 0, 0], + [255, 255, 255], + [255, 0, 0], + [255, 255, 0], + [0, 255, 0], + [0, 255, 255], + [0, 0, 255], + ] + ) + flatten_v = np.matmul(image.reshape(-1, c), np.array([2, 3, 4]).reshape(3, 1)) out = np.zeros_like(flatten_v) for idx, class_color in enumerate(color_map): - value_idx = np.matmul(class_color, - np.array([2, 3, 4]).reshape(3, 1)) + value_idx = np.matmul(class_color, np.array([2, 3, 4]).reshape(3, 1)) out[flatten_v == value_idx] = idx image = out.reshape(h, w) for box in boxes: start_x, start_y, end_x, end_y = box - clipped_image = image[start_y:end_y, - start_x:end_x] if to_label else image[ - start_y:end_y, start_x:end_x, :] - idx_i, idx_j = osp.basename(image_path).split('_')[2:4] + clipped_image = ( + image[start_y:end_y, start_x:end_x] + if to_label + else image[start_y:end_y, start_x:end_x, :] + ) + idx_i, idx_j = osp.basename(image_path).split("_")[2:4] mmcv.imwrite( clipped_image.astype(np.uint8), osp.join( clip_save_dir, - f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + f"{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png", + ), + ) def main(): args = parse_args() splits = { - 'train': [ - '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', - '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7', - '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9' + "train": [ + "2_10", + "2_11", + "2_12", + "3_10", + "3_11", + "3_12", + "4_10", + "4_11", + "4_12", + "5_10", + "5_11", + "5_12", + "6_10", + "6_11", + "6_12", + "6_7", + "6_8", + "6_9", + "7_10", + "7_11", + "7_12", + "7_7", + "7_8", + "7_9", + ], + "val": [ + "5_15", + "6_15", + "6_13", + "3_13", + "4_14", + "6_14", + "5_14", + "2_13", + "4_15", + "2_14", + "5_13", + "4_13", + "3_14", + "7_13", ], - 'val': [ - '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13', - '4_15', '2_14', '5_13', '4_13', '3_14', '7_13' - ] } dataset_path = args.dataset_path if args.out_dir is None: - out_dir = osp.join('data', 'potsdam') + out_dir = osp.join("data", "potsdam") else: out_dir = args.out_dir - print('Making directories...') - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + print("Making directories...") + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "val")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "val")) - zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) - print('Find the data', zipp_list) + zipp_list = glob.glob(os.path.join(dataset_path, "*.zip")) + print("Find the data", zipp_list) for zipp in zipp_list: with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: zip_file = zipfile.ZipFile(zipp) zip_file.extractall(tmp_dir) - src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + src_path_list = glob.glob(os.path.join(tmp_dir, "*.tif")) if not len(src_path_list): sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0]) - src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif')) + src_path_list = glob.glob(os.path.join(sub_tmp_dir, "*.tif")) prog_bar = mmcv.ProgressBar(len(src_path_list)) for i, src_path in enumerate(src_path_list): - idx_i, idx_j = osp.basename(src_path).split('_')[2:4] - data_type = 'train' if f'{idx_i}_{idx_j}' in splits[ - 'train'] else 'val' - if 'label' in src_path: - dst_dir = osp.join(out_dir, 'ann_dir', data_type) + idx_i, idx_j = osp.basename(src_path).split("_")[2:4] + data_type = "train" if f"{idx_i}_{idx_j}" in splits["train"] else "val" + if "label" in src_path: + dst_dir = osp.join(out_dir, "ann_dir", data_type) clip_big_image(src_path, dst_dir, args, to_label=True) else: - dst_dir = osp.join(out_dir, 'img_dir', data_type) + dst_dir = osp.join(out_dir, "img_dir", data_type) clip_big_image(src_path, dst_dir, args, to_label=False) prog_bar.update() - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/stare.py b/mmsegmentation/tools/convert_datasets/stare.py index 29b78c0..d6b3ba2 100644 --- a/mmsegmentation/tools/convert_datasets/stare.py +++ b/mmsegmentation/tools/convert_datasets/stare.py @@ -14,19 +14,20 @@ def un_gz(src, dst): g_file = gzip.GzipFile(src) - with open(dst, 'wb+') as f: + with open(dst, "wb+") as f: f.write(g_file.read()) g_file.close() def parse_args(): parser = argparse.ArgumentParser( - description='Convert STARE dataset to mmsegmentation format') - parser.add_argument('image_path', help='the path of stare-images.tar') - parser.add_argument('labels_ah', help='the path of labels-ah.tar') - parser.add_argument('labels_vk', help='the path of labels-vk.tar') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert STARE dataset to mmsegmentation format" + ) + parser.add_argument("image_path", help="the path of stare-images.tar") + parser.add_argument("labels_ah", help="the path of labels-ah.tar") + parser.add_argument("labels_vk", help="the path of labels-vk.tar") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") args = parser.parse_args() return args @@ -37,72 +38,78 @@ def main(): labels_ah = args.labels_ah labels_vk = args.labels_vk if args.out_dir is None: - out_dir = osp.join('data', 'STARE') + out_dir = osp.join("data", "STARE") else: out_dir = args.out_dir - print('Making directories...') + print("Making directories...") mmcv.mkdir_or_exist(out_dir) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + mmcv.mkdir_or_exist(osp.join(out_dir, "images")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "images", "validation")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "training")) + mmcv.mkdir_or_exist(osp.join(out_dir, "annotations", "validation")) with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "gz")) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "files")) - print('Extracting stare-images.tar...') + print("Extracting stare-images.tar...") with tarfile.open(image_path) as f: - f.extractall(osp.join(tmp_dir, 'gz')) + f.extractall(osp.join(tmp_dir, "gz")) - for filename in os.listdir(osp.join(tmp_dir, 'gz')): + for filename in os.listdir(osp.join(tmp_dir, "gz")): un_gz( - osp.join(tmp_dir, 'gz', filename), - osp.join(tmp_dir, 'files', - osp.splitext(filename)[0])) + osp.join(tmp_dir, "gz", filename), + osp.join(tmp_dir, "files", osp.splitext(filename)[0]), + ) - now_dir = osp.join(tmp_dir, 'files') + now_dir = osp.join(tmp_dir, "files") - assert len(os.listdir(now_dir)) == STARE_LEN, \ - 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) + assert ( + len(os.listdir(now_dir)) == STARE_LEN + ), f"len(os.listdir(now_dir)) != {STARE_LEN}" for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(now_dir, filename)) mmcv.imwrite( img, - osp.join(out_dir, 'images', 'training', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, "images", "training", osp.splitext(filename)[0] + ".png" + ), + ) for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(now_dir, filename)) mmcv.imwrite( img, - osp.join(out_dir, 'images', 'validation', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, "images", "validation", osp.splitext(filename)[0] + ".png" + ), + ) - print('Removing the temporary files...') + print("Removing the temporary files...") with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "gz")) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "files")) - print('Extracting labels-ah.tar...') + print("Extracting labels-ah.tar...") with tarfile.open(labels_ah) as f: - f.extractall(osp.join(tmp_dir, 'gz')) + f.extractall(osp.join(tmp_dir, "gz")) - for filename in os.listdir(osp.join(tmp_dir, 'gz')): + for filename in os.listdir(osp.join(tmp_dir, "gz")): un_gz( - osp.join(tmp_dir, 'gz', filename), - osp.join(tmp_dir, 'files', - osp.splitext(filename)[0])) + osp.join(tmp_dir, "gz", filename), + osp.join(tmp_dir, "files", osp.splitext(filename)[0]), + ) - now_dir = osp.join(tmp_dir, 'files') + now_dir = osp.join(tmp_dir, "files") - assert len(os.listdir(now_dir)) == STARE_LEN, \ - 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) + assert ( + len(os.listdir(now_dir)) == STARE_LEN + ), f"len(os.listdir(now_dir)) != {STARE_LEN}" for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(now_dir, filename)) @@ -112,55 +119,76 @@ def main(): # 128 equivalent to '1 if value >= 128 else 0' mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'training', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "training", + osp.splitext(filename)[0] + ".png", + ), + ) for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(now_dir, filename)) mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(filename)[0] + ".png", + ), + ) - print('Removing the temporary files...') + print("Removing the temporary files...") with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) - mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "gz")) + mmcv.mkdir_or_exist(osp.join(tmp_dir, "files")) - print('Extracting labels-vk.tar...') + print("Extracting labels-vk.tar...") with tarfile.open(labels_vk) as f: - f.extractall(osp.join(tmp_dir, 'gz')) + f.extractall(osp.join(tmp_dir, "gz")) - for filename in os.listdir(osp.join(tmp_dir, 'gz')): + for filename in os.listdir(osp.join(tmp_dir, "gz")): un_gz( - osp.join(tmp_dir, 'gz', filename), - osp.join(tmp_dir, 'files', - osp.splitext(filename)[0])) + osp.join(tmp_dir, "gz", filename), + osp.join(tmp_dir, "files", osp.splitext(filename)[0]), + ) - now_dir = osp.join(tmp_dir, 'files') + now_dir = osp.join(tmp_dir, "files") - assert len(os.listdir(now_dir)) == STARE_LEN, \ - 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) + assert ( + len(os.listdir(now_dir)) == STARE_LEN + ), f"len(os.listdir(now_dir)) != {STARE_LEN}" for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: img = mmcv.imread(osp.join(now_dir, filename)) mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'training', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "training", + osp.splitext(filename)[0] + ".png", + ), + ) for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: img = mmcv.imread(osp.join(now_dir, filename)) mmcv.imwrite( img[:, :, 0] // 128, - osp.join(out_dir, 'annotations', 'validation', - osp.splitext(filename)[0] + '.png')) + osp.join( + out_dir, + "annotations", + "validation", + osp.splitext(filename)[0] + ".png", + ), + ) - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/convert_datasets/vaihingen.py b/mmsegmentation/tools/convert_datasets/vaihingen.py index b025ae5..83ea17d 100644 --- a/mmsegmentation/tools/convert_datasets/vaihingen.py +++ b/mmsegmentation/tools/convert_datasets/vaihingen.py @@ -13,20 +13,23 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Convert vaihingen dataset to mmsegmentation format') - parser.add_argument('dataset_path', help='vaihingen folder path') - parser.add_argument('--tmp_dir', help='path of the temporary directory') - parser.add_argument('-o', '--out_dir', help='output path') + description="Convert vaihingen dataset to mmsegmentation format" + ) + parser.add_argument("dataset_path", help="vaihingen folder path") + parser.add_argument("--tmp_dir", help="path of the temporary directory") + parser.add_argument("-o", "--out_dir", help="output path") parser.add_argument( - '--clip_size', + "--clip_size", type=int, - help='clipped size of image after preparation', - default=512) + help="clipped size of image after preparation", + default=512, + ) parser.add_argument( - '--stride_size', + "--stride_size", type=int, - help='stride of clipping original images', - default=256) + help="stride of clipping original images", + default=256, + ) args = parser.parse_args() return args @@ -44,10 +47,16 @@ def clip_big_image(image_path, clip_save_dir, to_label=False): cs = args.clip_size ss = args.stride_size - num_rows = math.ceil((h - cs) / ss) if math.ceil( - (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1 - num_cols = math.ceil((w - cs) / ss) if math.ceil( - (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1 + num_rows = ( + math.ceil((h - cs) / ss) + if math.ceil((h - cs) / ss) * ss + cs >= h + else math.ceil((h - cs) / ss) + 1 + ) + num_cols = ( + math.ceil((w - cs) / ss) + if math.ceil((w - cs) / ss) * ss + cs >= w + else math.ceil((w - cs) / ss) + 1 + ) x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) xmin = x * cs @@ -57,99 +66,141 @@ def clip_big_image(image_path, clip_save_dir, to_label=False): ymin = ymin.ravel() xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin)) ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin)) - boxes = np.stack([ - xmin + xmin_offset, ymin + ymin_offset, - np.minimum(xmin + cs, w), - np.minimum(ymin + cs, h) - ], - axis=1) + boxes = np.stack( + [ + xmin + xmin_offset, + ymin + ymin_offset, + np.minimum(xmin + cs, w), + np.minimum(ymin + cs, h), + ], + axis=1, + ) if to_label: - color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], - [255, 255, 0], [0, 255, 0], [0, 255, 255], - [0, 0, 255]]) - flatten_v = np.matmul( - image.reshape(-1, c), - np.array([2, 3, 4]).reshape(3, 1)) + color_map = np.array( + [ + [0, 0, 0], + [255, 255, 255], + [255, 0, 0], + [255, 255, 0], + [0, 255, 0], + [0, 255, 255], + [0, 0, 255], + ] + ) + flatten_v = np.matmul(image.reshape(-1, c), np.array([2, 3, 4]).reshape(3, 1)) out = np.zeros_like(flatten_v) for idx, class_color in enumerate(color_map): - value_idx = np.matmul(class_color, - np.array([2, 3, 4]).reshape(3, 1)) + value_idx = np.matmul(class_color, np.array([2, 3, 4]).reshape(3, 1)) out[flatten_v == value_idx] = idx image = out.reshape(h, w) for box in boxes: start_x, start_y, end_x, end_y = box - clipped_image = image[start_y:end_y, - start_x:end_x] if to_label else image[ - start_y:end_y, start_x:end_x, :] - area_idx = osp.basename(image_path).split('_')[3].strip('.tif') + clipped_image = ( + image[start_y:end_y, start_x:end_x] + if to_label + else image[start_y:end_y, start_x:end_x, :] + ) + area_idx = osp.basename(image_path).split("_")[3].strip(".tif") mmcv.imwrite( clipped_image.astype(np.uint8), - osp.join(clip_save_dir, - f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + osp.join( + clip_save_dir, f"{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png" + ), + ) def main(): splits = { - 'train': [ - 'area1', 'area11', 'area13', 'area15', 'area17', 'area21', - 'area23', 'area26', 'area28', 'area3', 'area30', 'area32', - 'area34', 'area37', 'area5', 'area7' + "train": [ + "area1", + "area11", + "area13", + "area15", + "area17", + "area21", + "area23", + "area26", + "area28", + "area3", + "area30", + "area32", + "area34", + "area37", + "area5", + "area7", ], - 'val': [ - 'area6', 'area24', 'area35', 'area16', 'area14', 'area22', - 'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33', - 'area27', 'area38', 'area12', 'area29' + "val": [ + "area6", + "area24", + "area35", + "area16", + "area14", + "area22", + "area10", + "area4", + "area2", + "area20", + "area8", + "area31", + "area33", + "area27", + "area38", + "area12", + "area29", ], } dataset_path = args.dataset_path if args.out_dir is None: - out_dir = osp.join('data', 'vaihingen') + out_dir = osp.join("data", "vaihingen") else: out_dir = args.out_dir - print('Making directories...') - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) - mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + print("Making directories...") + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "img_dir", "val")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "train")) + mmcv.mkdir_or_exist(osp.join(out_dir, "ann_dir", "val")) - zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) - print('Find the data', zipp_list) + zipp_list = glob.glob(os.path.join(dataset_path, "*.zip")) + print("Find the data", zipp_list) with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: for zipp in zipp_list: zip_file = zipfile.ZipFile(zipp) zip_file.extractall(tmp_dir) - src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) - if 'ISPRS_semantic_labeling_Vaihingen' in zipp: + src_path_list = glob.glob(os.path.join(tmp_dir, "*.tif")) + if "ISPRS_semantic_labeling_Vaihingen" in zipp: src_path_list = glob.glob( - os.path.join(os.path.join(tmp_dir, 'top'), '*.tif')) - if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa - src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + os.path.join(os.path.join(tmp_dir, "top"), "*.tif") + ) + if ( + "ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE" in zipp + ): # noqa + src_path_list = glob.glob(os.path.join(tmp_dir, "*.tif")) # delete unused area9 ground truth for area_ann in src_path_list: - if 'area9' in area_ann: + if "area9" in area_ann: src_path_list.remove(area_ann) prog_bar = mmcv.ProgressBar(len(src_path_list)) for i, src_path in enumerate(src_path_list): - area_idx = osp.basename(src_path).split('_')[3].strip('.tif') - data_type = 'train' if area_idx in splits['train'] else 'val' - if 'noBoundary' in src_path: - dst_dir = osp.join(out_dir, 'ann_dir', data_type) + area_idx = osp.basename(src_path).split("_")[3].strip(".tif") + data_type = "train" if area_idx in splits["train"] else "val" + if "noBoundary" in src_path: + dst_dir = osp.join(out_dir, "ann_dir", data_type) clip_big_image(src_path, dst_dir, to_label=True) else: - dst_dir = osp.join(out_dir, 'img_dir', data_type) + dst_dir = osp.join(out_dir, "img_dir", data_type) clip_big_image(src_path, dst_dir, to_label=False) prog_bar.update() - print('Removing the temporary files...') + print("Removing the temporary files...") - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main() diff --git a/mmsegmentation/tools/convert_datasets/voc_aug.py b/mmsegmentation/tools/convert_datasets/voc_aug.py index 1d42c27..05342e2 100644 --- a/mmsegmentation/tools/convert_datasets/voc_aug.py +++ b/mmsegmentation/tools/convert_datasets/voc_aug.py @@ -13,9 +13,9 @@ def convert_mat(mat_file, in_dir, out_dir): data = loadmat(osp.join(in_dir, mat_file)) - mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) - seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) - Image.fromarray(mask).save(seg_filename, 'PNG') + mask = data["GTcls"][0]["Segmentation"][0].astype(np.uint8) + seg_filename = osp.join(out_dir, mat_file.replace(".mat", ".png")) + Image.fromarray(mask).save(seg_filename, "PNG") def generate_aug_list(merged_list, excluded_list): @@ -24,12 +24,12 @@ def generate_aug_list(merged_list, excluded_list): def parse_args(): parser = argparse.ArgumentParser( - description='Convert PASCAL VOC annotations to mmsegmentation format') - parser.add_argument('devkit_path', help='pascal voc devkit path') - parser.add_argument('aug_path', help='pascal voc aug path') - parser.add_argument('-o', '--out_dir', help='output path') - parser.add_argument( - '--nproc', default=1, type=int, help='number of process') + description="Convert PASCAL VOC annotations to mmsegmentation format" + ) + parser.add_argument("devkit_path", help="pascal voc devkit path") + parser.add_argument("aug_path", help="pascal voc aug path") + parser.add_argument("-o", "--out_dir", help="output path") + parser.add_argument("--nproc", default=1, type=int, help="number of process") args = parser.parse_args() return args @@ -40,53 +40,50 @@ def main(): aug_path = args.aug_path nproc = args.nproc if args.out_dir is None: - out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') + out_dir = osp.join(devkit_path, "VOC2012", "SegmentationClassAug") else: out_dir = args.out_dir mmcv.mkdir_or_exist(out_dir) - in_dir = osp.join(aug_path, 'dataset', 'cls') + in_dir = osp.join(aug_path, "dataset", "cls") mmcv.track_parallel_progress( partial(convert_mat, in_dir=in_dir, out_dir=out_dir), - list(mmcv.scandir(in_dir, suffix='.mat')), - nproc=nproc) + list(mmcv.scandir(in_dir, suffix=".mat")), + nproc=nproc, + ) full_aug_list = [] - with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: + with open(osp.join(aug_path, "dataset", "train.txt")) as f: full_aug_list += [line.strip() for line in f] - with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: + with open(osp.join(aug_path, "dataset", "val.txt")) as f: full_aug_list += [line.strip() for line in f] with open( - osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', - 'train.txt')) as f: + osp.join(devkit_path, "VOC2012/ImageSets/Segmentation", "train.txt") + ) as f: ori_train_list = [line.strip() for line in f] - with open( - osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', - 'val.txt')) as f: + with open(osp.join(devkit_path, "VOC2012/ImageSets/Segmentation", "val.txt")) as f: val_list = [line.strip() for line in f] - aug_train_list = generate_aug_list(ori_train_list + full_aug_list, - val_list) - assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( - AUG_LEN) + aug_train_list = generate_aug_list(ori_train_list + full_aug_list, val_list) + assert len(aug_train_list) == AUG_LEN, "len(aug_train_list) != {}".format(AUG_LEN) with open( - osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', - 'trainaug.txt'), 'w') as f: - f.writelines(line + '\n' for line in aug_train_list) + osp.join(devkit_path, "VOC2012/ImageSets/Segmentation", "trainaug.txt"), "w" + ) as f: + f.writelines(line + "\n" for line in aug_train_list) aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) - assert len(aug_list) == AUG_LEN - len( - ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - - len(ori_train_list)) + assert len(aug_list) == AUG_LEN - len(ori_train_list), "len(aug_list) != {}".format( + AUG_LEN - len(ori_train_list) + ) with open( - osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), - 'w') as f: - f.writelines(line + '\n' for line in aug_list) + osp.join(devkit_path, "VOC2012/ImageSets/Segmentation", "aug.txt"), "w" + ) as f: + f.writelines(line + "\n" for line in aug_list) - print('Done!') + print("Done!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/deploy_test.py b/mmsegmentation/tools/deploy_test.py index eca5430..bb102eb 100644 --- a/mmsegmentation/tools/deploy_test.py +++ b/mmsegmentation/tools/deploy_test.py @@ -20,30 +20,32 @@ class ONNXRuntimeSegmentor(BaseSegmentor): - def __init__(self, onnx_file: str, cfg: Any, device_id: int): - super(ONNXRuntimeSegmentor, self).__init__() + super().__init__() import onnxruntime as ort # get the custom op path - ort_custom_op_path = '' + ort_custom_op_path = "" try: from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() except (ImportError, ModuleNotFoundError): - warnings.warn('If input model has custom op from mmcv, \ - you may have to build mmcv with ONNXRuntime from source.') + warnings.warn( + "If input model has custom op from mmcv, \ + you may have to build mmcv with ONNXRuntime from source." + ) session_options = ort.SessionOptions() # register custom op for onnxruntime if osp.exists(ort_custom_op_path): session_options.register_custom_ops_library(ort_custom_op_path) sess = ort.InferenceSession(onnx_file, session_options) - providers = ['CPUExecutionProvider'] + providers = ["CPUExecutionProvider"] options = [{}] - is_cuda_available = ort.get_device() == 'GPU' + is_cuda_available = ort.get_device() == "GPU" if is_cuda_available: - providers.insert(0, 'CUDAExecutionProvider') - options.insert(0, {'device_id': device_id}) + providers.insert(0, "CUDAExecutionProvider") + options.insert(0, {"device_id": device_id}) sess.set_providers(providers, options) @@ -58,58 +60,59 @@ def __init__(self, onnx_file: str, cfg: Any, device_id: int): self.is_cuda_available = is_cuda_available def extract_feat(self, imgs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") def encode_decode(self, img, img_metas): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") def forward_train(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") - def simple_test(self, img: torch.Tensor, img_meta: Iterable, - **kwargs) -> list: + def simple_test(self, img: torch.Tensor, img_meta: Iterable, **kwargs) -> list: if not self.is_cuda_available: img = img.detach().cpu() elif self.device_id >= 0: img = img.cuda(self.device_id) device_type = img.device.type self.io_binding.bind_input( - name='input', + name="input", device_type=device_type, device_id=self.device_id, element_type=np.float32, shape=img.shape, - buffer_ptr=img.data_ptr()) + buffer_ptr=img.data_ptr(), + ) self.sess.run_with_iobinding(self.io_binding) seg_pred = self.io_binding.copy_outputs_to_cpu()[0] # whole might support dynamic reshape - ori_shape = img_meta[0]['ori_shape'] - if not (ori_shape[0] == seg_pred.shape[-2] - and ori_shape[1] == seg_pred.shape[-1]): + ori_shape = img_meta[0]["ori_shape"] + if not ( + ori_shape[0] == seg_pred.shape[-2] and ori_shape[1] == seg_pred.shape[-1] + ): seg_pred = torch.from_numpy(seg_pred).float() - seg_pred = resize( - seg_pred, size=tuple(ori_shape[:2]), mode='nearest') + seg_pred = resize(seg_pred, size=tuple(ori_shape[:2]), mode="nearest") seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred[0] seg_pred = list(seg_pred) return seg_pred def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") class TensorRTSegmentor(BaseSegmentor): - def __init__(self, trt_file: str, cfg: Any, device_id: int): - super(TensorRTSegmentor, self).__init__() + super().__init__() from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin + try: load_tensorrt_plugin() except (ImportError, ModuleNotFoundError): - warnings.warn('If input model has custom op from mmcv, \ - you may have to build mmcv with TensorRT from source.') - model = TRTWraper( - trt_file, input_names=['input'], output_names=['output']) + warnings.warn( + "If input model has custom op from mmcv, \ + you may have to build mmcv with TensorRT from source." + ) + model = TRTWraper(trt_file, input_names=["input"], output_names=["output"]) self.model = model self.device_id = device_id @@ -117,104 +120,111 @@ def __init__(self, trt_file: str, cfg: Any, device_id: int): self.test_mode = cfg.model.test_cfg.mode def extract_feat(self, imgs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") def encode_decode(self, img, img_metas): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") def forward_train(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") - def simple_test(self, img: torch.Tensor, img_meta: Iterable, - **kwargs) -> list: + def simple_test(self, img: torch.Tensor, img_meta: Iterable, **kwargs) -> list: with torch.cuda.device(self.device_id), torch.no_grad(): - seg_pred = self.model({'input': img})['output'] + seg_pred = self.model({"input": img})["output"] seg_pred = seg_pred.detach().cpu().numpy() # whole might support dynamic reshape - ori_shape = img_meta[0]['ori_shape'] - if not (ori_shape[0] == seg_pred.shape[-2] - and ori_shape[1] == seg_pred.shape[-1]): + ori_shape = img_meta[0]["ori_shape"] + if not ( + ori_shape[0] == seg_pred.shape[-2] and ori_shape[1] == seg_pred.shape[-1] + ): seg_pred = torch.from_numpy(seg_pred).float() - seg_pred = resize( - seg_pred, size=tuple(ori_shape[:2]), mode='nearest') + seg_pred = resize(seg_pred, size=tuple(ori_shape[:2]), mode="nearest") seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred[0] seg_pred = list(seg_pred) return seg_pred def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + raise NotImplementedError("This method is not implemented.") def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description='mmseg backend test (and eval)') - parser.add_argument('config', help='test config file path') - parser.add_argument('model', help='Input model file') + parser = argparse.ArgumentParser(description="mmseg backend test (and eval)") + parser.add_argument("config", help="test config file path") + parser.add_argument("model", help="Input model file") parser.add_argument( - '--backend', - help='Backend of the model.', - choices=['onnxruntime', 'tensorrt']) - parser.add_argument('--out', help='output result file in pickle format') + "--backend", help="Backend of the model.", choices=["onnxruntime", "tensorrt"] + ) + parser.add_argument("--out", help="output result file in pickle format") parser.add_argument( - '--format-only', - action='store_true', - help='Format the output results without perform evaluation. It is' - 'useful when you want to format the result to a specific format and ' - 'submit it to the test server') + "--format-only", + action="store_true", + help="Format the output results without perform evaluation. It is" + "useful when you want to format the result to a specific format and " + "submit it to the test server", + ) parser.add_argument( - '--eval', + "--eval", type=str, - nargs='+', + nargs="+", help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' - ' for generic datasets, and "cityscapes" for Cityscapes') - parser.add_argument('--show', action='store_true', help='show results') + ' for generic datasets, and "cityscapes" for Cityscapes', + ) + parser.add_argument("--show", action="store_true", help="show results") parser.add_argument( - '--show-dir', help='directory where painted images will be saved') + "--show-dir", help="directory where painted images will be saved" + ) parser.add_argument( - '--options', - nargs='+', + "--options", + nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " - 'not be supported in version v0.22.0. Override some settings in the ' - 'used config, the key-value pair in xxx=yyy format will be merged ' - 'into config file. If the value to be overwritten is a list, it ' + "not be supported in version v0.22.0. Override some settings in the " + "used config, the key-value pair in xxx=yyy format will be merged " + "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' - 'marks are necessary and that no white space is allowed.') + "marks are necessary and that no white space is allowed.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) parser.add_argument( - '--eval-options', - nargs='+', + "--eval-options", + nargs="+", action=DictAction, - help='custom options for evaluation') + help="custom options for evaluation", + ) parser.add_argument( - '--opacity', + "--opacity", type=float, default=0.5, - help='Opacity of painted segmentation map. In (0, 1] range.') - parser.add_argument('--local_rank', type=int, default=0) + help="Opacity of painted segmentation map. In (0, 1] range.", + ) + parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( - '--options and --cfg-options cannot be both ' - 'specified, --options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) if args.options: - warnings.warn('--options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + warnings.warn( + "--options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) args.cfg_options = args.options return args @@ -223,17 +233,17 @@ def parse_args() -> argparse.Namespace: def main(): args = parse_args() - assert args.out or args.eval or args.format_only or args.show \ - or args.show_dir, \ - ('Please specify at least one operation (save/eval/format/show the ' - 'results / save the results) with the argument "--out", "--eval"' - ', "--format-only", "--show" or "--show-dir"') + assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( + "Please specify at least one operation (save/eval/format/show the " + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"' + ) if args.eval and args.format_only: - raise ValueError('--eval and --format_only cannot be both specified') + raise ValueError("--eval and --format_only cannot be both specified") - if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): - raise ValueError('The output file must be a pkl file.') + if args.out is not None and not args.out.endswith((".pkl", ".pickle")): + raise ValueError("The output file must be a pkl file.") cfg = mmcv.Config.fromfile(args.config) if args.cfg_options is not None: @@ -252,14 +262,15 @@ def main(): samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, - shuffle=False) + shuffle=False, + ) # load onnx config and meta cfg.model.train_cfg = None - if args.backend == 'onnxruntime': + if args.backend == "onnxruntime": model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0) - elif args.backend == 'tensorrt': + elif args.backend == "tensorrt": model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0) model.CLASSES = dataset.CLASSES @@ -270,25 +281,27 @@ def main(): eval_kwargs = {} if args.eval_options is None else args.eval_options # Deprecated - efficient_test = eval_kwargs.get('efficient_test', False) + efficient_test = eval_kwargs.get("efficient_test", False) if efficient_test: warnings.warn( - '``efficient_test=True`` does not have effect in tools/test.py, ' - 'the evaluation and format results are CPU memory efficient by ' - 'default') + "``efficient_test=True`` does not have effect in tools/test.py, " + "the evaluation and format results are CPU memory efficient by " + "default" + ) - eval_on_format_results = ( - args.eval is not None and 'cityscapes' in args.eval) + eval_on_format_results = args.eval is not None and "cityscapes" in args.eval if eval_on_format_results: - assert len(args.eval) == 1, 'eval on format results is not ' \ - 'applicable for metrics other than ' \ - 'cityscapes' + assert len(args.eval) == 1, ( + "eval on format results is not " + "applicable for metrics other than " + "cityscapes" + ) if args.format_only or eval_on_format_results: - if 'imgfile_prefix' in eval_kwargs: - tmpdir = eval_kwargs['imgfile_prefix'] + if "imgfile_prefix" in eval_kwargs: + tmpdir = eval_kwargs["imgfile_prefix"] else: - tmpdir = '.format_cityscapes' - eval_kwargs.setdefault('imgfile_prefix', tmpdir) + tmpdir = ".format_cityscapes" + eval_kwargs.setdefault("imgfile_prefix", tmpdir) mmcv.mkdir_or_exist(tmpdir) else: tmpdir = None @@ -303,17 +316,19 @@ def main(): args.opacity, pre_eval=args.eval is not None and not eval_on_format_results, format_only=args.format_only or eval_on_format_results, - format_args=eval_kwargs) + format_args=eval_kwargs, + ) rank, _ = get_dist_info() if rank == 0: if args.out: warnings.warn( - 'The behavior of ``args.out`` has been changed since MMSeg ' - 'v0.16, the pickled outputs could be seg map as type of ' - 'np.array, pre-eval results or file paths for ' - '``dataset.format_results()``.') - print(f'\nwriting results to {args.out}') + "The behavior of ``args.out`` has been changed since MMSeg " + "v0.16, the pickled outputs could be seg map as type of " + "np.array, pre-eval results or file paths for " + "``dataset.format_results()``." + ) + print(f"\nwriting results to {args.out}") mmcv.dump(results, args.out) if args.eval: dataset.evaluate(results, args.eval, **eval_kwargs) @@ -322,17 +337,17 @@ def main(): shutil.rmtree(tmpdir) -if __name__ == '__main__': +if __name__ == "__main__": main() # Following strings of text style are from colorama package - bright_style, reset_style = '\x1b[1m', '\x1b[0m' - red_text, blue_text = '\x1b[31m', '\x1b[34m' - white_background = '\x1b[107m' + bright_style, reset_style = "\x1b[1m", "\x1b[0m" + red_text, blue_text = "\x1b[31m", "\x1b[34m" + white_background = "\x1b[107m" msg = white_background + bright_style + red_text - msg += 'DeprecationWarning: This tool will be deprecated in future. ' - msg += blue_text + 'Welcome to use the unified model deployment toolbox ' - msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy' + msg += "DeprecationWarning: This tool will be deprecated in future. " + msg += blue_text + "Welcome to use the unified model deployment toolbox " + msg += "MMDeploy: https://github.com/open-mmlab/mmdeploy" msg += reset_style warnings.warn(msg) diff --git a/mmsegmentation/tools/get_flops.py b/mmsegmentation/tools/get_flops.py index e30c36f..175fd61 100644 --- a/mmsegmentation/tools/get_flops.py +++ b/mmsegmentation/tools/get_flops.py @@ -8,53 +8,54 @@ def parse_args(): - parser = argparse.ArgumentParser( - description='Get the FLOPs of a segmentor') - parser.add_argument('config', help='train config file path') + parser = argparse.ArgumentParser(description="Get the FLOPs of a segmentor") + parser.add_argument("config", help="train config file path") parser.add_argument( - '--shape', - type=int, - nargs='+', - default=[2048, 1024], - help='input image size') + "--shape", type=int, nargs="+", default=[2048, 1024], help="input image size" + ) args = parser.parse_args() return args def main(): - args = parse_args() if len(args.shape) == 1: input_shape = (3, args.shape[0], args.shape[0]) elif len(args.shape) == 2: - input_shape = (3, ) + tuple(args.shape) + input_shape = (3,) + tuple(args.shape) else: - raise ValueError('invalid input shape') + raise ValueError("invalid input shape") cfg = Config.fromfile(args.config) cfg.model.pretrained = None model = build_segmentor( - cfg.model, - train_cfg=cfg.get('train_cfg'), - test_cfg=cfg.get('test_cfg')).cuda() + cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") + ).cuda() model.eval() - if hasattr(model, 'forward_dummy'): + if hasattr(model, "forward_dummy"): model.forward = model.forward_dummy else: raise NotImplementedError( - 'FLOPs counter is currently not currently supported with {}'. - format(model.__class__.__name__)) + "FLOPs counter is currently not currently supported with {}".format( + model.__class__.__name__ + ) + ) flops, params = get_model_complexity_info(model, input_shape) - split_line = '=' * 30 - print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( - split_line, input_shape, flops, params)) - print('!!!Please be cautious if you use the results in papers. ' - 'You may need to check if all ops are supported and verify that the ' - 'flops computation is correct.') - - -if __name__ == '__main__': + split_line = "=" * 30 + print( + "{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}".format( + split_line, input_shape, flops, params + ) + ) + print( + "!!!Please be cautious if you use the results in papers. " + "You may need to check if all ops are supported and verify that the " + "flops computation is correct." + ) + + +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/inference.py b/mmsegmentation/tools/inference.py index 53c55fd..93ab0d2 100644 --- a/mmsegmentation/tools/inference.py +++ b/mmsegmentation/tools/inference.py @@ -7,7 +7,7 @@ import pandas as pd from mmcv import Config from mmcv.parallel import MMDataParallel -from mmcv.runner import load_checkpoint + from mmseg.apis import single_gpu_test from mmseg.datasets import build_dataloader, build_dataset from mmseg.models import build_segmentor @@ -36,7 +36,7 @@ def get_latest(work_dir: Path) -> Union[str, None]: if not latest_file.exists(): return None - with open(latest_file, "r", encoding="utf8") as f: + with open(latest_file, encoding="utf8") as f: path = f.read() return path @@ -47,7 +47,7 @@ def get_last_checkpoint(work_dir: Path) -> Union[str, None]: if not latest_checkpoint_file.exists(): return None - with open(latest_checkpoint_file, "r", encoding="utf8") as f: + with open(latest_checkpoint_file, encoding="utf8") as f: checkpoint_path = f.read() return checkpoint_path @@ -107,7 +107,7 @@ def main(): submission = pd.read_csv(SAMPLE_PATH, index_col=None) - with open(TEST_JSON_PATH, "r", encoding="utf8") as outfile: + with open(TEST_JSON_PATH, encoding="utf8") as outfile: datas = json.load(outfile) # PredictionString 대입 diff --git a/mmsegmentation/tools/model_converters/beit2mmseg.py b/mmsegmentation/tools/model_converters/beit2mmseg.py index 91b91fa..681f07e 100644 --- a/mmsegmentation/tools/model_converters/beit2mmseg.py +++ b/mmsegmentation/tools/model_converters/beit2mmseg.py @@ -12,17 +12,17 @@ def convert_beit(ckpt): new_ckpt = OrderedDict() for k, v in ckpt.items(): - if k.startswith('blocks'): - new_key = k.replace('blocks', 'layers') - if 'norm' in new_key: - new_key = new_key.replace('norm', 'ln') - elif 'mlp.fc1' in new_key: - new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0') - elif 'mlp.fc2' in new_key: - new_key = new_key.replace('mlp.fc2', 'ffn.layers.1') + if k.startswith("blocks"): + new_key = k.replace("blocks", "layers") + if "norm" in new_key: + new_key = new_key.replace("norm", "ln") + elif "mlp.fc1" in new_key: + new_key = new_key.replace("mlp.fc1", "ffn.layers.0.0") + elif "mlp.fc2" in new_key: + new_key = new_key.replace("mlp.fc2", "ffn.layers.1") new_ckpt[new_key] = v - elif k.startswith('patch_embed'): - new_key = k.replace('patch_embed.proj', 'patch_embed.projection') + elif k.startswith("patch_embed"): + new_key = k.replace("patch_embed.proj", "patch_embed.projection") new_ckpt[new_key] = v else: new_key = k @@ -33,18 +33,19 @@ def convert_beit(ckpt): def main(): parser = argparse.ArgumentParser( - description='Convert keys in official pretrained beit models to' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys in official pretrained beit models to" + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') + parser.add_argument("dst", help="save path") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] else: state_dict = checkpoint weight = convert_beit(state_dict) @@ -52,5 +53,5 @@ def main(): torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/mit2mmseg.py b/mmsegmentation/tools/model_converters/mit2mmseg.py index 2eff1f7..ec8ac4e 100644 --- a/mmsegmentation/tools/model_converters/mit2mmseg.py +++ b/mmsegmentation/tools/model_converters/mit2mmseg.py @@ -12,43 +12,43 @@ def convert_mit(ckpt): new_ckpt = OrderedDict() # Process the concat between q linear weights and kv linear weights for k, v in ckpt.items(): - if k.startswith('head'): + if k.startswith("head"): continue # patch embedding conversion - elif k.startswith('patch_embed'): - stage_i = int(k.split('.')[0].replace('patch_embed', '')) - new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') + elif k.startswith("patch_embed"): + stage_i = int(k.split(".")[0].replace("patch_embed", "")) + new_k = k.replace(f"patch_embed{stage_i}", f"layers.{stage_i-1}.0") new_v = v - if 'proj.' in new_k: - new_k = new_k.replace('proj.', 'projection.') + if "proj." in new_k: + new_k = new_k.replace("proj.", "projection.") # transformer encoder layer conversion - elif k.startswith('block'): - stage_i = int(k.split('.')[0].replace('block', '')) - new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') + elif k.startswith("block"): + stage_i = int(k.split(".")[0].replace("block", "")) + new_k = k.replace(f"block{stage_i}", f"layers.{stage_i-1}.1") new_v = v - if 'attn.q.' in new_k: - sub_item_k = k.replace('q.', 'kv.') - new_k = new_k.replace('q.', 'attn.in_proj_') + if "attn.q." in new_k: + sub_item_k = k.replace("q.", "kv.") + new_k = new_k.replace("q.", "attn.in_proj_") new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) - elif 'attn.kv.' in new_k: + elif "attn.kv." in new_k: continue - elif 'attn.proj.' in new_k: - new_k = new_k.replace('proj.', 'attn.out_proj.') - elif 'attn.sr.' in new_k: - new_k = new_k.replace('sr.', 'sr.') - elif 'mlp.' in new_k: - string = f'{new_k}-' - new_k = new_k.replace('mlp.', 'ffn.layers.') - if 'fc1.weight' in new_k or 'fc2.weight' in new_k: + elif "attn.proj." in new_k: + new_k = new_k.replace("proj.", "attn.out_proj.") + elif "attn.sr." in new_k: + new_k = new_k.replace("sr.", "sr.") + elif "mlp." in new_k: + string = f"{new_k}-" + new_k = new_k.replace("mlp.", "ffn.layers.") + if "fc1.weight" in new_k or "fc2.weight" in new_k: new_v = v.reshape((*v.shape, 1, 1)) - new_k = new_k.replace('fc1.', '0.') - new_k = new_k.replace('dwconv.dwconv.', '1.') - new_k = new_k.replace('fc2.', '4.') - string += f'{new_k} {v.shape}-{new_v.shape}' + new_k = new_k.replace("fc1.", "0.") + new_k = new_k.replace("dwconv.dwconv.", "1.") + new_k = new_k.replace("fc2.", "4.") + string += f"{new_k} {v.shape}-{new_v.shape}" # norm layer conversion - elif k.startswith('norm'): - stage_i = int(k.split('.')[0].replace('norm', '')) - new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') + elif k.startswith("norm"): + stage_i = int(k.split(".")[0].replace("norm", "")) + new_k = k.replace(f"norm{stage_i}", f"layers.{stage_i-1}.2") new_v = v else: new_k = k @@ -59,18 +59,19 @@ def convert_mit(ckpt): def main(): parser = argparse.ArgumentParser( - description='Convert keys in official pretrained segformer to ' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys in official pretrained segformer to " + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') + parser.add_argument("dst", help="save path") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] else: state_dict = checkpoint weight = convert_mit(state_dict) @@ -78,5 +79,5 @@ def main(): torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/stdc2mmseg.py b/mmsegmentation/tools/model_converters/stdc2mmseg.py index 9241f86..c2daa07 100644 --- a/mmsegmentation/tools/model_converters/stdc2mmseg.py +++ b/mmsegmentation/tools/model_converters/stdc2mmseg.py @@ -9,32 +9,44 @@ def convert_stdc(ckpt, stdc_type): new_state_dict = {} - if stdc_type == 'STDC1': - stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1'] + if stdc_type == "STDC1": + stage_lst = ["0", "1", "2.0", "2.1", "3.0", "3.1", "4.0", "4.1"] else: stage_lst = [ - '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3', - '3.4', '4.0', '4.1', '4.2' + "0", + "1", + "2.0", + "2.1", + "2.2", + "2.3", + "3.0", + "3.1", + "3.2", + "3.3", + "3.4", + "4.0", + "4.1", + "4.2", ] for k, v in ckpt.items(): ori_k = k flag = False - if 'cp.' in k: - k = k.replace('cp.', '') - if 'features.' in k: - num_layer = int(k.split('.')[1]) - feature_key_lst = 'features.' + str(num_layer) + '.' - stages_key_lst = 'stages.' + stage_lst[num_layer] + '.' + if "cp." in k: + k = k.replace("cp.", "") + if "features." in k: + num_layer = int(k.split(".")[1]) + feature_key_lst = "features." + str(num_layer) + "." + stages_key_lst = "stages." + stage_lst[num_layer] + "." k = k.replace(feature_key_lst, stages_key_lst) flag = True - if 'conv_list' in k: - k = k.replace('conv_list', 'layers') + if "conv_list" in k: + k = k.replace("conv_list", "layers") flag = True - if 'avd_layer.' in k: - if 'avd_layer.0' in k: - k = k.replace('avd_layer.0', 'downsample.conv') - elif 'avd_layer.1' in k: - k = k.replace('avd_layer.1', 'downsample.bn') + if "avd_layer." in k: + if "avd_layer.0" in k: + k = k.replace("avd_layer.0", "downsample.conv") + elif "avd_layer.1" in k: + k = k.replace("avd_layer.1", "downsample.bn") flag = True if flag: new_state_dict[k] = ckpt[ori_k] @@ -44,28 +56,28 @@ def convert_stdc(ckpt, stdc_type): def main(): parser = argparse.ArgumentParser( - description='Convert keys in official pretrained STDC1/2 to ' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path') + description="Convert keys in official pretrained STDC1/2 to " + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') - parser.add_argument('type', help='model type: STDC1 or STDC2') + parser.add_argument("dst", help="save path") + parser.add_argument("type", help="model type: STDC1 or STDC2") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] else: state_dict = checkpoint - assert args.type in ['STDC1', - 'STDC2'], 'STD type should be STDC1 or STDC2!' + assert args.type in ["STDC1", "STDC2"], "STD type should be STDC1 or STDC2!" weight = convert_stdc(state_dict, args.type) mmcv.mkdir_or_exist(osp.dirname(args.dst)) torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/swin2mmseg.py b/mmsegmentation/tools/model_converters/swin2mmseg.py index 03b24ce..ebe98c2 100644 --- a/mmsegmentation/tools/model_converters/swin2mmseg.py +++ b/mmsegmentation/tools/model_converters/swin2mmseg.py @@ -14,8 +14,7 @@ def convert_swin(ckpt): def correct_unfold_reduction_order(x): out_channel, in_channel = x.shape x = x.reshape(out_channel, 4, in_channel // 4) - x = x[:, [0, 2, 1, 3], :].transpose(1, - 2).reshape(out_channel, in_channel) + x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel) return x def correct_unfold_norm_order(x): @@ -25,32 +24,32 @@ def correct_unfold_norm_order(x): return x for k, v in ckpt.items(): - if k.startswith('head'): + if k.startswith("head"): continue - elif k.startswith('layers'): + elif k.startswith("layers"): new_v = v - if 'attn.' in k: - new_k = k.replace('attn.', 'attn.w_msa.') - elif 'mlp.' in k: - if 'mlp.fc1.' in k: - new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') - elif 'mlp.fc2.' in k: - new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + if "attn." in k: + new_k = k.replace("attn.", "attn.w_msa.") + elif "mlp." in k: + if "mlp.fc1." in k: + new_k = k.replace("mlp.fc1.", "ffn.layers.0.0.") + elif "mlp.fc2." in k: + new_k = k.replace("mlp.fc2.", "ffn.layers.1.") else: - new_k = k.replace('mlp.', 'ffn.') - elif 'downsample' in k: + new_k = k.replace("mlp.", "ffn.") + elif "downsample" in k: new_k = k - if 'reduction.' in k: + if "reduction." in k: new_v = correct_unfold_reduction_order(v) - elif 'norm.' in k: + elif "norm." in k: new_v = correct_unfold_norm_order(v) else: new_k = k - new_k = new_k.replace('layers', 'stages', 1) - elif k.startswith('patch_embed'): + new_k = new_k.replace("layers", "stages", 1) + elif k.startswith("patch_embed"): new_v = v - if 'proj' in k: - new_k = k.replace('proj', 'projection') + if "proj" in k: + new_k = k.replace("proj", "projection") else: new_k = k else: @@ -64,18 +63,19 @@ def correct_unfold_norm_order(x): def main(): parser = argparse.ArgumentParser( - description='Convert keys in official pretrained swin models to' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys in official pretrained swin models to" + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') + parser.add_argument("dst", help="save path") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: - state_dict = checkpoint['model'] + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] else: state_dict = checkpoint weight = convert_swin(state_dict) @@ -83,5 +83,5 @@ def main(): torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/twins2mmseg.py b/mmsegmentation/tools/model_converters/twins2mmseg.py index ab64aa5..701bb71 100644 --- a/mmsegmentation/tools/model_converters/twins2mmseg.py +++ b/mmsegmentation/tools/model_converters/twins2mmseg.py @@ -9,72 +9,71 @@ def convert_twins(args, ckpt): - new_ckpt = OrderedDict() for k, v in list(ckpt.items()): new_v = v - if k.startswith('head'): + if k.startswith("head"): continue - elif k.startswith('patch_embeds'): - if 'proj.' in k: - new_k = k.replace('proj.', 'projection.') + elif k.startswith("patch_embeds"): + if "proj." in k: + new_k = k.replace("proj.", "projection.") else: new_k = k - elif k.startswith('blocks'): + elif k.startswith("blocks"): # Union - if 'attn.q.' in k: - new_k = k.replace('q.', 'attn.in_proj_') - new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]], - dim=0) - elif 'mlp.fc1' in k: - new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') - elif 'mlp.fc2' in k: - new_k = k.replace('mlp.fc2', 'ffn.layers.1') + if "attn.q." in k: + new_k = k.replace("q.", "attn.in_proj_") + new_v = torch.cat([v, ckpt[k.replace("attn.q.", "attn.kv.")]], dim=0) + elif "mlp.fc1" in k: + new_k = k.replace("mlp.fc1", "ffn.layers.0.0") + elif "mlp.fc2" in k: + new_k = k.replace("mlp.fc2", "ffn.layers.1") # Only pcpvt - elif args.model == 'pcpvt': - if 'attn.proj.' in k: - new_k = k.replace('proj.', 'attn.out_proj.') + elif args.model == "pcpvt": + if "attn.proj." in k: + new_k = k.replace("proj.", "attn.out_proj.") else: new_k = k # Only svt else: - if 'attn.proj.' in k: - k_lst = k.split('.') + if "attn.proj." in k: + k_lst = k.split(".") if int(k_lst[2]) % 2 == 1: - new_k = k.replace('proj.', 'attn.out_proj.') + new_k = k.replace("proj.", "attn.out_proj.") else: new_k = k else: new_k = k - new_k = new_k.replace('blocks.', 'layers.') - elif k.startswith('pos_block'): - new_k = k.replace('pos_block', 'position_encodings') - if 'proj.0.' in new_k: - new_k = new_k.replace('proj.0.', 'proj.') + new_k = new_k.replace("blocks.", "layers.") + elif k.startswith("pos_block"): + new_k = k.replace("pos_block", "position_encodings") + if "proj.0." in new_k: + new_k = new_k.replace("proj.0.", "proj.") else: new_k = k - if 'attn.kv.' not in k: + if "attn.kv." not in k: new_ckpt[new_k] = new_v return new_ckpt def main(): parser = argparse.ArgumentParser( - description='Convert keys in timm pretrained vit models to ' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys in timm pretrained vit models to " + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') - parser.add_argument('model', help='model: pcpvt or svt') + parser.add_argument("dst", help="save path") + parser.add_argument("model", help="model: pcpvt or svt") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") - if 'state_dict' in checkpoint: + if "state_dict" in checkpoint: # timm checkpoint - state_dict = checkpoint['state_dict'] + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint @@ -83,5 +82,5 @@ def main(): torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/vit2mmseg.py b/mmsegmentation/tools/model_converters/vit2mmseg.py index bc18ebe..3ad0ee0 100644 --- a/mmsegmentation/tools/model_converters/vit2mmseg.py +++ b/mmsegmentation/tools/model_converters/vit2mmseg.py @@ -9,33 +9,32 @@ def convert_vit(ckpt): - new_ckpt = OrderedDict() for k, v in ckpt.items(): - if k.startswith('head'): + if k.startswith("head"): continue - if k.startswith('norm'): - new_k = k.replace('norm.', 'ln1.') - elif k.startswith('patch_embed'): - if 'proj' in k: - new_k = k.replace('proj', 'projection') + if k.startswith("norm"): + new_k = k.replace("norm.", "ln1.") + elif k.startswith("patch_embed"): + if "proj" in k: + new_k = k.replace("proj", "projection") else: new_k = k - elif k.startswith('blocks'): - if 'norm' in k: - new_k = k.replace('norm', 'ln') - elif 'mlp.fc1' in k: - new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') - elif 'mlp.fc2' in k: - new_k = k.replace('mlp.fc2', 'ffn.layers.1') - elif 'attn.qkv' in k: - new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') - elif 'attn.proj' in k: - new_k = k.replace('attn.proj', 'attn.attn.out_proj') + elif k.startswith("blocks"): + if "norm" in k: + new_k = k.replace("norm", "ln") + elif "mlp.fc1" in k: + new_k = k.replace("mlp.fc1", "ffn.layers.0.0") + elif "mlp.fc2" in k: + new_k = k.replace("mlp.fc2", "ffn.layers.1") + elif "attn.qkv" in k: + new_k = k.replace("attn.qkv.", "attn.attn.in_proj_") + elif "attn.proj" in k: + new_k = k.replace("attn.proj", "attn.attn.out_proj") else: new_k = k - new_k = new_k.replace('blocks.', 'layers.') + new_k = new_k.replace("blocks.", "layers.") else: new_k = k new_ckpt[new_k] = v @@ -45,20 +44,21 @@ def convert_vit(ckpt): def main(): parser = argparse.ArgumentParser( - description='Convert keys in timm pretrained vit models to ' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys in timm pretrained vit models to " + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') + parser.add_argument("dst", help="save path") args = parser.parse_args() - checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') - if 'state_dict' in checkpoint: + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location="cpu") + if "state_dict" in checkpoint: # timm checkpoint - state_dict = checkpoint['state_dict'] - elif 'model' in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: # deit checkpoint - state_dict = checkpoint['model'] + state_dict = checkpoint["model"] else: state_dict = checkpoint weight = convert_vit(state_dict) @@ -66,5 +66,5 @@ def main(): torch.save(weight, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_converters/vitjax2mmseg.py b/mmsegmentation/tools/model_converters/vitjax2mmseg.py index 585f408..58c075b 100644 --- a/mmsegmentation/tools/model_converters/vitjax2mmseg.py +++ b/mmsegmentation/tools/model_converters/vitjax2mmseg.py @@ -11,84 +11,95 @@ def vit_jax_to_torch(jax_weights, num_layer=12): torch_weights = dict() # patch embedding - conv_filters = jax_weights['embedding/kernel'] + conv_filters = jax_weights["embedding/kernel"] conv_filters = conv_filters.permute(3, 2, 0, 1) - torch_weights['patch_embed.projection.weight'] = conv_filters - torch_weights['patch_embed.projection.bias'] = jax_weights[ - 'embedding/bias'] + torch_weights["patch_embed.projection.weight"] = conv_filters + torch_weights["patch_embed.projection.bias"] = jax_weights["embedding/bias"] # pos embedding - torch_weights['pos_embed'] = jax_weights[ - 'Transformer/posembed_input/pos_embedding'] + torch_weights["pos_embed"] = jax_weights["Transformer/posembed_input/pos_embedding"] # cls token - torch_weights['cls_token'] = jax_weights['cls'] + torch_weights["cls_token"] = jax_weights["cls"] # head - torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale'] - torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias'] + torch_weights["ln1.weight"] = jax_weights["Transformer/encoder_norm/scale"] + torch_weights["ln1.bias"] = jax_weights["Transformer/encoder_norm/bias"] # transformer blocks for i in range(num_layer): - jax_block = f'Transformer/encoderblock_{i}' - torch_block = f'layers.{i}' + jax_block = f"Transformer/encoderblock_{i}" + torch_block = f"layers.{i}" # attention norm - torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[ - f'{jax_block}/LayerNorm_0/scale'] - torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[ - f'{jax_block}/LayerNorm_0/bias'] + torch_weights[f"{torch_block}.ln1.weight"] = jax_weights[ + f"{jax_block}/LayerNorm_0/scale" + ] + torch_weights[f"{torch_block}.ln1.bias"] = jax_weights[ + f"{jax_block}/LayerNorm_0/bias" + ] # attention query_weight = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel'] + f"{jax_block}/MultiHeadDotProductAttention_1/query/kernel" + ] query_bias = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/query/bias'] + f"{jax_block}/MultiHeadDotProductAttention_1/query/bias" + ] key_weight = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel'] - key_bias = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/key/bias'] + f"{jax_block}/MultiHeadDotProductAttention_1/key/kernel" + ] + key_bias = jax_weights[f"{jax_block}/MultiHeadDotProductAttention_1/key/bias"] value_weight = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel'] + f"{jax_block}/MultiHeadDotProductAttention_1/value/kernel" + ] value_bias = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/value/bias'] + f"{jax_block}/MultiHeadDotProductAttention_1/value/bias" + ] qkv_weight = torch.from_numpy( - np.stack((query_weight, key_weight, value_weight), 1)) + np.stack((query_weight, key_weight, value_weight), 1) + ) qkv_weight = torch.flatten(qkv_weight, start_dim=1) - qkv_bias = torch.from_numpy( - np.stack((query_bias, key_bias, value_bias), 0)) + qkv_bias = torch.from_numpy(np.stack((query_bias, key_bias, value_bias), 0)) qkv_bias = torch.flatten(qkv_bias, start_dim=0) - torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight - torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias + torch_weights[f"{torch_block}.attn.attn.in_proj_weight"] = qkv_weight + torch_weights[f"{torch_block}.attn.attn.in_proj_bias"] = qkv_bias to_out_weight = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel'] + f"{jax_block}/MultiHeadDotProductAttention_1/out/kernel" + ] to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1) - torch_weights[ - f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight - torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[ - f'{jax_block}/MultiHeadDotProductAttention_1/out/bias'] + torch_weights[f"{torch_block}.attn.attn.out_proj.weight"] = to_out_weight + torch_weights[f"{torch_block}.attn.attn.out_proj.bias"] = jax_weights[ + f"{jax_block}/MultiHeadDotProductAttention_1/out/bias" + ] # mlp norm - torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[ - f'{jax_block}/LayerNorm_2/scale'] - torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[ - f'{jax_block}/LayerNorm_2/bias'] + torch_weights[f"{torch_block}.ln2.weight"] = jax_weights[ + f"{jax_block}/LayerNorm_2/scale" + ] + torch_weights[f"{torch_block}.ln2.bias"] = jax_weights[ + f"{jax_block}/LayerNorm_2/bias" + ] # mlp - torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[ - f'{jax_block}/MlpBlock_3/Dense_0/kernel'] - torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[ - f'{jax_block}/MlpBlock_3/Dense_0/bias'] - torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[ - f'{jax_block}/MlpBlock_3/Dense_1/kernel'] - torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[ - f'{jax_block}/MlpBlock_3/Dense_1/bias'] + torch_weights[f"{torch_block}.ffn.layers.0.0.weight"] = jax_weights[ + f"{jax_block}/MlpBlock_3/Dense_0/kernel" + ] + torch_weights[f"{torch_block}.ffn.layers.0.0.bias"] = jax_weights[ + f"{jax_block}/MlpBlock_3/Dense_0/bias" + ] + torch_weights[f"{torch_block}.ffn.layers.1.weight"] = jax_weights[ + f"{jax_block}/MlpBlock_3/Dense_1/kernel" + ] + torch_weights[f"{torch_block}.ffn.layers.1.bias"] = jax_weights[ + f"{jax_block}/MlpBlock_3/Dense_1/bias" + ] # transpose weights for k, v in torch_weights.items(): - if 'weight' in k and 'patch_embed' not in k and 'ln' not in k: + if "weight" in k and "patch_embed" not in k and "ln" not in k: v = v.permute(1, 0) torch_weights[k] = v @@ -98,11 +109,12 @@ def vit_jax_to_torch(jax_weights, num_layer=12): def main(): # stole refactoring code from Robin Strudel, thanks parser = argparse.ArgumentParser( - description='Convert keys from jax official pretrained vit models to ' - 'MMSegmentation style.') - parser.add_argument('src', help='src model path or url') + description="Convert keys from jax official pretrained vit models to " + "MMSegmentation style." + ) + parser.add_argument("src", help="src model path or url") # The dst path must be a full path of the new checkpoint. - parser.add_argument('dst', help='save path') + parser.add_argument("dst", help="save path") args = parser.parse_args() jax_weights = np.load(args.src) @@ -110,7 +122,7 @@ def main(): for key in jax_weights.files: value = torch.from_numpy(jax_weights[key]) jax_weights_tensor[key] = value - if 'L_16-i21k' in args.src: + if "L_16-i21k" in args.src: num_layer = 24 else: num_layer = 12 @@ -119,5 +131,5 @@ def main(): torch.save(torch_weights, args.dst) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/model_ensemble.py b/mmsegmentation/tools/model_ensemble.py index b526650..dcbfa71 100644 --- a/mmsegmentation/tools/model_ensemble.py +++ b/mmsegmentation/tools/model_ensemble.py @@ -16,7 +16,6 @@ @torch.no_grad() def main(args): - models = [] gpu_ids = args.gpus configs = args.config @@ -25,9 +24,7 @@ def main(args): cfg = mmcv.Config.fromfile(configs[0]) if args.aug_test: - cfg.data.test.pipeline[1].img_ratios = [ - 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0 - ] + cfg.data.test.pipeline[1].img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0] cfg.data.test.pipeline[1].flip = True else: cfg.data.test.pipeline[1].img_ratios = [1.0] @@ -50,10 +47,10 @@ def main(args): cfg.model.pretrained = None cfg.data.test.test_mode = True - model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) - if cfg.get('fp16', None): + model = build_segmentor(cfg.model, test_cfg=cfg.get("test_cfg")) + if cfg.get("fp16", None): wrap_fp16_model(model) - load_checkpoint(model, ckpt, map_location='cpu') + load_checkpoint(model, ckpt, map_location="cpu") torch.cuda.empty_cache() tmpdir = args.out mmcv.mkdir_or_exist(tmpdir) @@ -69,7 +66,8 @@ def main(args): for model in models: x, _ = scatter_kwargs( - inputs=data, kwargs=None, target_gpus=model.device_ids) + inputs=data, kwargs=None, target_gpus=model.device_ids + ) if args.aug_test: logits = model.module.aug_test_logits(**x[0]) else: @@ -83,39 +81,42 @@ def main(args): pred = result_logits.argmax(axis=1).squeeze() img_info = dataset.img_infos[batch_indices[0]] file_name = os.path.join( - tmpdir, img_info['ann']['seg_map'].split(os.path.sep)[-1]) + tmpdir, img_info["ann"]["seg_map"].split(os.path.sep)[-1] + ) Image.fromarray(pred.astype(np.uint8)).save(file_name) prog_bar.update() def parse_args(): - parser = argparse.ArgumentParser( - description='Model Ensemble with logits result') + parser = argparse.ArgumentParser(description="Model Ensemble with logits result") parser.add_argument( - '--config', type=str, nargs='+', help='ensemble config files path') + "--config", type=str, nargs="+", help="ensemble config files path" + ) parser.add_argument( - '--checkpoint', - type=str, - nargs='+', - help='ensemble checkpoint files path') + "--checkpoint", type=str, nargs="+", help="ensemble checkpoint files path" + ) parser.add_argument( - '--aug-test', - action='store_true', - help='control ensemble aug-result or single-result (default)') + "--aug-test", + action="store_true", + help="control ensemble aug-result or single-result (default)", + ) parser.add_argument( - '--out', type=str, default='results', help='the dir to save result') + "--out", type=str, default="results", help="the dir to save result" + ) parser.add_argument( - '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use') + "--gpus", type=int, nargs="+", default=[0], help="id of gpu to use" + ) args = parser.parse_args() - assert len(args.config) == len(args.checkpoint), \ - f'len(config) must equal len(checkpoint), ' \ - f'but len(config) = {len(args.config)} and' \ - f'len(checkpoint) = {len(args.checkpoint)}' + assert len(args.config) == len(args.checkpoint), ( + f"len(config) must equal len(checkpoint), " + f"but len(config) = {len(args.config)} and" + f"len(checkpoint) = {len(args.checkpoint)}" + ) assert args.out, "ensemble result out-dir can't be None" return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/mmsegmentation/tools/onnx2tensorrt.py b/mmsegmentation/tools/onnx2tensorrt.py index 0f60dce..3a3f1f1 100644 --- a/mmsegmentation/tools/onnx2tensorrt.py +++ b/mmsegmentation/tools/onnx2tensorrt.py @@ -11,8 +11,12 @@ import onnxruntime as ort import torch from mmcv.ops import get_onnxruntime_op_path -from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, - save_trt_engine) +from mmcv.tensorrt import ( + TRTWraper, + is_tensorrt_plugin_loaded, + onnx2trt, + save_trt_engine, +) from mmseg.apis.inference import LoadImage from mmseg.datasets import DATASETS @@ -24,27 +28,29 @@ def get_GiB(x: int): return x * (1 << 30) -def _prepare_input_img(img_path: str, - test_pipeline: Iterable[dict], - shape: Optional[Iterable] = None, - rescale_shape: Optional[Iterable] = None) -> dict: +def _prepare_input_img( + img_path: str, + test_pipeline: Iterable[dict], + shape: Optional[Iterable] = None, + rescale_shape: Optional[Iterable] = None, +) -> dict: # build the data pipeline if shape is not None: - test_pipeline[1]['img_scale'] = (shape[1], shape[0]) - test_pipeline[1]['transforms'][0]['keep_ratio'] = False + test_pipeline[1]["img_scale"] = (shape[1], shape[0]) + test_pipeline[1]["transforms"][0]["keep_ratio"] = False test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] - img_metas = [i.data for i in data['img_metas']] + imgs = data["img"] + img_metas = [i.data for i in data["img_metas"]] if rescale_shape is not None: for img_meta in img_metas: - img_meta['ori_shape'] = tuple(rescale_shape) + (3, ) + img_meta["ori_shape"] = tuple(rescale_shape) + (3,) - mm_inputs = {'imgs': imgs, 'img_metas': img_metas} + mm_inputs = {"imgs": imgs, "img_metas": img_metas} return mm_inputs @@ -53,34 +59,39 @@ def _update_input_img(img_list: Iterable, img_meta_list: Iterable): # update img and its meta list N = img_list[0].size(0) img_meta = img_meta_list[0][0] - img_shape = img_meta['img_shape'] - ori_shape = img_meta['ori_shape'] - pad_shape = img_meta['pad_shape'] - new_img_meta_list = [[{ - 'img_shape': - img_shape, - 'ori_shape': - ori_shape, - 'pad_shape': - pad_shape, - 'filename': - img_meta['filename'], - 'scale_factor': - (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, - 'flip': - False, - } for _ in range(N)]] + img_shape = img_meta["img_shape"] + ori_shape = img_meta["ori_shape"] + pad_shape = img_meta["pad_shape"] + new_img_meta_list = [ + [ + { + "img_shape": img_shape, + "ori_shape": ori_shape, + "pad_shape": pad_shape, + "filename": img_meta["filename"], + "scale_factor": ( + img_shape[1] / ori_shape[1], + img_shape[0] / ori_shape[0], + ) + * 2, + "flip": False, + } + for _ in range(N) + ] + ] return img_list, new_img_meta_list -def show_result_pyplot(img: Union[str, np.ndarray], - result: np.ndarray, - palette: Optional[Iterable] = None, - fig_size: Iterable[int] = (15, 10), - opacity: float = 0.5, - title: str = '', - block: bool = True): +def show_result_pyplot( + img: Union[str, np.ndarray], + result: np.ndarray, + palette: Optional[Iterable] = None, + fig_size: Iterable[int] = (15, 10), + opacity: float = 0.5, + title: str = "", + block: bool = True, +): img = mmcv.imread(img) img = img.copy() seg = result[0] @@ -105,42 +116,45 @@ def show_result_pyplot(img: Union[str, np.ndarray], plt.show(block=block) -def onnx2tensorrt(onnx_file: str, - trt_file: str, - config: dict, - input_config: dict, - fp16: bool = False, - verify: bool = False, - show: bool = False, - dataset: str = 'CityscapesDataset', - workspace_size: int = 1, - verbose: bool = False): +def onnx2tensorrt( + onnx_file: str, + trt_file: str, + config: dict, + input_config: dict, + fp16: bool = False, + verify: bool = False, + show: bool = False, + dataset: str = "CityscapesDataset", + workspace_size: int = 1, + verbose: bool = False, +): import tensorrt as trt - min_shape = input_config['min_shape'] - max_shape = input_config['max_shape'] + + min_shape = input_config["min_shape"] + max_shape = input_config["max_shape"] # create trt engine and wrapper - opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} + opt_shape_dict = {"input": [min_shape, min_shape, max_shape]} max_workspace_size = get_GiB(workspace_size) trt_engine = onnx2trt( onnx_file, opt_shape_dict, log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR, fp16_mode=fp16, - max_workspace_size=max_workspace_size) + max_workspace_size=max_workspace_size, + ) save_dir, _ = osp.split(trt_file) if save_dir: os.makedirs(save_dir, exist_ok=True) save_trt_engine(trt_engine, trt_file) - print(f'Successfully created TensorRT engine: {trt_file}') + print(f"Successfully created TensorRT engine: {trt_file}") if verify: inputs = _prepare_input_img( - input_config['input_path'], - config.data.test.pipeline, - shape=min_shape[2:]) + input_config["input_path"], config.data.test.pipeline, shape=min_shape[2:] + ) - imgs = inputs['imgs'] - img_metas = inputs['img_metas'] + imgs = inputs["imgs"] + img_metas = inputs["img_metas"] img_list = [img[None, :] for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] # update img_meta @@ -160,15 +174,16 @@ def onnx2tensorrt(onnx_file: str, if osp.exists(ort_custom_op_path): session_options.register_custom_ops_library(ort_custom_op_path) sess = ort.InferenceSession(onnx_file, session_options) - sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode - onnx_output = sess.run(['output'], - {'input': img_list[0].detach().numpy()})[0][0] + sess.set_providers(["CPUExecutionProvider"], [{}]) # use cpu mode + onnx_output = sess.run(["output"], {"input": img_list[0].detach().numpy()})[0][ + 0 + ] # Get results from TensorRT - trt_model = TRTWraper(trt_file, ['input'], ['output']) + trt_model = TRTWraper(trt_file, ["input"], ["output"]) with torch.no_grad(): - trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()}) - trt_output = trt_outputs['output'][0].cpu().detach().numpy() + trt_outputs = trt_model({"input": img_list[0].contiguous().cuda()}) + trt_output = trt_outputs["output"][0].cpu().detach().numpy() if show: dataset = DATASETS.get(dataset) @@ -176,91 +191,93 @@ def onnx2tensorrt(onnx_file: str, palette = dataset.PALETTE show_result_pyplot( - input_config['input_path'], - (onnx_output[0].astype(np.uint8), ), + input_config["input_path"], + (onnx_output[0].astype(np.uint8),), palette=palette, - title='ONNXRuntime', - block=False) + title="ONNXRuntime", + block=False, + ) show_result_pyplot( - input_config['input_path'], (trt_output[0].astype(np.uint8), ), + input_config["input_path"], + (trt_output[0].astype(np.uint8),), palette=palette, - title='TensorRT') + title="TensorRT", + ) - np.testing.assert_allclose( - onnx_output, trt_output, rtol=1e-03, atol=1e-05) - print('TensorRT and ONNXRuntime output all close.') + np.testing.assert_allclose(onnx_output, trt_output, rtol=1e-03, atol=1e-05) + print("TensorRT and ONNXRuntime output all close.") def parse_args(): parser = argparse.ArgumentParser( - description='Convert MMSegmentation models from ONNX to TensorRT') - parser.add_argument('config', help='Config file of the model') - parser.add_argument('model', help='Path to the input ONNX model') + description="Convert MMSegmentation models from ONNX to TensorRT" + ) + parser.add_argument("config", help="Config file of the model") + parser.add_argument("model", help="Path to the input ONNX model") parser.add_argument( - '--trt-file', type=str, help='Path to the output TensorRT engine') + "--trt-file", type=str, help="Path to the output TensorRT engine" + ) parser.add_argument( - '--max-shape', + "--max-shape", type=int, nargs=4, default=[1, 3, 400, 600], - help='Maximum shape of model input.') + help="Maximum shape of model input.", + ) parser.add_argument( - '--min-shape', + "--min-shape", type=int, nargs=4, default=[1, 3, 400, 600], - help='Minimum shape of model input.') - parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode') - parser.add_argument( - '--workspace-size', - type=int, - default=1, - help='Max workspace size in GiB') + help="Minimum shape of model input.", + ) + parser.add_argument("--fp16", action="store_true", help="Enable fp16 mode") parser.add_argument( - '--input-img', type=str, default='', help='Image for test') + "--workspace-size", type=int, default=1, help="Max workspace size in GiB" + ) + parser.add_argument("--input-img", type=str, default="", help="Image for test") parser.add_argument( - '--show', action='store_true', help='Whether to show output results') + "--show", action="store_true", help="Whether to show output results" + ) parser.add_argument( - '--dataset', - type=str, - default='CityscapesDataset', - help='Dataset name') + "--dataset", type=str, default="CityscapesDataset", help="Dataset name" + ) parser.add_argument( - '--verify', - action='store_true', - help='Verify the outputs of ONNXRuntime and TensorRT') + "--verify", + action="store_true", + help="Verify the outputs of ONNXRuntime and TensorRT", + ) parser.add_argument( - '--verbose', - action='store_true', - help='Whether to verbose logging messages while creating \ - TensorRT engine.') + "--verbose", + action="store_true", + help="Whether to verbose logging messages while creating \ + TensorRT engine.", + ) args = parser.parse_args() return args -if __name__ == '__main__': - - assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.' +if __name__ == "__main__": + assert is_tensorrt_plugin_loaded(), "TensorRT plugin should be compiled." args = parse_args() if not args.input_img: - args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png') + args.input_img = osp.join(osp.dirname(__file__), "../demo/demo.png") # check arguments - assert osp.exists(args.config), 'Config {} not found.'.format(args.config) - assert osp.exists(args.model), \ - 'ONNX model {} not found.'.format(args.model) - assert args.workspace_size >= 0, 'Workspace size less than 0.' - assert DATASETS.get(args.dataset) is not None, \ - 'Dataset {} does not found.'.format(args.dataset) + assert osp.exists(args.config), f"Config {args.config} not found." + assert osp.exists(args.model), f"ONNX model {args.model} not found." + assert args.workspace_size >= 0, "Workspace size less than 0." + assert ( + DATASETS.get(args.dataset) is not None + ), f"Dataset {args.dataset} does not found." for max_value, min_value in zip(args.max_shape, args.min_shape): - assert max_value >= min_value, \ - 'max_shape should be larger than min shape' + assert max_value >= min_value, "max_shape should be larger than min shape" input_config = { - 'min_shape': args.min_shape, - 'max_shape': args.max_shape, - 'input_path': args.input_img + "min_shape": args.min_shape, + "max_shape": args.max_shape, + "input_path": args.input_img, } cfg = mmcv.Config.fromfile(args.config) @@ -274,16 +291,17 @@ def parse_args(): show=args.show, dataset=args.dataset, workspace_size=args.workspace_size, - verbose=args.verbose) + verbose=args.verbose, + ) # Following strings of text style are from colorama package - bright_style, reset_style = '\x1b[1m', '\x1b[0m' - red_text, blue_text = '\x1b[31m', '\x1b[34m' - white_background = '\x1b[107m' + bright_style, reset_style = "\x1b[1m", "\x1b[0m" + red_text, blue_text = "\x1b[31m", "\x1b[34m" + white_background = "\x1b[107m" msg = white_background + bright_style + red_text - msg += 'DeprecationWarning: This tool will be deprecated in future. ' - msg += blue_text + 'Welcome to use the unified model deployment toolbox ' - msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy' + msg += "DeprecationWarning: This tool will be deprecated in future. " + msg += blue_text + "Welcome to use the unified model deployment toolbox " + msg += "MMDeploy: https://github.com/open-mmlab/mmdeploy" msg += reset_style warnings.warn(msg) diff --git a/mmsegmentation/tools/print_config.py b/mmsegmentation/tools/print_config.py index 3f9c08d..77bb514 100644 --- a/mmsegmentation/tools/print_config.py +++ b/mmsegmentation/tools/print_config.py @@ -8,41 +8,45 @@ def parse_args(): - parser = argparse.ArgumentParser(description='Print the whole config') - parser.add_argument('config', help='config file path') + parser = argparse.ArgumentParser(description="Print the whole config") + parser.add_argument("config", help="config file path") + parser.add_argument("--graph", action="store_true", help="print the models graph") parser.add_argument( - '--graph', action='store_true', help='print the models graph') - parser.add_argument( - '--options', - nargs='+', + "--options", + nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " - 'not be supported in version v0.22.0. Override some settings in the ' - 'used config, the key-value pair in xxx=yyy format will be merged ' - 'into config file. If the value to be overwritten is a list, it ' + "not be supported in version v0.22.0. Override some settings in the " + "used config, the key-value pair in xxx=yyy format will be merged " + "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' - 'marks are necessary and that no white space is allowed.') + "marks are necessary and that no white space is allowed.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) args = parser.parse_args() if args.options and args.cfg_options: raise ValueError( - '--options and --cfg-options cannot be both ' - 'specified, --options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) if args.options: - warnings.warn('--options is deprecated in favor of --cfg-options, ' - '--options will not be supported in version v0.22.0.') + warnings.warn( + "--options is deprecated in favor of --cfg-options, " + "--options will not be supported in version v0.22.0." + ) args.cfg_options = args.options return args @@ -54,16 +58,16 @@ def main(): cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) - print(f'Config:\n{cfg.pretty_text}') + print(f"Config:\n{cfg.pretty_text}") # dump config - cfg.dump('example.py') + cfg.dump("example.py") # dump models graph if args.graph: - model = init_segmentor(args.config, device='cpu') - print(f'Model graph:\n{str(model)}') - with open('example-graph.txt', 'w') as f: + model = init_segmentor(args.config, device="cpu") + print(f"Model graph:\n{str(model)}") + with open("example-graph.txt", "w") as f: f.writelines(str(model)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/publish_model.py b/mmsegmentation/tools/publish_model.py index e266057..dbfc972 100644 --- a/mmsegmentation/tools/publish_model.py +++ b/mmsegmentation/tools/publish_model.py @@ -6,25 +6,24 @@ def parse_args(): - parser = argparse.ArgumentParser( - description='Process a checkpoint to be published') - parser.add_argument('in_file', help='input checkpoint filename') - parser.add_argument('out_file', help='output checkpoint filename') + parser = argparse.ArgumentParser(description="Process a checkpoint to be published") + parser.add_argument("in_file", help="input checkpoint filename") + parser.add_argument("out_file", help="output checkpoint filename") args = parser.parse_args() return args def process_checkpoint(in_file, out_file): - checkpoint = torch.load(in_file, map_location='cpu') + checkpoint = torch.load(in_file, map_location="cpu") # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] + if "optimizer" in checkpoint: + del checkpoint["optimizer"] # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. torch.save(checkpoint, out_file) - sha = subprocess.check_output(['sha256sum', out_file]).decode() - final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) - subprocess.Popen(['mv', out_file, final_file]) + sha = subprocess.check_output(["sha256sum", out_file]).decode() + final_file = out_file.rstrip(".pth") + f"-{sha[:8]}.pth" + subprocess.Popen(["mv", out_file, final_file]) def main(): @@ -32,5 +31,5 @@ def main(): process_checkpoint(args.in_file, args.out_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/pytorch2onnx.py b/mmsegmentation/tools/pytorch2onnx.py index 060d187..31e4410 100644 --- a/mmsegmentation/tools/pytorch2onnx.py +++ b/mmsegmentation/tools/pytorch2onnx.py @@ -26,9 +26,13 @@ def _convert_batchnorm(module): module_output = module if isinstance(module, torch.nn.SyncBatchNorm): - module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) + module_output = torch.nn.BatchNorm2d( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) if module.affine: module_output.weight.data = module.weight.data.clone().detach() module_output.bias.data = module.bias.data.clone().detach() @@ -56,45 +60,44 @@ def _demo_mm_inputs(input_shape, num_classes): (N, C, H, W) = input_shape rng = np.random.RandomState(0) imgs = rng.rand(*input_shape) - segs = rng.randint( - low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - 'scale_factor': 1.0, - 'flip': False, - } for _ in range(N)] + segs = rng.randint(low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + img_metas = [ + { + "img_shape": (H, W, C), + "ori_shape": (H, W, C), + "pad_shape": (H, W, C), + "filename": ".png", + "scale_factor": 1.0, + "flip": False, + } + for _ in range(N) + ] mm_inputs = { - 'imgs': torch.FloatTensor(imgs).requires_grad_(True), - 'img_metas': img_metas, - 'gt_semantic_seg': torch.LongTensor(segs) + "imgs": torch.FloatTensor(imgs).requires_grad_(True), + "img_metas": img_metas, + "gt_semantic_seg": torch.LongTensor(segs), } return mm_inputs -def _prepare_input_img(img_path, - test_pipeline, - shape=None, - rescale_shape=None): +def _prepare_input_img(img_path, test_pipeline, shape=None, rescale_shape=None): # build the data pipeline if shape is not None: - test_pipeline[1]['img_scale'] = (shape[1], shape[0]) - test_pipeline[1]['transforms'][0]['keep_ratio'] = False + test_pipeline[1]["img_scale"] = (shape[1], shape[0]) + test_pipeline[1]["transforms"][0]["keep_ratio"] = False test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] - img_metas = [i.data for i in data['img_metas']] + imgs = data["img"] + img_metas = [i.data for i in data["img_metas"]] if rescale_shape is not None: for img_meta in img_metas: - img_meta['ori_shape'] = tuple(rescale_shape) + (3, ) + img_meta["ori_shape"] = tuple(rescale_shape) + (3,) - mm_inputs = {'imgs': imgs, 'img_metas': img_metas} + mm_inputs = {"imgs": imgs, "img_metas": img_metas} return mm_inputs @@ -107,33 +110,38 @@ def _update_input_img(img_list, img_meta_list, update_ori_shape=False): if update_ori_shape: ori_shape = img_shape else: - ori_shape = img_meta['ori_shape'] + ori_shape = img_meta["ori_shape"] pad_shape = img_shape - new_img_meta_list = [[{ - 'img_shape': - img_shape, - 'ori_shape': - ori_shape, - 'pad_shape': - pad_shape, - 'filename': - img_meta['filename'], - 'scale_factor': - (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, - 'flip': - False, - } for _ in range(N)]] + new_img_meta_list = [ + [ + { + "img_shape": img_shape, + "ori_shape": ori_shape, + "pad_shape": pad_shape, + "filename": img_meta["filename"], + "scale_factor": ( + img_shape[1] / ori_shape[1], + img_shape[0] / ori_shape[0], + ) + * 2, + "flip": False, + } + for _ in range(N) + ] + ] return img_list, new_img_meta_list -def pytorch2onnx(model, - mm_inputs, - opset_version=11, - show=False, - output_file='tmp.onnx', - verify=False, - dynamic_export=False): +def pytorch2onnx( + model, + mm_inputs, + opset_version=11, + show=False, + output_file="tmp.onnx", + verify=False, + dynamic_export=False, +): """Export Pytorch model to ONNX model and verify the outputs are same between Pytorch and ONNX. @@ -157,8 +165,8 @@ def pytorch2onnx(model, else: num_classes = model.decode_head.num_classes - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') + imgs = mm_inputs.pop("imgs") + img_metas = mm_inputs.pop("img_metas") img_list = [img[None, :] for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] @@ -168,50 +176,43 @@ def pytorch2onnx(model, # replace original forward function origin_forward = model.forward model.forward = partial( - model.forward, - img_metas=img_meta_list, - return_loss=False, - rescale=True) + model.forward, img_metas=img_meta_list, return_loss=False, rescale=True + ) dynamic_axes = None if dynamic_export: - if test_mode == 'slide': - dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}} + if test_mode == "slide": + dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}} else: dynamic_axes = { - 'input': { - 0: 'batch', - 2: 'height', - 3: 'width' - }, - 'output': { - 1: 'batch', - 2: 'height', - 3: 'width' - } + "input": {0: "batch", 2: "height", 3: "width"}, + "output": {1: "batch", 2: "height", 3: "width"}, } register_extra_symbolics(opset_version) with torch.no_grad(): torch.onnx.export( - model, (img_list, ), + model, + (img_list,), output_file, - input_names=['input'], - output_names=['output'], + input_names=["input"], + output_names=["output"], export_params=True, keep_initializers_as_inputs=False, verbose=show, opset_version=opset_version, - dynamic_axes=dynamic_axes) - print(f'Successfully exported ONNX model: {output_file}') + dynamic_axes=dynamic_axes, + ) + print(f"Successfully exported ONNX model: {output_file}") model.forward = origin_forward if verify: # check by onnx import onnx + onnx_model = onnx.load(output_file) onnx.checker.check_model(onnx_model) - if dynamic_export and test_mode == 'whole': + if dynamic_export and test_mode == "whole": # scale image for dynamic shape test img_list = [resize(_, scale_factor=1.5) for _ in img_list] # concate flip image for batch test @@ -223,7 +224,8 @@ def pytorch2onnx(model, # update img_meta img_list, img_meta_list = _update_input_img( - img_list, img_meta_list, test_mode == 'whole') + img_list, img_meta_list, test_mode == "whole" + ) # check the numerical value # get pytorch output @@ -233,102 +235,110 @@ def pytorch2onnx(model, # get onnx output input_all = [node.name for node in onnx_model.graph.input] - input_initializer = [ - node.name for node in onnx_model.graph.initializer - ] + input_initializer = [node.name for node in onnx_model.graph.initializer] net_feed_input = list(set(input_all) - set(input_initializer)) - assert (len(net_feed_input) == 1) + assert len(net_feed_input) == 1 sess = rt.InferenceSession(output_file) - onnx_result = sess.run( - None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0] + onnx_result = sess.run(None, {net_feed_input[0]: img_list[0].detach().numpy()})[ + 0 + ][0] # show segmentation results if show: import os.path as osp import cv2 - img = img_meta_list[0][0]['filename'] + + img = img_meta_list[0][0]["filename"] if not osp.exists(img): img = imgs[0][:3, ...].permute(1, 2, 0) * 255 img = img.detach().numpy().astype(np.uint8) ori_shape = img.shape[:2] else: - ori_shape = LoadImage()({'img': img})['ori_shape'] + ori_shape = LoadImage()({"img": img})["ori_shape"] # resize onnx_result to ori_shape - onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8), - (ori_shape[1], ori_shape[0])) + onnx_result_ = cv2.resize( + onnx_result[0].astype(np.uint8), (ori_shape[1], ori_shape[0]) + ) show_result_pyplot( model, - img, (onnx_result_, ), + img, + (onnx_result_,), palette=model.PALETTE, block=False, - title='ONNXRuntime', - opacity=0.5) + title="ONNXRuntime", + opacity=0.5, + ) # resize pytorch_result to ori_shape - pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8), - (ori_shape[1], ori_shape[0])) + pytorch_result_ = cv2.resize( + pytorch_result[0].astype(np.uint8), (ori_shape[1], ori_shape[0]) + ) show_result_pyplot( model, - img, (pytorch_result_, ), - title='PyTorch', + img, + (pytorch_result_,), + title="PyTorch", palette=model.PALETTE, - opacity=0.5) + opacity=0.5, + ) # compare results np.testing.assert_allclose( pytorch_result.astype(np.float32) / num_classes, onnx_result.astype(np.float32) / num_classes, rtol=1e-5, atol=1e-5, - err_msg='The outputs are different between Pytorch and ONNX') - print('The outputs are same between Pytorch and ONNX') + err_msg="The outputs are different between Pytorch and ONNX", + ) + print("The outputs are same between Pytorch and ONNX") def parse_args(): - parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') - parser.add_argument('config', help='test config file path') - parser.add_argument('--checkpoint', help='checkpoint file', default=None) - parser.add_argument( - '--input-img', type=str, help='Images for input', default=None) - parser.add_argument( - '--show', - action='store_true', - help='show onnx graph and segmentation results') + parser = argparse.ArgumentParser(description="Convert MMSeg to ONNX") + parser.add_argument("config", help="test config file path") + parser.add_argument("--checkpoint", help="checkpoint file", default=None) + parser.add_argument("--input-img", type=str, help="Images for input", default=None) parser.add_argument( - '--verify', action='store_true', help='verify the onnx model') - parser.add_argument('--output-file', type=str, default='tmp.onnx') - parser.add_argument('--opset-version', type=int, default=11) + "--show", action="store_true", help="show onnx graph and segmentation results" + ) + parser.add_argument("--verify", action="store_true", help="verify the onnx model") + parser.add_argument("--output-file", type=str, default="tmp.onnx") + parser.add_argument("--opset-version", type=int, default=11) parser.add_argument( - '--shape', + "--shape", type=int, - nargs='+', + nargs="+", default=None, - help='input image height and width.') + help="input image height and width.", + ) parser.add_argument( - '--rescale_shape', + "--rescale_shape", type=int, - nargs='+', + nargs="+", default=None, - help='output image rescale height and width, work for slide mode.') + help="output image rescale height and width, work for slide mode.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='Override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="Override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) parser.add_argument( - '--dynamic-export', - action='store_true', - help='Whether to export onnx with dynamic axis.') + "--dynamic-export", + action="store_true", + help="Whether to export onnx with dynamic axis.", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() cfg = mmcv.Config.fromfile(args.config) @@ -337,7 +347,7 @@ def parse_args(): cfg.model.pretrained = None if args.shape is None: - img_scale = cfg.test_pipeline[1]['img_scale'] + img_scale = cfg.test_pipeline[1]["img_scale"] input_shape = (1, 3, img_scale[1], img_scale[0]) elif len(args.shape) == 1: input_shape = (1, 3, args.shape[0], args.shape[0]) @@ -347,22 +357,20 @@ def parse_args(): 3, ) + tuple(args.shape) else: - raise ValueError('invalid input shape') + raise ValueError("invalid input shape") test_mode = cfg.model.test_cfg.mode # build the model and load checkpoint cfg.model.train_cfg = None - segmentor = build_segmentor( - cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + segmentor = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.get("test_cfg")) # convert SyncBN to BN segmentor = _convert_batchnorm(segmentor) if args.checkpoint: - checkpoint = load_checkpoint( - segmentor, args.checkpoint, map_location='cpu') - segmentor.CLASSES = checkpoint['meta']['CLASSES'] - segmentor.PALETTE = checkpoint['meta']['PALETTE'] + checkpoint = load_checkpoint(segmentor, args.checkpoint, map_location="cpu") + segmentor.CLASSES = checkpoint["meta"]["CLASSES"] + segmentor.PALETTE = checkpoint["meta"]["PALETTE"] # read input or create dummpy input if args.input_img is not None: @@ -374,7 +382,8 @@ def parse_args(): args.input_img, cfg.data.test.pipeline, shape=preprocess_shape, - rescale_shape=rescale_shape) + rescale_shape=rescale_shape, + ) else: if isinstance(segmentor.decode_head, nn.ModuleList): num_classes = segmentor.decode_head[-1].num_classes @@ -390,16 +399,17 @@ def parse_args(): show=args.show, output_file=args.output_file, verify=args.verify, - dynamic_export=args.dynamic_export) + dynamic_export=args.dynamic_export, + ) # Following strings of text style are from colorama package - bright_style, reset_style = '\x1b[1m', '\x1b[0m' - red_text, blue_text = '\x1b[31m', '\x1b[34m' - white_background = '\x1b[107m' + bright_style, reset_style = "\x1b[1m", "\x1b[0m" + red_text, blue_text = "\x1b[31m", "\x1b[34m" + white_background = "\x1b[107m" msg = white_background + bright_style + red_text - msg += 'DeprecationWarning: This tool will be deprecated in future. ' - msg += blue_text + 'Welcome to use the unified model deployment toolbox ' - msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy' + msg += "DeprecationWarning: This tool will be deprecated in future. " + msg += blue_text + "Welcome to use the unified model deployment toolbox " + msg += "MMDeploy: https://github.com/open-mmlab/mmdeploy" msg += reset_style warnings.warn(msg) diff --git a/mmsegmentation/tools/pytorch2torchscript.py b/mmsegmentation/tools/pytorch2torchscript.py index d76f5ec..5d40089 100644 --- a/mmsegmentation/tools/pytorch2torchscript.py +++ b/mmsegmentation/tools/pytorch2torchscript.py @@ -16,31 +16,36 @@ def digit_version(version_str): digit_version = [] - for x in version_str.split('.'): + for x in version_str.split("."): if x.isdigit(): digit_version.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') + elif x.find("rc") != -1: + patch_version = x.split("rc") digit_version.append(int(patch_version[0]) - 1) digit_version.append(int(patch_version[1])) return digit_version def check_torch_version(): - torch_minimum_version = '1.8.0' + torch_minimum_version = "1.8.0" torch_version = digit_version(torch.__version__) - assert (torch_version >= digit_version(torch_minimum_version)), \ - f'Torch=={torch.__version__} is not support for converting to ' \ - f'torchscript. Please install pytorch>={torch_minimum_version}.' + assert torch_version >= digit_version(torch_minimum_version), ( + f"Torch=={torch.__version__} is not support for converting to " + f"torchscript. Please install pytorch>={torch_minimum_version}." + ) def _convert_batchnorm(module): module_output = module if isinstance(module, torch.nn.SyncBatchNorm): - module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) + module_output = torch.nn.BatchNorm2d( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) if module.affine: module_output.weight.data = module.weight.data.clone().detach() module_output.bias.data = module.bias.data.clone().detach() @@ -68,29 +73,29 @@ def _demo_mm_inputs(input_shape, num_classes): (N, C, H, W) = input_shape rng = np.random.RandomState(0) imgs = rng.rand(*input_shape) - segs = rng.randint( - low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - 'scale_factor': 1.0, - 'flip': False, - } for _ in range(N)] + segs = rng.randint(low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + img_metas = [ + { + "img_shape": (H, W, C), + "ori_shape": (H, W, C), + "pad_shape": (H, W, C), + "filename": ".png", + "scale_factor": 1.0, + "flip": False, + } + for _ in range(N) + ] mm_inputs = { - 'imgs': torch.FloatTensor(imgs).requires_grad_(True), - 'img_metas': img_metas, - 'gt_semantic_seg': torch.LongTensor(segs) + "imgs": torch.FloatTensor(imgs).requires_grad_(True), + "img_metas": img_metas, + "gt_semantic_seg": torch.LongTensor(segs), } return mm_inputs -def pytorch2libtorch(model, - input_shape, - show=False, - output_file='tmp.pt', - verify=False): +def pytorch2libtorch( + model, input_shape, show=False, output_file="tmp.pt", verify=False +): """Export Pytorch model to TorchScript model and verify the outputs are same between Pytorch and TorchScript. @@ -111,7 +116,7 @@ def pytorch2libtorch(model, mm_inputs = _demo_mm_inputs(input_shape, num_classes) - imgs = mm_inputs.pop('imgs') + imgs = mm_inputs.pop("imgs") # replace the original forword with forward_dummy model.forward = model.forward_dummy @@ -126,30 +131,30 @@ def pytorch2libtorch(model, print(traced_model.graph) traced_model.save(output_file) - print('Successfully exported TorchScript model: {}'.format(output_file)) + print(f"Successfully exported TorchScript model: {output_file}") def parse_args(): - parser = argparse.ArgumentParser( - description='Convert MMSeg to TorchScript') - parser.add_argument('config', help='test config file path') - parser.add_argument('--checkpoint', help='checkpoint file', default=None) - parser.add_argument( - '--show', action='store_true', help='show TorchScript graph') + parser = argparse.ArgumentParser(description="Convert MMSeg to TorchScript") + parser.add_argument("config", help="test config file path") + parser.add_argument("--checkpoint", help="checkpoint file", default=None) + parser.add_argument("--show", action="store_true", help="show TorchScript graph") parser.add_argument( - '--verify', action='store_true', help='verify the TorchScript model') - parser.add_argument('--output-file', type=str, default='tmp.pt') + "--verify", action="store_true", help="verify the TorchScript model" + ) + parser.add_argument("--output-file", type=str, default="tmp.pt") parser.add_argument( - '--shape', + "--shape", type=int, - nargs='+', + nargs="+", default=[512, 512], - help='input image size (height, width)') + help="input image size (height, width)", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() check_torch_version() @@ -161,20 +166,19 @@ def parse_args(): 3, ) + tuple(args.shape) else: - raise ValueError('invalid input shape') + raise ValueError("invalid input shape") cfg = mmcv.Config.fromfile(args.config) cfg.model.pretrained = None # build the model and load checkpoint cfg.model.train_cfg = None - segmentor = build_segmentor( - cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + segmentor = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.get("test_cfg")) # convert SyncBN to BN segmentor = _convert_batchnorm(segmentor) if args.checkpoint: - load_checkpoint(segmentor, args.checkpoint, map_location='cpu') + load_checkpoint(segmentor, args.checkpoint, map_location="cpu") # convert the PyTorch model to LibTorch model pytorch2libtorch( @@ -182,4 +186,5 @@ def parse_args(): input_shape, show=args.show, output_file=args.output_file, - verify=args.verify) + verify=args.verify, + ) diff --git a/mmsegmentation/tools/test.py b/mmsegmentation/tools/test.py index a643b08..ded55cb 100644 --- a/mmsegmentation/tools/test.py +++ b/mmsegmentation/tools/test.py @@ -9,8 +9,7 @@ import mmcv import torch from mmcv.cnn.utils import revert_sync_batchnorm -from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, - wrap_fp16_model) +from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model from mmcv.utils import DictAction from mmseg import digit_version @@ -21,95 +20,111 @@ def parse_args(): - parser = argparse.ArgumentParser( - description='mmseg test (and eval) a model') - parser.add_argument('config', help='test config file path') - parser.add_argument('checkpoint', help='checkpoint file') + parser = argparse.ArgumentParser(description="mmseg test (and eval) a model") + parser.add_argument("config", help="test config file path") + parser.add_argument("checkpoint", help="checkpoint file") parser.add_argument( - '--work-dir', - help=('if specified, the evaluation metric results will be dumped' - 'into the directory as json')) + "--work-dir", + help=( + "if specified, the evaluation metric results will be dumped" + "into the directory as json" + ), + ) parser.add_argument( - '--aug-test', action='store_true', help='Use Flip and Multi scale aug') - parser.add_argument('--out', help='output result file in pickle format') + "--aug-test", action="store_true", help="Use Flip and Multi scale aug" + ) + parser.add_argument("--out", help="output result file in pickle format") parser.add_argument( - '--format-only', - action='store_true', - help='Format the output results without perform evaluation. It is' - 'useful when you want to format the result to a specific format and ' - 'submit it to the test server') + "--format-only", + action="store_true", + help="Format the output results without perform evaluation. It is" + "useful when you want to format the result to a specific format and " + "submit it to the test server", + ) parser.add_argument( - '--eval', + "--eval", type=str, - nargs='+', + nargs="+", help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' - ' for generic datasets, and "cityscapes" for Cityscapes') - parser.add_argument('--show', action='store_true', help='show results') + ' for generic datasets, and "cityscapes" for Cityscapes', + ) + parser.add_argument("--show", action="store_true", help="show results") parser.add_argument( - '--show-dir', help='directory where painted images will be saved') + "--show-dir", help="directory where painted images will be saved" + ) parser.add_argument( - '--gpu-collect', - action='store_true', - help='whether to use gpu to collect results.') + "--gpu-collect", + action="store_true", + help="whether to use gpu to collect results.", + ) parser.add_argument( - '--gpu-id', + "--gpu-id", type=int, default=0, - help='id of gpu to use ' - '(only applicable to non-distributed testing)') + help="id of gpu to use " "(only applicable to non-distributed testing)", + ) parser.add_argument( - '--tmpdir', - help='tmp directory used for collecting results from multiple ' - 'workers, available when gpu_collect is not specified') + "--tmpdir", + help="tmp directory used for collecting results from multiple " + "workers, available when gpu_collect is not specified", + ) parser.add_argument( - '--options', - nargs='+', + "--options", + nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " - 'not be supported in version v0.22.0. Override some settings in the ' - 'used config, the key-value pair in xxx=yyy format will be merged ' - 'into config file. If the value to be overwritten is a list, it ' + "not be supported in version v0.22.0. Override some settings in the " + "used config, the key-value pair in xxx=yyy format will be merged " + "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' - 'marks are necessary and that no white space is allowed.') + "marks are necessary and that no white space is allowed.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) parser.add_argument( - '--eval-options', - nargs='+', + "--eval-options", + nargs="+", action=DictAction, - help='custom options for evaluation') + help="custom options for evaluation", + ) parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + "--launcher", + choices=["none", "pytorch", "slurm", "mpi"], + default="none", + help="job launcher", + ) parser.add_argument( - '--opacity', + "--opacity", type=float, default=0.5, - help='Opacity of painted segmentation map. In (0, 1] range.') - parser.add_argument('--local_rank', type=int, default=0) + help="Opacity of painted segmentation map. In (0, 1] range.", + ) + parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( - '--options and --cfg-options cannot be both ' - 'specified, --options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) if args.options: - warnings.warn('--options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + warnings.warn( + "--options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) args.cfg_options = args.options return args @@ -117,17 +132,17 @@ def parse_args(): def main(): args = parse_args() - assert args.out or args.eval or args.format_only or args.show \ - or args.show_dir, \ - ('Please specify at least one operation (save/eval/format/show the ' - 'results / save the results) with the argument "--out", "--eval"' - ', "--format-only", "--show" or "--show-dir"') + assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( + "Please specify at least one operation (save/eval/format/show the " + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"' + ) if args.eval and args.format_only: - raise ValueError('--eval and --format_only cannot be both specified') + raise ValueError("--eval and --format_only cannot be both specified") - if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): - raise ValueError('The output file must be a pkl file.') + if args.out is not None and not args.out.endswith((".pkl", ".pickle")): + raise ValueError("The output file must be a pkl file.") cfg = mmcv.Config.fromfile(args.config) if args.cfg_options is not None: @@ -137,13 +152,11 @@ def main(): setup_multi_processes(cfg) # set cudnn_benchmark - if cfg.get('cudnn_benchmark', False): + if cfg.get("cudnn_benchmark", False): torch.backends.cudnn.benchmark = True if args.aug_test: # hard code index - cfg.data.test.pipeline[1].img_ratios = [ - 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 - ] + cfg.data.test.pipeline[1].img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] cfg.data.test.pipeline[1].flip = True cfg.model.pretrained = None cfg.data.test.test_mode = True @@ -152,13 +165,15 @@ def main(): cfg.gpu_ids = [args.gpu_id] # init distributed env first, since logger depends on the dist info. - if args.launcher == 'none': + if args.launcher == "none": cfg.gpu_ids = [args.gpu_id] distributed = False if len(cfg.gpu_ids) > 1: - warnings.warn(f'The gpu-ids is reset from {cfg.gpu_ids} to ' - f'{cfg.gpu_ids[0:1]} to avoid potential error in ' - 'non-distribute testing time.') + warnings.warn( + f"The gpu-ids is reset from {cfg.gpu_ids} to " + f"{cfg.gpu_ids[0:1]} to avoid potential error in " + "non-distribute testing time." + ) cfg.gpu_ids = cfg.gpu_ids[0:1] else: distributed = True @@ -168,24 +183,19 @@ def main(): # allows not to create if args.work_dir is not None and rank == 0: mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) if args.aug_test: - json_file = osp.join(args.work_dir, - f'eval_multi_scale_{timestamp}.json') + json_file = osp.join(args.work_dir, f"eval_multi_scale_{timestamp}.json") else: - json_file = osp.join(args.work_dir, - f'eval_single_scale_{timestamp}.json') + json_file = osp.join(args.work_dir, f"eval_single_scale_{timestamp}.json") elif rank == 0: - work_dir = osp.join('./work_dirs', - osp.splitext(osp.basename(args.config))[0]) + work_dir = osp.join("./work_dirs", osp.splitext(osp.basename(args.config))[0]) mmcv.mkdir_or_exist(osp.abspath(work_dir)) - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) if args.aug_test: - json_file = osp.join(work_dir, - f'eval_multi_scale_{timestamp}.json') + json_file = osp.join(work_dir, f"eval_multi_scale_{timestamp}.json") else: - json_file = osp.join(work_dir, - f'eval_single_scale_{timestamp}.json') + json_file = osp.join(work_dir, f"eval_single_scale_{timestamp}.json") # build the dataloader # TODO: support multiple images per gpu (only minor changes are needed) @@ -195,38 +205,47 @@ def main(): # cfg.gpus will be ignored if distributed num_gpus=len(cfg.gpu_ids), dist=distributed, - shuffle=False) + shuffle=False, + ) # The overall dataloader settings - loader_cfg.update({ - k: v - for k, v in cfg.data.items() if k not in [ - 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', - 'test_dataloader' - ] - }) + loader_cfg.update( + { + k: v + for k, v in cfg.data.items() + if k + not in [ + "train", + "val", + "test", + "train_dataloader", + "val_dataloader", + "test_dataloader", + ] + } + ) test_loader_cfg = { **loader_cfg, - 'samples_per_gpu': 1, - 'shuffle': False, # Not shuffle by default - **cfg.data.get('test_dataloader', {}) + "samples_per_gpu": 1, + "shuffle": False, # Not shuffle by default + **cfg.data.get("test_dataloader", {}), } # build the dataloader data_loader = build_dataloader(dataset, **test_loader_cfg) # build the model and load checkpoint cfg.model.train_cfg = None - model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) - fp16_cfg = cfg.get('fp16', None) + model = build_segmentor(cfg.model, test_cfg=cfg.get("test_cfg")) + fp16_cfg = cfg.get("fp16", None) if fp16_cfg is not None: wrap_fp16_model(model) - checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') - if 'CLASSES' in checkpoint.get('meta', {}): - model.CLASSES = checkpoint['meta']['CLASSES'] + checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") + if "CLASSES" in checkpoint.get("meta", {}): + model.CLASSES = checkpoint["meta"]["CLASSES"] else: print('"CLASSES" not found in meta, use dataset.CLASSES instead') model.CLASSES = dataset.CLASSES - if 'PALETTE' in checkpoint.get('meta', {}): - model.PALETTE = checkpoint['meta']['PALETTE'] + if "PALETTE" in checkpoint.get("meta", {}): + model.PALETTE = checkpoint["meta"]["PALETTE"] else: print('"PALETTE" not found in meta, use dataset.PALETTE instead') model.PALETTE = dataset.PALETTE @@ -236,25 +255,27 @@ def main(): eval_kwargs = {} if args.eval_options is None else args.eval_options # Deprecated - efficient_test = eval_kwargs.get('efficient_test', False) + efficient_test = eval_kwargs.get("efficient_test", False) if efficient_test: warnings.warn( - '``efficient_test=True`` does not have effect in tools/test.py, ' - 'the evaluation and format results are CPU memory efficient by ' - 'default') + "``efficient_test=True`` does not have effect in tools/test.py, " + "the evaluation and format results are CPU memory efficient by " + "default" + ) - eval_on_format_results = ( - args.eval is not None and 'cityscapes' in args.eval) + eval_on_format_results = args.eval is not None and "cityscapes" in args.eval if eval_on_format_results: - assert len(args.eval) == 1, 'eval on format results is not ' \ - 'applicable for metrics other than ' \ - 'cityscapes' + assert len(args.eval) == 1, ( + "eval on format results is not " + "applicable for metrics other than " + "cityscapes" + ) if args.format_only or eval_on_format_results: - if 'imgfile_prefix' in eval_kwargs: - tmpdir = eval_kwargs['imgfile_prefix'] + if "imgfile_prefix" in eval_kwargs: + tmpdir = eval_kwargs["imgfile_prefix"] else: - tmpdir = '.format_cityscapes' - eval_kwargs.setdefault('imgfile_prefix', tmpdir) + tmpdir = ".format_cityscapes" + eval_kwargs.setdefault("imgfile_prefix", tmpdir) mmcv.mkdir_or_exist(tmpdir) else: tmpdir = None @@ -262,12 +283,14 @@ def main(): cfg.device = get_device() if not distributed: warnings.warn( - 'SyncBN is only supported with DDP. To be compatible with DP, ' - 'we convert SyncBN to BN. Please use dist_train.sh which can ' - 'avoid this error.') + "SyncBN is only supported with DDP. To be compatible with DP, " + "we convert SyncBN to BN. Please use dist_train.sh which can " + "avoid this error." + ) if not torch.cuda.is_available(): - assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ - 'Please use MMCV >= 1.4.4 for CPU training!' + assert digit_version(mmcv.__version__) >= digit_version( + "1.4.4" + ), "Please use MMCV >= 1.4.4 for CPU training!" model = revert_sync_batchnorm(model) model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) results = single_gpu_test( @@ -279,13 +302,15 @@ def main(): args.opacity, pre_eval=args.eval is not None and not eval_on_format_results, format_only=args.format_only or eval_on_format_results, - format_args=eval_kwargs) + format_args=eval_kwargs, + ) else: model = build_ddp( model, cfg.device, - device_ids=[int(os.environ['LOCAL_RANK'])], - broadcast_buffers=False) + device_ids=[int(os.environ["LOCAL_RANK"])], + broadcast_buffers=False, + ) results = multi_gpu_test( model, data_loader, @@ -294,17 +319,19 @@ def main(): False, pre_eval=args.eval is not None and not eval_on_format_results, format_only=args.format_only or eval_on_format_results, - format_args=eval_kwargs) + format_args=eval_kwargs, + ) rank, _ = get_dist_info() if rank == 0: if args.out: warnings.warn( - 'The behavior of ``args.out`` has been changed since MMSeg ' - 'v0.16, the pickled outputs could be seg map as type of ' - 'np.array, pre-eval results or file paths for ' - '``dataset.format_results()``.') - print(f'\nwriting results to {args.out}') + "The behavior of ``args.out`` has been changed since MMSeg " + "v0.16, the pickled outputs could be seg map as type of " + "np.array, pre-eval results or file paths for " + "``dataset.format_results()``." + ) + print(f"\nwriting results to {args.out}") mmcv.dump(results, args.out) if args.eval: eval_kwargs.update(metric=args.eval) @@ -316,5 +343,5 @@ def main(): shutil.rmtree(tmpdir) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/test_jwseo.py b/mmsegmentation/tools/test_jwseo.py index a643b08..ded55cb 100644 --- a/mmsegmentation/tools/test_jwseo.py +++ b/mmsegmentation/tools/test_jwseo.py @@ -9,8 +9,7 @@ import mmcv import torch from mmcv.cnn.utils import revert_sync_batchnorm -from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, - wrap_fp16_model) +from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model from mmcv.utils import DictAction from mmseg import digit_version @@ -21,95 +20,111 @@ def parse_args(): - parser = argparse.ArgumentParser( - description='mmseg test (and eval) a model') - parser.add_argument('config', help='test config file path') - parser.add_argument('checkpoint', help='checkpoint file') + parser = argparse.ArgumentParser(description="mmseg test (and eval) a model") + parser.add_argument("config", help="test config file path") + parser.add_argument("checkpoint", help="checkpoint file") parser.add_argument( - '--work-dir', - help=('if specified, the evaluation metric results will be dumped' - 'into the directory as json')) + "--work-dir", + help=( + "if specified, the evaluation metric results will be dumped" + "into the directory as json" + ), + ) parser.add_argument( - '--aug-test', action='store_true', help='Use Flip and Multi scale aug') - parser.add_argument('--out', help='output result file in pickle format') + "--aug-test", action="store_true", help="Use Flip and Multi scale aug" + ) + parser.add_argument("--out", help="output result file in pickle format") parser.add_argument( - '--format-only', - action='store_true', - help='Format the output results without perform evaluation. It is' - 'useful when you want to format the result to a specific format and ' - 'submit it to the test server') + "--format-only", + action="store_true", + help="Format the output results without perform evaluation. It is" + "useful when you want to format the result to a specific format and " + "submit it to the test server", + ) parser.add_argument( - '--eval', + "--eval", type=str, - nargs='+', + nargs="+", help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' - ' for generic datasets, and "cityscapes" for Cityscapes') - parser.add_argument('--show', action='store_true', help='show results') + ' for generic datasets, and "cityscapes" for Cityscapes', + ) + parser.add_argument("--show", action="store_true", help="show results") parser.add_argument( - '--show-dir', help='directory where painted images will be saved') + "--show-dir", help="directory where painted images will be saved" + ) parser.add_argument( - '--gpu-collect', - action='store_true', - help='whether to use gpu to collect results.') + "--gpu-collect", + action="store_true", + help="whether to use gpu to collect results.", + ) parser.add_argument( - '--gpu-id', + "--gpu-id", type=int, default=0, - help='id of gpu to use ' - '(only applicable to non-distributed testing)') + help="id of gpu to use " "(only applicable to non-distributed testing)", + ) parser.add_argument( - '--tmpdir', - help='tmp directory used for collecting results from multiple ' - 'workers, available when gpu_collect is not specified') + "--tmpdir", + help="tmp directory used for collecting results from multiple " + "workers, available when gpu_collect is not specified", + ) parser.add_argument( - '--options', - nargs='+', + "--options", + nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " - 'not be supported in version v0.22.0. Override some settings in the ' - 'used config, the key-value pair in xxx=yyy format will be merged ' - 'into config file. If the value to be overwritten is a list, it ' + "not be supported in version v0.22.0. Override some settings in the " + "used config, the key-value pair in xxx=yyy format will be merged " + "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' - 'marks are necessary and that no white space is allowed.') + "marks are necessary and that no white space is allowed.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) parser.add_argument( - '--eval-options', - nargs='+', + "--eval-options", + nargs="+", action=DictAction, - help='custom options for evaluation') + help="custom options for evaluation", + ) parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + "--launcher", + choices=["none", "pytorch", "slurm", "mpi"], + default="none", + help="job launcher", + ) parser.add_argument( - '--opacity', + "--opacity", type=float, default=0.5, - help='Opacity of painted segmentation map. In (0, 1] range.') - parser.add_argument('--local_rank', type=int, default=0) + help="Opacity of painted segmentation map. In (0, 1] range.", + ) + parser.add_argument("--local_rank", type=int, default=0) args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( - '--options and --cfg-options cannot be both ' - 'specified, --options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) if args.options: - warnings.warn('--options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + warnings.warn( + "--options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) args.cfg_options = args.options return args @@ -117,17 +132,17 @@ def parse_args(): def main(): args = parse_args() - assert args.out or args.eval or args.format_only or args.show \ - or args.show_dir, \ - ('Please specify at least one operation (save/eval/format/show the ' - 'results / save the results) with the argument "--out", "--eval"' - ', "--format-only", "--show" or "--show-dir"') + assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( + "Please specify at least one operation (save/eval/format/show the " + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"' + ) if args.eval and args.format_only: - raise ValueError('--eval and --format_only cannot be both specified') + raise ValueError("--eval and --format_only cannot be both specified") - if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): - raise ValueError('The output file must be a pkl file.') + if args.out is not None and not args.out.endswith((".pkl", ".pickle")): + raise ValueError("The output file must be a pkl file.") cfg = mmcv.Config.fromfile(args.config) if args.cfg_options is not None: @@ -137,13 +152,11 @@ def main(): setup_multi_processes(cfg) # set cudnn_benchmark - if cfg.get('cudnn_benchmark', False): + if cfg.get("cudnn_benchmark", False): torch.backends.cudnn.benchmark = True if args.aug_test: # hard code index - cfg.data.test.pipeline[1].img_ratios = [ - 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 - ] + cfg.data.test.pipeline[1].img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] cfg.data.test.pipeline[1].flip = True cfg.model.pretrained = None cfg.data.test.test_mode = True @@ -152,13 +165,15 @@ def main(): cfg.gpu_ids = [args.gpu_id] # init distributed env first, since logger depends on the dist info. - if args.launcher == 'none': + if args.launcher == "none": cfg.gpu_ids = [args.gpu_id] distributed = False if len(cfg.gpu_ids) > 1: - warnings.warn(f'The gpu-ids is reset from {cfg.gpu_ids} to ' - f'{cfg.gpu_ids[0:1]} to avoid potential error in ' - 'non-distribute testing time.') + warnings.warn( + f"The gpu-ids is reset from {cfg.gpu_ids} to " + f"{cfg.gpu_ids[0:1]} to avoid potential error in " + "non-distribute testing time." + ) cfg.gpu_ids = cfg.gpu_ids[0:1] else: distributed = True @@ -168,24 +183,19 @@ def main(): # allows not to create if args.work_dir is not None and rank == 0: mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) if args.aug_test: - json_file = osp.join(args.work_dir, - f'eval_multi_scale_{timestamp}.json') + json_file = osp.join(args.work_dir, f"eval_multi_scale_{timestamp}.json") else: - json_file = osp.join(args.work_dir, - f'eval_single_scale_{timestamp}.json') + json_file = osp.join(args.work_dir, f"eval_single_scale_{timestamp}.json") elif rank == 0: - work_dir = osp.join('./work_dirs', - osp.splitext(osp.basename(args.config))[0]) + work_dir = osp.join("./work_dirs", osp.splitext(osp.basename(args.config))[0]) mmcv.mkdir_or_exist(osp.abspath(work_dir)) - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) if args.aug_test: - json_file = osp.join(work_dir, - f'eval_multi_scale_{timestamp}.json') + json_file = osp.join(work_dir, f"eval_multi_scale_{timestamp}.json") else: - json_file = osp.join(work_dir, - f'eval_single_scale_{timestamp}.json') + json_file = osp.join(work_dir, f"eval_single_scale_{timestamp}.json") # build the dataloader # TODO: support multiple images per gpu (only minor changes are needed) @@ -195,38 +205,47 @@ def main(): # cfg.gpus will be ignored if distributed num_gpus=len(cfg.gpu_ids), dist=distributed, - shuffle=False) + shuffle=False, + ) # The overall dataloader settings - loader_cfg.update({ - k: v - for k, v in cfg.data.items() if k not in [ - 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', - 'test_dataloader' - ] - }) + loader_cfg.update( + { + k: v + for k, v in cfg.data.items() + if k + not in [ + "train", + "val", + "test", + "train_dataloader", + "val_dataloader", + "test_dataloader", + ] + } + ) test_loader_cfg = { **loader_cfg, - 'samples_per_gpu': 1, - 'shuffle': False, # Not shuffle by default - **cfg.data.get('test_dataloader', {}) + "samples_per_gpu": 1, + "shuffle": False, # Not shuffle by default + **cfg.data.get("test_dataloader", {}), } # build the dataloader data_loader = build_dataloader(dataset, **test_loader_cfg) # build the model and load checkpoint cfg.model.train_cfg = None - model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) - fp16_cfg = cfg.get('fp16', None) + model = build_segmentor(cfg.model, test_cfg=cfg.get("test_cfg")) + fp16_cfg = cfg.get("fp16", None) if fp16_cfg is not None: wrap_fp16_model(model) - checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') - if 'CLASSES' in checkpoint.get('meta', {}): - model.CLASSES = checkpoint['meta']['CLASSES'] + checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") + if "CLASSES" in checkpoint.get("meta", {}): + model.CLASSES = checkpoint["meta"]["CLASSES"] else: print('"CLASSES" not found in meta, use dataset.CLASSES instead') model.CLASSES = dataset.CLASSES - if 'PALETTE' in checkpoint.get('meta', {}): - model.PALETTE = checkpoint['meta']['PALETTE'] + if "PALETTE" in checkpoint.get("meta", {}): + model.PALETTE = checkpoint["meta"]["PALETTE"] else: print('"PALETTE" not found in meta, use dataset.PALETTE instead') model.PALETTE = dataset.PALETTE @@ -236,25 +255,27 @@ def main(): eval_kwargs = {} if args.eval_options is None else args.eval_options # Deprecated - efficient_test = eval_kwargs.get('efficient_test', False) + efficient_test = eval_kwargs.get("efficient_test", False) if efficient_test: warnings.warn( - '``efficient_test=True`` does not have effect in tools/test.py, ' - 'the evaluation and format results are CPU memory efficient by ' - 'default') + "``efficient_test=True`` does not have effect in tools/test.py, " + "the evaluation and format results are CPU memory efficient by " + "default" + ) - eval_on_format_results = ( - args.eval is not None and 'cityscapes' in args.eval) + eval_on_format_results = args.eval is not None and "cityscapes" in args.eval if eval_on_format_results: - assert len(args.eval) == 1, 'eval on format results is not ' \ - 'applicable for metrics other than ' \ - 'cityscapes' + assert len(args.eval) == 1, ( + "eval on format results is not " + "applicable for metrics other than " + "cityscapes" + ) if args.format_only or eval_on_format_results: - if 'imgfile_prefix' in eval_kwargs: - tmpdir = eval_kwargs['imgfile_prefix'] + if "imgfile_prefix" in eval_kwargs: + tmpdir = eval_kwargs["imgfile_prefix"] else: - tmpdir = '.format_cityscapes' - eval_kwargs.setdefault('imgfile_prefix', tmpdir) + tmpdir = ".format_cityscapes" + eval_kwargs.setdefault("imgfile_prefix", tmpdir) mmcv.mkdir_or_exist(tmpdir) else: tmpdir = None @@ -262,12 +283,14 @@ def main(): cfg.device = get_device() if not distributed: warnings.warn( - 'SyncBN is only supported with DDP. To be compatible with DP, ' - 'we convert SyncBN to BN. Please use dist_train.sh which can ' - 'avoid this error.') + "SyncBN is only supported with DDP. To be compatible with DP, " + "we convert SyncBN to BN. Please use dist_train.sh which can " + "avoid this error." + ) if not torch.cuda.is_available(): - assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ - 'Please use MMCV >= 1.4.4 for CPU training!' + assert digit_version(mmcv.__version__) >= digit_version( + "1.4.4" + ), "Please use MMCV >= 1.4.4 for CPU training!" model = revert_sync_batchnorm(model) model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) results = single_gpu_test( @@ -279,13 +302,15 @@ def main(): args.opacity, pre_eval=args.eval is not None and not eval_on_format_results, format_only=args.format_only or eval_on_format_results, - format_args=eval_kwargs) + format_args=eval_kwargs, + ) else: model = build_ddp( model, cfg.device, - device_ids=[int(os.environ['LOCAL_RANK'])], - broadcast_buffers=False) + device_ids=[int(os.environ["LOCAL_RANK"])], + broadcast_buffers=False, + ) results = multi_gpu_test( model, data_loader, @@ -294,17 +319,19 @@ def main(): False, pre_eval=args.eval is not None and not eval_on_format_results, format_only=args.format_only or eval_on_format_results, - format_args=eval_kwargs) + format_args=eval_kwargs, + ) rank, _ = get_dist_info() if rank == 0: if args.out: warnings.warn( - 'The behavior of ``args.out`` has been changed since MMSeg ' - 'v0.16, the pickled outputs could be seg map as type of ' - 'np.array, pre-eval results or file paths for ' - '``dataset.format_results()``.') - print(f'\nwriting results to {args.out}') + "The behavior of ``args.out`` has been changed since MMSeg " + "v0.16, the pickled outputs could be seg map as type of " + "np.array, pre-eval results or file paths for " + "``dataset.format_results()``." + ) + print(f"\nwriting results to {args.out}") mmcv.dump(results, args.out) if args.eval: eval_kwargs.update(metric=args.eval) @@ -316,5 +343,5 @@ def main(): shutil.rmtree(tmpdir) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/torchserve/mmseg2torchserve.py b/mmsegmentation/tools/torchserve/mmseg2torchserve.py index 9063634..8d9eb02 100644 --- a/mmsegmentation/tools/torchserve/mmseg2torchserve.py +++ b/mmsegmentation/tools/torchserve/mmseg2torchserve.py @@ -17,7 +17,7 @@ def mmseg2torchserve( checkpoint_file: str, output_folder: str, model_name: str, - model_version: str = '1.0', + model_version: str = "1.0", force: bool = False, ): """Converts mmsegmentation model (config + checkpoint) to TorchServe @@ -48,64 +48,75 @@ def mmseg2torchserve( config = mmcv.Config.fromfile(config_file) with TemporaryDirectory() as tmpdir: - config.dump(f'{tmpdir}/config.py') + config.dump(f"{tmpdir}/config.py") args = Namespace( **{ - 'model_file': f'{tmpdir}/config.py', - 'serialized_file': checkpoint_file, - 'handler': f'{Path(__file__).parent}/mmseg_handler.py', - 'model_name': model_name or Path(checkpoint_file).stem, - 'version': model_version, - 'export_path': output_folder, - 'force': force, - 'requirements_file': None, - 'extra_files': None, - 'runtime': 'python', - 'archive_format': 'default' - }) + "model_file": f"{tmpdir}/config.py", + "serialized_file": checkpoint_file, + "handler": f"{Path(__file__).parent}/mmseg_handler.py", + "model_name": model_name or Path(checkpoint_file).stem, + "version": model_version, + "export_path": output_folder, + "force": force, + "requirements_file": None, + "extra_files": None, + "runtime": "python", + "archive_format": "default", + } + ) manifest = ModelExportUtils.generate_manifest_json(args) package_model(args, manifest) def parse_args(): parser = ArgumentParser( - description='Convert mmseg models to TorchServe `.mar` format.') - parser.add_argument('config', type=str, help='config file path') - parser.add_argument('checkpoint', type=str, help='checkpoint file path') + description="Convert mmseg models to TorchServe `.mar` format." + ) + parser.add_argument("config", type=str, help="config file path") + parser.add_argument("checkpoint", type=str, help="checkpoint file path") parser.add_argument( - '--output-folder', + "--output-folder", type=str, required=True, - help='Folder where `{model_name}.mar` will be created.') + help="Folder where `{model_name}.mar` will be created.", + ) parser.add_argument( - '--model-name', + "--model-name", type=str, default=None, - help='If not None, used for naming the `{model_name}.mar`' - 'file that will be created under `output_folder`.' - 'If None, `{Path(checkpoint_file).stem}` will be used.') + help="If not None, used for naming the `{model_name}.mar`" + "file that will be created under `output_folder`." + "If None, `{Path(checkpoint_file).stem}` will be used.", + ) parser.add_argument( - '--model-version', - type=str, - default='1.0', - help='Number used for versioning.') + "--model-version", type=str, default="1.0", help="Number used for versioning." + ) parser.add_argument( - '-f', - '--force', - action='store_true', - help='overwrite the existing `{model_name}.mar`') + "-f", + "--force", + action="store_true", + help="overwrite the existing `{model_name}.mar`", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if package_model is None: - raise ImportError('`torch-model-archiver` is required.' - 'Try: pip install torch-model-archiver') + raise ImportError( + "`torch-model-archiver` is required." + "Try: pip install torch-model-archiver" + ) - mmseg2torchserve(args.config, args.checkpoint, args.output_folder, - args.model_name, args.model_version, args.force) + mmseg2torchserve( + args.config, + args.checkpoint, + args.output_folder, + args.model_name, + args.model_version, + args.force, + ) diff --git a/mmsegmentation/tools/torchserve/mmseg_handler.py b/mmsegmentation/tools/torchserve/mmseg_handler.py index 28fe501..7195951 100644 --- a/mmsegmentation/tools/torchserve/mmseg_handler.py +++ b/mmsegmentation/tools/torchserve/mmseg_handler.py @@ -12,19 +12,20 @@ class MMsegHandler(BaseHandler): - def initialize(self, context): properties = context.system_properties - self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = torch.device(self.map_location + ':' + - str(properties.get('gpu_id')) if torch.cuda. - is_available() else self.map_location) + self.map_location = "cuda" if torch.cuda.is_available() else "cpu" + self.device = torch.device( + self.map_location + ":" + str(properties.get("gpu_id")) + if torch.cuda.is_available() + else self.map_location + ) self.manifest = context.manifest - model_dir = properties.get('model_dir') - serialized_file = self.manifest['model']['serializedFile'] + model_dir = properties.get("model_dir") + serialized_file = self.manifest["model"]["serializedFile"] checkpoint = os.path.join(model_dir, serialized_file) - self.config_file = os.path.join(model_dir, 'config.py') + self.config_file = os.path.join(model_dir, "config.py") self.model = init_segmentor(self.config_file, checkpoint, self.device) self.model = revert_sync_batchnorm(self.model) @@ -34,7 +35,7 @@ def preprocess(self, data): images = [] for row in data: - image = row.get('data') or row.get('body') + image = row.get("data") or row.get("body") if isinstance(image, str): image = base64.b64decode(image) image = mmcv.imfrombytes(image) @@ -50,7 +51,7 @@ def postprocess(self, data): output = [] for image_result in data: - _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) + _, buffer = cv2.imencode(".png", image_result[0].astype("uint8")) content = buffer.tobytes() output.append(content) return output diff --git a/mmsegmentation/tools/torchserve/test_torchserve.py b/mmsegmentation/tools/torchserve/test_torchserve.py index 432834a..dd92756 100644 --- a/mmsegmentation/tools/torchserve/test_torchserve.py +++ b/mmsegmentation/tools/torchserve/test_torchserve.py @@ -11,37 +11,38 @@ def parse_args(): parser = ArgumentParser( - description='Compare result of torchserve and pytorch,' - 'and visualize them.') - parser.add_argument('img', help='Image file') - parser.add_argument('config', help='Config file') - parser.add_argument('checkpoint', help='Checkpoint file') - parser.add_argument('model_name', help='The model name in the server') + description="Compare result of torchserve and pytorch," "and visualize them." + ) + parser.add_argument("img", help="Image file") + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("model_name", help="The model name in the server") parser.add_argument( - '--inference-addr', - default='127.0.0.1:8080', - help='Address and port of the inference server') + "--inference-addr", + default="127.0.0.1:8080", + help="Address and port of the inference server", + ) parser.add_argument( - '--result-image', + "--result-image", type=str, default=None, - help='save server output in result-image') - parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') + help="save server output in result-image", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") args = parser.parse_args() return args def main(args): - url = 'http://' + args.inference_addr + '/predictions/' + args.model_name - with open(args.img, 'rb') as image: + url = "http://" + args.inference_addr + "/predictions/" + args.model_name + with open(args.img, "rb") as image: tmp_res = requests.post(url, image) content = tmp_res.content if args.result_image: - with open(args.result_image, 'wb') as out_image: + with open(args.result_image, "wb") as out_image: out_image.write(content) - plt.imshow(mmcv.imread(args.result_image, 'grayscale')) + plt.imshow(mmcv.imread(args.result_image, "grayscale")) plt.show() else: plt.imshow(plt.imread(BytesIO(content))) @@ -53,6 +54,6 @@ def main(args): plt.show() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/mmsegmentation/tools/train.py b/mmsegmentation/tools/train.py index c4219b0..cefc482 100644 --- a/mmsegmentation/tools/train.py +++ b/mmsegmentation/tools/train.py @@ -17,92 +17,101 @@ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor -from mmseg.utils import (collect_env, get_device, get_root_logger, - setup_multi_processes) +from mmseg.utils import collect_env, get_device, get_root_logger, setup_multi_processes def parse_args(): - parser = argparse.ArgumentParser(description='Train a segmentor') - parser.add_argument('config', help='train config file path') - parser.add_argument('--work-dir', help='the dir to save logs and models') + parser = argparse.ArgumentParser(description="Train a segmentor") + parser.add_argument("config", help="train config file path") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument("--load-from", help="the checkpoint file to load weights from") + parser.add_argument("--resume-from", help="the checkpoint file to resume from") parser.add_argument( - '--load-from', help='the checkpoint file to load weights from') - parser.add_argument( - '--resume-from', help='the checkpoint file to resume from') - parser.add_argument( - '--no-validate', - action='store_true', - help='whether not to evaluate the checkpoint during training') + "--no-validate", + action="store_true", + help="whether not to evaluate the checkpoint during training", + ) group_gpus = parser.add_mutually_exclusive_group() group_gpus.add_argument( - '--gpus', + "--gpus", type=int, - help='(Deprecated, please use --gpu-id) number of gpus to use ' - '(only applicable to non-distributed training)') + help="(Deprecated, please use --gpu-id) number of gpus to use " + "(only applicable to non-distributed training)", + ) group_gpus.add_argument( - '--gpu-ids', + "--gpu-ids", type=int, - nargs='+', - help='(Deprecated, please use --gpu-id) ids of gpus to use ' - '(only applicable to non-distributed training)') + nargs="+", + help="(Deprecated, please use --gpu-id) ids of gpus to use " + "(only applicable to non-distributed training)", + ) group_gpus.add_argument( - '--gpu-id', + "--gpu-id", type=int, default=0, - help='id of gpu to use ' - '(only applicable to non-distributed training)') - parser.add_argument('--seed', type=int, default=None, help='random seed') + help="id of gpu to use " "(only applicable to non-distributed training)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed") parser.add_argument( - '--diff_seed', - action='store_true', - help='Whether or not set different seeds for different ranks') + "--diff_seed", + action="store_true", + help="Whether or not set different seeds for different ranks", + ) parser.add_argument( - '--deterministic', - action='store_true', - help='whether to set deterministic options for CUDNN backend.') + "--deterministic", + action="store_true", + help="whether to set deterministic options for CUDNN backend.", + ) parser.add_argument( - '--options', - nargs='+', + "--options", + nargs="+", action=DictAction, help="--options is deprecated in favor of --cfg_options' and it will " - 'not be supported in version v0.22.0. Override some settings in the ' - 'used config, the key-value pair in xxx=yyy format will be merged ' - 'into config file. If the value to be overwritten is a list, it ' + "not be supported in version v0.22.0. Override some settings in the " + "used config, the key-value pair in xxx=yyy format will be merged " + "into config file. If the value to be overwritten is a list, it " 'should be like key="[a,b]" or key=a,b It also allows nested ' 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' - 'marks are necessary and that no white space is allowed.') + "marks are necessary and that no white space is allowed.", + ) parser.add_argument( - '--cfg-options', - nargs='+', + "--cfg-options", + nargs="+", action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + "--launcher", + choices=["none", "pytorch", "slurm", "mpi"], + default="none", + help="job launcher", + ) + parser.add_argument("--local_rank", type=int, default=0) parser.add_argument( - '--auto-resume', - action='store_true', - help='resume from the latest checkpoint automatically.') + "--auto-resume", + action="store_true", + help="resume from the latest checkpoint automatically.", + ) args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) if args.options and args.cfg_options: raise ValueError( - '--options and --cfg-options cannot be both ' - 'specified, --options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) if args.options: - warnings.warn('--options is deprecated in favor of --cfg-options. ' - '--options will not be supported in version v0.22.0.') + warnings.warn( + "--options is deprecated in favor of --cfg-options. " + "--options will not be supported in version v0.22.0." + ) args.cfg_options = args.options return args @@ -116,39 +125,44 @@ def main(): cfg.merge_from_dict(args.cfg_options) # set cudnn_benchmark - if cfg.get('cudnn_benchmark', False): + if cfg.get("cudnn_benchmark", False): torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir - elif cfg.get('work_dir', None) is None: + elif cfg.get("work_dir", None) is None: # use config filename as default work_dir if cfg.work_dir is None - cfg.work_dir = osp.join('./work_dirs', - osp.splitext(osp.basename(args.config))[0]) + cfg.work_dir = osp.join( + "./work_dirs", osp.splitext(osp.basename(args.config))[0] + ) if args.load_from is not None: cfg.load_from = args.load_from if args.resume_from is not None: cfg.resume_from = args.resume_from if args.gpus is not None: cfg.gpu_ids = range(1) - warnings.warn('`--gpus` is deprecated because we only support ' - 'single GPU mode in non-distributed training. ' - 'Use `gpus=1` now.') + warnings.warn( + "`--gpus` is deprecated because we only support " + "single GPU mode in non-distributed training. " + "Use `gpus=1` now." + ) if args.gpu_ids is not None: cfg.gpu_ids = args.gpu_ids[0:1] - warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' - 'Because we only support single GPU mode in ' - 'non-distributed training. Use the first GPU ' - 'in `gpu_ids` now.') + warnings.warn( + "`--gpu-ids` is deprecated, please use `--gpu-id`. " + "Because we only support single GPU mode in " + "non-distributed training. Use the first GPU " + "in `gpu_ids` now." + ) if args.gpus is None and args.gpu_ids is None: cfg.gpu_ids = [args.gpu_id] cfg.auto_resume = args.auto_resume # init distributed env first, since logger depends on the dist info. - if args.launcher == 'none': + if args.launcher == "none": distributed = False else: distributed = True @@ -162,8 +176,8 @@ def main(): # dump config cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # init the logger before other steps - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + log_file = osp.join(cfg.work_dir, f"{timestamp}.log") logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # set multi-process settings @@ -174,39 +188,37 @@ def main(): meta = dict() # log env info env_info_dict = collect_env() - env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) - dash_line = '-' * 60 + '\n' - logger.info('Environment info:\n' + dash_line + env_info + '\n' + - dash_line) - meta['env_info'] = env_info + env_info = "\n".join([f"{k}: {v}" for k, v in env_info_dict.items()]) + dash_line = "-" * 60 + "\n" + logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) + meta["env_info"] = env_info # log some basic info - logger.info(f'Distributed training: {distributed}') - logger.info(f'Config:\n{cfg.pretty_text}') + logger.info(f"Distributed training: {distributed}") + logger.info(f"Config:\n{cfg.pretty_text}") # set random seeds cfg.device = get_device() seed = init_random_seed(args.seed, device=cfg.device) seed = seed + dist.get_rank() if args.diff_seed else seed - logger.info(f'Set random seed to {seed}, ' - f'deterministic: {args.deterministic}') + logger.info(f"Set random seed to {seed}, " f"deterministic: {args.deterministic}") set_random_seed(seed, deterministic=args.deterministic) cfg.seed = seed - meta['seed'] = seed - meta['exp_name'] = osp.basename(args.config) + meta["seed"] = seed + meta["exp_name"] = osp.basename(args.config) model = build_segmentor( - cfg.model, - train_cfg=cfg.get('train_cfg'), - test_cfg=cfg.get('test_cfg')) + cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") + ) model.init_weights() # SyncBN is not support for DP if not distributed: warnings.warn( - 'SyncBN is only supported with DDP. To be compatible with DP, ' - 'we convert SyncBN to BN. Please use dist_train.sh which can ' - 'avoid this error.') + "SyncBN is only supported with DDP. To be compatible with DP, " + "we convert SyncBN to BN. Please use dist_train.sh which can " + "avoid this error." + ) model = revert_sync_batchnorm(model) logger.info(model) @@ -220,10 +232,11 @@ def main(): # save mmseg version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( - mmseg_version=f'{__version__}+{get_git_hash()[:7]}', + mmseg_version=f"{__version__}+{get_git_hash()[:7]}", config=cfg.pretty_text, CLASSES=datasets[0].CLASSES, - PALETTE=datasets[0].PALETTE) + PALETTE=datasets[0].PALETTE, + ) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES # passing checkpoint meta for saving best checkpoint @@ -235,8 +248,9 @@ def main(): distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, - meta=meta) + meta=meta, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mmsegmentation/tools/train_jwseo.py b/mmsegmentation/tools/train_jwseo.py index 5264b3c..8f538c7 100644 --- a/mmsegmentation/tools/train_jwseo.py +++ b/mmsegmentation/tools/train_jwseo.py @@ -18,12 +18,13 @@ from mmcv.runner import get_dist_info, init_dist from mmcv.utils import Config, DictAction, get_git_hash from mmcv.utils.config import ConfigDict +from rich.console import Console + from mmseg import __version__ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor from mmseg.utils import collect_env, get_device, get_root_logger, setup_multi_processes -from rich.console import Console KST_TZ = pytz.timezone("Asia/Seoul") GPU_ID = 0 @@ -48,7 +49,7 @@ def get_latest_checkpoint(work_dir: Path) -> Union[str, None]: if not latest_checkpoint_file.exists(): return None - with open(latest_checkpoint_file, "r", encoding="utf8") as f: + with open(latest_checkpoint_file, encoding="utf8") as f: checkpoint_path = f.read() return checkpoint_path diff --git a/src/inference.py b/src/inference.py index 6b6fc5c..fd5647b 100644 --- a/src/inference.py +++ b/src/inference.py @@ -5,10 +5,9 @@ import numpy as np import pandas as pd import torch -from tqdm import tqdm - from dataset import make_dataloader from network import define_network +from tqdm import tqdm warnings.filterwarnings("ignore") @@ -47,7 +46,6 @@ def test(model, test_loader, device): with torch.no_grad(): for step, (imgs, image_infos) in enumerate(tqdm(test_loader)): - # inference (512 x 512) outs = model(torch.stack(imgs).to(device))["out"] oms = torch.argmax(outs.squeeze(), dim=1).detach().cpu().numpy() diff --git a/src/train.py b/src/train.py index e66340e..d06311f 100644 --- a/src/train.py +++ b/src/train.py @@ -4,10 +4,9 @@ import numpy as np import torch -from tqdm import tqdm - from dataset import make_dataloader from network import define_network +from tqdm import tqdm from utils import add_hist, label_accuracy_score warnings.filterwarnings("ignore") @@ -70,7 +69,6 @@ def validation(epoch, model, data_loader, criterion, device): hist = np.zeros((n_class, n_class)) for step, (images, masks, _) in enumerate(data_loader): - images = torch.stack(images) masks = torch.stack(masks).long() diff --git a/src/utils/stratified_kfold.py b/src/utils/stratified_kfold.py index 15e81af..d3b38fd 100644 --- a/src/utils/stratified_kfold.py +++ b/src/utils/stratified_kfold.py @@ -95,7 +95,6 @@ def main(args): def update_dataset(index, mode, input_json, output_dir): - with open(input_json) as file: data = json.load(file)