-
Notifications
You must be signed in to change notification settings - Fork 8
/
abstract_model.py
112 lines (91 loc) · 3.25 KB
/
abstract_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import abc
import tensorflow as tf
from prelude import *
class AbstractModel(metaclass=abc.ABCMeta):
    """Skeleton for TensorFlow-1.x models.

    Owns a private tf.Graph / tf.Session pair and implements the generic
    lifecycle: variable init, epoch training with early stopping,
    validation, and checkpoint save/load.  Subclasses implement
    _build_graph, _train_func and _test_func.
    """

    def __init__(self, name):
        """Build the subclass graph and open a session on it.

        name -- model identifier; used in checkpoint and summary paths.
        """
        self.name = name
        gpu_options = tf.GPUOptions()
        gpu_options.allow_growth = True  # allocate GPU memory on demand, not all up front
        self._graph = tf.Graph()
        with self._graph.as_default():
            self._build_graph()
            # Created after _build_graph so they pick up every variable/summary.
            self.__summary_op = tf.summary.merge_all()
            self.__init_op = tf.global_variables_initializer()
        self._sess = tf.Session(graph=self._graph,
                                config=tf.ConfigProto(gpu_options=gpu_options))

    @abc.abstractmethod
    def _build_graph(self):
        """Construct the model's ops; called inside self._graph's context."""
        pass

    @abc.abstractmethod
    def _train_func(self, batch_idx, batch_data):
        """Run one training step on batch_data (batch_idx is 0-based)."""
        pass

    @abc.abstractmethod
    def _test_func(self, batch_idx, batch_data):
        """Evaluate one batch and return its scalar loss."""
        pass

    def __create_writer(self):
        """Create a TensorBoard FileWriter, making its directory if needed."""
        # NOTE(review): path casing ("SavedModel") differs from the checkpoint
        # path ("savedModel") used by load()/save() -- confirm intentional.
        path = "SavedModel/Board/{}/".format(self.name)
        if not os.path.exists(path):
            os.makedirs(path)
        return tf.summary.FileWriter(
            logdir=path,
            graph=self._graph,
            session=self._sess
        )

    def load(self):
        """Restore all variables from this model's checkpoint path."""
        with self._graph.as_default():
            saver = tf.train.Saver(tf.global_variables())
            path = "savedModel/{}/".format(self.name)
            saver.restore(self._sess, path)
        print("Model \"{}\" loaded".format(self.name))

    def try_load(self):
        """Best-effort load(): ignore failures (e.g. no checkpoint yet)."""
        try:
            self.load()
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # still propagate.  A missing checkpoint is expected on first run,
            # so we deliberately continue with freshly initialized variables.
            pass

    def close(self):
        """Release the underlying TensorFlow session."""
        self._sess.close()

    def save(self):
        """Checkpoint all variables, creating the directory if needed."""
        with self._graph.as_default():
            saver = tf.train.Saver(tf.global_variables())
            path = "savedModel/{}/".format(self.name)
            folder = os.path.dirname(path)
            if not os.path.exists(folder):
                os.makedirs(folder)
            prefix = saver.save(self._sess, path)
        print("Model saved at \"{}\"".format(prefix))

    def __init_global_variables(self):
        """Run the variable-initializer op built in __init__."""
        print("Initializing global variables")
        self._sess.run(self.__init_op)

    def train(self, train_set, test_set, max_flip, max_epoch):
        """Train with early stopping; saves a checkpoint on each improvement.

        train_set, test_set -- objects whose fetch() yields batches.
        max_flip  -- stop after this many validations without improvement.
        max_epoch -- hard cap on training epochs.
        """
        self.__init_global_variables()
        flip_count = 0
        best_loss = None
        epoch = 0
        while epoch < max_epoch:
            batch_idx = 0
            for batch_data in train_set.fetch():
                self._train_func(batch_idx, batch_data)
                batch_idx += 1
            print()
            new_loss = self.__valid(test_set)
            if best_loss is None or new_loss < best_loss:
                best_loss = new_loss
                print("{} tested, new_loss={} best_loss={}".format(self.name, new_loss, best_loss))
                self.save()
                # NOTE(review): flip_count is NOT reset on improvement, so the
                # max_flip budget of bad validations is cumulative rather than
                # consecutive -- confirm this early-stopping policy is intended.
            else:
                print("{} tested, new_loss={} best_loss={}".format(self.name, new_loss, best_loss))
                flip_count += 1
                if flip_count >= max_flip:
                    break
            epoch += 1
        print("Model '{}' train over".format(self.name))

    def __valid(self, test_set):
        """Return the mean per-batch loss over test_set (0.0 if empty)."""
        total_loss = 0.0
        count = 0
        for batch_data in test_set.fetch():
            total_loss += self._test_func(count, batch_data)
            count += 1
        print()
        # Division hoisted out of the loop (original recomputed it per batch);
        # the guard fixes a ZeroDivisionError on an empty test set, returning
        # the original's initial avg_loss of 0.0 instead.
        return total_loss / count if count else 0.0