detect.py
import argparse

import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont

# The network classes must be importable so torch.load can unpickle the saved model
from SSD_VGG16D import SSD256
from SSD_VGG16D.networks import AuxiliaryNetwork, PredictionNetwork, VGG16DBaseNetwork, DetectionNetwork
from SSD_VGG16D.utils import *  # provides rev_label_map and label_color_map used below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ap = argparse.ArgumentParser()
ap.add_argument("--trained_model", default="models/ssd256-24b-6w.pth.tar", type=str,
                help="Path of the trained checkpoint to load")
ap.add_argument("--input", type=str, help="Path of the input image to run detection on")
ap.add_argument("--output", type=str, help="Path to save the annotated output image")
ap.add_argument("--min_score", default=0.4, type=float,
                help="Minimum score for a detected box to be kept")
ap.add_argument("--max_overlap", default=0.45, type=float,
                help="Maximum overlap allowed between boxes during NMS")
ap.add_argument("--top_k", default=200, type=int,
                help="Keep only the top k detections after NMS")
args = ap.parse_args()
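
# Example invocation (input/output file names are illustrative, not from the repo):
#   python detect.py --trained_model models/ssd256-24b-6w.pth.tar \
#       --input demo.jpg --output demo_out.jpg --min_score 0.4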
img_path = args.input
output_path = args.output

# The checkpoint stores the whole model object plus the last completed epoch;
# map_location lets a GPU-trained checkpoint load on a CPU-only machine
trained_model = torch.load(args.trained_model, map_location=device)
start_epoch = trained_model["epoch"] + 1
print('\nLoaded model trained for %d epochs.\n' % start_epoch)
model = trained_model['model']
model = model.to(device)
model.eval()
# Preprocessing for SSD256: 256x256 input normalized with the ImageNet mean/std
resize = transforms.Resize((256, 256))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
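
# Note: resize + to_tensor yield a (3, 256, 256) float tensor in [0, 1];
# detect() below adds the batch dimension with unsqueeze(0) -> (1, 3, 256, 256)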


def detect(model, device, original_image, min_score, max_overlap, top_k, suppress=None):
    # Preprocess: resize to the network input size, convert to tensor, normalize
    image = normalize(to_tensor(resize(original_image)))
    image = image.to(device)

    # Forward pass and decoding (no gradients needed at inference time)
    with torch.no_grad():
        locs_pred, cls_pred = model(image.unsqueeze(0))
        # Decode predictions into boxes, labels and scores, applying NMS
        detect_boxes, detect_labels, detect_scores = model.detect(locs_pred, cls_pred,
                                                                  min_score, max_overlap, top_k)
    detect_boxes = detect_boxes[0].to('cpu')

    # Boxes come out in fractional coordinates; scale back to the original image size
    original_dims = torch.FloatTensor(
        [original_image.width, original_image.height, original_image.width,
         original_image.height]).unsqueeze(0)
    detect_boxes = detect_boxes * original_dims

    detect_labels = [rev_label_map[l]
                     for l in detect_labels[0].to('cpu').tolist()]

    # A lone "background" label means nothing was detected
    if detect_labels == ["background"]:
        return original_image

    annotated_image = original_image
    draw = ImageDraw.Draw(annotated_image)
    font = ImageFont.truetype("arial.ttf", 15)

    for i in range(detect_boxes.size(0)):
        # Skip any classes the caller asked to suppress
        if suppress is not None and detect_labels[i] in suppress:
            continue

        # Draw a 2px outline by stacking two 1px rectangles offset by one pixel
        box_location = detect_boxes[i].tolist()
        draw.rectangle(xy=box_location,
                       outline=label_color_map[detect_labels[i]])
        draw.rectangle(xy=[l + 1. for l in box_location],
                       outline=label_color_map[detect_labels[i]])
text = ""
if detect_scores is not None:
text = f"{detect_labels[i].upper()} {detect_scores[i][0]:.2f}"
text_size = font.getsize(text)
text_location = [box_location[0] + 2., box_location[1] - text_size[1]]
textbox_location = [box_location[0], box_location[1] - text_size[1], box_location[0] + text_size[0] + 4.,
box_location[1]]
draw.rectangle(xy=textbox_location,
fill=label_color_map[detect_labels[i]])
draw.text(xy=text_location,
text=text, fill='white', font=font)
return annotated_image
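
# The optional suppress argument filters out classes by name; the class name in
# this sketch is assumed to exist in the repo's label_map (Pascal VOC-style):
#   detect(model, device, img, min_score=0.4, max_overlap=0.45, top_k=200,
#          suppress=['pottedplant'])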

if __name__ == '__main__':
    original_image = Image.open(img_path, mode='r')
    original_image = original_image.convert('RGB')
    annotated_image = detect(model, device, original_image, min_score=args.min_score,
                             max_overlap=args.max_overlap, top_k=args.top_k)
    annotated_image.save(output_path)
    annotated_image.show()