forked from 7788boy/CS_T0828_HW2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Test.py
47 lines (39 loc) · 1.43 KB
/
Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import cv2
from tqdm import tqdm
import json
import torch
from torchvision.transforms import ToTensor
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
# Input/output locations and dataset size for the test-set inference run.
input_path = 'dataset/test/'
checkpoint_path = 'checkpoints/model_0244999.pth'
output_path = 'result.json'
# Test images are expected to be named 1.png ... <num_test_images>.png.
num_test_images = 13068


def prepare_image(image, input_format='BGR'):
    """Convert an image read by ``cv2.imread`` into a detectron2 model input.

    detectron2 models take a float ``(C, H, W)`` tensor of raw 0-255 pixel
    values whose channel order is given by ``cfg.INPUT.FORMAT``; the model
    applies ``cfg.MODEL.PIXEL_MEAN`` / ``PIXEL_STD`` normalisation itself,
    so values must NOT be rescaled to [0, 1] (which ``ToTensor`` does).

    Args:
        image: ``(H, W, 3)`` array in OpenCV's BGR channel order.
        input_format: Channel order the model expects, ``'BGR'`` or
            ``'RGB'`` (normally ``cfg.INPUT.FORMAT``).

    Returns:
        A ``float32`` tensor of shape ``(3, H, W)``.
    """
    tensor = torch.as_tensor(image).permute(2, 0, 1).float()
    if input_format == 'RGB':
        # A *list* index reorders the channel axis; a tuple such as
        # (2, 1, 0) would instead select the single element tensor[2, 1, 0].
        tensor = tensor[[2, 1, 0]]
    return tensor


def format_prediction(boxes, scores, classes):
    """Turn the detection tensors for one image into a JSON-ready record.

    Args:
        boxes: ``(N, 4)`` tensor of boxes as ``(x1, y1, x2, y2)``.
        scores: ``(N,)`` tensor of detection confidences.
        classes: ``(N,)`` tensor of integer class ids.

    Returns:
        ``{'bbox': [(y1, x1, y2, x2), ...], 'score': [...], 'label': [...]}``
        with box coordinates truncated to ``int`` and scores kept as floats.
    """
    # Boxes are emitted in (y1, x1, y2, x2) order, the submission format
    # this script has always produced.
    bbox = [
        (int(y1), int(x1), int(y2), int(x2))
        for x1, y1, x2, y2 in boxes.tolist()
    ]
    return {
        'bbox': bbox,
        # Keep confidences as floats: they are fractional (typically in
        # [0, 1]), so int() would truncate almost all of them to 0.
        'score': [float(s) for s in scores.tolist()],
        'label': [int(c) for c in classes.tolist()],
    }


def main():
    """Run the detector over every test image and dump predictions to JSON.

    One record per image is appended in image-id order (1.png first) and the
    whole list is written to ``output_path``.

    Raises:
        FileNotFoundError: if a test image cannot be read.
    """
    # NOTE(review): get_cfg() is detectron2's *default* config. It must
    # describe the same model (architecture, number of classes, INPUT.FORMAT,
    # pixel mean/std) the checkpoint was trained with, or the weights may not
    # load as intended. TODO: merge the training config here, e.g.
    # cfg.merge_from_file(<training config>).
    cfg = get_cfg()
    model = build_model(cfg)
    DetectionCheckpointer(model).load(checkpoint_path)
    model.eval()

    # NOTE(review): unlike detectron2's DefaultPredictor, no test-time
    # resize (cfg.INPUT.MIN_SIZE_TEST) is applied; confirm this matches the
    # training pipeline. Labels are written as raw pred_classes ids -- TODO
    # confirm the id -> digit mapping used at training time.
    result = []
    with torch.no_grad():
        for image_id in tqdm(range(1, num_test_images + 1)):
            image_path = input_path + f'{image_id}.png'
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(f'Could not read test image: {image_path}')
            model_input = prepare_image(img, cfg.INPUT.FORMAT)
            instances = model([{'image': model_input}])[0]['instances']
            result.append(format_prediction(
                instances.pred_boxes.tensor,
                instances.scores,
                instances.pred_classes,
            ))

    with open(output_path, 'w') as output_file:
        json.dump(result, output_file)


if __name__ == "__main__":
    main()