-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
84 lines (74 loc) · 2.46 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import time
import numpy as np
class AgentBase:
def __init__(self):
# sub module
self._env = None
self._replay = None
# epsilon greedy
self._epsilon = 0.0
self._epsilon_min = 0.0
self._epsilon_max = 0.8
# global_step
self._global_episode = 0
self._global_step = 0
def is_not_used(self):
pass
def _epsilon_greedy(self):
# (0, 0) (10, 0.2) (30, 0.55) (50, 0.75) (60, 0.8)
self._epsilon = np.tanh(0.02 * self._global_episode)
self._epsilon = np.maximum(self._epsilon, self._epsilon_min)
self._epsilon = np.minimum(self._epsilon, self._epsilon_max)
return self._epsilon
def _train_impl(self, max_step):
# prepare for epsilon greedy
self._global_episode += 1
# train
state = self._env.reset()
for step in range(max_step):
self._global_step += 1
# 1. predict
action = self._explore(state)
# 2. action
state_, reward, done, info = self._env.step(action)
# 3. perceive
self._perceive(state, action, state_, reward, done)
# 4. update
state = state_
if done:
break
def _test_impl(self, max_step, delay, gui):
# test
total_reward = 0
state = self._env.reset()
for step in range(max_step):
if gui:
self._env.render()
time.sleep(delay)
# 1. predict
action = self._optimal_action(state)
# 2. action
state_, reward, done, info = self._env.step(action)
# 3. perceive
total_reward += reward
# 4. update
state = state_
if done:
break
return total_reward
def _explore(self, state):
epsilon = self._epsilon_greedy()
if np.random.uniform() > epsilon:
action = self._random_action()
else:
action = self._optimal_action(state)
return action
def _random_action(self):
self.is_not_used()
raise NotImplementedError('function must overridden.')
def _optimal_action(self, state):
self.is_not_used()
raise NotImplementedError('function must overridden.')
def _perceive(self, state, action, state_, reward, done):
self.is_not_used()
raise NotImplementedError('function must overridden.')