-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
118 lines (88 loc) · 2.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import cv2
import serial.tools.list_ports
from PIL import Image
import torch
from torchvision import transforms
from model import resnet18
import time
from com import send
from pynput import keyboard
from pynput.keyboard import Key, Controller
import json
import threading
# Open camera device 0 and request a 640x480 capture size.
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

# Most recent camera image, shared between the capture thread (writer) and the
# main loop (reader). It is bound here so the name always exists; it stays
# None until the first image has been stored. (A module-level ``global``
# statement, as used before, does not create the variable.)
frame = None
def cap0():
    """Continuously grab camera images into the module-level ``frame``.

    Intended to run as the target of a background thread. Each iteration
    reads one image from ``cap``; only successful reads are published, so
    ``frame`` keeps the last good image when a read fails instead of being
    overwritten with an invalid one. The loop ends when ``cv2.waitKey``
    reports that 'q' was pressed.
    """
    global frame
    while True:
        grabbed, latest = cap.read()
        if grabbed:
            # The first return value of read() tells whether an image was
            # actually delivered; skip the assignment otherwise.
            frame = latest
        # Debug preview (disabled): cv2.imshow("capture", frame)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
# Background thread that runs cap0; marked as a daemon thread so it does not
# keep the interpreter alive on its own. It is started later in the script.
capture = threading.Thread(target=cap0, daemon=True)

# pynput keyboard controller instance.
keyboard0 = Controller()

# Stop-request flag; keyboard_on_release sets it to True when Esc is released.
isEnd = False

# Mapping loaded from JSON; the main loop looks up str(predicted index) in it.
with open('./class_indices.json', "r") as index_file:
    class_indict = json.load(index_file)
def keyboard_on_release(key):
    """Key-release callback: flag a stop request when Esc is released.

    For Esc, sets the module-level ``isEnd`` to True and returns False
    (pynput stops a listener whose callback returns False). For any other
    key nothing is changed and None is returned.
    """
    global isEnd
    if key != keyboard.Key.esc:
        return None
    isEnd = True
    return False
# Run on the first CUDA device when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("{} is in use".format(device))

# Preprocessing applied to every camera image before it is fed to the network:
# centre crop to 200x200, resize to 224, convert to a tensor, then normalise
# each channel with the given mean / std values.
preprocess_steps = [
    transforms.CenterCrop([200, 200]),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
data_transform = transforms.Compose(preprocess_steps)

# Build the 7-class network, load its weights onto the chosen device and
# switch it to evaluation mode.
net = resnet18(num_classes=7).to(device)
state_dict = torch.load("weight/mms.pth", map_location=device)
net.load_state_dict(state_dict)
net.eval()
# 0 while no serial link is usable; afterwards it holds whatever `send`
# returned last (a falsy value triggers a new port scan).
device_exist = 0
# Open serial handle, or None while no matching port has been opened. It lives
# in its own name so the `serial` module is never shadowed by the handle.
serial_port = None
# Substring searched for in each port's description during a scan.
port_name = "COM5"
# Minimum number of seconds between two status / FPS printouts.
report_interval = 1

# Start the background camera thread.
capture.start()

# Each pass of the outer loop is one run: wait for Esc to be released, then
# classify camera images until Esc is released again.
while True:
    # Block until Esc is released: keyboard_on_release returns False for Esc,
    # which stops this listener and lets join() return.
    with keyboard.Listener(on_release=keyboard_on_release) as starter:
        starter.join()

    # The key press that started the run also set isEnd. Clear it *before*
    # arming the stop listener so a stop request can never be wiped out by
    # this reset.
    isEnd = False
    stopper = keyboard.Listener(on_release=keyboard_on_release)
    stopper.start()

    start_time = time.time()
    counter = 0  # images classified since the last printout

    while True:
        if device_exist == 0:
            # Release a handle left over from a lost link before rescanning.
            if serial_port is not None:
                try:
                    serial_port.close()
                except (serial.SerialException, OSError):
                    pass  # best effort; the handle is discarded either way
                serial_port = None

            for port in serial.tools.list_ports.comports():
                if port_name not in port.description:
                    continue
                try:
                    serial_port = serial.Serial(port.device, write_timeout=1)
                except (serial.SerialException, OSError):
                    # Deliberate best effort: the port may be busy or gone;
                    # the scan is retried on the next loop iteration.
                    continue
                device_exist = 1
                break  # keep the first port that opens

        # The capture thread may not have published an image yet (name unbound
        # or still None). Take one snapshot so the image cannot change while
        # it is being processed.
        current_frame = globals().get("frame")
        if current_frame is None:
            if isEnd:
                break
            time.sleep(0.01)
            continue

        frame_PIL = Image.fromarray(cv2.cvtColor(current_frame, cv2.COLOR_BGR2RGB))
        image = torch.unsqueeze(data_transform(frame_PIL), dim=0)
        with torch.no_grad():
            output = torch.squeeze(net(image.to(device))).cpu()
            predict = torch.argmax(output).numpy()

        # Report the prediction over the serial link; `send` returns the new
        # link status. NOTE(review): with no open port the call is skipped and
        # device_exist stays 0 -- this assumes com.send has no other effect
        # when it is handed an unusable port; confirm against com.send.
        if serial_port is not None:
            device_exist = send(serial_port, 0, predict)

        counter += 1
        elapsed = time.time() - start_time
        if elapsed >= report_interval:
            print(device_exist)
            print("FPS: %.3f" % (counter / elapsed))
            print(class_indict[str(predict)])
            print()
            counter = 0
            start_time = time.time()

        if isEnd:
            # Final message with second argument 1 instead of 0 (presumably a
            # stop notification -- see com.send), sent only if the link is up.
            if device_exist:
                send(serial_port, 1, predict)
            break