utils.py
import cv2
import numpy as np
import torch

import config


def load_checkpoint(checkpoint, model, optimizer, lr):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    # optimizer.load_state_dict(checkpoint["optimizer"])
    # Reset the learning rate explicitly; otherwise the optimizer keeps the
    # stale learning rate stored in the old checkpoint.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def save_checkpoint(state, filename="mask_rcnn.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def predict_single_frame(frame, model):
    # Resize to the model input size and scale pixel values to [0, 1].
    image = cv2.resize(frame, config.IMAGE_SIZE,
                       interpolation=cv2.INTER_LINEAR) / 255.0
    image = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
    # Convert NHWC -> NCHW, the layout the model expects.
    image = image.permute(0, 3, 1, 2)
    images = [img.to(config.DEVICE) for img in image]
    with torch.no_grad():
        pred = model(images)

    # Back to HWC float32 so OpenCV can draw on the frame.
    im = images[0].permute(1, 2, 0).detach().cpu().numpy().astype(np.float32)
    im2 = np.zeros_like(im).astype(np.float32)
    for i in range(len(pred[0]['masks'])):
        msk = pred[0]['masks'][i, 0].detach().cpu().numpy()
        scr = pred[0]['scores'][i].detach().cpu().numpy()
        box = pred[0]['boxes'][i].detach().cpu().numpy()
        # Keep only confident detections.
        if scr > 0.9:
            cv2.rectangle(im, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (0, 0, 1), 2)
            cv2.putText(im, "{0:.2f}%".format(scr * 100),
                        (int(box[0] + 5), int(box[1]) + 15),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 1), 2, cv2.LINE_AA)
            # Paint the thresholded mask with a random color per channel.
            im2[:, :, 0][msk > 0.8] = np.random.uniform(0, 1)
            im2[:, :, 1][msk > 0.8] = np.random.uniform(0, 1)
            im2[:, :, 2][msk > 0.8] = np.random.uniform(0, 1)
    # Blend the annotated frame with the mask overlay and return uint8 BGR.
    return (cv2.addWeighted(im, 0.8, im2, 0.2, 0) * 255).astype(np.uint8)


def predict_video(input_path, output_path, model):
    cap = cv2.VideoCapture(input_path)
    # NOTE: the hard-coded frame size must match the size of the frames
    # returned by predict_single_frame (i.e. config.IMAGE_SIZE).
    out = cv2.VideoWriter(output_path,
                          cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
                          60, (1152, 648))
    model.eval()
    if not cap.isOpened():
        print("Error opening video stream or file")
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            result_frame = predict_single_frame(frame, model)
            out.write(result_frame)
        else:
            break
    cap.release()
    out.release()
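

# Minimal usage sketch, hedged rather than definitive: it assumes config
# defines DEVICE and IMAGE_SIZE, that the checkpoint was produced by
# save_checkpoint() above for torchvision's Mask R-CNN (torchvision >= 0.13
# for the `weights` argument), and that num_classes/heads of the freshly
# built model match the checkpoint. The file paths are placeholders.
if __name__ == "__main__":
    from torchvision.models.detection import maskrcnn_resnet50_fpn

    model = maskrcnn_resnet50_fpn(weights=None).to(config.DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Restore weights and reset the learning rate to the current value.
    checkpoint = torch.load("mask_rcnn.pth.tar", map_location=config.DEVICE)
    load_checkpoint(checkpoint, model, optimizer, lr=1e-4)

    # Run inference on a video and write the annotated result.
    predict_video("input.mp4", "output.mp4", model)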