dataset_wvu_new.py
import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

import config
import utils_wvu_new

class WVUNewVerifierOne(Dataset):
    """Pairs contactless photo fingerprints with contact-based prints for verification."""

    def __init__(self, train=True):
        super().__init__()
        self.train = train

        if self.train:
            print("training phase ...")
            if config.num_join_fingers == 1:
                self.dict_photo, self.dict_print = utils_wvu_new.get_one_img_dict(
                    config.train_photo_dir, config.train_print_dir)

            elif config.num_join_fingers == 2:
                self.dict_photo, self.dict_print = utils_wvu_new.get_two_img_dict(
                    config.train_photo_dir, config.train_print_dir, config.fnums)

        else:
            print("testing phase ...")
            if config.num_join_fingers == 1:
                self.dict_photo, self.dict_print = utils_wvu_new.get_one_img_dict(
                    config.test_photo_dir, config.test_print_dir)

            elif config.num_join_fingers == 2:
                self.dict_photo, self.dict_print = utils_wvu_new.get_two_img_dict(
                    config.test_photo_dir, config.test_print_dir, config.fnums)

        self.num_photo_samples = len(self.dict_photo)

        # Grayscale normalization: maps [0, 1] tensors to [-1, 1].
        mean = [0.5]
        std = [0.5]
        fill_white = (255,)

        self.train_trans = transforms.Compose([
            transforms.Resize((config.img_size, config.img_size)),
            #transforms.RandomAffine(3),
            #transforms.Pad(16),
            #transforms.RandomCrop(256),
            #transforms.ColorJitter(brightness=0.2),
            transforms.RandomRotation(20, fill=fill_white),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

        self.test_trans = transforms.Compose([
            transforms.Resize((config.img_size, config.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

        self.trans = (self.train_trans if train else self.test_trans)

    def __len__(self):
        return self.num_photo_samples * config.num_imposter

    def __getitem__(self, index):
        # Every config.num_imposter-th sample is a genuine pair; the rest are imposters.
        same_class = (index % config.num_imposter == 0)

        finger_id, photo_image = self.dict_photo[index // config.num_imposter]

        # genuine pair
        if same_class:
            class_id = finger_id

        # imposter pair: draw a random print identity different from the photo identity
        else:
            class_id = random.choice(list(self.dict_print.keys()))
            while finger_id == class_id:
                class_id = random.choice(list(self.dict_print.keys()))

        if config.num_join_fingers == 1:
            img1 = self.trans(photo_image)
            img2 = self.trans(self.dict_print[class_id])

        elif config.num_join_fingers == 2:
            print_image = self.dict_print[class_id]

            ph_f1, ph_f2 = self.trans(photo_image[0]), self.trans(photo_image[1])
            pr_f1, pr_f2 = self.trans(print_image[0]), self.trans(print_image[1])

            if config.join_type == "concat":
                # join the two fingers side by side along the width
                img1 = torch.cat([ph_f1, ph_f2], dim=2)
                img2 = torch.cat([pr_f1, pr_f2], dim=2)

            elif config.join_type == "channel":
                # stack the two fingers as separate channels
                img1 = torch.cat([ph_f1, ph_f2], dim=0)
                img2 = torch.cat([pr_f1, pr_f2], dim=0)

        return img1, img2, same_class

if __name__ == "__main__":
    data = WVUNewVerifierOne(train=config.is_train)
    img1, img2, same_class = data[90]
    print(img1.shape)

    #title = ("genuine pair" if same_class else "imposter pair")
    #utils_wvu_new.plot_tensors([img1, img2], title)
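    # Illustrative sketch (an assumption, not part of the original script): the dataset
    # can be iterated in batches with a standard torch.utils.data.DataLoader; the
    # batch_size below is arbitrary.
    #loader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=config.is_train)
    #for img1, img2, same_class in loader:
    #    print(img1.shape, img2.shape, same_class)
    #    break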