-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Martha
authored and
Martha
committed
Nov 20, 2024
1 parent
95f2808
commit c91fe40
Showing
50 changed files
with
225,742 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Digit matrix analogies | ||
|
||
## Generating problems | ||
This directory contains code to generate digit matrix analogy problems with symbol alphabets or alternative blanks, in `gen_problems_coords.py`. The original problems used by Webb, Holyoak, Lu (2023) can be converted to a symbol alphabet using `convert_to_symb.py`. | ||
|
||
The problems used in the paper are all available in `problems`. | ||
|
||
## Testing GPT | ||
GPT can be tested on the problems by running `eval_GPT_matprob.py` with command line arguments `--id` to choose GPT model, `--user_prompt_num` and `--sys_prompt_num` for alternative user and system prompts -- best results obtained with 2 and 1 respectively `--prob_format` specifies the problem format out of 'digits', 'symb', or 'coords'. GPT model setup needs to be defined by the user. | ||
|
||
#### Example usage. | ||
To evaluate GPT 4 (0613) on symbol problems with system prompt 1 and user prompt 2, you would call | ||
|
||
`python eval_GPT_digitmat.py --gpt 0613 --sys_prompt_num 1 --user_prompt_num 2 --prob_format symb` | ||
|
||
|
||
## Results | ||
Results are stored in `results` directory as `.npz` files. Results are processed and saved as csv in `all_data.csv`. | ||
|
||
## Human data | ||
Human data is available in `all_data.csv` for numeric subj_id | ||
|
||
## Data analysis and plotting | ||
A notebook in `plotting` gives code to generate all plots in the paper. | ||
|
||
|
||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# building npz of problems just by replacing numbers with symbols | ||
import numpy as np | ||
|
||
all_prob = np.load(f'./problems/all_problems.npz', allow_pickle=True) | ||
|
||
all_problems = all_prob['all_problems'].item() | ||
all_prob_types = all_problems.keys() | ||
|
||
symbs = ['%', '&', '*', '!', '(', '<', '~', '>', ':', '$'] #NB: This replaces ? with ( | ||
all_prob_types = list(all_prob['all_problems'].item().keys()) | ||
print(all_prob_types) | ||
all_problems = all_prob['all_problems'].item() | ||
print(list(all_problems['row_constant'].keys())) | ||
all_symb_problems_np={} | ||
for prob_type in all_prob_types: | ||
symb_probs = [] | ||
for prob in all_problems[prob_type]['prob']: | ||
symb_prob = [] | ||
for row in prob: | ||
symb_row = [] | ||
for column in row: | ||
symb_col = [symbs[i] for i in column] | ||
symb_row.append(symb_col) | ||
symb_prob.append(symb_row) | ||
symb_probs.append(symb_prob) | ||
|
||
all_symb_choices = [] | ||
for choices in all_problems[prob_type]['answer_choices']: | ||
symb_choices = [] | ||
for choice in choices: | ||
symb_choice = [symbs[i] for i in choice] | ||
symb_choices.append(symb_choice) | ||
all_symb_choices.append(symb_choices) | ||
all_symb_problems_np[prob_type]={'prob':symb_probs,\ | ||
'answer_choices':all_symb_choices,\ | ||
'correct_ind':all_problems[prob_type]['correct_ind'],\ | ||
'perm_invariant':all_problems[prob_type]['perm_invariant']} | ||
|
||
|
||
np.savez('./all_problems_symb.npz', all_problems=all_symb_problems_np) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
|
||
import numpy as np | ||
import builtins | ||
import os | ||
from openai import AzureOpenAI | ||
import sys | ||
import argparse | ||
import time | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--id", help="The gpt version, can be 0613, 1106 or 0125 or 350613") | ||
parser.add_argument("--user_prompt_num", help="user prompt", type=int) | ||
parser.add_argument("--sys_prompt_num", help="sys prompt", type=int) | ||
parser.add_argument("--prob_format", help="use webb data") | ||
parser.add_argument("--letters", help="use letters", action="store_true") | ||
args = parser.parse_args() | ||
|
||
if args.prob_format not in ['digits', 'symb', 'coords']: | ||
print('prob_format must be one of digits, symb, or coords') | ||
|
||
id = args.id | ||
|
||
|
||
|
||
if args.prob_format=='digits': | ||
prob_format = '_digits' | ||
checkset = list(range(10)) | ||
elif args.prob_format=='coords': | ||
prob_format = '_coords' | ||
checkset = list(range(10)) | ||
elif args.prob_format=='symb': | ||
prob_format = '_symb' | ||
checkset =set(['%', '&', '*', '!', '(', '<', '~', '>', ':', '$']) #changed ? to ( | ||
else: | ||
print('prob_format must be one of digits, symb, or coords') | ||
sys.exit() | ||
|
||
|
||
if (args.sys_prompt_num == 0): | ||
sys_content = 'You are a helpful assistant.' | ||
elif args.sys_prompt_num == 1: | ||
sys_content = 'You are a genius at solving analogy problems.' | ||
|
||
versions = {'0125':{'resource_name':'0125-Preview', 'deployment_name':'0125-Preview'}, | ||
'1106':{'resource_name':'MMResearch', 'deployment_name':'gpt-4-1106-Preview'}, | ||
'0613':{'resource_name':'0613', 'deployment_name':'0613'}, | ||
'350613':{'resource_name':'0613', 'deployment_name':'0613'}} | ||
|
||
if id not in versions.keys(): | ||
print(f'id should be 0125, 1106, or 0613 or 350613') | ||
|
||
client = AzureOpenAI( | ||
azure_endpoint = os.getenv(f"AZURE_OPENAI_ENDPOINT_{id}"), | ||
api_key=os.getenv(f"AZURE_OPENAI_API_KEY_{id}"), | ||
api_version="2024-02-01" | ||
) | ||
|
||
# Split word into characters | ||
def split(word): | ||
return [char for char in word] | ||
|
||
# Load all problems | ||
if args.prob_format == 'digits': | ||
all_prob = np.load('./problems/all_problems.npz', allow_pickle=True) | ||
elif args.prob_format == 'coords': | ||
all_prob = np.load('./problems/all_problems_coords.npz', allow_pickle=True) | ||
else: | ||
all_prob = np.load('./problems/all_problems_symb.npz', allow_pickle=True) | ||
|
||
# GPT settings | ||
kwargs = {"temperature":0, "max_tokens":10, "stop":"\n", } | ||
|
||
|
||
# Loop through all problem types | ||
all_prob_types = builtins.list(all_prob['all_problems'].item().keys()) | ||
# Load data if it already exists | ||
all_data_fname = f'./results/gpt_matprob_results_{args.prob_format}_{id}_prompt_{args.sys_prompt_num}_{args.user_prompt_num}{prob_format}.npz' | ||
if os.path.exists(all_data_fname): | ||
data_exists = True | ||
all_data = np.load(all_data_fname, allow_pickle=True) | ||
else: | ||
data_exists = False | ||
# Create data structure for storing results | ||
all_gen_pred = {} | ||
all_gen_correct_pred = {} | ||
all_MC_pred = {} | ||
all_MC_correct_pred = {} | ||
all_alt_MC_correct_pred = {} | ||
for p in range(len(all_prob_types)): | ||
# Problem type | ||
prob_type = all_prob_types[p] | ||
# Load data | ||
if data_exists: | ||
all_gen_pred[prob_type] = all_data['all_gen_pred'].item()[prob_type] | ||
all_gen_correct_pred[prob_type] = all_data['all_gen_correct_pred'].item()[prob_type] | ||
all_MC_pred[prob_type] = all_data['all_MC_pred'].item()[prob_type] | ||
all_MC_correct_pred[prob_type] = all_data['all_MC_correct_pred'].item()[prob_type] | ||
all_alt_MC_correct_pred[prob_type] = all_data['all_alt_MC_correct_pred'].item()[prob_type] | ||
# Create data structure | ||
else: | ||
all_gen_pred[prob_type] = [] | ||
all_gen_correct_pred[prob_type] = [] | ||
all_MC_pred[prob_type] = [] | ||
all_MC_correct_pred[prob_type] = [] | ||
all_alt_MC_correct_pred[prob_type] = [] | ||
# Loop over all problem indices | ||
N_prob = 50 | ||
none_count = 0 | ||
for prob_ind in range(N_prob): | ||
print(str(prob_ind + 1) + ' of ' + str(N_prob) + '...') | ||
# Loop over all problem types | ||
for p in range(len(all_prob_types)): | ||
# Problem type | ||
prob_type = all_prob_types[p] | ||
print('Problem type: ' + prob_type + '...') | ||
perm_invariant = all_prob['all_problems'].item()[prob_type]['perm_invariant'] | ||
prob_type_N_prob = len(all_prob['all_problems'].item()[prob_type]['prob']) #changed this from shape[0] | ||
if prob_ind < prob_type_N_prob and len(all_gen_correct_pred[prob_type]) <= prob_ind: | ||
|
||
# Problem | ||
prob = all_prob['all_problems'].item()[prob_type]['prob'][prob_ind] | ||
size_prob =len(prob) | ||
answer_choices = all_prob['all_problems'].item()[prob_type]['answer_choices'][prob_ind] | ||
correct_ind = all_prob['all_problems'].item()[prob_type]['correct_ind'][prob_ind] | ||
correct_answer = answer_choices[correct_ind] | ||
|
||
if args.prob_format == 'coords': | ||
cx, cy = all_prob['all_problems'].item()[prob_type]['correct_coords'][prob_ind] | ||
|
||
# Generate prompt | ||
if args.user_prompt_num == 0: | ||
prompt = '' | ||
elif args.user_prompt_num ==1: | ||
prompt = "Let's try to complete the pattern:\n" #This prompt does not work Prompt 1 | ||
elif args.user_prompt_num == 2: | ||
prompt = "Try to complete the pattern below. Give ONLY the answer as briefly as possible.\n" | ||
elif args.user_prompt_num == 3: | ||
prompt = "Try to fill the gap in the pattern below. Give ONLY the answer as briefly as possible.\n" | ||
else: | ||
print('You must choose prompt 1, 2, or 3') | ||
sys.exit() | ||
for r in range(size_prob): | ||
for c in range(size_prob): | ||
prompt += '[' | ||
if args.prob_format != 'coords': | ||
if not (r == size_prob-1 and c == size_prob-1): | ||
for i in range(len(prob[r][c])): | ||
if prob[r][c][i] == -1: | ||
prompt += ' ' | ||
else: | ||
prompt += str(prob[r][c][i]) | ||
if i < len(prob[r][c]) - 1: | ||
prompt += ' ' | ||
prompt += ']' | ||
if c < size_prob-1: | ||
prompt += ' ' | ||
else: | ||
prompt += '\n' | ||
else: | ||
if (r == cx and c == cy): | ||
for i in range(len(prob[r][c])): | ||
prompt+= ' ' | ||
if i < (len(prob[r][c]) - 1): | ||
prompt += ' ' | ||
else: | ||
for i in range(len(prob[r][c])): | ||
if prob[r][c][i] == -1: | ||
prompt += ' ' | ||
else: | ||
prompt += str(prob[r][c][i]) | ||
if i < len(prob[r][c]) - 1: | ||
prompt += ' ' | ||
prompt += ']' | ||
if c < size_prob-1: | ||
prompt += ' ' | ||
else: | ||
prompt += '\n' | ||
print(prompt) | ||
# sys.exit() | ||
# Get response | ||
response = client.chat.completions.create( | ||
model= versions[id]['deployment_name'], | ||
messages=[ | ||
{"role": "system", "content": sys_content}, | ||
{"role": "user", "content": prompt}, | ||
], | ||
**kwargs | ||
) | ||
|
||
response_text = response.choices[0].message.content | ||
if response_text is None: | ||
response_text = 'None' | ||
none_count+=1 | ||
print(f'Nonecount is {none_count}') | ||
# print(response_text) | ||
# sys.exit() | ||
# Find portion of response corresponding to prediction | ||
prediction = response_text.lstrip('[') | ||
|
||
all_gen_pred[prob_type].append(prediction) | ||
# Get prediction set | ||
pred_set = [] | ||
invalid_char = False | ||
closing_bracket = False | ||
for p in split(prediction): | ||
if p != ' ': | ||
if args.prob_format != 'symb' and p.isdigit(): | ||
p = int(p) | ||
if p in checkset: | ||
pred_set.append(p) | ||
elif p == ']': | ||
break | ||
else: | ||
invalid_char = True | ||
break | ||
# Sort answer if problem type is permutation invariant | ||
if perm_invariant: | ||
correct_answer = np.sort(correct_answer) | ||
pred_set = np.sort(pred_set) | ||
# Determine whether prediction is correct | ||
correct_pred = False | ||
if not invalid_char and len(pred_set) == len(correct_answer): | ||
if np.all(pred_set == correct_answer): | ||
correct_pred = True | ||
all_gen_correct_pred[prob_type].append(correct_pred) | ||
|
||
# Save data | ||
eval_fname = f'./results/gpt_matprob_results_{args.prob_format}_{id}_prompt_{args.sys_prompt_num}_{args.user_prompt_num}{prob_format}.npz' | ||
np.savez(eval_fname, | ||
all_gen_pred=all_gen_pred, all_gen_correct_pred=all_gen_correct_pred, all_MC_pred=all_MC_pred, all_MC_correct_pred=all_MC_correct_pred, all_alt_MC_correct_pred=all_alt_MC_correct_pred, | ||
allow_pickle=True) | ||
|
||
print(f'Nonecount is {none_count}') |
Oops, something went wrong.