-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset.py
110 lines (81 loc) · 3.08 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
def random_crop(hr, lr, size, scale):
    """Cut spatially aligned random patches from an HR / LR image pair.

    A ``size`` x ``size`` window is chosen uniformly at random inside ``lr``.
    The matching ``size*scale`` x ``size*scale`` window is then sliced from
    ``hr`` at the scaled offset, so both patches cover the same region.

    Args:
        hr: High-resolution array indexed as ``hr[row, col, ...]``.
        lr: Low-resolution array indexed as ``lr[row, col, ...]``. Only its
            first two axes (height, width) are used to place the window, so
            both ``(H, W)`` and ``(H, W, C)`` arrays are accepted.
        size: Side length of the LR patch in pixels.
        scale: Integer factor relating LR pixel offsets to HR pixel offsets.

    Returns:
        ``(crop_hr, crop_lr)`` as independent copies, not views of the inputs.

    Raises:
        ValueError: If ``lr`` is smaller than ``size`` along either axis.
    """
    # Only the spatial axes matter. ``shape[:-1]`` broke on 2-D inputs.
    h, w = lr.shape[:2]
    if h < size or w < size:
        raise ValueError(
            "crop size {} exceeds LR image of {}x{} (HxW)".format(size, h, w))
    # Draw x before y to keep the RNG consumption order unchanged.
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    hsize = size * scale
    hx, hy = x * scale, y * scale
    crop_lr = lr[y:y + size, x:x + size].copy()
    crop_hr = hr[hy:hy + hsize, hx:hx + hsize].copy()
    return crop_hr, crop_lr
def random_flip_and_rotate(im1, im2):
    """Apply one random flip/rotation augmentation to two arrays in lockstep.

    Three independent draws are made, in this order:
    a vertical flip with probability 0.5, a horizontal flip with probability
    0.5, and a rotation by a uniformly chosen multiple of 90 degrees. Both
    inputs always receive exactly the same transform.

    Returns:
        ``(im1, im2)`` after the transform, as freshly copied arrays. The copy
        is needed before the images reach the transform function, because
        the flip/rotate helpers return strided views of the originals.
    """
    pair = (im1, im2)
    if random.random() < 0.5:
        pair = tuple(np.flipud(img) for img in pair)
    if random.random() < 0.5:
        pair = tuple(np.fliplr(img) for img in pair)
    quarter_turns = random.choice([0, 1, 2, 3])
    first, second = (np.rot90(img, quarter_turns) for img in pair)
    return first.copy(), second.copy()
class TrainDataset(data.Dataset):
    """Training set of random, augmented HR / LR patch pairs from an HDF5 file.

    The HR images and the LR images for every training scale are read fully
    into memory at construction time. Indexing returns, for one HR image, a
    list with one ``(hr_tensor, lr_tensor)`` pair per training scale. Each
    pair is randomly cropped, then randomly flipped and rotated.
    """

    def __init__(self, path, size, scale):
        """Load HR images and the LR images for the requested scale(s).

        Args:
            path: HDF5 file containing an ``"HR"`` group and one
                ``"X{scale}"`` group of LR images per scale used.
            size: Side length in pixels of the LR training patch.
            scale: Upscaling factor. ``0`` selects multi-scale training over
                scales ``[2, 3, 4]``.
        """
        super(TrainDataset, self).__init__()
        self.size = size
        # perform multi-scale training
        self.scale = [2, 3, 4] if scale == 0 else [scale]
        # The context manager closes the file even if a read fails midway.
        # The original open/close pair leaked the handle on error.
        with h5py.File(path, "r") as h5f:
            self.hr = [v[:] for v in h5f["HR"].values()]
            # self.lr[i] holds the LR images for scale self.scale[i].
            # NOTE(review): HR and LR are paired by position within their
            # groups, so both groups are assumed to list matching datasets
            # in the same order. TODO confirm against the file builder.
            self.lr = [
                [v[:] for v in h5f["X{}".format(s)].values()]
                for s in self.scale
            ]
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        """Return a list of ``(hr, lr)`` tensor pairs, one per training scale.

        Entry ``i`` holds an HR patch of ``size * self.scale[i]`` pixels square
        and the matching LR patch of ``size`` pixels square.
        """
        # Crop every scale first, then augment every scale. This keeps the
        # same order of random draws as before.
        pairs = [
            random_crop(self.hr[index], lr_images[index], self.size, s)
            for lr_images, s in zip(self.lr, self.scale)
        ]
        pairs = [random_flip_and_rotate(hr, lr) for hr, lr in pairs]
        return [(self.transform(hr), self.transform(lr)) for hr, lr in pairs]

    def __len__(self):
        """Number of HR images, i.e. distinct training samples."""
        return len(self.hr)
class TestDataset(data.Dataset):
    """Evaluation set of full-size HR / LR PNG image pairs.

    The directory layout is chosen from the last path component of ``dirname``:

    * If it contains ``"DIV"``, HR images are read from ``{dirname}_HR/*.png``
      and LR images from ``{dirname}_LR_bicubic/X{scale}/*.png``.
    * Otherwise, ``{dirname}/x{scale}/*.png`` is scanned. Files whose path
      contains ``"HR"`` are HR images, and those containing ``"LR"`` are LR
      images.

    Both lists are sorted independently and paired by position.
    """

    @staticmethod
    def _base_name(path):
        """Return the last component of ``path``, ignoring trailing separators."""
        return os.path.basename(os.path.normpath(path))

    def __init__(self, dirname, scale):
        """Collect HR / LR file paths for ``dirname`` at the given ``scale``.

        Args:
            dirname: Dataset directory, or the DIV2K-style prefix shared by the
                ``_HR`` and ``_LR_bicubic`` directories. A trailing path
                separator is tolerated.
            scale: Upscaling factor, used to select the ``X{scale}`` /
                ``x{scale}`` sub-directory.
        """
        super(TestDataset, self).__init__()
        # normpath drops a trailing separator. Previously "…/DIV2K_valid/"
        # yielded an empty name and a bogus "…/DIV2K_valid/_HR" path.
        root = os.path.normpath(dirname)
        self.name = self._base_name(root)
        self.scale = scale
        if "DIV" in self.name:
            self.hr = glob.glob(os.path.join("{}_HR".format(root), "*.png"))
            self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(root),
                                             "X{}/*.png".format(scale)))
        else:
            all_files = glob.glob(os.path.join(root, "x{}/*.png".format(scale)))
            self.hr = [name for name in all_files if "HR" in name]
            self.lr = [name for name in all_files if "LR" in name]
        # NOTE(review): pairing is purely positional after sorting. If the two
        # lists differ in length, pairs are misaligned or indexing overruns.
        self.hr.sort()
        self.lr.sort()
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        """Return ``(hr_tensor, lr_tensor, hr_filename)`` for pair ``index``."""
        hr_path = self.hr[index]
        # Use context managers so the underlying file handles are released.
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(hr_path) as hr_img:
            hr = hr_img.convert("RGB")
        with Image.open(self.lr[index]) as lr_img:
            lr = lr_img.convert("RGB")
        # basename is separator-aware. split("/") returned the whole path on
        # Windows, where glob yields backslash-separated paths.
        filename = os.path.basename(hr_path)
        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        """Number of HR images, i.e. evaluation pairs."""
        return len(self.hr)