Commit 81f1ece
added Q actor critic
lukearcus committed Aug 8, 2022
1 parent fe6e762 commit 81f1ece
Showing 2 changed files with 59 additions and 3 deletions.
Kuhn_poker/learners.py (60 changes: 58 additions & 2 deletions)
@@ -115,11 +115,11 @@ def learn(self):
            advantage = self.advantage_func.eval(elem["s"], elem["a"], elem["r"], elem["s'"])
            theta_update += grad_log_theta*advantage

            delta = self.advantage_func.eval(elem["s"], elem["a"], elem["r"], elem["s'"])
            a_prime = np.argmax(np.random.multinomial(1, pvals=self.opt_pol[elem["s'"]]))
            delta = self.advantage_func.calc_delta(elem["s"], elem["a"], elem["r"], elem["s'"], a_prime)
            self.advantage_func.update(lr*delta, elem["s"], elem["a"])
        self.pol_func.thetas += lr*theta_update
        self.opt_pol = self.pol_func.update()

        return self.opt_pol

class pol_func_base:
@@ -145,6 +145,9 @@ def __init__(self, init_adv, num_states, num_actions):

    def eval(self, s, a, r, s_prime):
        raise NotImplementedError

    def calc_delta(self, s, a, r, s_prime, a_prime):
        # default: the TD error used for the critic update is just the advantage estimate
        return self.eval(s, a, r, s_prime)

    def update(self, update, s, a):
        raise NotImplementedError
@@ -175,6 +178,59 @@ def reset(self):
        self.thetas = np.ones_like(self.thetas)
        self.update()

class Q_advantage(advantage_func_base):

    def __init__(self, init_adv, num_states, num_actions, df):
        self.Q = np.ones((num_states, num_actions))*init_adv
        self.V = np.ones(num_states)*init_adv
        self.init_adv = init_adv
        self.gamma = df

    def eval(self, s, a, r, s_prime):
        # advantage estimate: A(s, a) = Q(s, a) - V(s)
        return self.Q[s, a] - self.V[s]

    def calc_delta(self, s, a, r, s_prime, a_prime):
        # TD errors for the two tables; s_prime == -1 marks a terminal transition
        delta = np.zeros(2)
        if s_prime != -1:
            delta[0] = r + self.gamma*np.max(self.Q[s_prime, :]) - self.Q[s, a]  # Q-learning target
            delta[1] = r + self.gamma*self.V[s_prime] - self.V[s]                # TD(0) target for V
        else:
            delta[0] = r - self.Q[s, a]
            delta[1] = r - self.V[s]
        return delta

    def update(self, update, s, a):
        self.Q[s, a] += update[0]
        self.V[s] += update[1]

    def reset(self):
        self.Q = np.ones_like(self.Q) * self.init_adv
        self.V = np.ones_like(self.V) * self.init_adv

class Q_actor_critic(advantage_func_base):

    def __init__(self, init_adv, num_states, num_actions, df):
        self.Q = np.ones((num_states, num_actions))*init_adv
        self.init_adv = init_adv
        self.gamma = df

    def eval(self, s, a, r, s_prime):
        # the critic value Q(s, a) is used directly in place of an advantage
        return self.Q[s, a]

    def calc_delta(self, s, a, r, s_prime, a_prime):
        # SARSA-style TD error; a_prime is sampled from the current policy in learn()
        if s_prime != -1:
            delta = r + self.gamma*self.Q[s_prime, a_prime] - self.Q[s, a]
        else:
            delta = r - self.Q[s, a]
        return delta

    def update(self, update, s, a):
        self.Q[s, a] += update

    def reset(self):
        self.Q = np.ones_like(self.Q) * self.init_adv


class value_advantage(advantage_func_base):

    def __init__(self, init_adv, num_states, _, df):
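For orientation, here is a minimal sketch of how the new Q_actor_critic critic gets exercised by the learn() loop above. The import path, transition values, learning rate, and discount are illustrative assumptions, not part of the commit:

from Kuhn_poker import learners  # assumes learners.py is importable as a module

# hypothetical single transition; in practice learn() pulls these from the stored elements it iterates over
s, a, r, s_prime = 0, 1, 0, 3
a_prime = 0            # learn() samples this from self.opt_pol[s']
lr = 0.05              # illustrative learning rate
gamma = 1.0            # illustrative discount factor

critic = learners.Q_actor_critic(init_adv=-2, num_states=6, num_actions=2, df=gamma)

delta = critic.calc_delta(s, a, r, s_prime, a_prime)   # r + gamma*Q[s', a'] - Q[s, a]
critic.update(lr*delta, s, a)                          # move Q[s, a] toward the SARSA-style target
weight = critic.eval(s, a, r, s_prime)                 # Q[s, a], the quantity learn() multiplies into grad log pi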
Kuhn_poker/main_FSP.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@

KP_game = game.Kuhn_Poker_int_io()

RL_learners = [learners.actor_critic(learners.softmax, learners.value_advantage, 2, 6, extra_samples = extras)\
RL_learners = [learners.actor_critic(learners.softmax, learners.value_advantage, 2, 6, init_adv=-2, extra_samples = extras)\
for p in range(2)]
SL_learners = [learners.count_based_SL((6,2)) for p in range(2)]

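If one wanted main_FSP.py to train with the new critic instead of value_advantage, the constructor call above could presumably just name the new class. A sketch, assuming actor_critic forwards init_adv and the remaining arguments to the advantage class exactly as it already does for value_advantage:

RL_learners = [learners.actor_critic(learners.softmax, learners.Q_actor_critic, 2, 6, init_adv=-2, extra_samples=extras)
               for p in range(2)]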
