-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
58 lines (45 loc) · 1.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from ultralytics import YOLO
import utils
# Relative paths; presumably resolved against the current working directory,
# so run the script from the project root (TODO confirm).
# Starting weights that train() builds its model from.
yolo_model: str = 'model/yolov8n.pt'
# Dataset config passed to model.train(data=...) in train().
custom_yaml: str = 'dataset/dataset.yaml'
# Trained weights loaded via load() in the __main__ section.
custom_model: str = 'model/blue_focus.pt'
# model.save() kept misbehaving for the author, so after training the run's
# runs\detect\trainN\weights\best.pt must be copied by hand to the path above
# (presumably custom_model).
@utils.timer
def train(weights=yolo_model, data=custom_yaml, epochs=100,
          val_split='test', save_to=None) -> YOLO:
    """Train a YOLO model on the custom dataset, then validate it.

    Called with no arguments it does exactly what it always did:
    - start from ``yolo_model``;
    - train for 100 epochs on ``custom_yaml``;
    - validate on the ``'test'`` split.

    Args:
        weights: Starting weights passed to ``YOLO(...)``.
        data: Dataset YAML passed to ``model.train(data=...)``.
        epochs: Number of training epochs.
        val_split: Dataset split used by ``model.val`` after training.
        save_to: Optional destination path. When given, the run's
            ``best.pt`` is copied there. This replaces the manual copy of
            ``runs/detect/trainN/weights/best.pt``.

    Returns:
        The trained model.

    Raises:
        FileNotFoundError: If ``save_to`` is given but the run produced no
            ``best.pt`` (e.g. training was interrupted).
    """
    model = YOLO(weights)
    model.train(data=data, epochs=epochs)
    if save_to is not None:
        # Local imports keep the default (no-copy) path unchanged.
        import shutil
        from pathlib import Path
        # The Ultralytics trainer records the checkpoint path as
        # `trainer.best` (runs/detect/trainN/weights/best.pt). Copying that
        # file avoids model.save(), which the author found unreliable.
        # Copy before validating so a failing val() cannot lose the weights.
        destination = Path(save_to)
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(Path(model.trainer.best), destination)
    model.val(split=val_split)
    return model
@utils.timer
def load(path) -> YOLO:
    """Construct and return a YOLO model from the weights file at *path*.

    The call is wrapped by ``utils.timer`` (presumably reports how long
    loading took — confirm in utils).
    """
    return YOLO(path)
@utils.timer
def predict(model, image, target_class, confidence_threshold):
    """Report whether *model* detects *target_class* in *image*.

    Args:
        model: A loaded YOLO model.
        image: Any source ``model.predict`` accepts, e.g. an image path.
        target_class: Class name as it appears in ``model.names``.
        confidence_threshold: Minimum confidence (inclusive) a box needs.

    Returns:
        True if at least one box of *target_class* has confidence
        >= *confidence_threshold*, otherwise False.

    Raises:
        ValueError: If *target_class* is not one of the model's class names.
    """
    # model.names maps class index -> class name; invert it to get the index.
    name_to_index = {name: index for index, name in model.names.items()}
    if target_class not in name_to_index:
        # Without this check a misspelled name made every call return False.
        raise ValueError(
            f'unknown class {target_class!r}; '
            f'model classes are {sorted(name_to_index)}'
        )
    target_index = name_to_index[target_class]

    # predict() drops boxes below its own `conf` (0.25 by default) before we
    # see them, so pass the threshold through. Otherwise thresholds below the
    # default could never match.
    results = model.predict(image, conf=confidence_threshold)  # , save=True)
    for result in results:
        for box in result.boxes:
            if (int(box.cls.item()) == target_index
                    and float(box.conf.item()) >= confidence_threshold):
                return True
    return False


if __name__ == '__main__':
    # To retrain, uncomment the next line and comment out the load() call
    # after it; otherwise load() overwrites the freshly trained model.
    # model = train()
    model = load(custom_model)

    target_class = 'blue_focus'
    confidence_threshold = 0.25
    # The first prediction seems to be slow; later ones are fast
    # (presumably one-time model warm-up).
    for image in (
        r'D:\图片\steam\永劫无间\1203220_20240522144157_1.png',
        r'D:\图片\steam\永劫无间\1203220_20240522144156_2.png',
    ):
        result = predict(model, image, target_class, confidence_threshold)
        print(result)