-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluate.py
63 lines (48 loc) · 2 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/python3
import tensorflow as tf
import time, os
import argparse
# User-defined
from network import Network
from diagnostics import Diagnostics
from data import Data
from model import Model
from config import config_test, directories
# Raise the tf.logging threshold to ERROR so lower-severity TensorFlow Python
# log messages are suppressed during evaluation.
tf.logging.set_verbosity(tf.logging.ERROR)
def evaluate(config, directories, ckpt):
    """Restore a trained checkpoint and report validation accuracy.

    Builds the ``Model`` graph, restores the moving-average (EMA) values of
    the learned variables from ``ckpt``, then runs the validation iterator to
    exhaustion, updating the accuracy metric on every batch. Evaluation is
    forced onto the CPU.

    Args:
        config: Configuration object passed straight through to ``Model``.
        directories: Directory configuration passed straight through to ``Model``.
        ckpt: Checkpoint state as returned by ``tf.train.get_checkpoint_state``;
            may be ``None`` when no checkpoint exists.

    Returns:
        The value of ``cnn.str_accuracy`` read after the last validation batch.

    Raises:
        FileNotFoundError: if ``ckpt`` is ``None`` or names no checkpoint path.
        ValueError: if the validation iterator yields no batches, so no
            accuracy was ever computed.
    """
    # Fail fast, before any graph is built. An explicit raise (rather than
    # `assert`) still runs under `python -O`, and getattr also covers
    # ckpt=None, which get_checkpoint_state returns when nothing is saved.
    checkpoint_path = getattr(ckpt, 'model_checkpoint_path', None)
    if not checkpoint_path:
        raise FileNotFoundError('Missing checkpoint file!')

    # Expose zero GPUs so everything runs on CPU; soft placement lets ops that
    # request a GPU device fall back to the CPU instead of erroring.
    pin_cpu = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
    start = time.time()

    # Build graph
    cnn = Model(config, directories)
    # Restore the moving average version of the learned variables for eval.
    variables_to_restore = cnn.ema.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    with tf.Session(config=pin_cpu) as sess:
        # Initialize all global and local variables first; saver.restore then
        # overwrites the ones stored in the checkpoint. Local variables
        # presumably include the metric counters behind cnn.update_accuracy.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, checkpoint_path)
        print('{} restored.'.format(checkpoint_path))

        val_handle = sess.run(cnn.val_iterator.string_handle())
        sess.run(cnn.val_iterator.initializer)
        eval_dict = {cnn.training_phase: False, cnn.handle: val_handle}

        # Stays None if the iterator is exhausted before the first batch.
        v_acc = None
        while True:
            try:
                sess.run([cnn.update_accuracy, cnn.merge_op], feed_dict=eval_dict)
                # NOTE(review): if str_accuracy only reads the metric state (not
                # the input pipeline), it could be read once after the loop.
                v_acc = sess.run(cnn.str_accuracy, feed_dict=eval_dict)
            except tf.errors.OutOfRangeError:
                # The validation iterator is exhausted.
                break

        if v_acc is None:
            raise ValueError('Validation iterator produced no batches; '
                             'accuracy is undefined.')

        print("Validation accuracy: {:.3f}".format(v_acc))
        print("Inference complete. Duration: %g s" %(time.time()-start))
        return v_acc
def main(**kwargs):
    """Command-line entry point: evaluate the latest checkpoint on validation data.

    Looks up the checkpoint state in ``directories.checkpoints`` and runs
    ``evaluate`` with ``config_test``. ``kwargs`` is accepted for backward
    compatibility but is not used.

    Returns:
        The validation accuracy returned by ``evaluate``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", help="path to test dataset in tfrecords format")
    args = parser.parse_args()
    # NOTE(review): args.input is parsed but never used below; the dataset that
    # is actually evaluated is whatever config_test / Model select. Either wire
    # it through or drop the option.

    # Locate the checkpoint state recorded in the checkpoint directory.
    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)
    if ckpt is None or not ckpt.model_checkpoint_path:
        # Report a clean usage error (exit status 2) instead of a traceback.
        parser.error('no checkpoint found in {}'.format(directories.checkpoints))

    # Evaluate
    val_accuracy = evaluate(config_test, directories, ckpt)
    return val_accuracy
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()