forked from Hadisalman/smoothing-adversarial
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzipdata.py
95 lines (85 loc) · 3.4 KB
/
zipdata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import multiprocessing
import os.path as op
from threading import local
from zipfile import ZipFile, BadZipFile
from PIL import Image
from io import BytesIO
import torch.utils.data as data
# File extensions (lower-case, with leading dot) accepted when the caller
# does not supply its own ``extensions``.
_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']


class ZipData(data.Dataset):
    """Image dataset whose samples are read directly from a zip archive.

    The archive members to use, and their integer labels, come from a
    tab-separated map file.  Each non-empty line has at least two fields:

    * field 0 contains an ``@``; the text after the first ``@`` (minus one
      optional leading ``/``) is the member name inside the archive;
    * field 1 is the integer label for that member.

    Attributes:
        samples: list of ``(member_name, label)`` tuples, in archive order.
        class_to_idx: dict mapping member name -> label as read from the map.
        zip_dict: dict mapping process id -> ``ZipFile`` opened by that
            process; filled lazily by ``__getitem__``.
    """

    # Attributes replaced by ``None`` when the instance is pickled.
    _IGNORE_ATTRS = {'_zip_file'}

    def __init__(self, path, map_file,
                 transform=None, target_transform=None,
                 extensions=None):
        """Index the archive at ``path`` using the labels in ``map_file``.

        Args:
            path: path of the zip archive.
            map_file: path of the tab-separated map file described on the
                class.
            transform: optional callable applied to each loaded RGB image.
            target_transform: optional callable applied to each label.
            extensions: container of accepted lower-case file extensions
                (with leading dot); a falsy value selects
                ``_VALID_IMAGE_TYPES``.

        Raises:
            AssertionError: on a malformed map line, on a member mapped to
                two different labels, or when no sample is found.
        """
        self._path = path
        if not extensions:
            extensions = _VALID_IMAGE_TYPES
        self._zip_file = ZipFile(path)
        self.zip_dict = {}
        self.samples = []
        self.transform = transform
        self.target_transform = target_transform
        self.class_to_idx = {}

        with open(map_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = [part for part in line.split('\t') if part]
                if not fields:
                    continue
                assert len(fields) >= 2, "invalid line: {}".format(line)
                idx = int(fields[1])
                cls = fields[0]
                at_idx = cls.find('@')
                assert at_idx >= 0, "invalid class: {}".format(cls)
                cls = cls[at_idx + 1:]
                if cls.startswith('/'):
                    # Python ZipFile expects no root
                    cls = cls[1:]
                assert cls, "invalid class in line {}".format(line)
                prev_idx = self.class_to_idx.get(cls)
                assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
                    cls, idx, prev_idx
                )
                self.class_to_idx[cls] = idx

        for fst in self._zip_file.infolist():
            fname = fst.filename
            target = self.class_to_idx.get(fname)
            if target is None:
                # Member is not listed in the map file.
                continue
            if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
                # Skip directory entries, dot-prefixed names and empty files.
                continue
            ext = op.splitext(fname)[1].lower()
            if ext in extensions:
                self.samples.append((fname, target))
        assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)

    def __repr__(self):
        return 'ZipData({}, size={})'.format(self._path, len(self))

    def __getstate__(self):
        """Return a picklable copy of the instance state.

        Attributes named in ``_IGNORE_ATTRS`` are replaced by ``None`` and
        ``zip_dict`` is replaced by an empty dict, because open ``ZipFile``
        handles cannot be pickled; ``__getitem__`` reopens the archive on
        demand in whichever process uses the restored object.
        """
        # ``dict.items()``: ``iteritems()`` does not exist on Python 3.
        state = {
            key: (None if key in self._IGNORE_ATTRS else val)
            for key, val in self.__dict__.items()
        }
        state['zip_dict'] = {}
        return state

    def __getitem__(self, index):
        """Load sample ``index`` as ``(image, label)``.

        Returns:
            The RGB image (after ``transform``, if set) and its label (after
            ``target_transform``, if set).  If reading the member raises
            ``BadZipFile`` a message is printed and ``(None, None)`` is
            returned instead.

        Raises:
            KeyError: if ``index`` is negative or not smaller than
                ``len(self)``.
        """
        if index >= len(self) or index < 0:
            raise KeyError("{} is invalid".format(index))

        # One ZipFile handle per process, keyed by pid, opened on first use.
        pid = multiprocessing.current_process().pid
        if pid not in self.zip_dict:
            self.zip_dict[pid] = ZipFile(self._path)
        zip_file = self.zip_dict[pid]

        path, target = self.samples[index]
        try:
            sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
        except BadZipFile:
            print("bad zip file")
            return None, None
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)