# display_for_ddqn.py
import collections
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim

from fightingice_env import FightingiceEnv

# A fighting game is a typical real-time action game: the player chooses actions
# and tries to defeat the opposing character within the time limit to win the
# round. This task is built on the FightingICE platform, with the known, fixed
# bot "MctsAi" as the opponent. The goal is to design a game AI using the
# reinforcement learning methods covered in class and train it to a reasonable
# level of play.
# TODO: pit the final trained RL agent against MctsAi and report the win rate
# TODO: over 100 rounds, as well as the average HP difference between the two
# TODO: sides at the end of each round, as the performance measure of the RL
# TODO: system (a sketch of such an evaluation loop is given below, after
# TODO: plot_curve).
# Files in this repository:
#   .\FightingICE.jar      - the Java program of the fighting game;
#   .\fightingice_env.py   - the interface the RL system uses to launch the game;
#   .\gym_ai.py            - the code that lets the RL system control the game character;
#   .\data\ai\MctsAi.jar   - the opponent bot, implemented in Java;
#   .\train.py             - the main design framework of the RL system.
def ensure_dir(file_path):
    # Create the parent directory of file_path if it does not exist yet.
    directory = os.path.dirname(file_path)
    if directory and not os.path.exists(directory):
        os.makedirs(directory)
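
# The DQN class below serves as both the online network and the target network.
# Besides the dueling Q-network itself, it also owns the replay buffer (a deque)
# and the epsilon-greedy action selection used while interacting with the game.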
class DQN(nn.Module):
def __init__(self, input_size, output_size, mem_len):
super(DQN, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.memory = collections.deque(maxlen = mem_len)
self.net = nn.Sequential(
nn.Linear(self.input_size, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU()
)
        # Dueling heads: a scalar state value V(s) and per-action advantages A(s, a)
self.V = nn.Linear(128, 1)
self.A = nn.Linear(128, self.output_size)
    def forward(self, input):
        net_output = self.net(input)
        v = self.V(net_output)
        advantage = self.A(net_output)
        # Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
        # The mean is taken over the action dimension per sample (not over the
        # whole batch), so the advantage stream is zero-centred for each state.
        advantage = advantage - advantage.mean(dim=-1, keepdim=True)
        q_value = v + advantage
        return q_value
    def sample_action(self, inputs, epsilon):
        # Epsilon-greedy action selection on a single observation.
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_value = self(inputs)
        if np.random.rand() > epsilon:
            action_choice = int(torch.argmax(q_value))
        else:
            action_choice = random.choice(range(self.output_size))
        return action_choice
def save_trans(self, transition):
self.memory.append(transition)
def sample_memory(self, batch_size):
s_ls, a_ls, r_ls, s_next_ls, done_flag_ls = [], [], [], [], []
trans_batch = random.sample(self.memory, batch_size)
for trans in trans_batch:
s, a, r, s_next, done_flag = trans
s_ls.append(s)
a_ls.append([a])
r_ls.append([r])
s_next_ls.append(s_next)
done_flag_ls.append([done_flag])
return torch.tensor(s_ls,dtype=torch.float32),\
torch.tensor(a_ls,dtype=torch.int64),\
torch.tensor(r_ls,dtype=torch.float32),\
torch.tensor(s_next_ls,dtype=torch.float32),\
torch.tensor(done_flag_ls,dtype=torch.float32)
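
# train_net performs one Double DQN update on a minibatch drawn from the replay
# buffer: the online network Q_net selects the greedy action in the next state,
# and the target network Q_target evaluates it. done_flag is 1 for non-terminal
# transitions and 0 for terminal ones, so the bootstrap term vanishes at the end
# of a round.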
def train_net(Q_net, Q_target, optimizer, losses, loss_list, replay_time, gamma, batch_size):
    s, a, r, s_next, done_flag = Q_net.sample_memory(batch_size)
    # Q(s, a) of the actions that were actually taken.
    q_value = torch.gather(Q_net(s), 1, a)
    # Double DQN target: the online net picks argmax_a Q_net(s', a) and the
    # target net supplies its value; no gradient flows through the target.
    with torch.no_grad():
        a_index = torch.argmax(Q_net(s_next), 1, keepdim=True)
        q_target = torch.gather(Q_target(s_next), 1, a_index)
        q_target = r + gamma * q_target * done_flag
    loss = losses(q_value, q_target)
    loss_list.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
def plot_curve(target_list, loss_list):
    # Plot the per-epoch score and the per-step training loss.
    plt.figure()
    plt.grid()
    plt.plot(range(len(target_list)), target_list, '-r')
    plt.xlabel('epoch')
    plt.ylabel('score')
    plt.figure()
    plt.grid()
    plt.plot(range(len(loss_list)), loss_list, '-b')
    plt.xlabel('train step')
    plt.ylabel('loss')
    plt.show()
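
# The evaluation TODO in the header asks for the win rate and the average
# end-of-round HP difference over 100 rounds against MctsAi. The function below
# is a minimal sketch of such an evaluation loop, not part of the original
# training script: it assumes the same env.reset/env.step interface used in the
# main loop further down, runs the trained network greedily (epsilon = 0), and
# the name evaluate_vs_mctsai is our own.
def evaluate_vs_mctsai(q_net, env, env_args, n_rounds=100, max_steps=300):
    wins = 0
    hp_diffs = []
    for _ in range(n_rounds):
        obs = env.reset(env_args=env_args)
        for _ in range(max_steps):
            action = q_net.sample_action(obs, epsilon=0.0)  # greedy policy
            obs, reward, done, info = env.step(action)
            if done:
                if info is not None:
                    if info[0] > info[1]:
                        wins += 1
                    hp_diffs.append(info[0] - info[1])
                break
    win_rate = wins / n_rounds
    avg_hp_diff = sum(hp_diffs) / len(hp_diffs) if hp_diffs else 0.0
    print("win rate over {} rounds: {:.2%}, average HP difference: {:.2f}".format(
        n_rounds, win_rate, avg_hp_diff))
    return win_rate, avg_hp_diff
# Example usage (after training or after loading a checkpoint):
#   evaluate_vs_mctsai(Q_net, env, env_args, n_rounds=100)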
# Hyperparameters and file-name prefixes
max_epoch = 10000
path = "param\\DDQN"          # prefix for network weight checkpoints (.ckpt)
model_path = "model\\DDQN"    # prefix for full model snapshots (.pth)
R_info = "info\\Reward_DDQN"  # prefix for saved per-epoch scores
L_info = "info\\Loss_DDQN"    # prefix for saved training losses
hp_info = "info\\Hp_diff"     # HP-difference log
# Create the directories the models and parameters are saved into.
ensure_dir(path)
ensure_dir(model_path)
ensure_dir(R_info)
LOAD_KEY = True
learning_rate = 1e-3
max_steps = 300
replay_time = 1
epsilon = 0.01
gamma = 0.95
step_count = 0
train_flag = False
mem_len = 30000
train_begin = 20000 #4000
batch_size = 32
Hp_diff = []
win_count = 0
if __name__ == '__main__':
env = FightingiceEnv(port=4242)
    # For Windows users the port parameter is necessary because the port_for
    # library does not work on Windows; on Linux you can omit it and simply use
    # env = FightingiceEnv().
    env_args = ["--fastmode", "--grey-bg", "--inverted-player", "1", "--mute"]
    # In this mode both players have infinite HP, so their HP within a round can
    # go negative. You can disable the game window by using the following
    # arguments instead:
    # env_args = ["--fastmode", "--disable-window", "--grey-bg", "--inverted-player", "4", "--mute"]
Q_net = DQN(input_size = 144, output_size = 40, mem_len = mem_len)
Q_target = DQN(input_size = 144, output_size = 40, mem_len = mem_len)
Q_target.load_state_dict(Q_net.state_dict())
    if LOAD_KEY:
        # Resume from a previously saved full-model snapshot (epoch 960).
        epoch_model_param = 960
        Q_net = torch.load(model_path + str(epoch_model_param) + ".pth")
        Q_target.load_state_dict(Q_net.state_dict())
# score_m = np.load(R_info + str(690) + ".npy")
# score_list = score_m.tolist()
# loss_m = np.load(L_info + str(690) + ".npy")
# loss_list = loss_m.tolist()
print("Load Weights and param/info!")
optimizer = optim.Adam(Q_net.parameters(), lr = learning_rate)
losses = nn.MSELoss()
loss_list = []
score_list = []
    for epo_i in range(960, max_epoch):  # resume counting from the loaded checkpoint's epoch
obs = env.reset(env_args=env_args)
# reward, done, info = 0, False, None
score = 0.
epsilon = max(0.01, epsilon*0.999)
for step in range(max_steps):
step_count += 1
action_choice = Q_net.sample_action(obs, epsilon)
new_obs, reward, done, info = env.step(action_choice)
            if done:
                # done_flag = 0 masks the bootstrap term in the TD target; the
                # terminal next state is replaced by a zero vector.
                done_flag = 0
                new_obs = [0. for i in range(len(obs))]
            else:
                done_flag = 1
Q_net.save_trans((obs, action_choice, reward, new_obs, done_flag))
score += reward
# train
if step_count > train_begin:
train_flag = True
train_net(Q_net, Q_target, optimizer, losses, loss_list, 1, gamma, batch_size)
# target copy online net
if step_count % 3000 == 0 and train_flag == True:
Q_target.load_state_dict(Q_net.state_dict())
obs = new_obs
            # logging at the end of a round
if done or step + 1 == max_steps:
score_list.append(score)
# if (epo_i+1) % 30 == 0 and train_flag == True:
# print("Log information and save weights/models...")
# torch.save(Q_net.state_dict(), path + str(epo_i+1) + ".ckpt")
# torch.save(Q_net, model_path + str(epo_i+1) + ".pth")
# score_np = np.array(score_list)
# loss_np = np.array(loss_list)
# np.save(R_info + str(epo_i + 1) + ".npy", score_np)
# np.save(L_info + str(epo_i + 1) + ".npy", loss_np)
                if done:
                    if info is not None:
                        if info[0] > info[1]:
                            win_count += 1  # win-rate bookkeeping
                        # Log the end-of-round HP difference.
                        with open(hp_info + ".txt", "a") as f:
                            f.write("Epoch: " + str(epo_i + 1) + " Hp_diff: " + str(info[0] - info[1]) + "\n")
                if info is not None:
                    print("Epoch: {} round result: own hp {} vs opp hp {}, you {} training: {} epsilon: {} done: {} step_count: {}".format(
                        epo_i, info[0], info[1], 'win' if info[0] > info[1] else 'lose',
                        train_flag, epsilon, done, step_count))
                else:
                    # the Java side terminated unexpectedly
                    pass
                break
    plot_curve(score_list, loss_list)
print("finish training")
# Example program: the original template provided with the environment.
# if __name__ == '__main__':
# env = FightingiceEnv(port=4242)
# # for windows user, port parameter is necessary because port_for library does not work in windows
# # for linux user, you can omit port parameter, just let env = FightingiceEnv()
#
# env_args = ["--fastmode", "--grey-bg", "--inverted-player", "1", "--mute"]
# # this mode let two players have infinite hp, their hp in round can be negative
# # you can close the window display functional by using the following mode
# #env_args = ["--fastmode", "--disable-window", "--grey-bg", "--inverted-player", "1", "--mute"]
#
# while True:
# obs = env.reset(env_args=env_args)
# reward, done, info = 0, False, None
#
# while not done:
# act = random.randint(0, 10)
# # TODO: or you can design with your RL algorithm to choose action [act] according to game state [obs]
# new_obs, reward, done, info = env.step(act)
#
# if not done:
# # TODO: (main part) learn with data (obs, act, reward, new_obs)
# # suggested discount factor value: gamma in [0.9, 0.95]
# pass
# elif info is not None:
# print("round result: own hp {} vs opp hp {}, you {}".format(info[0], info[1],
# 'win' if info[0]>info[1] else 'lose'))
# else:
# # java terminates unexpectedly
# pass
#
# print("finish training")