Skip to content

Commit

Permalink
Add mlp experiment scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Nov 26, 2024
1 parent 0159ed2 commit 75d4bb9
Show file tree
Hide file tree
Showing 2 changed files with 590 additions and 0 deletions.
291 changes: 291 additions & 0 deletions experiments/plot_mlp_results.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plot MLP results"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"# kaleido is pinned: `-U` would pull kaleido 1.x, which needs plotly >= 6.1.1,\n",
"# while this notebook pins plotly 5.x and relies on kaleido for `fig.write_image`.\n",
"# `%pip` (not `!pip`) installs into the environment of the running kernel.\n",
"%pip install kaleido==0.2.1\n",
"%pip install plotly==5.24.1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"import plotly.graph_objs as go\n",
"import plotly.colors as pc"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def plot_accuracies_over_ts_and_lrs(accuracies, max_ts, save_path):\n",
"    \"\"\"Plot accuracy against max T with one line per learning rate and save it.\n",
"\n",
"    Args:\n",
"        accuracies: dict mapping a learning rate to a sequence of accuracies;\n",
"            each sequence is plotted in order against the entries of `max_ts`.\n",
"        max_ts: sequence of max T values, used as the x-axis tick labels.\n",
"        save_path: output path passed to `fig.write_image`.\n",
"    \"\"\"\n",
"    n_lrs = len(accuracies)\n",
"    # The traces carry no explicit `x`, so points sit at 0, 1, 2, ...; use exactly\n",
"    # one tick per label instead of reading the length from a hard-coded lr key.\n",
"    xtickvals = list(range(len(max_ts)))\n",
"\n",
"    # Sample n_lrs + 2 shades and reverse them; only the first n_lrs are used.\n",
"    colors = pc.sample_colorscale(\"Oranges\", n_lrs+2)[::-1]\n",
"    fig = go.Figure()\n",
"    for color, (lr, max_avg_acc) in zip(colors, accuracies.items()):\n",
"        fig.add_trace(\n",
"            go.Scatter(\n",
"                y=max_avg_acc,\n",
"                name=f\"$lr = {{{lr}}}$\",\n",
"                mode=\"lines+markers\",\n",
"                line=dict(width=2, color=color)\n",
"            )\n",
"        )\n",
"\n",
"    fig.update_layout(\n",
"        height=350,\n",
"        width=550,\n",
"        xaxis=dict(\n",
"            title=\"Max T\",\n",
"            tickvals=xtickvals,\n",
"            ticktext=max_ts,\n",
"        ),\n",
"        yaxis=dict(\n",
"            title=\"Max mean accuracy (%)\",\n",
"            nticks=5\n",
"        ),\n",
"        font=dict(size=16),\n",
"        margin=dict(r=120)\n",
"    )\n",
"    fig.write_image(save_path)\n",
"\n",
"\n",
"def _count_points(values):\n",
"    \"\"\"Return the length of the first series in `values` (all are expected to match).\"\"\"\n",
"    if not values:\n",
"        raise ValueError(\"expected at least one optimiser to plot\")\n",
"    return len(next(iter(values.values())))\n",
"\n",
"\n",
"def _plot_mean_std_over_optims(values, save_path, mode, tickvals, ticktext, ytitle):\n",
"    \"\"\"Draw one mean line with a +/- 1 std band per optimiser and save the figure.\n",
"\n",
"    Args:\n",
"        values: dict mapping an optimiser id to an array; mean and std are taken\n",
"            over its last axis, and the result is plotted at x = 1, 2, 3, ...\n",
"        save_path: output path passed to `fig.write_image`.\n",
"        mode: plotly scatter mode used for the mean line.\n",
"        tickvals: x-axis tick positions.\n",
"        ticktext: x-axis tick labels, one per entry of `tickvals`.\n",
"        ytitle: y-axis title.\n",
"    \"\"\"\n",
"    # NOTE(review): #636EAF is close to plotly's default blue #636EFA; confirm it is intentional.\n",
"    colors = [\"#636EAF\", \"#EF553B\", \"#00CC96\"]\n",
"    xs = list(range(1, _count_points(values) + 1))\n",
"\n",
"    fig = go.Figure()\n",
"    for i, (optim_id, value) in enumerate(values.items()):\n",
"        # Cycle through the palette so more than three optimisers do not raise.\n",
"        color = colors[i % len(colors)]\n",
"        means = value.mean(axis=-1)\n",
"        stds = value.std(axis=-1)\n",
"        y_upper, y_lower = means + stds, means - stds\n",
"\n",
"        # Closed polygon: upper edge left-to-right, then lower edge right-to-left.\n",
"        fig.add_trace(\n",
"            go.Scatter(\n",
"                x=xs + xs[::-1],\n",
"                y=list(y_upper) + list(y_lower[::-1]),\n",
"                fill=\"toself\",\n",
"                fillcolor=color,\n",
"                line=dict(color=\"rgba(255,255,255,0)\"),\n",
"                hoverinfo=\"skip\",\n",
"                showlegend=False,\n",
"                opacity=0.3\n",
"            )\n",
"        )\n",
"        fig.add_trace(\n",
"            go.Scatter(\n",
"                x=xs,\n",
"                y=means,\n",
"                mode=mode,\n",
"                # The \"SGD\" optimiser id is shown as \"GD\" in the legend.\n",
"                name=optim_id if optim_id != \"SGD\" else \"GD\",\n",
"                line=dict(width=2, color=color)\n",
"            )\n",
"        )\n",
"\n",
"    fig.update_layout(\n",
"        height=300,\n",
"        width=400,\n",
"        xaxis=dict(\n",
"            title=\"Training iteration\",\n",
"            tickvals=tickvals,\n",
"            ticktext=ticktext\n",
"        ),\n",
"        yaxis=dict(title=ytitle),\n",
"        font=dict(size=16)\n",
"    )\n",
"    fig.write_image(save_path)\n",
"\n",
"\n",
"def plot_accuracies_over_optims(accuracies, save_path, test_every=100):\n",
"    \"\"\"Plot mean +/- std test accuracy per optimiser and save the figure.\n",
"\n",
"    Args:\n",
"        accuracies: dict mapping an optimiser id to an array whose last axis is\n",
"            averaged over (assumed to be (n_tests, n_seeds) - confirm with caller).\n",
"        save_path: output path passed to `fig.write_image`.\n",
"        test_every: factor applied to the middle and last x tick labels.\n",
"    \"\"\"\n",
"    n_tests = _count_points(accuracies)\n",
"    mid = n_tests // 2 + 1\n",
"    _plot_mean_std_over_optims(\n",
"        accuracies,\n",
"        save_path,\n",
"        mode=\"lines+markers\",\n",
"        tickvals=[1, mid, n_tests],\n",
"        ticktext=[1, mid*test_every, n_tests*test_every],\n",
"        ytitle=\"Test accuracy (%)\"\n",
"    )\n",
"\n",
"\n",
"def plot_runtimes_over_optims(runtimes, save_path):\n",
"    \"\"\"Plot mean +/- std runtime per optimiser and save the figure.\n",
"\n",
"    Args:\n",
"        runtimes: dict mapping an optimiser id to an array whose last axis is\n",
"            averaged over (assumed to be (n_train_iters, n_seeds) - confirm with caller).\n",
"        save_path: output path passed to `fig.write_image`.\n",
"    \"\"\"\n",
"    n_train_iters = _count_points(runtimes)\n",
"    ticks = [1, n_train_iters // 2, n_train_iters]\n",
"    _plot_mean_std_over_optims(\n",
"        runtimes,\n",
"        save_path,\n",
"        mode=\"lines\",\n",
"        tickvals=ticks,\n",
"        ticktext=ticks,\n",
"        ytitle=\"Runtime (ms)\"\n",
"    )\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"### plot max avg test acc as a function of T and lr, for each optim ###\n",
"DATASETS = [\"MNIST\", \"Fashion-MNIST\"]\n",
"N_HIDDENS = [3, 5]\n",
"ACTIVITY_OPTIMS_ID = [\"Euler\", \"Heun\", \"SGD\"]\n",
"\n",
"MAX_T1S = [5, 10, 20, 50, 100, 200, 500]\n",
"ACTIVITY_LRS = [5e-1, 1e-1, 5e-2]\n",
"\n",
"N_SEEDS = 3\n",
"\n",
"for dataset in DATASETS:\n",
"    for n_hidden in N_HIDDENS:\n",
"        for activity_optim_id in ACTIVITY_OPTIMS_ID:\n",
"            # Euler is plotted for every max T except the last entry of MAX_T1S.\n",
"            if activity_optim_id == \"Euler\":\n",
"                max_t1s = MAX_T1S[:-1]\n",
"            else:\n",
"                max_t1s = MAX_T1S\n",
"\n",
"            max_test_accs = {}\n",
"            for activity_lr in ACTIVITY_LRS:\n",
"                best_per_t = []\n",
"                for max_t1 in max_t1s:\n",
"                    # One saved test-accuracy array per seed.\n",
"                    seed_accs = [\n",
"                        np.load(\n",
"                            f\"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh/max_t1_{max_t1}/activity_lr_{activity_lr}/param_lr_0.001/{activity_optim_id}/{seed}/test_accs.npy\"\n",
"                        )\n",
"                        for seed in range(N_SEEDS)\n",
"                    ]\n",
"                    # Average element-wise over seeds, then keep the peak of the mean.\n",
"                    avg_test_acc = sum(seed_accs, 0.) / N_SEEDS\n",
"                    best_per_t.append(max(avg_test_acc))\n",
"                max_test_accs[activity_lr] = best_per_t\n",
"\n",
"            plot_accuracies_over_ts_and_lrs(\n",
"                accuracies=max_test_accs,\n",
"                max_ts=max_t1s,\n",
"                save_path=f\"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh/max_avg_test_acc_{activity_optim_id}.pdf\"\n",
"            )\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"### plot inference runtimes for best accuracies of different optimisers ###\n",
"BATCH_SIZE = 64\n",
"TEST_EVERY = 100\n",
"# NOTE(review): presumably the training-set size of both datasets - confirm.\n",
"N_TRAIN_SAMPLES = 60000\n",
"\n",
"n_train_iters = math.floor(N_TRAIN_SAMPLES/BATCH_SIZE)-1\n",
"n_tests = math.floor(n_train_iters/TEST_EVERY)\n",
"\n",
"for dataset in DATASETS:\n",
"    for n_hidden in N_HIDDENS:\n",
"        results_dir = f\"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh\"\n",
"        test_accs, inference_runtimes = {}, {}\n",
"        for activity_optim_id in ACTIVITY_OPTIMS_ID:\n",
"            # Hard-coded best (max T, activity lr) per dataset, depth and optimiser.\n",
"            # Any dataset other than \"MNIST\" takes the else branches.\n",
"            if n_hidden == 3:\n",
"                best_t = 10 if activity_optim_id == \"SGD\" else 20\n",
"                if dataset == \"MNIST\":\n",
"                    best_lr = 0.05 if activity_optim_id == \"Heun\" else 0.5\n",
"                else:\n",
"                    best_lr = 0.1 if activity_optim_id == \"Heun\" else 0.5\n",
"            elif n_hidden == 5:\n",
"                if dataset == \"MNIST\":\n",
"                    best_t = 50\n",
"                    best_lr = 0.05 if activity_optim_id == \"Heun\" else 0.5\n",
"                else:\n",
"                    best_t = 200\n",
"                    best_lr = 0.1 if activity_optim_id == \"SGD\" else 0.5\n",
"            else:\n",
"                # Fail loudly instead of reusing the values of the previous iteration.\n",
"                raise ValueError(\n",
"                    f\"no best hyperparameters recorded for n_hidden={n_hidden}\"\n",
"                )\n",
"\n",
"            test_accs[activity_optim_id] = np.zeros((n_tests, N_SEEDS))\n",
"            inference_runtimes[activity_optim_id] = np.zeros((n_train_iters, N_SEEDS))\n",
"            for seed in range(N_SEEDS):\n",
"                # Both arrays of a run live in the same directory.\n",
"                run_dir = f\"{results_dir}/max_t1_{best_t}/activity_lr_{best_lr}/param_lr_0.001/{activity_optim_id}/{seed}\"\n",
"                test_acc = np.load(f\"{run_dir}/test_accs.npy\")\n",
"                inference_runtime = np.load(f\"{run_dir}/inference_runtimes.npy\")\n",
"                test_accs[activity_optim_id][:, seed] = test_acc\n",
"                # skip first point for jit compilation\n",
"                inference_runtimes[activity_optim_id][:, seed] = inference_runtime[1:]\n",
"\n",
"        # Pass TEST_EVERY explicitly so the tick labels follow the constant above.\n",
"        plot_accuracies_over_optims(test_accs, f\"{results_dir}/best_mean_test_accs.pdf\", test_every=TEST_EVERY)\n",
"        plot_runtimes_over_optims(inference_runtimes, f\"{results_dir}/best_mean_infer_runtimes.pdf\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 75d4bb9

Please sign in to comment.