-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
56 lines (50 loc) · 2.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from model import MultiModal
from solver import Solver
# Command-line flags. Every value, including the numeric ones, is declared as
# a string; main() converts them with int()/float() at the point of use.
flags = tf.app.flags
flags.DEFINE_string(
    'mode', 'None',
    "one of: 'train_rgb', 'test_rgb', 'test_rgb1', 'test_ensemble_baseline', "
    "'train_depth', 'test_depth', 'train_double_stream', "
    "'train_double_stream_moddrop', 'test_moddrop', 'train_hallucination', "
    "'train_hallucination_p2', 'finetune_hallucination', "
    "'test_hallucination', 'train_autoencoder', 'test_autoencoder', "
    "'test_double_stream_with_ae', 'test_double_stream', 'test_disc', "
    "'train_eccv'")
flags.DEFINE_string('bs', '32', "batch size")
flags.DEFINE_string('lr', '0.0001', "learning rate")
flags.DEFINE_string('it', '2000', "training iterations")
flags.DEFINE_string('noise', '0', "test-time noise")
FLAGS = flags.FLAGS
def main(_):
    """Build the model and solver from the flags and run the requested mode.

    The model is created with ``FLAGS.mode`` and ``FLAGS.lr``; the solver is
    created with ``FLAGS.bs`` as the batch size and ``FLAGS.it`` as both
    ``train_iter`` and ``train_iter_adv``. The solver method to call is then
    selected from ``FLAGS.mode``. An unknown mode only prints a message.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.
    """
    model = MultiModal(mode=FLAGS.mode, learning_rate=float(FLAGS.lr))
    solver = Solver(model, batch_size=int(FLAGS.bs), train_iter=int(FLAGS.it),
                    train_iter_adv=int(FLAGS.it))

    # NOTE: the order of the branches matters. Several tests are substring
    # matches ("in"), so one branch can serve a family of mode names and must
    # come before/after the exact-match branches exactly as listed here.
    if FLAGS.mode == 'train_rgb':
        solver.train_single_stream(modality='rgb')
    elif 'test_rgb' in FLAGS.mode:  # test also rgb1
        # The modality is the last '_'-separated token of the mode name.
        solver.test_single_stream(modality=FLAGS.mode.split('_')[-1])
    elif FLAGS.mode == 'test_ensemble_baseline':  # rgb+rgb1
        solver.test_ensemble_baseline()
    elif FLAGS.mode == 'train_depth':
        solver.train_single_stream(modality='depth')
    elif FLAGS.mode == 'test_depth':
        solver.test_single_stream(modality='depth')
    elif 'train_double_stream' in FLAGS.mode:
        # Substring match: also selected by 'train_double_stream_moddrop'.
        solver.train_double_stream()
    elif 'test_moddrop' in FLAGS.mode:
        solver.test_moddrop(noise=float(FLAGS.noise))
    elif 'train_hallucination' in FLAGS.mode:
        # Substring match: also selected by 'train_hallucination_p2'.
        solver.train_hallucination()
    elif FLAGS.mode == 'finetune_hallucination':
        solver.finetune_hallucination()
    elif FLAGS.mode == 'test_hallucination':
        solver.test_hallucination()
    elif FLAGS.mode == 'train_autoencoder':
        solver.train_autoencoder()
    elif FLAGS.mode == 'test_autoencoder':
        solver.test_autoencoder()
    elif FLAGS.mode == 'test_double_stream_with_ae':
        solver.test_double_stream_with_ae()
    elif FLAGS.mode == 'test_double_stream':
        solver.test_double_stream(noise=float(FLAGS.noise))
    elif FLAGS.mode == 'test_disc':
        solver.test_disc(noise=float(FLAGS.noise))
    elif FLAGS.mode == 'train_eccv':
        solver.train_eccv()
    else:
        # Function-call form prints the same text on Python 2 and Python 3.
        print('Unrecognized mode.')
# Script entry point: let TensorFlow parse the command-line flags and then
# invoke main().
if __name__ == "__main__":
    tf.app.run(main=main)