diff --git a/contents/11_Dyna_Q/RL_brain.py b/contents/11_Dyna_Q/RL_brain.py index 2fb39aa..23fa6fc 100644 --- a/contents/11_Dyna_Q/RL_brain.py +++ b/contents/11_Dyna_Q/RL_brain.py @@ -15,16 +15,23 @@ def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): self.lr = learning_rate self.gamma = reward_decay self.epsilon = e_greedy - self.q_table = pd.DataFrame(columns=self.actions) + + ## argmax type error + self.q_table = pd.DataFrame(columns=self.actions).astype('float32') def choose_action(self, observation): self.check_state_exist(observation) # action selection if np.random.uniform() < self.epsilon: # choose best action - state_action = self.q_table.ix[observation, :] + + + # state_action = self.q_table.ix[observation, :] + state_action = self.q_table.loc[observation, :] # for label indexing state_action = state_action.reindex(np.random.permutation(state_action.index)) # some actions have same value action = state_action.argmax() + + else: # choose random action action = np.random.choice(self.actions) @@ -32,12 +39,13 @@ def choose_action(self, observation): def learn(self, s, a, r, s_): self.check_state_exist(s_) - q_predict = self.q_table.ix[s, a] + + q_predict = self.q_table.loc[s, a] if s_ != 'terminal': - q_target = r + self.gamma * self.q_table.ix[s_, :].max() # next state is not terminal + q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal else: q_target = r # next state is terminal - self.q_table.ix[s, a] += self.lr * (q_target - q_predict) # update + self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update def check_state_exist(self, state): if state not in self.q_table.index: @@ -71,9 +79,9 @@ def store_transition(self, s, a, r, s_): def sample_s_a(self): s = np.random.choice(self.database.index) - a = np.random.choice(self.database.ix[s].dropna().index) # filter out the None value + a = np.random.choice(self.database.loc[s].dropna().index) # filter out the None value return s, a def get_r_s_(self, s, a): - r, s_ = self.database.ix[s, a] + r, s_ = self.database.loc[s, a] return r, s_