Skip to content

Commit

Permalink
Edited saving of results
Browse files (browse the repository at this point in the history)
  • Loading branch information
lukearcus committed Sep 12, 2022
1 parent eeca3c7 commit f2298b8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
7 changes: 5 additions & 2 deletions UI/get_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ def run():
if '--game' in sys.argv:
game_ind = sys.argv.index('--game')
if len(sys.argv) > game_ind:
if sys.argv[game_ind+1] == "kuhn":
game_name = sys.argv[game_ind+1]
if game_name == "kuhn":
game = Kuhn_Poker_int_io()
fict_game = Fict_Kuhn_int()
exploit_learner = learners.kuhn_exact_solver()
elif sys.argv[game_ind+1] == "leduc":
elif game_name == "leduc":
game = leduc_int()
fict_game = leduc_fict()
exploit_learner = learners.actor_critic(learners.softmax, learners.value_advantage, \
Expand All @@ -41,6 +42,7 @@ def run():
print("Please select a game")
return(-1)
else:
game_name = "kuhn"
game = Kuhn_Poker_int_io()
fict_game = Fict_Kuhn_int()
exploit_learner = learners.kuhn_exact_solver()
Expand Down Expand Up @@ -71,6 +73,7 @@ def run():
else:
learner_type = "rl"
opts = {"num_lvls":num_lvls,
"game_name":game_name,
"game":game,
"fict_game":fict_game,
"exploit_learner":exploit_learner,
Expand Down
2 changes: 1 addition & 1 deletion UI/plot_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_heatmap(im, ax, label_name, overlay_vals=False):
def exploitability_plot(exploitability, exploit_freq, ax=None):
if ax is None:
fig, ax = plt.subplots()
iters = np.arange(0, len(exploitability)*exploit_freq, exploit_freq)
iters = np.arange(1, (len(exploitability)+1)*exploit_freq, exploit_freq)
ax.loglog(iters, exploitability)
ax.set_xlabel("Iteration")
ax.set_ylabel("Exploitability")
Expand Down
9 changes: 5 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main():

options = get_args.run()
num_lvls = options["num_lvls"]
game_name = options["game_name"]
game = options["game"]
fict_game = options["fict_game"]
exploit_learner = options["exploit_learner"]
Expand Down Expand Up @@ -73,7 +74,7 @@ def main():
for p_id, other_p in enumerate(p.other_players):
if other_p != "me":
other_p.opt_pol = players[p_id].opt_pol
if averaged_bel:
if not averaged_bel:
p.belief_buff = []
p.update_mem_and_bel()
bels.append(np.copy(p.belief))
Expand Down Expand Up @@ -137,7 +138,7 @@ def main():
else:
pol_hist.append(pols)
exploit, _, _, _ = calc_exploitability(pols, game, exploit_learner)
exploitability.append(exploit)
exploitability.append(exploit)
else:
if averaged_pol:
new_avg_pols = []
Expand All @@ -155,8 +156,8 @@ def main():
pol_plot = pol_hist
bel_plot = belief_hist
plot_everything(pol_plot, bel_plot, "kuhn", reward_hist[-1], exploitability)
filename="results/OBL_all_average"
np.savez(filename, pols=pol_plot, bels=bel_plot, explot=exploitability, rewards=reward_hist)
filename="results/" + game_name + "_" + learner_type + "_" + str(num_lvls) + "lvls"
np.savez(filename, pols=pol_plot, bels=bel_plot, exploit=exploitability, rewards=reward_hist)
return 0

if __name__=="__main__":
Expand Down

0 comments on commit f2298b8

Please sign in to comment.