forked from rosinality/fcos-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
executable file
·119 lines (77 loc) · 3.1 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import torch
from torchvision import datasets
from boxlist import BoxList
def has_only_empty_bbox(annot):
return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in annot)
def has_valid_annotation(annot):
if len(annot) == 0:
return False
if has_only_empty_bbox(annot):
return False
return True
class COCODataset(datasets.CocoDetection):
def __init__(self, path, split, transform=None):
root = os.path.join(path, f'{split}2017')
annot = os.path.join(path, 'annotations', f'instances_{split}2017.json')
super().__init__(root, annot)
self.ids = sorted(self.ids)
if split == 'train':
ids = []
for id in self.ids:
ann_ids = self.coco.getAnnIds(imgIds=id, iscrowd=None)
annot = self.coco.loadAnns(ann_ids)
if has_valid_annotation(annot):
ids.append(id)
self.ids = ids
self.category2id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}
self.id2category = {v: k for k, v in self.category2id.items()}
self.id2img = {k: v for k, v in enumerate(self.ids)}
self.transform = transform
def __getitem__(self, index):
img, annot = super().__getitem__(index)
annot = [o for o in annot if o['iscrowd'] == 0]
boxes = [o['bbox'] for o in annot]
boxes = torch.as_tensor(boxes).reshape(-1, 4)
target = BoxList(boxes, img.size, mode='xywh').convert('xyxy')
classes = [o['category_id'] for o in annot]
classes = [self.category2id[c] for c in classes]
classes = torch.tensor(classes)
target.fields['labels'] = classes
target.clip(remove_empty=True)
if self.transform is not None:
img, target = self.transform(img, target)
return img, target, index
def get_image_meta(self, index):
id = self.id2img[index]
img_data = self.coco.imgs[id]
return img_data
class ImageList:
def __init__(self, tensors, sizes):
self.tensors = tensors
self.sizes = sizes
def to(self, *args, **kwargs):
tensor = self.tensors.to(*args, **kwargs)
return ImageList(tensor, self.sizes)
def image_list(tensors, size_divisible=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
if size_divisible > 0:
stride = size_divisible
max_size = list(max_size)
max_size[1] = (max_size[1] | (stride - 1)) + 1
max_size[2] = (max_size[2] | (stride - 1)) + 1
max_size = tuple(max_size)
shape = (len(tensors),) + max_size
batch = tensors[0].new(*shape).zero_()
for img, pad_img in zip(tensors, batch):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
sizes = [img.shape[-2:] for img in tensors]
return ImageList(batch, sizes)
def collate_fn(config):
def collate_data(batch):
batch = list(zip(*batch))
imgs = image_list(batch[0], config.size_divisible)
targets = batch[1]
ids = batch[2]
return imgs, targets, ids
return collate_data