forked from GuoxiaWang/insightface
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: ofrecord_util.py
150 lines (126 loc) · 4.42 KB
/
ofrecord_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import oneflow as flow
from config import config
def train_dataset_reader(
    data_dir, batch_size, data_part_num, part_name_suffix_length=None
):
    """Build the OneFlow OFRecord reader for the training set.

    Each record's ``encoded`` feature is decoded into a (112, 112, 3) float
    image. The ``bgr2rgb`` and ``mirror`` image preprocessors are applied,
    followed by per-channel normalization with mean 127.5 and std 128. The
    ``label`` feature is decoded as a scalar int32.

    Args:
        data_dir: Directory holding the training OFRecord part files.
        batch_size: Number of samples per batch.
        data_part_num: Number of OFRecord part files to read.
        part_name_suffix_length: Length of the numeric suffix in the part
            file names (forwarded to ``decode_ofrecord``). When ``None`` (the
            default), ``config.part_name_suffix_length`` is used.

    Returns:
        The ``(label, image)`` blobs produced by ``flow.data.decode_ofrecord``,
        in the order of the blob confs passed to it.

    Raises:
        Exception: If ``data_dir`` does not exist.
    """
    if not os.path.exists(data_dir):
        raise Exception("Invalid train dataset dir", data_dir)
    print("Loading train data from {}".format(data_dir))

    # Honor an explicitly supplied suffix length; fall back to the config
    # value only when the caller does not provide one.
    if part_name_suffix_length is None:
        part_name_suffix_length = config.part_name_suffix_length

    image_blob_conf = flow.data.BlobConf(
        "encoded",
        shape=(112, 112, 3),
        dtype=flow.float,
        codec=flow.data.ImageCodec(
            image_preprocessors=[
                flow.data.ImagePreprocessor("bgr2rgb"),
                # NOTE(review): presumably random horizontal flipping used as
                # training-time augmentation — confirm against OneFlow docs.
                flow.data.ImagePreprocessor("mirror"),
            ]
        ),
        preprocessors=[
            flow.data.NormByChannelPreprocessor(
                mean_values=(127.5, 127.5, 127.5), std_values=(128, 128, 128)
            ),
        ],
    )
    label_blob_conf = flow.data.BlobConf(
        "label", shape=(), dtype=flow.int32, codec=flow.data.RawCodec()
    )
    return flow.data.decode_ofrecord(
        data_dir,
        (label_blob_conf, image_blob_conf),
        batch_size=batch_size,
        data_part_num=data_part_num,
        part_name_prefix=config.part_name_prefix,
        part_name_suffix_length=part_name_suffix_length,
        shuffle=config.shuffle,
        buffer_size=16384,
    )
def validation_dataset_reader(
    val_dataset_dir,
    val_batch_size=1,
    val_data_part_num=1,
    val_part_name_suffix_length=1,
):
    """Build the OneFlow OFRecord reader for a face-verification validation set.

    Each record's ``encoded`` feature is decoded as an RGB image and its
    ``issame`` feature as a scalar int32. Images are resized to 112x112 and
    normalized with mean 127.5 / std 128 per channel. The result is then
    transposed with ``perm=[0, 2, 3, 1]``.

    Dataset sizes noted by the original author (presumably (N, C, H, W)):
        lfw:      (12000, 3, 112, 112)
        cfp_fp:   (14000, 3, 112, 112)
        agedb_30: (12000, 3, 112, 112)

    Args:
        val_dataset_dir: Directory holding the validation OFRecord part files.
        val_batch_size: Number of samples per batch.
        val_data_part_num: Number of OFRecord part files to read.
        val_part_name_suffix_length: Length of the numeric suffix in the part
            file names. Defaults to 1, the value previously hard-coded.

    Returns:
        ``(issame, images)``: the decoded ``issame`` blob and the normalized,
        transposed image blob.

    Raises:
        Exception: If ``val_dataset_dir`` does not exist.
    """
    if not os.path.exists(val_dataset_dir):
        raise Exception("Invalid validation dataset dir", val_dataset_dir)
    print("Loading validation data from {}".format(val_dataset_dir))

    color_space = "RGB"
    ofrecord = flow.data.ofrecord_reader(
        val_dataset_dir,
        batch_size=val_batch_size,
        data_part_num=val_data_part_num,
        part_name_suffix_length=val_part_name_suffix_length,
        # Keep record order fixed across epochs so evaluation is repeatable.
        shuffle_after_epoch=False,
    )
    image = flow.data.OFRecordImageDecoder(ofrecord, "encoded", color_space=color_space)
    issame = flow.data.OFRecordRawDecoder(
        ofrecord, "issame", shape=(), dtype=flow.int32
    )
    # Resize also yields scale and new-size blobs, which are not needed here.
    resized, _, _ = flow.image.Resize(image, target_size=(112, 112), channels=3)
    # NOTE(review): crop_h/crop_w of 0 presumably disables cropping — confirm.
    normal = flow.image.CropMirrorNormalize(
        resized,
        color_space=color_space,
        crop_h=0,
        crop_w=0,
        crop_pos_y=0.5,
        crop_pos_x=0.5,
        mean=[127.5, 127.5, 127.5],
        std=[128.0, 128.0, 128.0],
        output_dtype=flow.float,
    )
    # NOTE(review): presumably CropMirrorNormalize emits channels-first data;
    # this perm moves channels last, matching the (112, 112, 3) image shape
    # used for training — confirm against OneFlow docs.
    normal = flow.transpose(normal, name="transpose_val", perm=[0, 2, 3, 1])
    return issame, normal
