-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_generator.py
129 lines (113 loc) · 5.06 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import keras
from conf import conf
import h5py
from sklearn.model_selection import train_test_split
import os
import shutil
def get_prev_self_play_model_dir(best_model_name=None, self_play_dir=None):
    """Return the self-play entry of the newest model older than ``best_model_name``.

    Entries of the self-play directory are parsed as ``<prefix>_<index>``,
    optionally followed by an extension (e.g. ``.h5``); the integer after the
    last underscore is the model index. Entries whose index cannot be parsed
    are reported and ignored.

    Args:
        best_model_name: model name or path; only its final path component is
            used. When ``None`` -- or when no integer index can be parsed from
            it -- there is no upper limit, i.e. the newest entry is returned.
        self_play_dir: directory to search; defaults to ``conf['SELF_PLAY_DIR']``.

    Returns:
        ``os.path.join(self_play_dir, entry)`` for the entry with the largest
        index strictly below the limit, or ``None`` if no entry qualifies.
    """
    if self_play_dir is None:
        self_play_dir = conf['SELF_PLAY_DIR']

    max_index = float('inf')  # no upper limit: select the newest model
    if best_model_name is not None:
        try:
            # basename/normpath instead of split("/"): works with the OS path
            # separator (os.path.join output) and tolerates a trailing separator.
            base_name = os.path.basename(os.path.normpath(best_model_name))
            max_index = int(base_name.split('.')[0].split('_')[-1])
        except (AttributeError, TypeError, ValueError):
            pass  # unparsable name: keep "no upper limit"

    index = 0
    best_model_name_result = None
    for filename in os.listdir(self_play_dir):
        name = filename.split('.')[0]  # remove extension such as .h5
        try:
            i = int(name.split('_')[-1])
        except ValueError:
            print("EXCEPTION IN get_prev_self_play_model_dir")
            continue
        # NOTE(review): the strict lower bound (index starts at 0) means an
        # entry with index 0 is never selected -- confirm this is intended.
        if index < i < max_index:
            best_model_name_result = filename
            index = i
    return os.path.join(self_play_dir, best_model_name_result) if best_model_name_result else None
def clean_unused_self_play_data(latest_trained_dir):
    """Delete every self-play entry whose model index is below ``latest_trained_dir``'s.

    ``latest_trained_dir`` itself is kept. Lower-index entries are found one at
    a time with ``get_prev_self_play_model_dir`` and removed with
    ``shutil.rmtree`` -- this permanently deletes data on disk.
    Passing ``None`` does nothing.
    """
    if latest_trained_dir is None:
        return
    stale_dir = get_prev_self_play_model_dir(latest_trained_dir)
    while stale_dir is not None:
        shutil.rmtree(stale_dir)
        stale_dir = get_prev_self_play_model_dir(stale_dir)
def _collect_files(directory):
    """Return the paths of all files found anywhere under ``directory`` (os.walk order)."""
    return [os.path.join(root, f) for root, _dirs, files in os.walk(directory) for f in files]


def get_training_desc():
    """Collect self-play files for training and split them into train/validation sets.

    Implements a sliding window over the most recent ``conf['N_MOST_RECENT_GAMES']``
    games: self-play directories are visited from the highest model index
    downwards, and every entry of a visited directory counts as one game
    (presumably one entry per game -- verify against the self-play writer).
    For each visited directory, a directory with the same name under every path
    in ``conf['EXTRA_TRAINING_DATA_DIR']`` is also included when it exists.

    Files are split 90/10 with a fixed ``random_state=2``. If the window was
    filled, self-play directories with a lower index than the oldest one used
    are then deleted via ``clean_unused_self_play_data``; if the search ran out
    of directories, nothing is deleted.

    Returns:
        dict with keys ``'train'`` and ``'validation'``, each a list of file paths.

    Raises:
        FileNotFoundError: if no self-play directory exists at all.
    """
    # TODO: add log for which data folder it get for training
    all_files = []
    n_game = 0
    self_play_best_model_dir = None
    while n_game < conf['N_MOST_RECENT_GAMES']:
        self_play_best_model_dir = get_prev_self_play_model_dir(self_play_best_model_dir)
        if self_play_best_model_dir is None:
            if n_game == 0:  # Found no game data at all
                raise FileNotFoundError("Can not find self-play directory")
            break
        n_game += len(os.listdir(self_play_best_model_dir))
        print("Reading training data from ", self_play_best_model_dir)
        all_files.extend(_collect_files(self_play_best_model_dir))

        # look for data of the same model in the extra training folders;
        # basename/normpath instead of split("/") so OS path separators work
        model_dir_name = os.path.basename(os.path.normpath(self_play_best_model_dir))
        for extra_dir in conf['EXTRA_TRAINING_DATA_DIR']:
            extra_full_dir = os.path.join(extra_dir, model_dir_name)
            if os.path.isdir(extra_full_dir):
                print("Found extra training data dir:", extra_full_dir)
                n_game += len(os.listdir(extra_full_dir))
                all_files.extend(_collect_files(extra_full_dir))
    # reported on both exits: window filled, or ran out of directories
    print("Total games for training per epoch", n_game)

    x_train, x_test = train_test_split(all_files, test_size=0.1, random_state=2)
    # clean up old data that is no longer needed for training
    clean_unused_self_play_data(self_play_best_model_dir)
    return {'train': x_train, 'validation': x_test}
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that loads training samples from HDF5 files, one per file.

    Each entry of ``list_IDs`` is a path to an HDF5 file holding the datasets
    ``board``, ``policy_target`` and ``value_target``. Each batch is returned as
    ``(X, [policy_y, value_y])``.
    """

    def __init__(self, list_IDs, labels, batch_size=32, dim=(32, 32, 32), n_channels=1,
                 n_classes=10, shuffle=True):
        """
        Args:
            list_IDs: sequence of HDF5 file paths, one sample per file.
            labels: stored as ``self.labels``; not used by this class.
            batch_size: number of samples per batch.
            dim: shape of one ``board`` sample; ``X`` is ``(batch_size, *dim)``.
            n_channels: stored; not used by this class.
            n_classes: stored; not used by this class.
            shuffle: if truthy, reshuffle the sample order at every epoch end.
        """
        # Keras 3's PyDataset (aliased as keras.utils.Sequence) expects its
        # __init__ to run; for older Keras Sequence this resolves to object.__init__.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, [policy_y, value_y])``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_batch):
        """Load one batch of samples from disk.

        Returns ``(X, [policy_y, value_y])`` with shapes ``(batch_size, *dim)``,
        ``(batch_size, SIZE * SIZE + 1)`` and ``(batch_size, 1)``. A file that
        cannot be read is reported and skipped; its rows stay all zeros.
        """
        SIZE = conf['SIZE']
        X = np.zeros((self.batch_size, *self.dim))
        # one policy entry per board point plus one extra (presumably "pass")
        policy_y = np.zeros((self.batch_size, SIZE * SIZE + 1))
        value_y = np.zeros((self.batch_size, 1))
        for j, filename in enumerate(list_IDs_batch):
            try:
                # explicit read-only mode: h5py < 3.0 defaulted to 'a'
                # (read/write, create if missing)
                with h5py.File(filename, 'r') as f:
                    board = f['board'][:]
                    policy = f['policy_target'][:]
                    value_target = f['value_target'][()]
                X[j] = board
                policy_y[j] = policy
                value_y[j] = value_target
            except Exception as e:  # best-effort: skip unreadable samples
                print("Exception while reading", filename, " skipping it", repr(e))
        # output order: policy head target first, value head target second
        return X, [policy_y, value_y]