-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
43 lines (35 loc) · 853 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
# Ordered list (despite the "DICT" suffix) of instrument class names.
# NOTE(review): the position of each name presumably serves as its class
# index elsewhere in the project — confirm before reordering.
INST_DICT = [
    "clarinet",
    "distorted electric guitar",
    "female singer",
    "flute",
    "piano",
    "tenor saxophone",
    "trumpet",
    "violin",
]
# Ordered list (despite the "DICT" suffix) of eleven three-letter instrument
# codes, kept in alphabetical order.
# NOTE(review): these presumably are the IRMAS dataset label abbreviations
# (e.g. "cel" = cello, "voi" = voice) — confirm against the dataset docs, and
# check whether list position is used as a class index before reordering.
IRMAS_INST_DICT = [
    "cel",
    "cla",
    "flu",
    "gac",
    "gel",
    "org",
    "pia",
    "sax",
    "tru",
    "vio",
    "voi",
]
def stereo_to_mono(audio: np.ndarray, channel_axis: int = 0) -> np.ndarray:
    """Down-mix multi-channel audio to mono by averaging across channels.

    Args:
        audio: Audio samples (array or array-like). For multi-channel input
            the channels are expected along ``channel_axis``; the default
            ``0`` corresponds to a ``(channels, samples)`` layout. A 1-D
            array is treated as an already-mono signal.
        channel_axis: Axis that holds the channels. Pass ``-1`` for a
            ``(samples, channels)`` layout.

    Returns:
        The channel-averaged signal. ``np.mean`` promotes integer samples to
        floating point. 1-D input is returned as-is rather than being
        reduced to its scalar mean.
    """
    audio = np.asarray(audio)
    if audio.ndim == 1:
        # Already mono: averaging over axis 0 would collapse the whole signal
        # into a single scalar.
        return audio
    return np.mean(audio, axis=channel_axis)
def save_model(model, filename: str):
    """Save ``model.state_dict()`` into the ``trained_models`` directory.

    The directory sits next to this source file. It (and any subfolder
    included in ``filename``) is created on demand, so saving also works on
    a fresh checkout where the folder does not exist yet.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        filename: Target file name, relative to ``trained_models``. An
            absolute path makes ``os.path.join`` discard the directory part.

    Returns:
        Whatever ``torch.save`` returns (``None``).
    """
    # Function-local imports: importing this module does not pull in torch.
    from os import makedirs, path
    from torch import save

    target = path.join(path.dirname(path.abspath(__file__)), 'trained_models', filename)
    # torch.save does not create missing parent directories; without this,
    # saving into a not-yet-existing trained_models/ folder fails.
    makedirs(path.dirname(target), exist_ok=True)
    return save(model.state_dict(), target)
def load_model(model, filename: str):
    """Load weights from ``trained_models/<filename>`` into ``model``.

    The ``trained_models`` directory is resolved relative to this source
    file. Tensors are mapped onto the CPU (``map_location='cpu'``) before
    being handed to ``model.load_state_dict``.

    Args:
        model: Object exposing ``load_state_dict()`` (e.g. a
            ``torch.nn.Module``); it is updated in place.
        filename: Checkpoint file name, relative to ``trained_models``.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    from os import path
    from torch import load

    models_dir = path.join(path.dirname(path.abspath(__file__)), 'trained_models')
    checkpoint_path = path.join(models_dir, filename)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; newer torch versions accept weights_only=True to restrict it.
    state_dict = load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model