Added parallelisation
lukearcus committed Sep 19, 2022
1 parent 8231da3 commit cdbcdb2
Showing 5 changed files with 194 additions and 172 deletions.
13 changes: 7 additions & 6 deletions FSP.py
@@ -55,10 +55,10 @@ def run_algo(self):
         if isinstance(self.game, kuhn):
             exploit_learner = learners.kuhn_exact_solver()
         else:
-            exploit_learner = learners.fitted_Q_iteration(0, (self.game.num_states, self.game.num_actions))
-            #exploit_learner = learners.actor_critic(learners.softmax, learners.value_advantage, \
-            #                                         self.game.num_actions[0], self.game.num_states[0])
-
+            #exploit_learner = learners.fitted_Q_iteration(0, (self.game.num_states[0], self.game.num_actions[0]))
+            exploit_learner = learners.actor_critic(learners.softmax, learners.value_advantage, \
+                                                     self.game.num_actions[0], self.game.num_states[0])
+        times = []
         for j in range(1,self.max_iters): # start from 1 or 2?
             log.info("Iteration " + str(j))
             eta_j = 1/j
@@ -86,13 +86,14 @@ def run_algo(self):
             results = {'true' : [], 'est':[], 'beta': []}

             exploit, br_pols, _, values = calc_exploitability(new_pi, self.game, exploit_learner,\
-                                                              num_iters = -1, num_exploit_iters=-1)
+                                                              num_iters = 10**4, num_exploit_iters=10**4)
             log.info("exploitability: " + str(exploit))
             exploitability.append(exploit)
             toc = time.perf_counter()
+            times.append(toc-tic)
             if toc-tic > self.max_time:
                 break
-        return pi[-1], exploitability, {'pi': pi, 'beta':beta, 'D': D}
+        return pi[-1], exploitability, {'pi': pi, 'beta':beta, 'D': D, 'times':times}

     def play_game(self, strat):
         buffer = [[] for i in range(self.num_players)]
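The FSP.py hunks above switch the exploitability learner to actor_critic, cap the exploitability evaluation at 10**4 iterations, and record per-iteration wall-clock times that are now returned in the extra results dict. Below is a minimal, repository-independent sketch of that timing/budget pattern; `run_iteration` and the parameter names are illustrative stand-ins, and it assumes `tic` is taken once before the loop (the hunk does not show where `tic` is set).

```python
import time

def run_with_time_budget(run_iteration, max_iters=100, max_time=60.0):
    """Sketch of the pattern in the diff: time each outer iteration and stop
    once the wall-clock budget is exceeded. Names here are illustrative."""
    tic = time.perf_counter()   # assumption: a single start time taken before the loop
    times = []                  # mirrors the new `times` list returned with the results
    for j in range(1, max_iters):
        run_iteration(j)
        toc = time.perf_counter()
        times.append(toc - tic)
        if toc - tic > max_time:  # same break condition as in the diff
            break
    return times
```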
11 changes: 8 additions & 3 deletions agents/learners.py
@@ -17,13 +17,13 @@ def wipe_memory(self):
class RL_base(learner_base):
    opt_pol = None
    memory = []

    eval = False

    def __init__(self, extra_samples, init_lr, df):
        self.extra_samples = extra_samples
        self.gamma = df
        self.init_lr = init_lr
        self.max_mem = 100000
        self.eval = False

    def update_memory(self, data):
        self.last_round = len(data)
@@ -113,12 +113,13 @@ def reset(self):

 class actor_critic(RL_base):

-    def __init__(self, pol_func, advantage_func, num_actions, num_states, init_adv = 0, extra_samples=10, init_lr=0.05, df=1.0, tol=999):
+    def __init__(self, pol_func, advantage_func, num_actions, num_states, init_adv = 0, extra_samples=10, init_lr=0.05, df=1.0, tol=999, max_iters = 10**0):
         self.pol_func = pol_func(num_states, num_actions)
         self.advantage_func = advantage_func(init_adv, num_states, num_actions, df, self.entropy)
         self.opt_pol = self.pol_func.policy
         self.memory = []
         self.tol=tol
+        self.max_iters = max_iters
         super().__init__(extra_samples, init_lr, df)

     def reset(self):
@@ -131,6 +132,7 @@ def learn(self):
         self.iteration += 1
         lr = self.init_lr/(1+0.003*np.sqrt(self.iteration))
         prev_pol = np.copy(self.opt_pol) - self.tol - 1
+        itt = 0
         while np.linalg.norm(prev_pol - self.opt_pol) > self.tol:
             prev_pol = np.copy(self.opt_pol)
             RL_buff = random.sample(self.memory, min(self.extra_samples, len(self.memory)))
@@ -145,6 +147,9 @@
                 self.advantage_func.update(lr*delta, elem["s"], elem["a"])
             self.pol_func.thetas += lr*theta_update
             self.opt_pol = self.pol_func.update()
+            itt += 1
+            if itt > self.max_iters:
+                break
         return self.opt_pol

 class pol_func_base:
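The learners.py hunks above give actor_critic.learn() a hard iteration cap (`max_iters`) on top of its existing tolerance test. A minimal sketch of that dual stopping rule, independent of the repo's classes (`step` and `x0` are illustrative stand-ins):

```python
import numpy as np

def iterate_until_converged(step, x0, tol=999, max_iters=1):
    """Run `step` until successive iterates differ by less than `tol`,
    but stop early once an iteration counter exceeds `max_iters` --
    mirroring the new cap (default 10**0 == 1) added to actor_critic.learn()."""
    x = np.asarray(x0, dtype=float)
    prev = x - tol - 1          # force at least one pass, as the diff does with prev_pol
    itt = 0
    while np.linalg.norm(prev - x) > tol:
        prev = x.copy()
        x = step(x)
        itt += 1
        if itt > max_iters:     # hard cap so a loose tolerance cannot stall the outer loop
            break
    return x
```

With the defaults carried over from the diff (tol=999, max_iters=1), this loop appears to perform essentially one update per call, which keeps each outer FSP iteration cheap.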
8 changes: 5 additions & 3 deletions agents/players.py
@@ -232,6 +232,7 @@ def __init__(self, learner, player_id, fict_game, belief_iters = 10000, averagin
         self.avg_pol = np.ones_like(self.opt_pol)/self.opt_pol.shape[1]
         self.curr_opp_lvl = 0
         self.learn_avg = averaging == "FSP_style"
+        self.ot_lvls = 10

     def set_other_players(self, other_players):
         self.other_players = other_players.copy()
@@ -242,7 +243,7 @@ def observe(self, observation, fict=False):
         self.r = observation[1]
         if not fict:
             if self.state != -1:
-                for lvl in range(self.curr_lvl):
+                for lvl in range(max(self.curr_lvl - self.ot_lvls, 0),self.curr_lvl):
                     belief_probs = self.beliefs[lvl][self.state, :]
                     #Here we do OBL
                     res = -1
Expand Down Expand Up @@ -273,7 +274,7 @@ def observe(self, observation, fict=False):
r = next_obs[1]
self.buffer[lvl].append({"s":self.state, "a": act, "r":r, "s'":s_prime})
else:
for lvl in range(self.curr_lvl):
for lvl in range(max(self.curr_lvl - self.ot_lvls, 0), self.curr_lvl):
self.learner[lvl].update_memory([(self.buffer[lvl], None)])
self.pols[lvl+1] = self.learner[lvl].learn()
self.opt_pol = self.pols[self.curr_lvl]
@@ -295,14 +296,15 @@ def action(self, lvl=-1):
     def add_to_mem(self):
         for i in range(self.belief_iters):
             self.fict_game.start_game()
+            lvl = np.random.randint(self.curr_lvl)
             while not self.fict_game.ended:
                 p_id = self.fict_game.curr_player
                 if p_id == self.id:
                     player = self
                 else:
                     player = self.other_players[p_id]
                 player.observe(self.fict_game.observe(), fict=True)
-                act = player.action()
+                act = player.action(lvl)
                 self.fict_game.action(act)
                 if p_id == self.id:
                     hidden_state = self.fict_game.get_hidden(self.id)
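The players.py hunks above restrict per-update learning to a sliding window of the most recent `ot_lvls` levels and play each fictitious game at a randomly sampled level. A small illustrative sketch of those two ideas (the class and method names below are stand-ins, not the repo's API):

```python
import numpy as np

class LevelWindowSketch:
    """Illustrative only: mirrors the two players.py changes -- keep a sliding
    window of the most recent levels to update, and sample a random level for
    each fictitious game instead of always using the newest policy."""

    def __init__(self, window=10):
        self.ot_lvls = window   # same role as the new self.ot_lvls = 10
        self.curr_lvl = 0

    def levels_to_update(self):
        # range(max(curr_lvl - ot_lvls, 0), curr_lvl), as in the diff:
        # only the last `ot_lvls` levels get fresh data and a learning step.
        return range(max(self.curr_lvl - self.ot_lvls, 0), self.curr_lvl)

    def sample_play_level(self):
        # As in add_to_mem: each fictitious game is played at one randomly
        # drawn earlier level, so the belief memory mixes data across levels.
        # (The guard for curr_lvl == 0 is added here to keep the sketch safe.)
        return np.random.randint(self.curr_lvl) if self.curr_lvl > 0 else 0
```

Limiting updates to the newest levels presumably bounds the per-iteration cost as `curr_lvl` grows, which fits the commit's goal of speeding up runs.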