-
Notifications
You must be signed in to change notification settings - Fork 0
/
Trainer.py
149 lines (124 loc) · 4.39 KB
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import copy
import time
import numpy as np
from GAME.GameKuhn import Kuhn
from GAME.GameLeduc import Leduc
from GAME.GameLeduc5Pot import Leduc5Pot
from GAME.GameLeduc3Pot import Leduc3Pot
from GAME.GameGoofspiel import Goofspiel
from GAME.GameKuhnNPot import KuhnNPot
from GAME.GamePrincessAndMonster import PrincessAndMonster as PAM
from Solver.CFVFP import CFVFPSolver
from Solver.CFR import CFRSolver
from CONFIG import test_sampling_train_config
import draw.convergence_rate
from draw.convergence_rate import plt_perfect_game_convergence_inline
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import gc
def train_sec(tmp_train_config):
    """Build the solver selected by ``op_env`` and run its training loop.

    This is the per-repetition job handed to ``joblib.Parallel`` by the
    script entry point.

    Args:
        tmp_train_config: Training configuration dict. Its ``'op_env'`` key
            selects the solver: ``'CFVFP'`` -> ``CFVFPSolver`` and ``'CFR'``
            -> ``CFRSolver``. ``'CFR'`` is the default when the key is
            missing. The whole dict is passed to the solver constructor.

    Returns:
        None. Whatever ``train()`` produces (presumably log files consumed
        by the plotting step — TODO confirm) is a side effect of the solver.
        An unknown ``op_env`` is skipped with a ``RuntimeWarning`` instead of
        being ignored silently.
    """
    op_env = tmp_train_config.get('op_env', 'CFR')
    if op_env == 'CFVFP':
        solver = CFVFPSolver(tmp_train_config)
    elif op_env == 'CFR':
        solver = CFRSolver(tmp_train_config)
    else:
        # Previously this returned silently, so a misspelled solver name
        # produced no training and no logs without any hint. Still skip the
        # run (do not crash the whole batch), but make the skip visible.
        import warnings
        warnings.warn(
            f"train_sec: unknown op_env {op_env!r}; expected 'CFVFP' or 'CFR'. "
            "Skipping this run.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    solver.train()
    # Release the solver and force a collection before returning. A joblib
    # worker process may be reused for the next job, and it should not keep
    # the previous run's memory alive.
    del solver
    gc.collect()
    return None
if __name__ == '__main__':
    # Script entry point: train every configuration in
    # `test_sampling_train_config` several times in parallel, then plot the
    # convergence curves read back from `logdir/total_exp_name`.
    import os  # only the entry point needs path helpers

    np.set_printoptions(precision=6, suppress=True)

    # ---- experiment settings -------------------------------------------
    logdir = 'logCFRSampling'
    game_name = 'Leduc'
    is_show_policy = False
    prior_state_num = 15
    y_pot = 1
    z_len = 1
    game_config = {
        'game_name': game_name,
        'prior_state_num': prior_state_num,
        'y_pot': y_pot,
        'z_len': z_len,
        'player_num': 2,
    }

    # Only the selected game is instantiated, as with the old if/elif chain.
    game_classes = {
        'Leduc': Leduc,
        'Kuhn': Kuhn,
        'Goofspiel': Goofspiel,
        'Leduc3Pot': Leduc3Pot,
        'Leduc5Pot': Leduc5Pot,
        'KuhnNPot': KuhnNPot,
        'PAM': PAM,
    }
    if game_name not in game_classes:
        # Fail fast. Previously an unknown name left the game variable
        # unbound and surfaced later as a confusing NameError in the loop.
        raise ValueError(
            f"Unknown game_name {game_name!r}; "
            f"expected one of {sorted(game_classes)}"
        )
    game = game_classes[game_name](game_config)

    train_mode = 'fix_node_touched'      # alternatives: 'fix_itr', 'fix_train_time'
    log_interval_mode = 'node_touched'   # alternatives: 'itr', 'train_time'
    log_mode = 'exponential'             # alternative: 'normal'
    total_train_constraint = 10000000
    log_interval = 1.5
    num_of_train_repetitions = 9
    n_jobs = 9
    total_exp_name = (
        str(prior_state_num) + '_' + game_name + '_'
        + time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    )

    # Settings that override the per-key base config for every run.
    shared_settings = {
        'train_mode': train_mode,
        'log_interval_mode': log_interval_mode,
        'log_mode': log_mode,
        'is_show_policy': is_show_policy,
        'total_exp_name': total_exp_name,
        'total_train_constraint': total_train_constraint,
        'log_interval': log_interval,
    }

    # ---- training --------------------------------------------------------
    for key, base_config in test_sampling_train_config.items():
        start = time.perf_counter()
        print(key)
        parallel_train_config_list = []
        for i_train_repetition in range(num_of_train_repetitions):
            # Deep copies give each run its own config dict and game object.
            train_config = copy.deepcopy(base_config)
            train_config['game'] = copy.deepcopy(game)
            train_config['game_info'] = key
            train_config.update(shared_settings)
            train_config['No.'] = i_train_repetition
            parallel_train_config_list.append(train_config)
        # train_sec returns None, so there is nothing to collect here. The
        # plotting step below reads results from disk instead.
        Parallel(n_jobs=n_jobs)(
            delayed(train_sec)(cfg) for cfg in parallel_train_config_list
        )
        print(time.perf_counter() - start)

    # ---- plotting --------------------------------------------------------
    plt.figure(figsize=(32, 10), dpi=60)
    if game_name == 'KuhnNPot':
        fig_title = (
            str(prior_state_num) + 'C' + str(y_pot) + 'P' + str(z_len) + 'L_Kuhn'
        )
    else:
        fig_title = str(prior_state_num) + '_' + game_name
    result_dir = os.path.join(logdir, total_exp_name)
    # (subplot position, x_label_index, x_label_name). Both panels plot
    # y_label_index=2 ('epsilon').
    panels = (
        (1, 4, 'node touched'),
        (2, 1, 'time(ms)'),
    )
    for subplot_position, x_label_index, x_label_name in panels:
        plt.subplot(1, 2, subplot_position)
        plt_perfect_game_convergence_inline(
            fig_title,
            result_dir,
            is_x_log=True,
            is_y_log=False,
            x_label_index=x_label_index,
            y_label_index=2,
            x_label_name=x_label_name,
            y_label_name='epsilon',
        )
    plt.savefig(os.path.join(result_dir, 'pic.png'))
    # plt.show()  # uncomment for interactive display
    plt.close()