-
Notifications
You must be signed in to change notification settings - Fork 185
/
DOTA.py
121 lines (113 loc) · 4.09 KB
/
DOTA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# This code is used for visualization, inspired by cocoapi
# Licensed under the Simplified BSD License [see bsd.txt]
import os
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Circle
import numpy as np
import dota_utils as util
from collections import defaultdict
import cv2
def _isArrayLike(obj):
if type(obj) == str:
return False
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
class DOTA:
    """
    Index the annotations of a DOTA-style dataset directory and visualise them.

    The constructor derives two sub-directories from ``basepath``:
    ``labelTxt`` (annotation files) and ``images`` (``.png`` images whose
    file name is ``<image id>.png``).

    Attributes:
        basepath: dataset root passed to the constructor.
        labelpath: ``<basepath>/labelTxt``.
        imagepath: ``<basepath>/images``.
        imgpaths: annotation file paths returned by
            ``util.GetFileFromThisRootDir(labelpath)``.
        imglist: ``util.custombasename`` of every entry of ``imgpaths``;
            these values are used as image ids.
        catToImgs: category name -> list of image ids. An id is appended once
            per object of that category, so it may appear several times.
        ImgToAnns: image id -> objects returned by ``util.parse_dota_poly``.
    """

    def __init__(self, basepath):
        """
        :param basepath: dataset root containing 'labelTxt' and 'images'
        """
        self.basepath = basepath
        self.labelpath = os.path.join(basepath, 'labelTxt')
        self.imagepath = os.path.join(basepath, 'images')
        self.imgpaths = util.GetFileFromThisRootDir(self.labelpath)
        self.imglist = [util.custombasename(x) for x in self.imgpaths]
        self.catToImgs = defaultdict(list)
        self.ImgToAnns = defaultdict(list)
        self.createIndex()

    def createIndex(self):
        """
        Parse every annotation file and fill ImgToAnns and catToImgs.

        :return: None
        """
        for filename in self.imgpaths:
            objects = util.parse_dota_poly(filename)
            imgid = util.custombasename(filename)
            self.ImgToAnns[imgid] = objects
            for obj in objects:
                self.catToImgs[obj['name']].append(imgid)

    def getImgIds(self, catNms=[]):
        """
        :param catNms: a category name or a collection of category names
        :return: new list with the ids of the images that contain every given
                 category (in no particular order); all image ids when catNms
                 is empty
        """
        catNms = catNms if _isArrayLike(catNms) else [catNms]
        # len() rather than truthiness so array-like inputs whose truth value
        # is ambiguous still work.
        if len(catNms) == 0:
            # Copy so callers cannot reorder or modify the internal id list.
            return list(self.imglist)
        # .get() instead of [] so that asking for an unknown category does not
        # insert an empty entry into the defaultdict index.
        per_category = [set(self.catToImgs.get(cat, ())) for cat in catNms]
        return list(set.intersection(*per_category))

    def loadAnns(self, catNms=[], imgId=None, difficult=None):
        """
        :param catNms: a category name or a collection of category names;
                       empty means no filtering
        :param imgId: the img to load anns
        :param difficult: accepted for interface compatibility; currently not
                          used for filtering
        :return: new list with the objects of imgId whose 'name' is in catNms
                 (all objects of imgId when catNms is empty); empty list for
                 an unknown imgId
        """
        catNms = catNms if _isArrayLike(catNms) else [catNms]
        # .get() instead of [] so that an unknown image id does not insert an
        # empty entry into the defaultdict index.
        objects = self.ImgToAnns.get(imgId, [])
        if len(catNms) == 0:
            return list(objects)
        return [obj for obj in objects if obj['name'] in catNms]

    def showAnns(self, objects, imgId, range):
        """
        Draw the image and the given objects on the current matplotlib axes.

        Every object is drawn as a translucent filled polygon with an outline
        of the same random colour; the first vertex of each polygon is marked
        with a red circle.

        :param objects: objects to show; each must provide a 'poly' entry
                        holding the polygon vertices as (x, y) points
        :param imgId: img to show
        :param range: display range in the img; accepted for interface
                      compatibility, currently not used
        :return: None
        :raises FileNotFoundError: if the image for imgId could not be read
        """
        img = self.loadImgs(imgId)[0]
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(
                'could not read image for id %r in %s' % (imgId, self.imagepath))
        # cv2.imread yields BGR channel order while matplotlib expects RGB;
        # convert so the colours are displayed correctly.
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        circles = []
        r = 5  # radius of the first-vertex marker
        for obj in objects:
            # Random RGB colour with every channel in [0.4, 1.0).
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            poly = obj['poly']
            polygons.append(Polygon(poly))
            color.append(c)
            point = poly[0]
            circles.append(Circle((point[0], point[1]), r))
        # Translucent fill, then an opaque outline, then the vertex markers.
        p = PatchCollection(polygons, facecolors=color, linewidths=0, alpha=0.4)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolors='none', edgecolors=color, linewidths=2)
        ax.add_collection(p)
        p = PatchCollection(circles, facecolors='red')
        ax.add_collection(p)

    def loadImgs(self, imgids=[]):
        """
        :param imgids: an image id (string) or a collection of image ids
        :return: list with one entry per id, in order: the result of
                 cv2.imread on '<imagepath>/<id>.png' (None for a file that
                 cannot be read)
        """
        imgids = imgids if _isArrayLike(imgids) else [imgids]
        return [cv2.imread(os.path.join(self.imagepath, imgid + '.png'))
                for imgid in imgids]
# if __name__ == '__main__':
# examplesplit = DOTA('examplesplit')
# imgids = examplesplit.getImgIds(catNms=['plane'])
# img = examplesplit.loadImgs(imgids)
# for imgid in imgids:
# anns = examplesplit.loadAnns(imgId=imgid)
# examplesplit.showAnns(anns, imgid, 2)