train.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
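"""Train a DQN agent on the CartPole-v0 task with PARL.

The script pre-fills a replay memory, then alternates between training
episodes (run_episode) and evaluation runs (evaluate), logging the average
evaluation reward after every 50 training episodes.
"""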
import gym
import numpy as np
import parl
from parl.utils import logger
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from replay_memory import ReplayMemory

LEARN_FREQ = 5  # update the model parameters every 5 environment steps
MEMORY_SIZE = 20000  # replay memory size
MEMORY_WARMUP_SIZE = 200  # pre-fill the replay memory with this many transitions before learning starts
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
GAMMA = 0.99  # reward discount factor
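# With these settings the agent starts learning once the replay memory holds
# more than MEMORY_WARMUP_SIZE transitions, and then performs one gradient
# update on a BATCH_SIZE batch every LEARN_FREQ environment steps.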


def run_episode(agent, env, rpm):
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)
        next_obs, reward, isOver, _ = env.step(action)
        rpm.append((obs, action, reward, next_obs, isOver))

        # train model
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_isOver) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs, batch_isOver)

        total_reward += reward
        obs = next_obs
        if isOver:
            break
    return total_reward


def evaluate(agent, env, render=False):
    # evaluation: run 5 episodes and average the episode rewards
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        isOver = False
        while not isOver:
            action = agent.predict(obs)
            if render:
                env.render()
            obs, reward, isOver, _ = env.step(action)
            episode_reward += reward
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v0')
    action_dim = env.action_space.n
    obs_shape = env.observation_space.shape

    rpm = ReplayMemory(MEMORY_SIZE)

    model = CartpoleModel(act_dim=action_dim)
    algorithm = parl.algorithms.DQN(
        model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
    agent = CartpoleAgent(
        algorithm,
        obs_dim=obs_shape[0],
        act_dim=action_dim,
        e_greed=0.1,  # initial exploration probability
        e_greed_decrement=1e-6
    )  # the exploration probability decays gradually during training

    while len(rpm) < MEMORY_WARMUP_SIZE:  # warm up the replay memory
        run_episode(agent, env, rpm)

    max_episode = 2000

    # start training
    episode = 0
    while episode < max_episode:
        # train part
        for i in range(0, 50):
            total_reward = run_episode(agent, env, rpm)
            episode += 1

        eval_reward = evaluate(agent, env)
        logger.info('episode:{} test_reward:{}'.format(
            episode, eval_reward))


if __name__ == '__main__':
    main()
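

# A minimal way to launch this example (assuming the sibling modules
# cartpole_model.py, cartpole_agent.py and replay_memory.py from the same
# example directory are importable):
#
#     python train.py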