-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
147 lines (127 loc) · 5.88 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
import logging
import cv2
from PIL import Image
import imageio
import numpy as np
import torch.utils.data as data
from os.path import join, exists
import math
import random
import sys
import json
import random
from subnet.basics import *
from subnet.ms_ssim_torch import ms_ssim
from augmentation import random_flip, random_crop_and_pad_image_and_labels
class UVGDataSet(data.Dataset):
def __init__(self, root="autodl-tmp/data/UVG/images/", filelist="autodl-tmp/data/UVG/originalv.txt", refdir='H265L23', testfull=False):
with open(filelist) as f:
folders = f.readlines()
self.ref = []
self.refbpp = []
self.input = []
self.hevcclass = []
AllIbpp = self.getbpp(refdir)
ii = 0
print(folders)
for folder in folders:
seq = folder.rstrip()
seqIbpp = AllIbpp[ii]
imlist = os.listdir(os.path.join(root, seq))
cnt = 0
for im in imlist:
if im[-4:] == '.png':
cnt += 1
if testfull:
framerange = cnt // 12
else:
framerange = 1
for i in range(framerange):
if i*12+1+12>=600:
break
refpath = os.path.join(root, seq, refdir, 'im'+str(i * 12 + 1).zfill(4)+'.png')
inputpath = []
for j in range(12):
inputpath.append(os.path.join(root, seq, 'im' + str(i * 12 + j + 1).zfill(3)+'.png'))
self.ref.append(refpath)
self.refbpp.append(seqIbpp)
self.input.append(inputpath)
ii += 1
def getbpp(self, ref_i_folder):
Ibpp = None
if ref_i_folder == 'H265L20':
print('use H265L20')
Ibpp = [1.5809944058641974]# you need to fill bpps after generating crf=20
elif ref_i_folder == 'H265L23':
print('use H265L23')
Ibpp = [1.1798924575617284]# you need to fill bpps after generating crf=23
elif ref_i_folder == 'H265L26':
print('use H265L26')
Ibpp = [0.8653549382716049]# you need to fill bpps after generating crf=26
elif ref_i_folder == 'H265L29':
print('use H265L29')
Ibpp = [0.6269290123456791]# you need to fill bpps after generating crf=29
else:
print('cannot find ref : ', ref_i_folder)
exit()
if len(Ibpp) == 0:
print('You need to generate I frames and fill the bpps above!')
exit()
return Ibpp
def __len__(self):
return len(self.ref)
def __getitem__(self, index):
ref_image = imageio.imread(self.ref[index]).transpose(2, 0, 1).astype(np.float32) / 255.0
h = (ref_image.shape[1] // 64) * 64
w = (ref_image.shape[2] // 64) * 64
ref_image = np.array(ref_image[:, :h, :w])
input_images = []
refpsnr = None
refmsssim = None
for filename in self.input[index]:
input_image = (imageio.imread(filename).transpose(2, 0, 1)[:, :h, :w]).astype(np.float32) / 255.0
if refpsnr is None:
refpsnr = CalcuPSNR(input_image, ref_image)
refmsssim = ms_ssim(torch.from_numpy(input_image[np.newaxis, :]), torch.from_numpy(ref_image[np.newaxis, :]), data_range=1.0).numpy()
else:
input_images.append(input_image[:, :h, :w])
input_images = np.array(input_images)
return input_images, ref_image, self.refbpp[index], refpsnr, refmsssim
class DataSet(data.Dataset):
def __init__(self, path="autodl-tmp/data/vimeo_septuplet/test.txt", im_height=256, im_width=256):
self.image_input_list, self.image_ref_list = self.get_vimeo(filefolderlist=path)
self.im_height = im_height
self.im_width = im_width
self.featurenoise = torch.zeros([out_channel_M, self.im_height // 16, self.im_width // 16])
self.znoise = torch.zeros([out_channel_N, self.im_height // 64, self.im_width // 64])
self.mvnois = torch.zeros([out_channel_mv, self.im_height // 16, self.im_width // 16])
print("dataset find image: ", len(self.image_input_list))
def get_vimeo(self, rootdir="autodl-tmp/data/vimeo_septuplet/sequences/", filefolderlist="autodl-tmp/data/vimeo_septuplet/test.txt"):
with open(filefolderlist) as f:
data = f.readlines()
fns_train_input = []
fns_train_ref = []
for n, line in enumerate(data, 1):
y = os.path.join(rootdir, line.rstrip())
fns_train_input += [y]
refnumber = int(y[-5:-4]) - 2
refname = y[0:-5] + str(refnumber) + '.png'
fns_train_ref += [refname]
return fns_train_input, fns_train_ref
def __len__(self):
return len(self.image_input_list)
def __getitem__(self, index):
input_image = imageio.imread(self.image_input_list[index])
ref_image = imageio.imread(self.image_ref_list[index])
input_image = input_image.astype(np.float32) / 255.0
ref_image = ref_image.astype(np.float32) / 255.0
input_image = input_image.transpose(2, 0, 1)
ref_image = ref_image.transpose(2, 0, 1)
input_image = torch.from_numpy(input_image).float()
ref_image = torch.from_numpy(ref_image).float()
input_image, ref_image = random_crop_and_pad_image_and_labels(input_image, ref_image, [self.im_height, self.im_width])
input_image, ref_image = random_flip(input_image, ref_image)
quant_noise_feature, quant_noise_z, quant_noise_mv = torch.nn.init.uniform_(torch.zeros_like(self.featurenoise), -0.5, 0.5), torch.nn.init.uniform_(torch.zeros_like(self.znoise), -0.5, 0.5), torch.nn.init.uniform_(torch.zeros_like(self.mvnois), -0.5, 0.5)
return input_image, ref_image, quant_noise_feature, quant_noise_z, quant_noise_mv