forked from mit-han-lab/gan-compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
coco_generate_instance_map.py
56 lines (48 loc) · 2.42 KB
/
coco_generate_instance_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import argparse
import os
import numpy as np
import skimage.io as io
import tqdm
from pycocotools.coco import COCO
from skimage.draw import polygon
# Command-line options. All defaults are relative paths, so they resolve
# against the current working directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--annotation_file', type=str, default="./annotations/instances_train2017.json",
    help="Path to the annotation file. It can be downloaded at "
         "http://images.cocodataset.org/annotations/annotations_trainval2017.zip. "
         "Should be either instances_train2017.json or instances_val2017.json")
parser.add_argument(
    '--input_label_dir', type=str, default="./train_label/",
    help="Path to the directory containing label maps. It can be downloaded at "
         "http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip")
parser.add_argument(
    '--output_instance_dir', type=str, default="./train_inst/",
    help="Path to the output directory of instance maps")
opt = parser.parse_args()
# Make sure the destination exists (an existing directory is not an error),
# then echo the resolved paths so each run's configuration appears in its log.
os.makedirs(opt.output_instance_dir, exist_ok=True)
print(f"annotation file at {opt.annotation_file}")
print(f"input label maps at {opt.input_label_dir}")
print(f"output dir at {opt.output_instance_dir}")
# Initialize the COCO API for the instance annotations.
coco = COCO(opt.annotation_file)
# Every label map needs a matching instance map, so take all image ids,
# including images with no instance annotations. Filtering through
# getImgIds(catIds=...) would instead keep only images that contain *all*
# of the given categories.
imgIds = coco.getImgIds()
# For each image: load its label map, paint every polygon annotation into it
# with a per-annotation id, and save the result as the instance map.
for img_id in tqdm.tqdm(imgIds):
    img_dict = coco.loadImgs(img_id)[0]
    # Label/instance maps are PNGs named after the source image. Swap only the
    # extension; a plain str.replace("jpg", ...) would also rewrite "jpg"
    # appearing anywhere else in the file name.
    filename = os.path.splitext(img_dict["file_name"])[0] + ".png"
    label_name = os.path.join(opt.input_label_dir, filename)
    inst_name = os.path.join(opt.output_instance_dir, filename)
    # Start from the label map: pixels not covered by any polygon below keep
    # their original label value in the saved instance map.
    img = io.imread(label_name, as_gray=True)

    annIds = coco.getAnnIds(imgIds=img_id, catIds=[], iscrowd=None)
    anns = coco.loadAnns(annIds)
    # Instance ids are assigned sequentially from 0 per image.
    # NOTE(review): assumes the ids fit in img's dtype (e.g. uint8 for an
    # 8-bit label PNG) — verify for images with very many annotations.
    count = 0
    for ann in anns:
        segmentation = ann.get("segmentation")
        # Only polygon segmentations (a list of flat [x0, y0, x1, y1, ...]
        # lists) are drawn. Any other encoding (COCO stores RLE as a dict) is
        # skipped and does not consume an id.
        if not isinstance(segmentation, list):
            continue
        for seg in segmentation:
            poly = np.asarray(seg).reshape(-1, 2)  # one (x, y) vertex per row
            # polygon() takes (rows, cols), i.e. (y, x). The -1 shift on both
            # axes is kept from the upstream script; its intent is not stated
            # here — confirm before changing. shape= clips the coordinates so
            # polygons touching the image border cannot index out of bounds.
            rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1, shape=img.shape[:2])
            img[rr, cc] = count
        # All parts of one annotation share a single instance id.
        count += 1
    io.imsave(inst_name, img)