-
Notifications
You must be signed in to change notification settings - Fork 6
/
memory.py
61 lines (42 loc) · 2.07 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#TODO TF & restoring
import numpy as np
import random
class ExperienceMemory(object):
    """Fixed-capacity store of (state, action, reward, done, next_state) tuples.

    Each field is kept in its own pre-allocated NumPy array.  Entries are
    written sequentially; once ``buffer_size`` entries have been written, the
    write position wraps to 0 and subsequent writes overwrite existing slots.

    Attributes:
        stateMem: array of shape ``(buffer_size, s0, s1, s2)`` holding states.
        actionMem: integer array of shape ``(buffer_size,)``.
        rewardMem: float array of shape ``(buffer_size,)``.
        doneMem: boolean array of shape ``(buffer_size,)``.
        nxt_stateMem: array shaped like ``stateMem`` holding next states.
        buffer_size: capacity, in number of experience tuples.
        isFull: set to True the first time the write position wraps around.
        indexer: index of the slot the next ``add`` call writes to.
    """

    def __init__(self, state_shape, buffer_size=50000):
        """Allocate the (uninitialized) storage arrays.

        Args:
            state_shape: indexable of at least three ints; only the first
                three entries are used as the per-state array shape.
            buffer_size: capacity, in number of experience tuples.
        """
        state_store_shape = [
            buffer_size, state_shape[0], state_shape[1], state_shape[2]
        ]
        # np.empty leaves the contents arbitrary; slots hold meaningful data
        # only after add() has written to them.
        self.stateMem = np.empty(state_store_shape)
        self.actionMem = np.empty([buffer_size], dtype=int)
        self.rewardMem = np.empty([buffer_size], dtype=float)
        self.doneMem = np.empty([buffer_size], dtype=bool)
        self.nxt_stateMem = np.empty(state_store_shape)
        self.buffer_size = buffer_size
        self.isFull = False
        self.indexer = 0

    def add(self, experience):
        """Store one experience at the current write position.

        Args:
            experience: indexable whose items 0..4 are, in order, state,
                action, reward, done flag and next state.
        """
        # The wrap is performed lazily, on the first add *after* the last
        # slot was written, so ``indexer`` may equal ``buffer_size`` between
        # calls while ``isFull`` is still False.
        if self.indexer == self.buffer_size:
            self.indexer = 0
            self.isFull = True
            print("memory refill")
        slot = self.indexer
        self.stateMem[slot] = experience[0]
        self.actionMem[slot] = experience[1]
        self.rewardMem[slot] = experience[2]
        self.doneMem[slot] = experience[3]
        self.nxt_stateMem[slot] = experience[4]
        self.indexer += 1

    def sample(self, size):
        """Return ``size`` distinct stored experiences chosen at random.

        Before the buffer has wrapped, only the slots written so far are
        eligible; afterwards every slot is.

        Args:
            size: number of experiences to draw (without replacement).

        Returns:
            The 5-tuple produced by ``getSamples`` for the drawn indexes.

        Raises:
            AssertionError: if ``size`` exceeds the number of eligible slots.
        """
        if self.isFull:
            assert self.buffer_size >= size, "batch size can't be larger than memory size!"
            indexes = random.sample(range(self.buffer_size), size)
        else:
            assert self.indexer >= size, "batch size can't be larger than currently filled memory!"
            indexes = random.sample(range(self.indexer), size)
        return self.getSamples(indexes)

    def getSamples(self, indexes):
        """Gather the stored fields at the given slot indexes.

        Args:
            indexes: slot indexes used to index each storage array.

        Returns:
            Tuple ``(states, actions, rewards, dones, next_states)``.
        """
        return (
            self.stateMem[indexes],
            self.actionMem[indexes],
            self.rewardMem[indexes],
            self.doneMem[indexes],
            self.nxt_stateMem[indexes],
        )

    # The three methods below are unimplemented placeholders.
    # NOTE(review): names suggest hooks for a prioritized variant of this
    # memory -- confirm intended semantics before implementing.

    def update(self, deltas):
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented!")

    def getISW(self):
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented!")

    def betaAnneal(self, s):
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented!")