From 9bbfdf12aad0aaa54cf7015c49431eb546e87a0a Mon Sep 17 00:00:00 2001 From: lukearcus Date: Tue, 6 Sep 2022 13:16:35 +0100 Subject: [PATCH] fixed a bug, added a inset to exploit --- UI/plot_funcs.py | 21 +++++++++++++-------- functions.py | 7 ++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/UI/plot_funcs.py b/UI/plot_funcs.py index ed878e6..8af1b63 100644 --- a/UI/plot_funcs.py +++ b/UI/plot_funcs.py @@ -25,9 +25,7 @@ def plot_everything(pols, bels, game, reward, exploitability): fig2 = plt.figure() ax = fig2.subplots() reward_smoothed(reward, ax) - fig3 = plt.figure() - ax = fig3.subplots() - ax.plot(exploitability) + exploitability_plot(exploitability, 1) else: multiple_heatmaps(pols, fig, game + "_policy") fig.suptitle('Policies', fontsize=32) @@ -86,11 +84,18 @@ def plot_heatmap(im, ax, label_name, overlay_vals=False): for (j,i),label in np.ndenumerate(im): ax.text(i,j,np.round(label,2),ha='center',va='center') -def exploitability_plot(exploitability, exploit_freq): - plt.plot(np.arange(0, len(exploitability)*exploit_freq, exploit_freq), exploitability) - plt.xlabel("Iteration") - plt.ylabel("Exploitability") - plt.title("Exploitability vs Iteration") +def exploitability_plot(exploitability, exploit_freq, ax=None): + if ax is None: + fig, ax = plt.subplots() + iters = np.arange(0, len(exploitability)*exploit_freq, exploit_freq) + ax.plot(iters, exploitability) + ax.set_xlabel("Iteration") + ax.set_ylabel("Exploitability") + ax.set_title("Exploitability vs Iteration") + inset = ax.inset_axes([0.5,0.5,0.45,0.4]) + inset.loglog(iters, exploitability) + + def FSP_plots(exploitability, exploit_freq, pols, game): exploitability_plot(exploitability, exploit_freq) diff --git a/functions.py b/functions.py index 8725e01..c2d4009 100644 --- a/functions.py +++ b/functions.py @@ -23,12 +23,13 @@ def play_to_convergence(players, game, max_iters=1000000, tol=1e-5): converged_itt = 0 for i in range(max_iters): converged_itt += 1 - for i, p in enumerate(players): - old_pol[i] = np.copy(p.opt_pol) + for k, p in enumerate(players): + old_pol[k] = np.copy(p.opt_pol) play_game(players, game) converged = True for j, p in enumerate(players): - pol_diff = np.linalg.norm(p.opt_pol-old_pol[j]) + pol_diff = np.linalg.norm(p.opt_pol-old_pol[j],ord=np.inf) + logging.debug("Iteration "+str(i) + " diff " + str(pol_diff)) converged = converged and pol_diff <= tol if not converged: break