FSP might be converging to nash now (very slowly)
lukearcus committed Aug 24, 2022
1 parent 76241d4 commit 37735ef
Showing 4 changed files with 49 additions and 19 deletions.
17 changes: 9 additions & 8 deletions FSP.py
@@ -8,7 +8,7 @@
 
 class FSP:
 
-    def __init__(self, _game, _agents, max_iters=100, max_time=300, m=50, n=50, exploit_iters=100, exploit_freq=10):
+    def __init__(self, _game, _agents, max_iters=100, max_time=300, m=50, n=50, exploit_iters=100000, exploit_freq=10):
         self.game = _game
         self.agents = _agents
         self.num_players = self.game.num_players
@@ -71,23 +71,23 @@ def run_algo(self):
                 log.debug("p" + str(p+1) + " new_beta: " + str(new_beta[p]))
                 #import pdb; pdb.set_trace()
                 diff += np.linalg.norm(new_pi[p]-sigma[p])
-            log.info("norm difference between new_pi and sigma: " +str(diff))
+            log.debug("norm difference between new_pi and sigma: " +str(diff))
             pi.append(new_pi)
             beta.append(new_beta)
             #import pdb; pdb.set_trace()
             if j%self.est_exploit_freq == 0:
 
-                exploit, br_pols, _ = calc_exploitability(new_pi, self.game, exploit_learner)
-                #exploit = self.est_exploitability(new_pi, new_beta)
-                import pdb; pdb.set_trace()
+                exploit_calced_br, br_pols, _, values = calc_exploitability(new_pi, self.game, exploit_learner)
+                exploit = self.est_exploitability(new_pi, new_beta)
+                #import pdb; pdb.set_trace()
                 # compare br_pols with beta
-                log.info("exploitability: " + str(exploit))
+                #log.info("exploitability: " + str(exploit_calced_br))
+                log.info("exploitability using beta: " + str(exploit))
                 exploitability.append(exploit)
             toc = time.perf_counter()
             if toc-tic > self.max_time:
                 break
         #import pdb; pdb.set_trace()
-        return pi[-1], exploitability, (pi, beta, D)
+        return pi[-1], exploitability, {'pi': pi, 'beta':beta, 'D': D}
 
     def play_game(self, strat):
         buffer = [[] for i in range(self.num_players)]
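The third return value is now a dict rather than a positional tuple, so downstream code can unpack the run history by name. A minimal sketch of consuming it (variable names illustrative; 'D' is assumed to be the accumulated FSP experience data):

    pi_final, exploitability, data = worker.run_algo()
    pi_history = data['pi']      # average-policy iterates, one entry per iteration
    beta_history = data['beta']  # best-response iterates
    D = data['D']                # experience data (assumed: the FSP memory D)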
@@ -125,6 +125,7 @@ def play_game(self, strat):
 
 
     def est_exploitability(self, pol, br):
+        #import pdb; pdb.set_trace()
         #BRs = self.calc_BRs(pi)
         R = [0 for i in range(self.num_players)]
         for p in range(self.num_players):
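For reference, the quantity being estimated here is the usual exploitability of a strategy profile pi,

    exploit(pi) = sum_p [ max_{beta_p} R_p(beta_p, pi_{-p}) - R_p(pi) ],

which est_exploitability presumably approximates by plugging the tracked best responses beta into the inner maximisation, while calc_exploitability trains fresh best responses from scratch.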
32 changes: 30 additions & 2 deletions agents/learners.py
@@ -22,11 +22,14 @@ def __init__(self, extra_samples, init_lr, df):
         self.extra_samples = extra_samples
         self.gamma = df
         self.init_lr = init_lr
+        self.max_mem = 100000
 
     def update_memory(self, data):
         self.last_round = len(data)
         for elem in data:
             self.memory += elem[0]
+            if len(self.memory) > self.max_mem:
+                self.memory.pop(0)
 
 class SL_base(learner_base):
     learned_pol = None
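Aside: list.pop(0) is O(n) per call, and the check above removes a single element at a time. A collections.deque with maxlen enforces the same FIFO bound automatically; a sketch, not the commit's code:

    from collections import deque

    # Bounded replay memory: once 100000 entries are held, appending
    # silently drops the oldest entry (equivalent to the pop(0) cap above).
    memory = deque(maxlen=100000)
    for transition in [("s0", "a0", 1.0), ("s1", "a1", 0.0)]:  # illustrative data
        memory.append(transition)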
@@ -59,7 +62,7 @@ def learn(self):
         self.learned_pol = pi
         return pi
 
-class Q_learn(RL_base):
+class fitted_Q_iteration(RL_base):
 
     def __init__(self, init_q, Q_shape, extra_samples=0, init_lr = 0.05, df=1.0):
         self.Q = np.ones(Q_shape)*init_q
@@ -95,7 +98,8 @@ def learn(self):
         return self.opt_pol
 
     def reset(self):
-        self.Q = self.init_q_mat
+        pass
+        #self.Q = self.init_q_mat
 
 
 class actor_critic(RL_base):
@@ -185,6 +189,30 @@ def reset(self):
         self.thetas = np.ones_like(self.thetas)
         self.update()
 
+class linpol(pol_func_base):
+
+    def __init__(self, num_states, num_actions):
+        self.thetas = np.ones((num_states, 1))/2
+        self.update()
+
+    def grad_log(self, s, a):
+        grad = np.zeros_like(self.thetas)
+        for act in range(self.thetas.shape[1]):
+            if a == 0:
+                grad[s] = min(1/(self.thetas[s]),100)
+            else:
+                grad[s] = max((-1)/(1-self.thetas[s]),-100)
+        return grad
+
+    def update(self):
+        self.thetas = np.minimum(1,np.maximum(0,self.thetas))
+        self.policy = np.hstack((self.thetas, 1-self.thetas))
+        return self.policy
+
+    def reset(self):
+        self.thetas = np.ones_like(self.thetas)
+        self.update()
+
 class Q_advantage(advantage_func_base):
 
     def __init__(self, init_adv, num_states, num_actions, df):
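The new linpol class parameterises a two-action policy directly: thetas[s] is P(action 0 | s), update() clips thetas into [0, 1] and rebuilds the policy matrix, and grad_log is the score function with the 1/theta terms clipped to +/-100 near the boundary. A minimal standalone sketch of the same computation (illustrative names, not part of the commit):

    import numpy as np

    theta = np.full((3, 1), 0.5)            # 3 states, uniform start: P(a=0|s) = 0.5
    policy = np.hstack((theta, 1 - theta))  # row s is [P(a=0|s), P(a=1|s)]

    def grad_log_pi(theta, s, a, clip=100.0):
        # d/d theta[s] of log pi(a|s): 1/theta for a = 0, -1/(1-theta) for a = 1
        if a == 0:
            return min(1.0 / theta[s, 0], clip)
        return max(-1.0 / (1.0 - theta[s, 0]), -clip)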
6 changes: 3 additions & 3 deletions functions.py
@@ -39,7 +39,7 @@ def calc_exploitability(pol, game, learner, num_iters=100000, num_exploit_iters
         exploit_rewards[0].append(float(play_game(players, game)))
 
     p_avg_exploitability[0] = sum(exploit_rewards[0])/len(exploit_rewards[0])
-    #V_1 = players[0].learner.advantage_func.V
+    V_1 = learner.advantage_func.V
 
     learner.reset()
     learner.wipe_memory()
@@ -59,11 +59,11 @@ def calc_exploitability(pol, game, learner, num_iters=100000, num_exploit_iters
 
     p_avg_exploitability[1] = sum(exploit_rewards[1])/len(exploit_rewards[1])
 
-    #V_2 = players[1].learner.advantage_func.V
+    V_2 = learner.advantage_func.V
 
     avg_exploitability = sum(p_avg_exploitability)
     learner.reset()
     learner.wipe_memory()
 
     #import pdb; pdb.set_trace()
-    return avg_exploitability, new_pols, reward_hist
+    return avg_exploitability, new_pols, reward_hist, (V_1, V_2)
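The extended return now also carries the two best responders' value estimates, as unpacked at the call site in FSP.py above:

    # 'values' holds the (V_1, V_2) value estimates of the two best responders.
    exploit, br_pols, reward_hist, values = calc_exploitability(new_pi, self.game, exploit_learner)
    V_1, V_2 = values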
13 changes: 7 additions & 6 deletions main_FSP.py
@@ -16,11 +16,12 @@
 
 #test
 extras = 0
-num_BR = 300
-num_mixed = 50
+num_BR = 30000
+num_mixed = 20000
 iters = 1000000
-time = 300
-RL_iters = 10
+time = 600
+RL_iters = 1
+check_freq = 10
 
 #new test
 #extras = 0
@@ -35,12 +36,12 @@
 RL_learners = [learners.actor_critic(learners.softmax, learners.value_advantage, game_obj.num_actions[p],\
         game_obj.num_states[p], init_adv=0, extra_samples = extras)\
         for p in range(2)]
-RL_learners = [learners.Q_learn(0, (game_obj.num_states[p], game_obj.num_actions[p])) for p in range(2)]
+#RL_learners = [learners.fitted_Q_iteration(0, (game_obj.num_states[p], game_obj.num_actions[p])) for p in range(2)]
 SL_learners = [learners.count_based_SL((game_obj.num_states[p], game_obj.num_actions[p])) for p in range(2)]
 
 agents = [learners.complete_learner(RL_learners[p], SL_learners[p], num_loops = RL_iters) for p in range(2)]
 
-worker = FSP.FSP(game_obj, agents, max_iters=iters, max_time=time, m=num_BR, n=num_mixed, exploit_freq=1)
+worker = FSP.FSP(game_obj, agents, max_iters=iters, max_time=time, m=num_BR, n=num_mixed, exploit_freq=check_freq)
 pi, exploitability, data = worker.run_algo()
 
 FSP_plots(exploitability, worker.est_exploit_freq, [pi], 'kuhn')
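Assuming m and n keep their usual FSP roles (episode counts feeding the best-response RL memory and the average-policy SL memory, respectively; the constructor maps m=num_BR, n=num_mixed), the new settings make each iteration far more sample-hungry. An annotated version of the constructor call (comments are interpretation, not part of the commit):

    worker = FSP.FSP(game_obj, agents,
                     max_iters=iters,          # up to 1000000 iterations...
                     max_time=time,            # ...or a 600 s wall-clock cap
                     m=num_BR,                 # 30000 episodes per iteration (RL memory)
                     n=num_mixed,              # 20000 episodes per iteration (SL memory)
                     exploit_freq=check_freq)  # estimate exploitability every 10 iterations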
