-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_benchmarks.py
106 lines (97 loc) · 3.1 KB
/
run_benchmarks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import glob
import os
from tabulate import tabulate
from tqdm import tqdm
from video_sampler.sampler import SamplerConfig, VideoSampler
# Gate configurations handed to `run_benchmarks` as `gate_config`.

# "clip" gate: one positive sample and several negative samples.
# NOTE(review): the samples look like text prompts scored by the named CLIP
# model, with the margins acting as decision thresholds — confirm against the
# video_sampler gate implementation.
clip_gate = {
    "type": "clip",
    "pos_samples": ["a cat"],
    "neg_samples": [
        "an empty background",
        "text on screen",
        "blurry image",
        "a forest with no animals",
    ],
    "model_name": "ViT-B-32",
    "batch_size": 32,
    "pos_margin": 0.2,
    "neg_margin": 0.3,
}

# "pass" gate: carries no settings besides its type.
pass_gate = {"type": "pass"}

# "blur" gates: same gate type, two methods, each with its own threshold.
blur_gate_laplacian = {"type": "blur", "method": "laplacian", "threshold": 120}
blur_gate_fft = {"type": "blur", "method": "fft", "threshold": 20}
def run_benchmarks(gate_config: dict, target_size: int = 256, debug: bool = False):
    """Sample every ``videos/*.mp4`` with several sampler configs and print stats.

    Three ``SamplerConfig`` variants are run: a "grid" buffer, a "hash"
    buffer, and a "hash" buffer combined with ``gate_config``. For each
    (config, video) pair the sampled frames are saved as an animated GIF under
    ``assets/`` and one row of sampler statistics is collected. All rows are
    printed as a GitHub-style table once every config has been processed.

    Args:
        gate_config: Gate settings passed to the third ``SamplerConfig``.
        target_size: Side length in pixels; every frame is resized to a
            ``target_size`` x ``target_size`` square before being saved.
        debug: Forwarded as the ``debug`` entry of every buffer config.
    """
    configs = [
        SamplerConfig(
            buffer_config=dict(
                type="grid",
                hash_size=4,
                size=30,
                debug=debug,
                grid_x=4,
                grid_y=4,
                max_hits=1,
            ),
        ),
        SamplerConfig(
            buffer_config=dict(type="hash", hash_size=4, size=30, debug=debug)
        ),
        SamplerConfig(
            buffer_config=dict(type="hash", hash_size=4, size=30, debug=debug),
            gate_config=gate_config,
        ),
    ]
    # Saving a GIF fails if the output directory is missing, so create it once.
    os.makedirs("assets", exist_ok=True)
    table = []
    for cfg in tqdm(configs, desc="Creating gifs..."):
        sampler = VideoSampler(cfg)
        gate_type = cfg.gate_config["type"]
        model_type = (
            f"{cfg.buffer_config['type']}_{cfg.buffer_config['hash_size']}_{gate_type}"
        )
        # Sorted so the table rows come out in a reproducible order.
        for video_fn in sorted(glob.glob("videos/*.mp4")):
            timed_frames = []
            for res in sampler.sample(video_path=video_fn):
                for frame_obj in res:
                    if frame_obj.frame is None:
                        continue
                    timed_frames.append(
                        (float(frame_obj.metadata["frame_time"]), frame_obj.frame)
                    )
            # Order by timestamp only. Sorting the raw (timestamp, frame)
            # tuples would fall back to comparing frame objects whenever two
            # timestamps are equal, which raises TypeError.
            timed_frames.sort(key=lambda pair: pair[0])
            frames = [
                frame.resize((target_size, target_size)) for _, frame in timed_frames
            ]
            bsn = os.path.basename(video_fn)
            # A gate may reject every frame; skip the GIF in that case but
            # still record the statistics row below.
            if frames:
                savename = f"assets/{bsn}_{model_type}.gif"
                frames[0].save(
                    savename,
                    format="GIF",
                    # The image being saved is already the first frame, so
                    # only the remaining frames are appended (no duplicate).
                    append_images=frames[1:],
                    save_all=True,
                    duration=250,
                    loop=0,
                )
            stats = sampler.stats
            table.append(
                [
                    bsn,
                    cfg.buffer_config["type"],
                    gate_type,
                    stats["decoded"],
                    stats["produced"],
                    stats["gated"],
                ]
            )
    print(
        tabulate(
            table,
            headers=["video", "buffer", "gate", "decoded", "produced", "gated"],
            tablefmt="github",
        )
    )
if __name__ == "__main__":
    # Run the benchmark once per gate, in this order: clip, pass, FFT blur.
    for selected_gate in (clip_gate, pass_gate, blur_gate_fft):
        run_benchmarks(gate_config=selected_gate, debug=False)