-
Notifications
You must be signed in to change notification settings - Fork 336
/
Copy pathtransformer-cnn-emotion-recognition.py
167 lines (123 loc) · 4.19 KB
/
transformer-cnn-emotion-recognition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import time
import sys
import numpy as np
import librosa
import ailia
# import original modules
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser # noqa: E402
from model_utils import check_and_download_models # noqa: E402
# logger
from logging import getLogger # noqa: E402
logger = getLogger(__name__)
# ======================
# PARAMETERS
# ======================
# https://smartlaboratory.org/ravdess/
WAVE_PATH = "03-01-01-01-01-01-01.wav"
WEIGHT_PATH = "parallel_is_all_you_want_ep428.onnx"
MODEL_PATH = "parallel_is_all_you_want_ep428.onnx.prototxt"
REMOTE_PATH = "https://storage.googleapis.com/ailia-models/parallel_is_all_you_want/"
LABELS = [
"surprised", "neutral", "calm", "happy",
"sad", "angry", "fearful", "disgust",
]
# ======================
# Arguemnt Parser Config
# ======================
parser = get_base_parser(
'Parallel_is_All_You_Want.', WAVE_PATH, None, input_ftype='audio')
args = update_parser(parser)
# ======================
# Utils
# ======================
def get_waveforms(file, sample_rate=48000):
# load an individual sample audio file
# read the full 3 seconds of the file, cut off the first 0.5s of silence; native sample rate = 48k
# don't need to store the sample rate that librosa.load returns
waveform, _ = librosa.load(file, duration=3, offset=0.5, sr=sample_rate)
# make sure waveform vectors are homogenous by defining explicitly
waveform_homo = np.zeros((int(sample_rate * 3, )))
waveform_homo[:len(waveform)] = waveform
# return a single file's waveform
return waveform_homo
def feature_mfcc(
waveform,
sample_rate,
n_mfcc=40,
fft=1024,
winlen=512,
window='hamming',
# hop=256, # increases # of time steps; was not helpful
mels=128):
# Compute the MFCCs for all STFT frames
# 40 mel filterbanks (n_mfcc) = 40 coefficients
mfc_coefficients = librosa.feature.mfcc(
y=waveform,
sr=sample_rate,
n_mfcc=n_mfcc,
n_fft=fft,
win_length=winlen,
window=window,
# hop_length=hop,
n_mels=mels,
fmax=sample_rate / 2
)
return mfc_coefficients
def preprocess(x):
c, h, w = x.shape
x = x.reshape(-1)
mean = np.load('std_mean.npy')
var = np.load('std_var.npy')
x = (x - mean) / np.sqrt(var)
x = x.reshape((c, h, w))
return x
# ======================
# Main function
# ======================
def predict(net, waveform):
sample_rate = 48000
features = feature_mfcc(waveform, sample_rate)
features = np.expand_dims(features, axis=0)
x = preprocess(features)
# feedforward
x = np.expand_dims(x, axis=0)
output = net.predict([x])
output_logits, output_softmax = output
output_softmax = output_softmax[0]
label = np.argmax(output_softmax)
conf = output_softmax[label]
return label, conf
def recognize_from_audio(net):
for input_path in args.input:
logger.info(f'input: {input_path}')
# load audio
waveform = get_waveforms(input_path)
# inference
logger.info('Start inference...')
if args.benchmark:
logger.info('BENCHMARK mode')
total_time_estimation = 0
for i in range(args.benchmark_count):
start = int(round(time.time() * 1000))
label, conf = predict(net, waveform)
end = int(round(time.time() * 1000))
estimation_time = (end - start)
# Loggin
logger.info(f'\tailia processing estimation time {estimation_time} ms')
if i != 0:
total_time_estimation = total_time_estimation + estimation_time
else:
label, conf = predict(net, waveform)
label = LABELS[label]
logger.info("Emotion: %s" % label)
logger.info("Confidence: %s" % conf)
logger.info('Script finished successfully.')
def main():
# model files check and download
check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)
# initialize
net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id)
recognize_from_audio(net)
if __name__ == "__main__":
main()