Fixed averaging steps
lukearcus committed Sep 7, 2022
1 parent 9bbfdf1 commit 071cba3
Showing 3 changed files with 45 additions and 49 deletions.
UI/plot_funcs.py (4 changes: 1 addition & 3 deletions)
@@ -88,12 +88,10 @@ def exploitability_plot(exploitability, exploit_freq, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    iters = np.arange(0, len(exploitability)*exploit_freq, exploit_freq)
-   ax.plot(iters, exploitability)
+   ax.loglog(iters, exploitability)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Exploitability")
    ax.set_title("Exploitability vs Iteration")
-   inset = ax.inset_axes([0.5,0.5,0.45,0.4])
-   inset.loglog(iters, exploitability)
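The change above swaps the linear exploitability curve plus a log-log inset for a single log-log plot. A minimal, self-contained sketch of the resulting plotting behaviour, with a made-up decaying curve standing in for the real exploitability data (matplotlib masks the iteration-0 point on the log x-axis):

import numpy as np
import matplotlib.pyplot as plt

exploit_freq = 10
exploitability = 1.0 / np.arange(1, 101)        # placeholder decaying curve
iters = np.arange(0, len(exploitability) * exploit_freq, exploit_freq)

fig, ax = plt.subplots()
ax.loglog(iters, exploitability)                # both axes log-scaled
ax.set_xlabel("Iteration")
ax.set_ylabel("Exploitability")
ax.set_title("Exploitability vs Iteration")
plt.show()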



agents/players.py (41 changes: 33 additions & 8 deletions)
@@ -113,9 +113,12 @@ class OBL(RL):

    def __init__(self, learner, player_id, fict_game, belief_iters = 10000):
        self.belief_iters = belief_iters
+       self.belief_buff = []
+       self.pol_buff = []
        super().__init__(learner, player_id)
        self.fict_game = fict_game

+       self.avg_pol = np.ones_like(self.opt_pol)/self.opt_pol.shape[1]
+
    def set_other_players(self, other_players):
        self.other_players = other_players.copy()
        self.other_players.insert(self.id, "me")
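The new avg_pol attribute starts as a uniform distribution over actions for every state. A tiny illustration of that initialiser, with an assumed 6-state, 2-action policy shape (the shape is made up for the example):

import numpy as np

opt_pol = np.zeros((6, 2))                          # e.g. 6 information states, 2 actions
avg_pol = np.ones_like(opt_pol) / opt_pol.shape[1]
print(avg_pol[0])                                   # [0.5 0.5] -> uniform over actions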
@@ -159,11 +162,8 @@ def action(self):
        probs = self.opt_pol[self.state, :]
        act = np.argmax(np.random.multinomial(1, pvals=probs))
        return act

-   def update_belief(self):
-       num_hidden = len(self.fict_game.poss_hidden)
-       num_states = self.opt_pol.shape[0]
-       new_belief = np.ones((num_states, num_hidden))
+   def add_to_mem(self):
        for i in range(self.belief_iters):
            self.fict_game.start_game()
            while not self.fict_game.ended:
@@ -173,9 +173,34 @@ def update_belief(self):
                else:
                    player = self.other_players[p_id]
                player.observe(self.fict_game.observe(), fict=True)
-               self.fict_game.action(player.action())
+               act = player.action()
+               self.fict_game.action(act)
                if p_id == self.id:
                    hidden_state = self.fict_game.get_hidden(self.id)
-                   new_belief[self.state, hidden_state] += 1
+                   self.belief_buff.append({'s':self.state, 'hidden':hidden_state})
+                   self.pol_buff.append({'s':self.state, 'a':act, 'probs':self.opt_pol[self.state, :]})
+
+   def update_belief(self):
+       num_hidden = len(self.fict_game.poss_hidden)
+       num_states = self.opt_pol.shape[0]
+       new_belief = np.ones((num_states, num_hidden))
+       for elem in self.belief_buff:
+           new_belief[elem['s'], elem['hidden']] += 1
        new_belief /= np.sum(new_belief,1,keepdims=True)
        self.belief = new_belief

+   def learn_avg_pol(self):
+       N = np.zeros(self.opt_pol.shape)
+       for elem in self.pol_buff:
+           state = elem['s']
+           N[state,:] += elem['probs']
+       pi = N/np.sum(N,axis=1)[:,np.newaxis]
+       for i, s in enumerate(pi):
+           if np.all(np.isnan(s)):
+               pi[i] = np.ones(s.shape)/s.size
+       self.avg_pol = pi
+
+   def update_mem_and_bel(self):
+       self.add_to_mem()
+       self.update_belief()
+       self.learn_avg_pol()
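Taken together, the new methods compute the belief from Laplace-smoothed counts of hidden states stored in belief_buff, and the average policy as the per-state mean of the action distributions stored in pol_buff, falling back to uniform in unvisited states. A standalone sketch of those two averaging steps, using made-up buffer contents and sizes rather than data from the fictitious game:

import numpy as np

num_states, num_actions, num_hidden = 3, 2, 2       # made-up sizes
belief_buff = [{'s': 0, 'hidden': 1}, {'s': 0, 'hidden': 1}, {'s': 2, 'hidden': 0}]
pol_buff = [{'s': 0, 'probs': np.array([0.9, 0.1])},
            {'s': 0, 'probs': np.array([0.5, 0.5])}]

# Belief: Laplace-smoothed visit counts over hidden states, normalised per state.
new_belief = np.ones((num_states, num_hidden))
for elem in belief_buff:
    new_belief[elem['s'], elem['hidden']] += 1
new_belief /= np.sum(new_belief, 1, keepdims=True)

# Average policy: per-state mean of the recorded action distributions,
# with unvisited states replaced by a uniform policy.
N = np.zeros((num_states, num_actions))
for elem in pol_buff:
    N[elem['s'], :] += elem['probs']
with np.errstate(invalid='ignore'):                  # unvisited rows are 0/0 -> NaN
    pi = N / np.sum(N, axis=1)[:, np.newaxis]
for i, s in enumerate(pi):
    if np.all(np.isnan(s)):
        pi[i] = np.ones(s.shape) / s.size
print(new_belief)                                    # rows sum to 1
print(pi)                                            # state 0 averaged, others uniform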
main.py (49 changes: 11 additions & 38 deletions)
@@ -53,33 +53,20 @@ def main():
                for p_id, other_p in enumerate(p.other_players):
                    if other_p != "me":
                        other_p.opt_pol = players[p_id].opt_pol
-               p.update_belief()
+               if averaged_bel:
+                   p.belief_buff = []
+               p.update_mem_and_bel()
                bels.append(np.copy(p.belief))
            else:
                bels.append(np.zeros((1,1)))
        pol_hist.append(pols)
        log.debug("Policies at lvl "+str(lvl) + ": " + str(pols))
        belief_hist.append(bels)
        log.debug("Beliefs at lvl "+str(lvl) + ": " + str(bels))
-       if averaged_bel:
-           new_avg_bels = []
-           for p_id, p in enumerate(players):
-               total_bel = np.zeros_like(belief_hist[0][p_id])
-               for i in range(lvl+1):
-                   total_bel += belief_hist[i][p_id]
-               avg_bel = total_bel / (lvl+1)
-               p.belief = np.copy(avg_bel)
-               new_avg_bels.append(avg_bel)
-           avg_bels.append(new_avg_bels)
-           log.debug("Average beliefs at lvl "+str(lvl) + ": " + str(new_avg_bels))
        if averaged_pol or learn_with_avg:
            new_avg_pols = []
-           for p_id, p in enumerate(players):
-               total_pol = np.zeros_like(pol_hist[0][p_id])
-               for i in range(lvl+1):
-                   total_pol += pol_hist[i][p_id]
-               avg_pol = total_pol / (lvl+1)
-               new_avg_pols.append(avg_pol)
+           for p in players:
+               new_avg_pols.append(p.avg_pol)
            avg_pols.append(new_avg_pols)
            log.debug("Average polices at lvl "+str(lvl) + ": " + str(new_avg_pols))
        if lvl % exploit_freq == 0:
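In this hunk the per-level recomputation of average beliefs and policies from the full belief_hist/pol_hist is replaced by each player's own buffer-based update: the loop calls update_mem_and_bel() and reads back p.avg_pol. A runnable toy version of the revised loop body is sketched below; StubPlayer and its update rule are stand-ins invented for the example, not the repository's agents:

import numpy as np

class StubPlayer:
    def __init__(self):
        self.opt_pol = np.full((4, 2), 0.5)
        self.belief = np.full((4, 3), 1 / 3)
        self.belief_buff = []
        self.avg_pol = np.ones_like(self.opt_pol) / self.opt_pol.shape[1]

    def update_mem_and_bel(self):
        # stand-in for add_to_mem + update_belief + learn_avg_pol
        self.avg_pol = 0.5 * (self.avg_pol + self.opt_pol)

players = [StubPlayer(), StubPlayer()]
averaged_bel, averaged_pol, learn_with_avg = True, True, False
pol_hist, belief_hist, avg_pols = [], [], []

for lvl in range(3):
    pols, bels = [], []
    for p in players:
        pols.append(p.opt_pol)
        if p.belief is not None:
            if averaged_bel:
                p.belief_buff = []          # buffer reset, as in this hunk
            p.update_mem_and_bel()
            bels.append(np.copy(p.belief))
        else:
            bels.append(np.zeros((1, 1)))
    pol_hist.append(pols)
    belief_hist.append(bels)
    if averaged_pol or learn_with_avg:
        avg_pols.append([p.avg_pol for p in players])

print(len(avg_pols), avg_pols[-1][0])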
@@ -105,30 +92,19 @@
    for p in players:
        pols.append(p.opt_pol)
        if p.belief is not None:
-           p.update_belief()
+           if not averaged_bel:
+               p.belief_buff = []
+           p.update_mem_and_bel()
            bels.append(p.belief)
        else:
            bels.append(np.zeros((1,1)))
    pol_hist.append(pols)
    belief_hist.append(bels)

-   if averaged_bel:
-       new_avg_bels = []
-       for p_id, p in enumerate(players):
-           total_bel = np.zeros_like(belief_hist[0][p_id])
-           for i in range(lvl+1):
-               total_bel += belief_hist[i][p_id]
-           avg_bel = total_bel / (lvl+1)
-           new_avg_bels.append(avg_bel)
-       avg_bels.append(new_avg_bels)
    if averaged_pol:
        new_avg_pols = []
-       for p_id, p in enumerate(players):
-           total_pol = np.zeros_like(pol_hist[0][p_id])
-           for i in range(lvl+1):
-               total_pol += pol_hist[i][p_id]
-           avg_pol = total_pol / (lvl+1)
-           new_avg_pols.append(avg_pol)
+       for p in players:
+           new_avg_pols.append(p.avg_pol)
        avg_pols.append(new_avg_pols)
        exploit, _, _, _ = calc_exploitability(new_avg_pols, game, exploit_learner)
    else:
@@ -141,10 +117,7 @@
        pol_plot = avg_pols
    else:
        pol_plot = pol_hist
-   if averaged_bel:
-       bel_plot = avg_bels
-   else:
-       bel_plot = belief_hist
+   bel_plot = belief_hist
    plot_everything(pol_plot, bel_plot, "kuhn", reward_hist[-1], exploitability)
    filename="results/OBL_all_average"
    np.savez(filename, pols=pol_plot, bels=bel_plot, explot=exploitability, rewards=reward_hist)
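Assuming a run completes and the np.savez call succeeds, the saved arrays can be reloaded with the same keys (note the key is spelled "explot" in main.py). A sketch; allow_pickle=True covers the case that the ragged histories were stored as object arrays:

import numpy as np

data = np.load("results/OBL_all_average.npz", allow_pickle=True)
print(data.files)                 # expected: ['pols', 'bels', 'explot', 'rewards']
exploitability = data["explot"]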
