Skip to content

Commit

Permalink
Updated plot of IMDB
Browse files Browse the repository at this point in the history
  • Loading branch information
viictorjimenezzz committed Sep 18, 2024
1 parent 3165981 commit 5430d0d
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 134 deletions.
48 changes: 17 additions & 31 deletions src/plot/adv/adv_eval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -137,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -184,32 +184,17 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:32<00:00, 2.95s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataframe stored in /cluster/home/vjimenez/adv_pa_new/results/adv/Wong2020Fast/def=Wong2020Fast_att=FMN.pkl.\n"
]
}
],
"outputs": [],
"source": [
"df_pgd = _get_attack_df(\"PGD\")\n",
"df_fmn = _get_attack_df(\"FMN\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -579,7 +564,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -594,7 +579,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -709,7 +694,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -812,7 +797,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -986,6 +971,7 @@
"\n",
" ax.grid(linestyle=\"--\")\n",
"\n",
" ax.set_title(r\"$\\operatorname{PA}(\\mathbf{x}^\\prime, \\mathbf{x}^{\\prime \\prime}; \\beta)$\", fontname=fontname)\n",
" \n",
" # Legend\n",
" if legend == True:\n",
Expand Down Expand Up @@ -1022,7 +1008,7 @@
" pngname = f\"{attack_name}_{value}_{metric}\"\n",
" fname = osp.join(\n",
" savedir, \n",
" pngname + version_appendix + \".pdf\"\n",
" pngname + version_appendix + \".png\"\n",
" )\n",
" plt.savefig(fname)\n",
" plt.clf()\n",
Expand Down Expand Up @@ -1066,7 +1052,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1076,7 +1062,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1088,7 +1074,7 @@
" save = True,\n",
" savedir = r\"/cluster/home/vjimenez/adv_pa_new/results/adv\",\n",
" legend = False,\n",
" version_appendix=\"\"\n",
" version_appendix=\"FOO\"\n",
")"
]
},
Expand Down
143 changes: 110 additions & 33 deletions src/plot/dg/dg_datashift.ipynb

Large diffs are not rendered by default.

