-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathdata_loader.py
99 lines (84 loc) · 3.76 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import matplotlib.pyplot as plt
class DataLoader:
    """Data Loader class serving ``(X, y)`` mini-batches from in-memory NumPy arrays.

    As a simple case, the model is tried on a few sample images (the original
    note mentions TinyImageNet). For larger datasets, you may need to adapt
    this class to use the TensorFlow ``tf.data`` Dataset API.

    Attributes:
        X_train, y_train: training inputs and labels (``None`` until loaded).
        X_val, y_val: validation inputs and labels (``None`` until loaded).
        X_test, y_test: test inputs and labels. ``load_data`` does not set
            these, so they stay ``None`` unless the caller assigns them.
        train_data_len, val_data_len, test_data_len: sample counts recorded
            by ``load_data`` (``0`` until set).
        img_mean: initialised to ``None``; not used by this class.
        shuffle: if true, training batches follow a fresh random order each
            epoch.
        batch_size: maximum number of samples per yielded batch.
    """

    # Valid values for the ``type`` argument of ``generate_batch``.
    _SPLITS = ('train', 'val', 'test')

    def __init__(self, batch_size, shuffle=False):
        """Create an empty loader; call ``load_data`` (or assign arrays) before batching."""
        self.X_train = None
        self.y_train = None
        self.img_mean = None
        self.train_data_len = 0
        self.X_val = None
        self.y_val = None
        self.val_data_len = 0
        self.X_test = None
        self.y_test = None
        self.test_data_len = 0
        self.shuffle = shuffle
        self.batch_size = batch_size

    def load_data(self):
        """Load the training and validation splits into memory.

        Please make sure to change this method to load your own
        train/validation/test data. The bundled example uses the same four
        sample images, with fixed labels, for both the training and the
        validation split. No test split is loaded.

        Returns:
            tuple: ``(img_height, img_width, num_channels, train_data_len,
            val_data_len)``. The image dimensions are hard-coded to
            224 x 224 x 3; they are not derived from the loaded images.
        """
        image_paths = ['./data/test_images/0.jpg', './data/test_images/1.jpg',
                       './data/test_images/2.jpg', './data/test_images/3.jpg']
        labels = [284, 264, 682, 2]
        # Read each file once. np.array builds a new array on each call, so
        # the train and validation arrays do not share memory.
        images = [plt.imread(path) for path in image_paths]
        self.X_train = np.array(images)
        self.y_train = np.array(labels)
        self.X_val = np.array(images)
        self.y_val = np.array(labels)
        self.train_data_len = self.X_train.shape[0]
        self.val_data_len = self.X_val.shape[0]
        img_height = 224
        img_width = 224
        num_channels = 3
        return img_height, img_width, num_channels, self.train_data_len, self.val_data_len

    def generate_batch(self, type='train'):
        """Return an endless generator of ``(X_batch, y_batch)`` tuples.

        'train' batches cover the training set once per epoch. The order is
        reshuffled at each epoch when ``shuffle`` is set. 'val' and 'test'
        batches are taken in stored order and wrap around after the last one.
        The final batch of a pass may hold fewer than ``batch_size`` samples.

        Args:
            type: split to draw from: 'train', 'val' or 'test'. The name
                shadows the builtin but is kept for keyword callers.

        Returns:
            A generator. The split's arrays are read when iteration starts,
            not when this method is called, so the generator may be created
            before the data is loaded.

        Raises:
            ValueError: immediately, if ``type`` is not a known split.
            The generator itself raises ``RuntimeError`` on ``next()`` if the
            split has no samples, and ``ValueError`` if ``batch_size`` is not
            positive or X and y differ in length.
        """
        # Validate here, outside any generator body, so a bad split name fails
        # at call time instead of on the first next().
        if type not in self._SPLITS:
            raise ValueError("Please select a type from \'train\', \'val\', or \'test\'")
        if type == 'train':
            return self._train_batches()
        return self._sequential_batches(type)

    def _split_arrays(self, split):
        """Return ``(X, y, n)`` for ``split`` after checking it can be batched.

        ``n`` comes from ``len(X)`` rather than the ``*_data_len`` attributes,
        which can be stale (``test_data_len`` is never set by ``load_data``).
        """
        X = getattr(self, 'X_' + split)
        y = getattr(self, 'y_' + split)
        if X is None or y is None:
            raise RuntimeError(
                "No %r data loaded; set X_%s and y_%s (e.g. via load_data()) first"
                % (split, split, split))
        n = len(X)
        if n == 0:
            # Without this check the generators would yield empty batches forever.
            raise RuntimeError("The %r split contains no samples" % split)
        if len(y) != n:
            raise ValueError("X_%s has %d samples but y_%s has %d"
                             % (split, n, split, len(y)))
        if self.batch_size <= 0:
            # A non-positive step would never advance through the data.
            raise ValueError("batch_size must be positive, got %r" % (self.batch_size,))
        return X, y, n

    def _train_batches(self):
        """Yield training batches forever, one full pass over the data per epoch."""
        while True:
            # Re-read the arrays each epoch so data (re)loaded between epochs is used.
            X, y, n = self._split_arrays('train')
            order = np.random.permutation(n) if self.shuffle else np.arange(n)
            for start in range(0, n, self.batch_size):
                idx = order[start:start + self.batch_size]
                yield X[idx], y[idx]

    def _sequential_batches(self, split):
        """Yield batches of ``split`` in stored order forever, wrapping after the last one."""
        while True:
            X, y, n = self._split_arrays(split)
            for start in range(0, n, self.batch_size):
                stop = start + self.batch_size
                yield X[start:stop], y[start:stop]