def load_synthetic(config):
    """Return randomly generated ``(label, image)`` blobs instead of real data.

    Note that the ``config`` parameter shadows the module-level ``config``
    import. Only ``config.train_batch_size`` is read from it.

    Args:
        config: Object exposing ``train_batch_size``.

    Returns:
        ``(label, image)``. Labels are scalar int32 values produced with
        ``flow.zeros_initializer``. Images are float blobs of shape
        (112, 112, 3) produced by ``decode_random`` with its default
        initializer.
    """
    n_samples = config.train_batch_size
    side = 112

    label_blob = flow.data.decode_random(
        shape=(),
        dtype=flow.int32,
        batch_size=n_samples,
        initializer=flow.zeros_initializer(flow.int32),
    )
    image_blob = flow.data.decode_random(
        shape=(side, side, 3),
        dtype=flow.float,
        batch_size=n_samples,
    )
    return label_blob, image_blob
def load_train_dataset(args):
    """Build the training reader from the module ``config`` and ``args``.

    The dataset directory, part count and part-name suffix length come from
    the module-level ``config``. The batch size comes from
    ``args.train_batch_size``.

    Returns:
        ``(labels, images)`` as produced by ``train_dataset_reader``.
    """
    # Read settings in the same order as before, then log the batch size.
    data_dir, per_step_batch, part_count, suffix_len = (
        config.dataset_dir,
        args.train_batch_size,
        config.train_data_part_num,
        config.part_name_suffix_length,
    )
    print("train batch size in load train dataset: ", per_step_batch)
    label_blob, image_blob = train_dataset_reader(
        data_dir, per_step_batch, part_count, suffix_len
    )
    return label_blob, image_blob
def load_lfw_dataset(args):
    """Build the LFW validation reader.

    Reads ``args.lfw_dataset_dir``, ``args.val_batch_size_per_device`` and
    ``args.val_data_part_num`` and forwards them to
    ``validation_dataset_reader``.

    Returns:
        ``(issame, images)`` from ``validation_dataset_reader``.
    """
    reader_kwargs = dict(
        val_dataset_dir=args.lfw_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )
    pair_flags, face_images = validation_dataset_reader(**reader_kwargs)
    return pair_flags, face_images
def load_cfp_fp_dataset(args):
    """Build the CFP-FP validation reader.

    Reads ``args.cfp_fp_dataset_dir``, ``args.val_batch_size_per_device`` and
    ``args.val_data_part_num`` and forwards them to
    ``validation_dataset_reader``.

    Returns:
        ``(issame, images)`` from ``validation_dataset_reader``.
    """
    reader_kwargs = dict(
        val_dataset_dir=args.cfp_fp_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )
    pair_flags, face_images = validation_dataset_reader(**reader_kwargs)
    return pair_flags, face_images
def load_agedb_30_dataset(args):
    """Build the AgeDB-30 validation reader.

    Reads ``args.agedb_30_dataset_dir``, ``args.val_batch_size_per_device``
    and ``args.val_data_part_num`` and forwards them to
    ``validation_dataset_reader``.

    Returns:
        ``(issame, images)`` from ``validation_dataset_reader``.
    """
    reader_kwargs = dict(
        val_dataset_dir=args.agedb_30_dataset_dir,
        val_batch_size=args.val_batch_size_per_device,
        val_data_part_num=args.val_data_part_num,
    )
    pair_flags, face_images = validation_dataset_reader(**reader_kwargs)
    return pair_flags, face_images