data_resize_bedroom.py (forked from phizaz/diffae)
import os
import shutil
from io import BytesIO

import lmdb
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import LSUNClass
from torchvision.transforms import functional as trans_fn
from tqdm import tqdm

def resize_and_convert(img, size, resample, quality=100):
    """Resize the short side to `size`, center-crop to `size` x `size`,
    and return the image encoded as WebP bytes."""
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="webp", quality=quality)
    return buffer.getvalue()


def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode `img` at every size in `sizes`; returns a list of WebP bytes."""
    return [resize_and_convert(img, size, resample, quality) for size in sizes]


def resize_worker(idx, img, sizes, resample):
    """Convert to RGB and encode at multiple sizes (not used in this script)."""
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)
    return idx, out

class ConvertDataset(Dataset):
    """Wraps an image dataset and yields each image as 256x256 WebP bytes."""
    def __init__(self, data) -> None:
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, _ = self.data[index]
        return resize_and_convert(img, 256, Image.LANCZOS, quality=90)

if __name__ == "__main__":
    """
    Convert LSUN's original LMDB into our own LMDB, which is more
    performant to read.
    """
    # path to the original LSUN LMDB
    src_path = 'datasets/bedroom_train_lmdb'
    out_path = 'datasets/bedroom256.lmdb'

    dataset = LSUNClass(root=os.path.expanduser(src_path))
    dataset = ConvertDataset(dataset)
    # the identity collate_fn keeps the raw WebP bytes instead of stacking tensors
    loader = DataLoader(dataset,
                        batch_size=50,
                        num_workers=12,
                        collate_fn=lambda x: x,
                        shuffle=False)

    # start from a clean output directory
    target = os.path.expanduser(out_path)
    if os.path.exists(target):
        shutil.rmtree(target)

    with lmdb.open(target, map_size=1024**4, readahead=False) as env:
        with tqdm(total=len(dataset)) as progress:
            i = 0
            for batch in loader:
                # one write transaction per batch
                with env.begin(write=True) as txn:
                    for img in batch:
                        # keys look like b"256-0000000", b"256-0000001", ...
                        key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
                        txn.put(key, img)
                        i += 1
                        progress.update()

        # record the total number of images under the "length" key
        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))