forked from phizaz/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lmdb_writer.py
executable file
·131 lines (105 loc) · 3.53 KB
/
lmdb_writer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import operator
import os
import shutil
from contextlib import contextmanager
from io import BytesIO
from multiprocessing import Process, Queue
from queue import Full

import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset
def convert(x, format, quality=100):
    """Encode a single image tensor to compressed bytes with PIL.

    Args:
        x: (c, h, w) tensor with values in [0, 1]. It may live on any device or be
            attached to an autograd graph; it is detached and moved to the CPU first.
            A single-channel image (c == 1) is stored as a 2-D (grayscale) array;
            other channel counts are handed to ``Image.fromarray`` as (h, w, c).
        format: PIL format name passed to ``Image.save`` (e.g. 'webp', 'png').
        quality: encoder quality passed to ``Image.save`` (formats that do not
            support it may ignore it).

    Returns:
        The encoded image as ``bytes``.
    """
    # Original note: "to prevent locking!" -- restrict torch to one intra-op thread.
    # Presumably this avoids thread-pool lock-ups inside the forked writer process; TODO confirm.
    torch.set_num_threads(1)
    # .numpy() requires a CPU tensor that does not require grad.
    x = x.detach().cpu()
    # Adding 0.5 before the truncating uint8 cast rounds to the nearest integer.
    x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    arr = x.to(torch.uint8).numpy()
    if arr.shape[-1] == 1:
        # Image.fromarray cannot handle (h, w, 1) uint8 arrays; use (h, w) instead.
        arr = arr[..., 0]
    buffer = BytesIO()
    Image.fromarray(arr).save(buffer, format=format, quality=quality)
    return buffer.getvalue()
@contextmanager
def nullcontext(enter_result=None):
    """Context manager that does nothing and yields ``enter_result``.

    Mirrors the signature of :func:`contextlib.nullcontext`. Called with no
    arguments (as before), the ``with ... as`` target is ``None``.

    Args:
        enter_result: value bound by ``with nullcontext(...) as value``.
    """
    yield enter_result
class _WriterWroker(Process):
    """Child process that drains image batches from a queue into an LMDB database.

    Each item taken from ``q`` is either an iterable of image tensors (one batch)
    or ``None``, which ends the stream. Every image is encoded with ``convert`` and
    stored under the key ``str(i).zfill(zfill)``. The running image count is stored
    under ``"length"``.

    NOTE: constructing this object (in the parent process) deletes whatever already
    exists at ``path`` via ``shutil.rmtree``. The misspelled class name is kept
    because ``LMDBImageWriter`` refers to it.
    """

    def __init__(self, path, format, quality, zfill, q):
        super().__init__()
        # Always start from an empty database.
        if os.path.exists(path):
            shutil.rmtree(path)
        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        self.q = q
        self.i = 0  # index of the next image to write

    def run(self):
        os.makedirs(self.path, exist_ok=True)
        length_key = "length".encode("utf-8")
        # map_size is LMDB's upper bound on the database size (here 1 TiB).
        with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
            while True:
                job = self.q.get()
                if job is None:
                    break
                with env.begin(write=True) as txn:
                    for x in job:
                        key = str(self.i).zfill(self.zfill).encode("utf-8")
                        txn.put(key, convert(x, self.format, self.quality))
                        self.i += 1
                    # Commit the count together with the batch. If this process dies
                    # later, the database stays readable: "length" always matches the
                    # images committed so far (a failed batch is aborted as a whole).
                    txn.put(length_key, str(self.i).encode("utf-8"))
            # Also covers the case where no batch was ever received (length 0).
            with env.begin(write=True) as txn:
                txn.put(length_key, str(self.i).encode("utf-8"))
class LMDBImageWriter:
    """Write batches of image tensors to an LMDB database from a background process.

    Use as a context manager::

        with LMDBImageWriter(path) as writer:
            writer.put_images(batch)

    Entering starts a ``_WriterWroker`` process, which deletes anything already at
    ``path``. Exiting sends the end-of-stream sentinel and waits for the worker to
    finish.

    Args:
        path: directory of the LMDB environment.
        format: PIL format used to encode each image.
        quality: encoder quality passed to PIL.
        zfill: zero-padding width of the integer image keys.
    """

    def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
        self.path = path
        self.format = format
        self.quality = quality
        self.zfill = zfill
        self.queue = None
        self.worker = None

    def __enter__(self):
        # Bounded queue: producers block once 3 batches are waiting.
        self.queue = Queue(maxsize=3)
        self.worker = _WriterWroker(self.path, self.format, self.quality,
                                    self.zfill, self.queue)
        self.worker.start()
        # Return the writer so `with LMDBImageWriter(...) as writer:` works.
        return self

    def _worker_failed(self):
        """Build the error raised when the worker process has died."""
        return RuntimeError(
            f'LMDB writer process exited with code {self.worker.exitcode}; '
            f'the database at {self.path!r} is incomplete')

    def _put(self, item, poll_interval=1.0):
        """Put ``item`` on the queue, raising instead of hanging if the worker died.

        A dead worker never drains the bounded queue, so a plain blocking ``put``
        could wait forever.

        Raises:
            RuntimeError: if the worker process is no longer alive.
        """
        while True:
            if not self.worker.is_alive():
                raise self._worker_failed()
            try:
                self.queue.put(item, timeout=poll_interval)
                return
            except Full:
                continue

    def put_images(self, tensor):
        """Queue one batch of images for writing.

        Args:
            tensor: (n, c, h, w) [0-1] tensor

        Raises:
            RuntimeError: if the writer process has died.
        """
        # multiprocessing.Queue pickles items later, in a background feeder thread,
        # so hand it a private CPU copy: in-place edits the caller makes after this
        # returns cannot leak into the database. detach(): autograd graphs cannot
        # be sent across process boundaries.
        self._put(tensor.detach().to('cpu', copy=True))

    def __exit__(self, *args, **kwargs):
        try:
            if self.worker.is_alive():
                # End-of-stream sentinel; the worker then writes "length" and exits.
                self._put(None)
        finally:
            self.queue.close()
            self.worker.join()
            if self.worker.exitcode != 0:
                # Data buffered for a dead consumer can never be delivered; don't let
                # interpreter shutdown block trying to flush it.
                self.queue.cancel_join_thread()
        exc_type = args[0] if args else None
        # Surface a crashed worker, unless the with-body already raised.
        if exc_type is None and self.worker.exitcode != 0:
            raise self._worker_failed()
class LMDBImageReader(Dataset):
    """Map-style dataset over an LMDB database in the layout LMDBImageWriter produces.

    That layout stores encoded image bytes under zero-padded integer keys and the
    image count under ``"length"``.

    Args:
        path: directory of the LMDB environment.
        zfill: zero-padding width of the keys; must match the writer's ``zfill``.

    Raises:
        IOError: if the environment cannot be opened or has no ``"length"`` entry.
    """

    def __init__(self, path, zfill: int = 7):
        self.zfill = zfill
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)
        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        if raw_length is None:
            # Without the count we cannot tell how many images are valid.
            raise IOError('lmdb dataset has no "length" entry', path)
        self.length = int(raw_length.decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return image ``index`` as a PIL image (PIL decodes pixel data lazily).

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: if ``index`` is out of range. Plain ``for`` iteration relies
                on this to terminate.
        """
        # Accept numpy / 0-d torch integer indices; str() of a tensor would build
        # a bogus key like "tensor(5)".
        index = operator.index(index)
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(
                f'index out of range for dataset of length {self.length}')
        key = str(index).zfill(self.zfill).encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        return Image.open(BytesIO(img_bytes))