-
Notifications
You must be signed in to change notification settings - Fork 220
/
Copy pathpredict.py
39 lines (31 loc) · 1.16 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import numpy as np
import extract_feats.opensmile as of
import extract_feats.librosa as lf
import models
import utils
def predict(config, audio_path: str, model) -> None:
"""
预测音频情感
Args:
config: 配置项
audio_path (str): 要预测的音频路径
model: 加载的模型
"""
# utils.play_audio(audio_path)
if config.feature_method == 'o':
# 一个玄学 bug 的暂时性解决方案
of.get_data(config, audio_path, train=False)
test_feature = of.load_feature(config, train=False)
elif config.feature_method == 'l':
test_feature = lf.get_data(config, audio_path, train=False)
result = model.predict(test_feature)
result_prob = model.predict_proba(test_feature)
print('Recogntion: ', config.class_labels[int(result)])
print('Probability: ', result_prob)
utils.radar(result_prob, config.class_labels)
if __name__ == '__main__':
audio_path = '/Users/zou/Renovamen/Developing/Speech-Emotion-Recognition/datasets/CASIA/angry/201-angry-liuchanhg.wav'
config = utils.parse_opt()
model = models.load(config)
predict(config, audio_path, model)