168 changes: 144 additions & 24 deletions src/plot/nlp/plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,177 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import os.path as osp\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.lines as mlines\n",
"import matplotlib.font_manager as fm"
"import matplotlib.font_manager as fm\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"with open(r\"/cluster/home/vjimenez/adv_pa_new/results/nlp/sentiment_analysis/levenshtein.pkl\", 'rb') as f:\n",
" data_lev = pickle.load(f)\n",
"df = pd.DataFrame(data_lev)\n",
"df[\"method\"] = \"levenshtein\"\n",
"\n",
"with open(r\"/cluster/home/vjimenez/adv_pa_new/results/nlp/sentiment_analysis/adversarial.pkl\", 'rb') as f:\n",
" data_adv = pickle.load(f)\n",
"df2 = pd.DataFrame(data_adv)\n",
"df2[\"method\"] = \"adversarial\"\n",
"\n",
"with open(r\"/cluster/home/vjimenez/adv_pa_new/results/nlp/sentiment_analysis/adversarial-inverted.pkl\", 'rb') as f:\n",
" data_advinv = pickle.load(f)\n",
"df3 = pd.DataFrame(data_advinv)\n",
"df3[\"method\"] = \"adversarial_inverted\"\n",
"\n",
"df = pd.concat([df, df2, df3])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"COLORS_DICT = {\n",
" \"levenshtein\": \"tab:blue\",\n",
" \"adversarial\": \"tab:red\",\n",
" \"adversarial_inverted\": \"limegreen\"\n",
"}\n",
"\n",
"METRICS_DICT = {\n",
" \"AFR_true\": \"Accuracy\",\n",
" \"logPA\": \"PA\"\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"df[\"intensity_plot\"] = np.log2(df[\"intensity\"])\n",
"\n",
"dfplot = df.loc[(df[\"intensity_plot\"] < 9)]"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'fm' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m fontname \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDejaVu Serif\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mfm\u001b[49m\u001b[38;5;241m.\u001b[39mfindfont(fm\u001b[38;5;241m.\u001b[39mFontProperties(family\u001b[38;5;241m=\u001b[39mfontname))\n\u001b[1;32m 4\u001b[0m fig, ax \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39msubplots()\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# Set the title with the desired font name and font size\u001b[39;00m\n",
"\u001b[0;31mNameError\u001b[0m: name 'fm' is not defined"
"name": "stderr",
"output_type": "stream",
"text": [
"/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib/python3.9/site-packages/matplotlib/ticker.py:2937: RuntimeWarning: invalid value encountered in log10\n",
" majorstep_no_exponent = 10 ** (np.log10(majorstep) % 1)\n",
"/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib/python3.9/site-packages/matplotlib/ticker.py:2937: RuntimeWarning: invalid value encountered in log10\n",
" majorstep_no_exponent = 10 ** (np.log10(majorstep) % 1)\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 772.2x545.82 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib.ticker import ScalarFormatter\n",
"metric = \"logPA\"\n",
"\n",
"fontname = \"DejaVu Serif\"\n",
"_ = fm.findfont(fm.FontProperties(family=fontname))\n",
"\n",
"fig, ax = plt.subplots()\n",
"# Create a line plot for PGD attack type with Seaborn\n",
"_, ax = plt.subplots(figsize=(2 * 3.861, 2 * 2.7291))\n",
"fontsize = 18 \n",
"sns.set(font_scale=1.9)\n",
"plt.rcParams[\"font.family\"] = \"serif\"\n",
"plt.rcParams[\"font.serif\"] = fontname\n",
"sns.set_style(\"ticks\")\n",
"\n",
"sns.lineplot(\n",
" data=dfplot,\n",
" ax=ax,\n",
" x=\"intensity_plot\",\n",
" y=metric,\n",
" hue=\"method\",\n",
" style=\"method\",\n",
" palette=COLORS_DICT,\n",
" dashes=False,\n",
" marker=\"o\",\n",
" linewidth=3,\n",
" legend=False\n",
")\n",
"\n",
"# Set the title with the desired font name and font size\n",
"title_text = r\"AFR$_\\text{T}$\"\n",
"fontsize = 18 # Replace with your desired font size\n",
"ax.axis('off')\n",
"\n",
"# Set the title\n",
"ax.set_title(title_text, fontname=fontname, fontsize=fontsize, pad=20)\n",
"ax.minorticks_on()\n",
"ax.set_xticks([i for i in range(9)])\n",
"ax.set_xlabel(\"Attack Power\", fontname=fontname)\n",
"ax.set_title(\"IMDB classification\", fontname=fontname)\n",
"\n",
"# Adjust the layout to center the title\n",
"plt.tight_layout()\n",
"ax.tick_params(axis=\"both\", which=\"both\", direction=\"in\")\n",
"xticks_font = fm.FontProperties(family=fontname)\n",
"for tick in ax.get_xticklabels():\n",
" tick.set_fontproperties(xticks_font)\n",
"ax.grid(linestyle=\"--\")\n",
"\n",
"# Save the figure with the title\n",
"plt.savefig(\"title2.png\", bbox_inches='tight', dpi=500)\n",
"if metric == \"logPA\":\n",
" ax.set_yticks([0, -5000, -10000, -15000])\n",
" ax.set_yticklabels([\"0\", r\"$-0.5$\", r\"$-1.0$\", r\"$-1.5$\"])\n",
" ax.text(0, 1.03, r'$\\times 10^4$', transform=ax.transAxes, fontname=fontname, fontsize=18, verticalalignment='center', horizontalalignment='left')\n",
" # ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=True))\n",
" # ax.yaxis.get_offset_text().set_visible(True) # Show the offset text (scale factor)\n",
" # ax.yaxis.offsetText.set_text(r\"$\\times 10^{-4}$\") # Set custom scale factor label\n",
"\n",
"# Display the figure (optional)\n",
"plt.show()\n",
"\n"
"ax.set_ylabel(METRICS_DICT[metric], fontname=fontname)\n",
"# ax.set_yscale(\"linear\") \n",
"\n",
"# handles = [mlines.Line2D([], [], color='tab:blue', linewidth=3), mlines.Line2D([], [], color='limegreen', linewidth=3), mlines.Line2D([], [], color='tab:red', linewidth=3)]\n",
"# labels = [\"Levenshtein\", \"Amplification\", \"Contradiction\"]\n",
"# ax.legend(\n",
"# handles,\n",
"# labels,\n",
"# # loc=\"upper right\",\n",
"# loc=\"lower left\",\n",
"# # fontsize=12,\n",
"# handlelength=0.5,\n",
"# prop={\n",
"# \"family\": fontname,\n",
"# 'size': 18\n",
"# } \n",
"# )\n",
"\n",
"\n",
"plt.tight_layout()\n",
"fname = osp.join(\n",
" r\"/cluster/home/vjimenez/adv_pa_new/results/nlp\", \n",
" f\"presentation_pa.pdf\"\n",
")\n",
"plt.savefig(fname, dpi=300)\n",
"plt.clf()\n",
"plt.show()\n"
]
},
{
Expand Down
Loading

0 comments on commit 5430d0d

Please sign in to comment.