-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
71 lines (64 loc) · 2.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gym
from baselines import deepq
from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFrame
from collections import deque
import numpy as np
from Advantage_Actor_Critic import Advantage_Actor_Critic
import pickle
import time
GAMMA = 0.99  # discount factor applied per step (also accumulated into I in main)
EPISODE = 350  # number of training episodes run by main()
TEST = 100  # number of evaluation episodes run by play()
def main():
    """Train an Advantage Actor-Critic agent on CartPole-v0.

    Runs EPISODE training episodes, tracks a sliding window of the last
    100 episode rewards, and every 100 episodes checkpoints both the
    running average-reward history (pickle) and the network weights.
    """
    env = gym.make("CartPole-v0")
    agent = Advantage_Actor_Critic(env)
    # deque(maxlen=100) keeps only the last 100 episode rewards with O(1)
    # eviction (replaces the previous manual list.pop(0) sliding window).
    episodes_rewards = deque(maxlen=100)
    avg_rewards = []
    skip_rewards = []
    step_num = 0
    for episode in range(EPISODE):
        goal = 0  # cumulative reward for this episode
        I = 1.0  # discount accumulator: GAMMA ** t, passed to the update
        state = env.reset()
        while True:
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            # NOTE(review): I is multiplied by GAMMA *before* perceive, so the
            # first transition is weighted GAMMA rather than 1.0 — preserved
            # as-is since the agent's update may rely on this ordering.
            I = GAMMA * I
            # env.render()
            agent.perceive(state, action, reward, next_state, done, I, step_num)
            goal += reward
            step_num += 1
            state = next_state
            if done:
                episodes_rewards.append(goal)
                break
        print("Episode: ", episode, " Last 100 episode average reward: ", np.average(episodes_rewards), " Total step number: ", step_num)
        avg_rewards.append(np.average(episodes_rewards))
        if episode % 100 == 0:
            # 'with' guarantees the file is closed even if pickle.dump raises.
            with open("avg_rewards.pkl", 'wb') as out_file:
                pickle.dump(avg_rewards, out_file)
            agent.saver.save(agent.session, 'saved_networks/' + 'network' + '-dqn', global_step=episode)
    env.close()
def play():
    """Evaluate the trained agent on CartPole-v0 for TEST episodes.

    Renders each step and prints the total reward per episode. Uses
    agent.action (greedy policy) rather than select_action (exploratory).
    """
    env = gym.make("CartPole-v0")
    agent = Advantage_Actor_Critic(env)
    for episode in range(TEST):
        goal = 0  # cumulative reward for this episode
        state = env.reset()
        while True:
            action = agent.action(state)
            next_state, reward, done, _ = env.step(action)
            env.render()
            goal += reward
            state = next_state
            if done:
                print("Episode: ", episode, " Total reward: ", goal)
                break
    # Close the env (and its render window) — main() does this; play() didn't.
    env.close()
if __name__ == '__main__':
    # Train first, then immediately evaluate the trained agent.
    main()
    play()