
Commit

analysis tool refactor
TheJDen committed Oct 18, 2023
1 parent 2d1d835 commit 5d02b3d
Showing 2 changed files with 81 additions and 41 deletions.
27 changes: 27 additions & 0 deletions decryptoai/decryptoai/analysis/guesser_tools.py
@@ -0,0 +1,27 @@
from dataclasses import dataclass
import pandas
import numpy as np
import players.unsupervised.numpy_guesser as nguesser
import word2vec_loader.loader as wv_loader

@dataclass
class Suite:
name: str
clue_df: pandas.DataFrame
correct_code_index: pandas.Series

@dataclass
class Strat:
name: str
strat_func: callable


def get_guess(word_index, strat_func, input_row):
keyword_card = (input_row.keyword1, input_row.keyword2, input_row.keyword3, input_row.keyword4)
clues = (input_row.clue1, input_row.clue2, input_row.clue3)
wv_kw_card = map(wv_loader.official_keyword_to_word, keyword_card)
random_vars = nguesser.guesser_random_variables(wv_kw_card, word_index)
clue_indices = nguesser.np_clues(clues, word_index)
code_log_probabilities = nguesser.log_expected_probabilities_codes(strat_func, random_vars, clue_indices)
code_index_guess = np.argmax(code_log_probabilities)
return pandas.Series([code_index_guess, code_log_probabilities[code_index_guess]], index=["code_index_guess", "log_expected_prob"])
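For orientation, here is a minimal usage sketch of the new module (not part of the commit). The gensim model path, the CSV filename, and the decryptoai.analysis.guesser_tools import path are assumptions inferred from the file location, and log_square_cosine_similarity simply mirrors the heuristic defined in the notebook diff below:

```python
# Hypothetical usage of the refactored analysis helpers (not in this commit).
# Assumes a saved gensim KeyedVectors model and a clue CSV with columns
# keyword1..keyword4, clue1..clue3, code_index; paths are illustrative.
from functools import partial

import numpy as np
import pandas
from gensim.models import KeyedVectors

from decryptoai.analysis import guesser_tools as gt  # import path assumed from the file location

google_news_wv = KeyedVectors.load("google_news.kv")   # illustrative path
clue_df = pandas.read_csv("meaning_clues.csv")          # illustrative path

suite = gt.Suite("meaning", clue_df.drop("code_index", axis=1), clue_df["code_index"])

def log_square_cosine_similarity(embedding, clue_index, keyword_index):
    # Same heuristic as the notebook cell below: log of the squared cosine similarity.
    clue_embedding = embedding[clue_index.squeeze()].transpose()
    keyword_embedding = embedding[keyword_index.squeeze()]
    similarity = np.expand_dims(keyword_embedding.dot(clue_embedding), axis=-1)
    return 2 * np.log(np.abs(similarity))

strat = gt.Strat("log_square_cosine_similarity",
                 partial(log_square_cosine_similarity, google_news_wv))

# get_guess(word_index, strat_func, row) returns a Series with the guessed
# code index and its log expected probability for one row of clues/keywords.
guess_row = partial(gt.get_guess, google_news_wv.key_to_index, strat.strat_func)
guesses = suite.clue_df.apply(guess_row, axis=1)
accuracy = (guesses["code_index_guess"] == suite.correct_code_index).mean()
print(f"{suite.name}: {100 * accuracy:.2f}% correct")
```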
95 changes: 54 additions & 41 deletions notebooks/word2vec/word2vec_unsupervised_guesser.ipynb
@@ -65,26 +65,26 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "477a36be",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_similarity(clue_index, keyword_index):\n",
"def cosine_similarity(embedding, clue_index, keyword_index):\n",
" clue_embedding = google_news_wv[clue_index.squeeze()].transpose()\n",
" keyword_embedding = google_news_wv[keyword_index.squeeze()]\n",
" return np.expand_dims(keyword_embedding.dot(clue_embedding), axis=-1)\n",
"\n",
"# simple heuristics\n",
"\n",
"def log_square_cosine_similarity(clue_index, keyword_index):\n",
" similarity = cosine_similarity(clue_index, keyword_index)\n",
"def log_square_cosine_similarity(embedding, clue_index, keyword_index):\n",
" similarity = cosine_similarity(embedding, clue_index, keyword_index)\n",
" return 2 * np.log(np.abs(similarity))\n",
"\n",
"def log_normalized_cosine_similarity(clue_index, keyword_index):\n",
" similarity = cosine_similarity(clue_index, keyword_index)\n",
"def log_normalized_cosine_similarity(embedding, clue_index, keyword_index):\n",
" similarity = cosine_similarity(embedding, clue_index, keyword_index)\n",
" normalized_similiarity = (1 + similarity) / 2\n",
" return np.log(normalized_similiarity)"
]
@@ -99,7 +99,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "331f9d33",
"metadata": {},
"outputs": [],
@@ -113,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"id": "f91f30e9",
"metadata": {},
"outputs": [
@@ -128,7 +128,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pandas Apply: 100%|██████████| 36000/36000 [00:34<00:00, 1046.36it/s]\n"
"Pandas Apply: 100%|██████████| 36000/36000 [00:34<00:00, 1040.22it/s]\n"
]
},
{
@@ -142,7 +142,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pandas Apply: 100%|██████████| 36000/36000 [00:37<00:00, 964.41it/s] \n"
"Pandas Apply: 100%|██████████| 36000/36000 [00:51<00:00, 700.76it/s] \n"
]
},
{
@@ -157,7 +157,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pandas Apply: 100%|██████████| 36000/36000 [00:26<00:00, 1340.60it/s]\n"
"Pandas Apply: 100%|██████████| 36000/36000 [00:39<00:00, 912.38it/s] \n"
]
},
{
@@ -171,7 +171,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pandas Apply: 100%|██████████| 36000/36000 [00:27<00:00, 1294.73it/s]"
"Pandas Apply: 100%|██████████| 36000/36000 [00:24<00:00, 1466.94it/s]"
]
},
{
@@ -204,32 +204,45 @@
" clue_df: pandas.DataFrame\n",
" correct_code_index: pandas.Series\n",
"\n",
"def get_guess(strat_func, input_row):\n",
"@dataclass\n",
"class Strat:\n",
" name: str\n",
" strat_func: callable\n",
"\n",
"\n",
"def get_guess(word_index, strat_func, input_row):\n",
" keyword_card = (input_row.keyword1, input_row.keyword2, input_row.keyword3, input_row.keyword4)\n",
" clues = (input_row.clue1, input_row.clue2, input_row.clue3)\n",
" wv_kw_card = map(wv_loader.official_keyword_to_word, keyword_card)\n",
" random_vars = nguesser.guesser_random_variables(wv_kw_card, google_news_wv.key_to_index)\n",
" code_log_probabilities = nguesser.log_expected_probabilities_codes(strat_func, random_vars, nguesser.np_clues(clues, google_news_wv.key_to_index))\n",
" random_vars = nguesser.guesser_random_variables(wv_kw_card, word_index)\n",
" clue_indices = nguesser.np_clues(clues, word_index)\n",
" code_log_probabilities = nguesser.log_expected_probabilities_codes(strat_func, random_vars, clue_indices)\n",
" code_index_guess = np.argmax(code_log_probabilities)\n",
" return pandas.Series([code_index_guess, code_log_probabilities[code_index_guess]], index=[\"code_index_guess\", \"log_expected_prob\"])\n",
"\n",
"meaning_clue_df, meaning_correct_code_index = meaning_df.drop('code_index', axis=1), meaning_df['code_index']\n",
"triggerword_clue_df, triggerword_correct_code_index = triggerword_df.drop('code_index', axis=1), triggerword_df['code_index']\n",
"\n",
"suites = [Suite(\"meaning\", meaning_clue_df, meaning_correct_code_index), Suite(\"triggerword\", triggerword_clue_df, triggerword_correct_code_index)]\n",
"suites = [\n",
" Suite(\"meaning\", meaning_clue_df, meaning_correct_code_index),\n",
" Suite(\"triggerword\", triggerword_clue_df, triggerword_correct_code_index)\n",
" ]\n",
"\n",
"strat_funcs = [log_square_cosine_similarity, log_normalized_cosine_similarity]\n",
"guesses_by_suitename_and_strat = defaultdict(dict)\n",
"for strat_func in strat_funcs:\n",
" print(strat_func.__name__)\n",
" get_guess_with_strat = partial(get_guess, strat_func)\n",
"strats = [\n",
" Strat(\"log_square_cosine_similarity\", partial(log_square_cosine_similarity, google_news_wv)),\n",
" Strat(\"log_normalized_cosine_similarity\", partial(log_normalized_cosine_similarity, google_news_wv))\n",
" ]\n",
"guesses_by_suitename_and_stratname = defaultdict(dict)\n",
"for strat in strats:\n",
" print(strat.name)\n",
" get_guess_with_strat = partial(get_guess, google_news_wv.key_to_index, strat.strat_func)\n",
"\n",
" for suite in suites:\n",
" guesses = suite.clue_df.swifter.apply(get_guess_with_strat, axis=1)\n",
" guesses[\"correct\"] = guesses[\"code_index_guess\"] == suite.correct_code_index\n",
" percent_correct = 100 * np.sum(guesses[\"correct\"]) / len(guesses)\n",
" print(f\"percent {suite.name} clues correct: {percent_correct:.2f}%\")\n",
" guesses_by_suitename_and_strat[suite.name][strat_func] = guesses"
" guesses_by_suitename_and_stratname[suite.name][strat.name] = guesses"
]
},
{
@@ -250,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"id": "90c185e2",
"metadata": {},
"outputs": [
@@ -300,10 +300,10 @@
"import matplotlib.pyplot as plt\n",
"\n",
"for suite in suites:\n",
" for strat_func, meaning_guesses_df in guesses_by_suitename_and_strat[suite.name].items():\n",
" for strat_name, meaning_guesses_df in guesses_by_suitename_and_stratname[suite.name].items():\n",
" _, ax = plt.subplots()\n",
" plot = sns.histplot(ax=ax, data=meaning_guesses_df, x=\"log_expected_prob\", hue=\"correct\", kde=True)\n",
" plot.set(title=f'{strat_func.__name__} {suite.name.capitalize()} Clue Confidence')\n"
" plot.set(title=f'{strat_name} {suite.name.capitalize()} Clue Confidence')\n"
]
},
{
@@ -318,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 16,
"id": "80fbf193",
"metadata": {},
"outputs": [
@@ -435,24 +435,24 @@
"4 17.0 4 "
]
},
"execution_count": 13,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%capture --no-display\n",
"\n",
"guessed_wrong_by_suitename_and_strat = defaultdict(dict)\n",
"guessed_wrong_by_suitename_and_stratname = defaultdict(dict)\n",
"for suite in suites:\n",
" for strat_func, guesses_df in guesses_by_suitename_and_strat[suite.name].items():\n",
" for strat_name, guesses_df in guesses_by_suitename_and_stratname[suite.name].items():\n",
" wrong_guesses = guesses_df[\"correct\"] == False\n",
" guessed_wrong = suite.clue_df[wrong_guesses]\n",
" guessed_wrong[\"guessed_code_index\"] = guesses_df[\"code_index_guess\"][wrong_guesses]\n",
" guessed_wrong[\"correct_code_index\"] = suite.correct_code_index[wrong_guesses]\n",
" guessed_wrong_by_suitename_and_strat[suite.name][strat_func] = guessed_wrong\n",
" guessed_wrong_by_suitename_and_stratname[suite.name][strat_name] = guessed_wrong\n",
"\n",
"guessed_wrong_by_suitename_and_strat[\"meaning\"][log_square_cosine_similarity].head()\n"
"guessed_wrong_by_suitename_and_stratname[\"meaning\"][\"log_square_cosine_similarity\"].head()\n"
]
},
{
@@ -481,29 +481,29 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 17,
"id": "759cfa55",
"metadata": {},
"outputs": [],
"source": [
"def cosine_distance(clue_index, keyword_index):\n",
" return 1 - cosine_similarity(clue_index, keyword_index)\n",
"def cosine_distance(embedding, clue_index, keyword_index):\n",
" return 1 - cosine_similarity(embedding, clue_index, keyword_index)\n",
"\n",
"def log_zipf(clue_index, keyword_index):\n",
" return np.log(keyword_index) - np.log(clue_index)\n",
"\n",
"def log_outer_radius_proportion(clue_index, keyword_index):\n",
" clue_distance = cosine_distance(clue_index, keyword_index)\n",
"def log_outer_radius_proportion(embedding, clue_index, keyword_index):\n",
" clue_distance = cosine_distance(embedding, clue_index, keyword_index)\n",
" all_word_indices = np.arange(len(google_news_wv))\n",
" all_distances = cosine_distance(np.expand_dims(all_word_indices, axis=-1), keyword_index).swapaxes(-1, -2)\n",
" all_distances = cosine_distance(embedding, np.expand_dims(all_word_indices, axis=-1), keyword_index).swapaxes(-1, -2)\n",
" num_outside = (all_distances > clue_distance).sum(axis=-1)\n",
" return np.log(num_outside) - np.log(len(google_news_wv))\n",
"\n",
"def log_zipf_scaled(clue_index, keyword_index):\n",
" return log_zipf(clue_index, keyword_index) + log_square_cosine_similarity(clue_index, keyword_index)\n",
"def log_zipf_scaled(embedding, clue_index, keyword_index):\n",
" return log_zipf(clue_index, keyword_index) + log_square_cosine_similarity(embedding, clue_index, keyword_index)\n",
"\n",
"def log_outer_radius_proportion_scaled(clue_index, keyword_index):\n",
" return log_outer_radius_proportion(clue_index, keyword_index) + log_square_cosine_similarity(clue_index, keyword_index)"
"def log_outer_radius_proportion_scaled(embedding, clue_index, keyword_index):\n",
" return log_outer_radius_proportion(embedding, clue_index, keyword_index) + log_square_cosine_similarity(embedding, clue_index, keyword_index)"
]
}
],
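Because the refactor wraps each heuristic in a Strat with the embedding bound via partial, the new distance-based heuristics above can be fed through the same evaluation loop. A sketch of such a follow-up cell, assuming Strat, partial, google_news_wv, and the functions from the cells above are in scope (this cell is not part of the commit):

```python
# Hypothetical follow-up cell (not in this commit): register the distance-based
# heuristics as additional strategies for the existing evaluation loop.
extra_strats = [
    Strat("log_zipf", log_zipf),  # uses vocabulary ranks only, so no embedding is bound
    Strat("log_zipf_scaled", partial(log_zipf_scaled, google_news_wv)),
    Strat("log_outer_radius_proportion", partial(log_outer_radius_proportion, google_news_wv)),
    Strat("log_outer_radius_proportion_scaled", partial(log_outer_radius_proportion_scaled, google_news_wv)),
]
# Re-running the guessing loop over strats + extra_strats would populate
# guesses_by_suitename_and_stratname with results for these heuristics as well.
```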
