calculate_fvd.py
import numpy as np
import torch
from tqdm import tqdm


def trans(x):
    # if grayscale, repeat the single channel so the videos have 3 channels
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)
    # permute BTCHW -> BCTHW
    x = x.permute(0, 2, 1, 3, 4)
    return x
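# Illustrative example (not part of the original script): trans() maps a
# grayscale batch [B, T, 1, H, W] to [B, 3, T, H, W], e.g.
#   trans(torch.zeros(2, 16, 1, 64, 64)).shape  # -> torch.Size([2, 3, 16, 64, 64])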
def calculate_fvd(videos1, videos2, device, method='styleganv', only_final=False):

    if method == 'styleganv':
        from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    elif method == 'videogpt':
        from fvd.videogpt.fvd import load_i3d_pretrained, frechet_distance
        from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
    else:
        raise ValueError(f"unknown method: {method}")

    print("calculate_fvd...")

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape

    i3d = load_i3d_pretrained(device=device)

    # support grayscale input: if grayscale -> channel*3
    # BTCHW -> BCTHW
    # videos -> [batch_size, channel, timestamps, h, w]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    fvd_results = []

    if only_final:
        assert videos1.shape[2] >= 10, "to calculate FVD, each clip must be at least 10 frames long"

        # videos_clip [batch_size, channel, timestamps, h, w]
        videos_clip1 = videos1
        videos_clip2 = videos2

        # get FVD features
        feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
        feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)

        # calculate FVD over the full-length clips
        fvd_results.append(frechet_distance(feats1, feats2))
    else:
        # to calculate FVD, each clip must be at least 10 frames long
        for clip_timestamp in tqdm(range(10, videos1.shape[-3] + 1)):
            # get a video clip of the first clip_timestamp frames
            # videos_clip [batch_size, channel, timestamps[:clip], h, w]
            videos_clip1 = videos1[:, :, :clip_timestamp]
            videos_clip2 = videos2[:, :, :clip_timestamp]

            # get FVD features
            feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
            feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)

            # calculate FVD for the clips truncated at clip_timestamp
            fvd_results.append(frechet_distance(feats1, feats2))

    result = {
        "value": fvd_results,
    }

    return result
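# For reference: frechet_distance above is provided by the fvd.styleganv /
# fvd.videogpt backends. The sketch below is NOT that implementation; it is an
# illustrative reimplementation of the standard Frechet (2-Wasserstein) distance
# between Gaussian fits of two feature sets, assuming numpy arrays of shape
# [num_samples, feature_dim] and that scipy is installed.
def _frechet_distance_sketch(feats1, feats2):
    from scipy.linalg import sqrtm  # assumption: scipy is available

    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    sigma1 = np.cov(feats1, rowvar=False)
    sigma2 = np.cov(feats2, rowvar=False)

    # matrix square root of the covariance product; drop tiny imaginary parts
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))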
# test code / usage example
def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 30
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    result = calculate_fvd(videos1, videos2, device, method='videogpt', only_final=False)
    print("[fvd-videogpt ]", result["value"])

    result = calculate_fvd(videos1, videos2, device, method='styleganv', only_final=False)
    print("[fvd-styleganv]", result["value"])


if __name__ == "__main__":
    main()
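# Usage note (sketch, not from the original script): with only_final=True the
# returned result["value"] holds a single FVD computed over the full-length
# clips, while only_final=False (as in main above) yields one FVD per clip
# length from 10 frames up to VIDEO_LENGTH, e.g.
#   result = calculate_fvd(videos1, videos2, device, method='styleganv', only_final=True)
#   print(result["value"])  # a list with one FVD value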