forked from phizaz/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_resize_celeba.py
116 lines (90 loc) · 3.13 KB
/
data_resize_celeba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import multiprocessing
import os
import shutil
from functools import partial
from io import BytesIO
from multiprocessing import Process, Queue
from os.path import exists, join
from pathlib import Path
import lmdb
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import LSUNClass
from torchvision.transforms import functional as trans_fn
from tqdm import tqdm
def resize_and_convert(img, size, resample, quality=100):
if size is not None:
img = trans_fn.resize(img, size, resample)
img = trans_fn.center_crop(img, size)
buffer = BytesIO()
img.save(buffer, format="webp", quality=quality)
val = buffer.getvalue()
return val
def resize_multiple(img,
sizes=(128, 256, 512, 1024),
resample=Image.LANCZOS,
quality=100):
imgs = []
for size in sizes:
imgs.append(resize_and_convert(img, size, resample, quality))
return imgs
def resize_worker(idx, img, sizes, resample):
img = img.convert("RGB")
out = resize_multiple(img, sizes=sizes, resample=resample)
return idx, out
class ConvertDataset(Dataset):
def __init__(self, data, size) -> None:
self.data = data
self.size = size
def __len__(self):
return len(self.data)
def __getitem__(self, index):
img = self.data[index]
bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100)
return bytes
class ImageFolder(Dataset):
def __init__(self, folder, ext='jpg'):
super().__init__()
paths = sorted([p for p in Path(f'{folder}').glob(f'*.{ext}')])
self.paths = paths
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = os.path.join(self.paths[index])
img = Image.open(path)
return img
if __name__ == "__main__":
from tqdm import tqdm
out_path = 'datasets/celeba.lmdb'
in_path = 'datasets/celeba'
ext = 'jpg'
size = None
dataset = ImageFolder(in_path, ext)
print('len:', len(dataset))
dataset = ConvertDataset(dataset, size)
loader = DataLoader(dataset,
batch_size=50,
num_workers=12,
collate_fn=lambda x: x,
shuffle=False)
target = os.path.expanduser(out_path)
if os.path.exists(target):
shutil.rmtree(target)
with lmdb.open(target, map_size=1024**4, readahead=False) as env:
with tqdm(total=len(dataset)) as progress:
i = 0
for batch in loader:
with env.begin(write=True) as txn:
for img in batch:
key = f"{size}-{str(i).zfill(7)}".encode("utf-8")
# print(key)
txn.put(key, img)
i += 1
progress.update()
# if i == 1000:
# break
# if total == len(imgset):
# break
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(i).encode("utf-8"))