Skip to content

Commit

Permalink
add lvis support (open-mmlab#2088)
Browse files Browse the repository at this point in the history
* add lvis dataset

* fixed eval

* fixed test cfg

* add resnext config

* update md

* fixed name

* update model urls

* minor fix

* fixed typo

* use open-mmlab lvis

* update travis

* fixed install

* make class balance as default
  • Loading branch information
xvjiarui authored Jun 5, 2020
1 parent 206107e commit 8fc0542
Show file tree
Hide file tree
Showing 12 changed files with 556 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ install:
- pip install Pillow==6.2.2 # remove this line when torchvision>=0.5
- pip install torch==${TORCH} torchvision==${TORCHVISION}
- pip install mmcv-nightly
- pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
- pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI"
- pip install -r requirements.txt

before_script:
Expand Down
22 changes: 22 additions & 0 deletions configs/_base_/datasets/lvis_instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
_base_ = 'coco_instance.py'
dataset_type = 'LVISDataset'
data_root = 'data/lvis/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='ClassBalancedDataset',
oversample_thr=1e-3,
dataset=dict(
type=dataset_type,
ann_file=data_root + 'annotations/lvis_v0.5_train.json',
img_prefix=data_root + 'train2017/')),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/lvis_v0.5_val.json',
img_prefix=data_root + 'val2017/'),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/lvis_v0.5_val.json',
img_prefix=data_root + 'val2017/'))
evaluation = dict(metric=['bbox', 'segm'])
24 changes: 24 additions & 0 deletions configs/lvis/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# LVIS dataset

## Introduction
```
@inproceedings{gupta2019lvis,
title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation},
author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross},
booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```

## Common Setting
* All experiments use oversample strategy [here](../../docs/tutorials/new_dataset.md#class-balanced-dataset) with oversample threshold `1e-3`.
* The size of LVIS v0.5 is half of COCO, so schedule `2x` in LVIS is roughly the same iterations as `1x` in COCO.

## Results and models

| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | mask AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :-----: | :------: |
| R-50-FPN | pytorch | 2x | - | - | 26.1 | 25.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis-dbd06831.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis_20200531_160435.log.json) |
| R-101-FPN | pytorch | 2x | - | - | 27.1 | 27.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis-54582ee2.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis_20200601_134748.log.json) |
| X-101-32x4d-FPN | pytorch | 2x | - | - | 26.7 | 26.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis-3cf55ea2.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_32x4d_fpn_sample1e-3_mstrain_2x_lvis_20200531_221749.log.json) |
| X-101-64x4d-FPN | pytorch | 2x | - | - | 26.4 | 26.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis-1c99a5ad.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis/mask_rcnn_x101_64x4d_fpn_sample1e-3_mstrain_2x_lvis_20200601_194651.log.json) |
2 changes: 2 additions & 0 deletions configs/lvis/mask_rcnn_r101_fpn_sample1e-3_mstrain_2x_lvis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
31 changes: 31 additions & 0 deletions configs/lvis/mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/lvis_instance.py',
'../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
model = dict(
roi_head=dict(
bbox_head=dict(num_classes=1230), mask_head=dict(num_classes=1230)))
test_cfg = dict(
rcnn=dict(
score_thr=0.0001,
# LVIS allows up to 300
max_per_img=300))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(dataset=dict(pipeline=train_pipeline)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py'
model = dict(
pretrained='open-mmlab://resnext101_32x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=32,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
style='pytorch'))
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = './mask_rcnn_r50_fpn_sample1e-3_mstrain_2x_lvis.py'
model = dict(
pretrained='open-mmlab://resnext101_64x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
style='pytorch'))
7 changes: 4 additions & 3 deletions docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ cd mmdetection
```

d. Install build requirements and then install mmdetection.
(We install pycocotools via the github repo instead of pypi because the pypi version is old and not compatible with the latest numpy.)
(We install our forked version of pycocotools via the github repo instead of pypi
for better compatibility with our repo.)

```shell
pip install -r requirements/build.txt
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI"
pip install -v -e . # or "python setup.py develop"
```

Expand Down Expand Up @@ -130,7 +131,7 @@ conda install -c pytorch pytorch torchvision -y
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -r requirements/build.txt
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=PythonAPI"
pip install -v -e .
```

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/new_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ dataset_A_train = dict(
)
```

### Repeat factor dataset
### Class balanced dataset

We use `ClassBalancedDataset` as wrapper to repeat the dataset based on category
frequency. The dataset to repeat needs to instantiate function `self.get_cat_ids(idx)`
Expand Down
9 changes: 5 additions & 4 deletions mmdet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
RepeatDataset)
from .lvis import LVISDataset
from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset

__all__ = [
'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
'build_dataset'
'CityscapesDataset', 'LVISDataset', 'GroupSampler',
'DistributedGroupSampler', 'DistributedSampler', 'build_dataloader',
'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset',
'WIDERFaceDataset', 'DATASETS', 'PIPELINES', 'build_dataset'
]
22 changes: 11 additions & 11 deletions mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@ class CocoDataset(CustomDataset):

def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES)
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.getImgIds()
self.img_ids = self.coco.get_img_ids()
data_infos = []
for i in self.img_ids:
info = self.coco.loadImgs([i])[0]
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
return data_infos

def get_ann_info(self, idx):
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)

def get_cat_ids(self, idx):
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return [ann['category_id'] for ann in ann_info]

def _filter_imgs(self, min_size=32):
Expand Down Expand Up @@ -83,12 +83,12 @@ def get_subset_by_classes(self):

ids = set()
for i, class_id in enumerate(self.cat_ids):
ids |= set(self.coco.catToImgs[class_id])
ids |= set(self.coco.cat_img_map[class_id])
self.img_ids = list(ids)

data_infos = []
for i in self.img_ids:
info = self.coco.loadImgs([i])[0]
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
return data_infos
Expand Down Expand Up @@ -268,8 +268,8 @@ def results2json(self, results, outfile_prefix):
def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
gt_bboxes = []
for i in range(len(self.img_ids)):
ann_ids = self.coco.getAnnIds(imgIds=self.img_ids[i])
ann_info = self.coco.loadAnns(ann_ids)
ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
ann_info = self.coco.load_anns(ann_ids)
if len(ann_info) == 0:
gt_bboxes.append(np.zeros((0, 4)))
continue
Expand Down
Loading

0 comments on commit 8fc0542

Please sign in to comment.