-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
164 lines (148 loc) · 5.02 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from argparse import ArgumentParser
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from vap.modules.VAP import load_model_from_state_dict, step_extraction
from vap.modules.lightning_module import VAPModule
from vap.utils.audio import load_waveform
from vap.utils.plot import plot_stereo
from vap.utils.utils import (
batch_to_device,
everything_deterministic,
tensor_dict_to_json,
write_json,
)
# Module-level setup: these two calls run whenever this file is executed or imported.
# NOTE(review): `everything_deterministic` is imported from vap.utils.utils and its
# body is not visible here; presumably it enables deterministic torch behaviour -- confirm.
everything_deterministic()
# Seed torch's global random number generator with a fixed value (0).
torch.manual_seed(0)
def get_args():
    """Parse and validate the command-line arguments of the script.

    Returns:
        argparse.Namespace: with the attributes ``audio``, ``output``,
        ``state_dict``, ``checkpoint``, ``chunk_time``, ``step_time``,
        ``force_no_chunk`` and ``plot``.

    Exits through ``parser.error`` (usage message on stderr, exit status 2)
    when the audio file does not exist, when neither ``--state_dict`` nor
    ``--checkpoint`` is given, or when the selected weights file is missing.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-a",
        "--audio",
        type=str,
        help="Path to waveform",
        required=True,
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to the output JSON file",
    )
    parser.add_argument(
        "-sd",
        "--state_dict",
        type=str,
        default=None,  # "example/checkpoints/VAP_state_dict.pt",
        help="Path to state_dict",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Path to trained model",
    )
    parser.add_argument(
        "--chunk_time",
        type=float,
        default=20,
        help="Duration of each chunk processed by model (total duration of clips to evaluate)",
    )
    parser.add_argument(
        "--step_time",
        type=float,
        default=5,
        help="Increment to process in a step. (Uses the last `chunk_time - step_time` as context to predict `step_time` for each relevant clip/chunk)",
    )
    parser.add_argument(
        "--force_no_chunk",
        action="store_true",
        help="Don't use chunking but process the entire audio in one pass.",
    )
    parser.add_argument(
        "--plot", action="store_true", help="Visualize output (matplotlib)"
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are removed under
    # `python -O`, and parser.error gives the user a usage message instead of
    # a traceback.
    if not Path(args.audio).exists():
        parser.error(f"Audio {args.audio} does not exist")

    if args.state_dict is None and args.checkpoint is None:
        parser.error("Must provide state_dict or checkpoint")

    # Only the weights source that will actually be used is checked: the
    # state-dict takes precedence over the checkpoint when both are given.
    if args.state_dict:
        if not Path(args.state_dict).exists():
            parser.error(f"State-dict {args.state_dict} does not exist")
    elif args.checkpoint:
        if not Path(args.checkpoint).exists():
            parser.error(f"Checkpoint {args.checkpoint} does not exist")

    return args
def load_vap_model(args):
    """Load the model from whichever weights source the arguments provide.

    A truthy ``args.state_dict`` takes precedence and is passed to
    ``load_model_from_state_dict``. Otherwise a truthy ``args.checkpoint`` is
    loaded through ``VAPModule.load_model``, mapped to "cuda" when CUDA is
    available and to "cpu" when it is not.

    Raises:
        ValueError: if neither ``args.state_dict`` nor ``args.checkpoint`` is set.
    """
    state_dict_path = args.state_dict
    if state_dict_path:
        return load_model_from_state_dict(state_dict_path)

    checkpoint_path = args.checkpoint
    if not checkpoint_path:
        raise ValueError("Must provide state_dict or checkpoint")

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return VAPModule.load_model(checkpoint_path, map_location=target_device)
if __name__ == "__main__":
    # Script entry point: load the model, run it on one audio file, write the
    # output as JSON and optionally plot it.
    args = get_args()

    # Echo the effective configuration.
    for k, v in vars(args).items():
        print(f"{k}: {v}")

    ###########################################################
    # Load the Model
    ###########################################################
    print("Load Model...")
    model = load_vap_model(args)
    model = model.eval()

    ###########################################################
    # Load the Audio
    ###########################################################
    print("Loading Audio...")
    # Take the first item of what `load_waveform` returns and add a leading
    # dimension. NOTE(review): presumably that item is the waveform tensor and
    # the new dimension is the batch -- confirm against `load_waveform`.
    waveform = load_waveform(args.audio, sample_rate=model.sample_rate, mono=False)[
        0
    ].unsqueeze(0)
    # Duration in seconds, rounded to the nearest whole second.
    duration = round(waveform.shape[-1] / model.sample_rate)

    ###########################################################
    # Model Forward
    ###########################################################
    # For consistency with training, we need to ensure that we use the
    # normal context length (default: 20s). The threshold follows
    # `--chunk_time` instead of a hard-coded 20, so the chunked path is taken
    # exactly when the audio is longer than one chunk.
    print("Model Forward...")
    exceeds_chunk = duration > args.chunk_time
    if exceeds_chunk:
        print(f"Duration > {args.chunk_time}: ", duration)

    if exceeds_chunk and not args.force_no_chunk:
        out = step_extraction(
            waveform, model, chunk_time=args.chunk_time, step_time=args.step_time
        )
    else:
        # Short audio, or chunking explicitly disabled: single forward pass.
        out = model.probs(waveform.to(model.device))
    out = batch_to_device(out, "cpu")  # to cpu for plot/save

    ###########################################################
    # Save Output
    ###########################################################
    if args.output is None:
        args.output = "vap_output.json"
    data = tensor_dict_to_json(out)
    write_json(data, args.output)
    print("wavefile: ", args.audio)
    print("Saved output -> ", args.output)

    ###########################################################
    # Plot
    ###########################################################
    if args.plot:
        fig, ax = plot_stereo(
            waveform[0].cpu(),
            p_now=out["p_now"][0].cpu(),
            p_fut=out["p_future"][0].cpu(),
            vad=out["vad"][0].cpu(),
        )
        # Save the figure next to the JSON output. `with_suffix` swaps only
        # the final extension, so an output name without ".json" no longer
        # makes the figure path collide with the JSON file.
        figpath = Path(args.output).with_suffix(".png")
        fig.savefig(figpath)
        # figpath already carries the ".png" extension.
        print(f"Saved figure as {figpath}")
        print("Close figure to exit")
        plt.show()