-
Notifications
You must be signed in to change notification settings - Fork 763
/
main.py
70 lines (52 loc) · 2.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import print_function
import random
import tensorflow as tf
from dqn.agent import Agent
from dqn.environment import GymEnvironment, SimpleGymEnvironment
from config import get_config
# Command-line flags for this script, declared through tf.app.flags.
# Each DEFINE_* call takes (name, default, help text); values are read
# back through the FLAGS object bound at the end of this section.
flags = tf.app.flags
# Model options. Not read directly in this file; presumably consumed via
# get_config(FLAGS) -- verify against config.py.
flags.DEFINE_string('model', 'm1', 'Type of model')
flags.DEFINE_boolean('dueling', False, 'Whether to use dueling deep q-network')
flags.DEFINE_boolean('double_q', False, 'Whether to use double q-learning')
# Environment options. Also not read directly in this file (see note above).
flags.DEFINE_string('env_name', 'Breakout-v0', 'The name of gym environment to use')
flags.DEFINE_integer('action_repeat', 4, 'The number of action to be repeated')
# Miscellaneous run options.
# use_gpu: main() raises if this is True and TensorFlow reports no GPU, and
# forces config.cnn_format to 'NHWC' when it is False.
flags.DEFINE_boolean('use_gpu', True, 'Whether to use gpu or not')
# gpu_fraction: an 'idx/num' string converted by calc_gpu_fraction().
flags.DEFINE_string('gpu_fraction', '1/1', 'idx / # of gpu fraction e.g. 1/3, 2/3, 3/3')
flags.DEFINE_boolean('display', False, 'Whether to do display the game screen or not')
# is_train: main() calls agent.train() when True, agent.play() otherwise.
flags.DEFINE_boolean('is_train', True, 'Whether to do training or testing')
# random_seed: passed to both tf.set_random_seed() and random.seed() below.
flags.DEFINE_integer('random_seed', 123, 'Value of random seed')
# Accessor used throughout the rest of the file to read flag values.
FLAGS = flags.FLAGS
# Seed Python's `random` module and TensorFlow from the same --random_seed.
random.seed(FLAGS.random_seed)
tf.set_random_seed(FLAGS.random_seed)

# calc_gpu_fraction() expects an 'idx/num' string, so refuse an empty flag
# here, at import time, before main() ever runs.
if FLAGS.gpu_fraction == '':
    raise ValueError("--gpu_fraction should be defined")
def calc_gpu_fraction(fraction_string):
    """Convert an 'idx/num' string into a per-process GPU memory fraction.

    The result is ``1 / (num - idx + 1)``, so '1/3' -> 1/3, '2/3' -> 1/2 and
    '3/3' -> 1. The computed value is also printed.

    Args:
        fraction_string: Text made of two numbers separated by a single '/'.

    Returns:
        The computed fraction as a positive float.

    Raises:
        ValueError: If the text is not two numbers separated by '/', or if
            ``num - idx + 1`` is not positive (which would otherwise divide
            by zero or yield a negative fraction).
    """
    try:
        idx, num = fraction_string.split('/')
        idx, num = float(idx), float(num)
    except ValueError:
        raise ValueError(
            "--gpu_fraction should look like 'idx/num' (e.g. 1/3), got %r"
            % (fraction_string,))

    denominator = num - idx + 1
    # `not > 0` (rather than `<= 0`) also rejects NaN.
    if not denominator > 0:
        raise ValueError(
            "--gpu_fraction %r is out of range: idx must be smaller than "
            "num + 1" % (fraction_string,))

    fraction = 1 / denominator
    print(" [*] GPU : %.4f" % fraction)
    return fraction
def main(_):
    """Create the environment and agent in a TF session, then train or play.

    The positional argument is ignored.

    Raises:
        Exception: If --use_gpu is set but TensorFlow reports no GPU.
    """
    memory_fraction = calc_gpu_fraction(FLAGS.gpu_fraction)
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=memory_fraction))

    with tf.Session(config=session_config) as sess:
        # Use the raw flags when get_config() returns something falsy.
        config = get_config(FLAGS) or FLAGS

        if config.env_type == 'simple':
            env_class = SimpleGymEnvironment
        else:
            env_class = GymEnvironment
        env = env_class(config)

        # Probe for a GPU first, regardless of the flag, as the original
        # condition did.
        gpu_found = tf.test.is_gpu_available()
        if FLAGS.use_gpu:
            if not gpu_found:
                raise Exception("use_gpu flag is true when no GPUs are available")
        else:
            # Without a GPU requested, force the 'NHWC' data format.
            config.cnn_format = 'NHWC'

        agent = Agent(config, env, sess)

        run = agent.train if FLAGS.is_train else agent.play
        run()
# Script entry point: hand main() to TensorFlow's app runner.
if __name__ == '__main__':
    tf.app.run(main=main)