
Gpt math solver #991

Draft · wants to merge 136 commits into main
Commits (136)
7dc4035
handle format error in message in _construct_params
yiranwu0 Apr 11, 2023
83ff983
fix typo
yiranwu0 Apr 12, 2023
ab2cada
Add math solver with automatic tool queries.
yiranwu0 Apr 16, 2023
2d70c99
add imports in QueryHandler
yiranwu0 Apr 16, 2023
c823cbf
update math solver
yiranwu0 Apr 23, 2023
766b022
require wolfram id in readme
yiranwu0 Apr 23, 2023
84ba0be
Merge branch 'main' into gpt_math_solver
yiranwu0 Apr 23, 2023
8f67ed7
fix bug in running python code
yiranwu0 Apr 23, 2023
a511f0a
Update flaml/autogen/math_solver/MathSolver.py
yiranwu0 Apr 23, 2023
87ad79d
Update flaml/autogen/math_solver/README.md
yiranwu0 Apr 23, 2023
a16fa5f
revise according to comments
yiranwu0 Apr 23, 2023
e21fd76
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 Apr 23, 2023
45dcb7f
fix code format
yiranwu0 Apr 23, 2023
435e7a4
Add prompt to system message
yiranwu0 Apr 23, 2023
d1747cf
refactor file names
yiranwu0 Apr 24, 2023
56627a7
refine prompts
yiranwu0 Apr 24, 2023
9821820
add baseline PoT
yiranwu0 Apr 24, 2023
e37ee3e
fix bugs in query_handler
yiranwu0 Apr 24, 2023
5d44e5e
refine prompts
yiranwu0 Apr 24, 2023
bab2878
refine prompt to output fractions
yiranwu0 Apr 24, 2023
d0b0d4b
change prompt
yiranwu0 Apr 24, 2023
3e171a3
add temperature as args
yiranwu0 Apr 24, 2023
2261c5c
fix concat float to str
yiranwu0 Apr 24, 2023
8c5a86c
change prompt back to use fractions instead of decimal
yiranwu0 Apr 24, 2023
2b8b717
rewind prompt back to e37ee3
yiranwu0 Apr 25, 2023
8b68ff7
pass args.samples_per_category in PoT
yiranwu0 Apr 25, 2023
54407a7
fix counting bug in PoT and print in math_solver
yiranwu0 Apr 25, 2023
4806631
fix error: convert exception to str
yiranwu0 Apr 25, 2023
80a7063
add logger to log stdouts and compress files
yiranwu0 Apr 25, 2023
d737644
refine logging
yiranwu0 Apr 25, 2023
d146e35
add option to put prompt in either system or user message, add option…
yiranwu0 Apr 26, 2023
26c0caa
clean up main.py
yiranwu0 Apr 26, 2023
2a1a47e
create pseudo_main.py
yiranwu0 Apr 26, 2023
edfc679
fix category loading bug
yiranwu0 Apr 26, 2023
6a15761
handle timeout
yiranwu0 Apr 26, 2023
ab64723
two new prompts
yiranwu0 Apr 26, 2023
f723a8f
add bash
yiranwu0 Apr 27, 2023
1a5c93c
more prompts
yiranwu0 Apr 27, 2023
955edca
change run sequence
yiranwu0 Apr 27, 2023
8519967
add more prompts
yiranwu0 Apr 28, 2023
912193e
catch wolfram error
yiranwu0 Apr 28, 2023
c8f90b4
more runs on v2.1 select, v1.2 select, add new v3select
yiranwu0 Apr 28, 2023
7a8c2ac
compress when all finished
yiranwu0 Apr 28, 2023
b9a7e04
py exec output fix
yiranwu0 Apr 28, 2023
65f1580
v3.1 select
yiranwu0 Apr 29, 2023
73088ce
new both prompt, v3.2select
yiranwu0 Apr 29, 2023
144c148
change execute to run
yiranwu0 Apr 29, 2023
812477a
refine query handling and v3.3select
yiranwu0 Apr 30, 2023
25e2708
catch wolfram errors
yiranwu0 Apr 30, 2023
1c00283
ablation on only using python and zeroshot baseline
yiranwu0 May 1, 2023
1330a00
change run sequence
yiranwu0 May 1, 2023
e61212f
new run
yiranwu0 May 1, 2023
2b5dd52
new run
yiranwu0 May 1, 2023
ac11d2a
consistent output folder in PoT
yiranwu0 May 1, 2023
9d291b9
rerun pot, refined prompt v1.3 v1.4 and v3.4
yiranwu0 May 2, 2023
ce7144a
resume 22 if not finished
yiranwu0 May 2, 2023
6fefde3
handle wolfram exception
yiranwu0 May 2, 2023
eaae6ce
one run for v1.5
yiranwu0 May 2, 2023
8fdf74f
one run for v1.5 corrections
yiranwu0 May 2, 2023
ca75c91
two more prompts v3.5select and v3.1python based on v3python
yiranwu0 May 3, 2023
47179ce
remove error string clipping
yiranwu0 May 3, 2023
a8c3758
handle UnicodeDecodeError
yiranwu0 May 3, 2023
132638a
handle UnicodeDecodeError
yiranwu0 May 3, 2023
280f9de
quick test on adding wolfram to v3.1python
yiranwu0 May 3, 2023
45a4abd
rerun v3.1 with refine, add v3.7select to further test wolfram
yiranwu0 May 4, 2023
b0efcbf
switch run seq v3.7select then v3.1python
yiranwu0 May 4, 2023
10c28ae
add v3.2python, slightly refine from v3.1. try out v3.3python
yiranwu0 May 4, 2023
bfe61aa
more args for PoT and refine load_level5 func
yiranwu0 May 5, 2023
39ea367
trial 38-42: validate our methods on all levels of problems, run large…
yiranwu0 May 5, 2023
0cebecb
update run.sh
yiranwu0 May 5, 2023
bddb610
move
sonichi May 6, 2023
326da82
add v4
yiranwu0 May 7, 2023
bd040b5
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 May 7, 2023
c8ba447
test with new system message
yiranwu0 May 7, 2023
62b5259
add baseline pnas, run v4.2 on level5 problems, test new sys message …
yiranwu0 May 8, 2023
ef509d4
fix trial 49
yiranwu0 May 8, 2023
e60850f
remove print
yiranwu0 May 8, 2023
5fe0b0b
run v3 with specified sentence removed, 4.2 with original sys message…
yiranwu0 May 9, 2023
d92b559
remove trial 52
yiranwu0 May 9, 2023
ede98a5
endpoint
sonichi May 9, 2023
7082355
Merge branch 'gpt_math_solver' of https://github.com/kevin666aa/FLAML…
sonichi May 9, 2023
7d34485
fix bug in queryhandler
yiranwu0 May 9, 2023
8e34218
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 May 9, 2023
c592837
fix queryhandler 2
yiranwu0 May 9, 2023
6345e0b
v3.3python
yiranwu0 May 10, 2023
40fd299
remove print
yiranwu0 May 10, 2023
dac1551
test final prompts
yiranwu0 May 11, 2023
fff4e4b
change run sequence
yiranwu0 May 11, 2023
da0f7d9
run exact v3.1 as before
yiranwu0 May 11, 2023
2775e08
keep running v3.1python and add general_5
yiranwu0 May 11, 2023
ad10b71
add general_5
yiranwu0 May 11, 2023
7800a46
continue run 55 and 56
yiranwu0 May 12, 2023
4f78539
switch seq
yiranwu0 May 12, 2023
a76113f
trial 63 v3.5python, then run large-scale with v3.3python
yiranwu0 May 12, 2023
7d22c07
add v3.3, 3.7, 3.8
yiranwu0 May 13, 2023
908d283
revise 3.6-3.8
yiranwu0 May 13, 2023
079b4e2
v3.9
yiranwu0 May 13, 2023
f071214
test intermediate algebra and precalculus on v3.9
yiranwu0 May 13, 2023
6444f91
test v3.9 on 50 problems, then zero shot
yiranwu0 May 13, 2023
1c4a278
fix prompt
yiranwu0 May 13, 2023
c744613
endpoint
sonichi May 13, 2023
2ad469f
Merge branch 'gpt_math_solver' of https://github.com/kevin666aa/FLAML…
sonichi May 13, 2023
b806dfb
run all problems on v3.9, and pnas
yiranwu0 May 13, 2023
3733028
endpoint
sonichi May 13, 2023
cbd0be0
Merge remote-tracking branch 'upstream/main' into gpt_math_solver
May 15, 2023
89d7512
run v1python
yiranwu0 May 15, 2023
3791326
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 May 15, 2023
2a6ffa1
run v1python+wolfram
yiranwu0 May 16, 2023
d833938
run pot with sys message
yiranwu0 May 19, 2023
bf73756
endpoint
sonichi May 19, 2023
f1b3873
Merge branch 'gpt_math_solver' of https://github.com/kevin666aa/FLAML…
sonichi May 19, 2023
e3d8de1
run pot with system message
yiranwu0 May 19, 2023
d84213d
Merge branch 'gpt_math_solver' of https://github.com/kevin666aa/FLAML…
sonichi May 19, 2023
f8c68ff
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
May 19, 2023
9bc17db
fewshot+zeroshot prompt
May 19, 2023
769803e
add assert
May 19, 2023
85d9b59
refine fewshot
yiranwu0 May 20, 2023
bce7f4f
run pre-commit
yiranwu0 May 20, 2023
59bc9f9
rerun v3.9 with cache and get token info
yiranwu0 May 21, 2023
32de58f
run PoT on all problems
yiranwu0 May 22, 2023
9dabf61
Merge remote-tracking branch 'upstream/main' into gpt_math_solver
yiranwu0 May 22, 2023
9c3efd4
merge new changes and update pot
yiranwu0 May 22, 2023
fc8bcdc
endpoint
sonichi May 22, 2023
c711143
Merge branch 'gpt_math_solver' of https://github.com/kevin666aa/FLAML…
sonichi May 22, 2023
f535e50
fix decode in PoT
yiranwu0 May 22, 2023
841ff2a
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 May 22, 2023
43d8277
clean up and rename
yiranwu0 May 27, 2023
01f7712
resolve conflict in setup
yiranwu0 May 27, 2023
d4d8242
Merge branch 'microsoft:main' into gpt_math_solver
yiranwu0 Jun 7, 2023
1cfce5f
clean up
yiranwu0 Jun 7, 2023
be7bb3d
update readme
yiranwu0 Jun 7, 2023
d3e8719
add mathchat flow chart
yiranwu0 Jun 7, 2023
2c8823f
Update README.md
yiranwu0 Jun 7, 2023
7808b4f
Merge branch 'microsoft:main' into gpt_math_solver
yiranwu0 Jun 10, 2023
348446b
add missing files
yiranwu0 Jul 10, 2023
c49ab9c
Merge branch 'gpt_math_solver' of github.com:kevin666aa/FLAML into gp…
yiranwu0 Jul 10, 2023
211 changes: 211 additions & 0 deletions flaml/autogen/math_solver/MathSolver.py
@@ -0,0 +1,211 @@
from QueryHandler import QueryHandler
from flaml.autogen.math_utils import (
    eval_math_responses,
    remove_boxed,
    last_boxed_only_string,
    nestmkdir,
    write_json,
    remove_asy_sections,
    math_type_mapping,
)
from flaml import oai
import os
import json
import re
import copy


PROMPT = """
Let's use two tools (python code and Wolfram Alpha) to solve this problem step by step. You should always follow your own reasoning and only query when necessary.

First state the key idea to solve the problem. Then follow the process:
1. Output one step.
2. Take out any queries that can be asked through python or Wolfram Alpha (for example, any calculations or equations that can be calculated) and choose the best tool to be used. When you are querying python, you should: 1. use tab('\\t') for indentation. 2. use the 'print' function for the output. 3. always output fractions instead of decimals.
3. Format the query in json:
{ "tool" : "", # "python" or "wolfram"
"query": "", # your query here, either python code or Wolfram query.
}
4. Wait for me to give the results.
5. Give a new query if the results are invalid or unexpected.
6. When you get the answer, put the answer in \\box{}.

Problem:
"""


class MathSolver:
    def __init__(self, model, max_tokens, max_round=10, n=1, use_cache=True, cache_folder='.cache'):
        self.max_round = max_round

        self.default_config = {
            'model': model,
            "max_tokens": max_tokens,
            'messages': [
                {"role": "system", "content": "You are a helpful assistant."},
[Contributor comment] One question could be: Does changing the system prompt make any difference?

[Contributor comment] What about making the prompt before the "Problem:" a system message?
            ],
            'n': n,  # n should be 1 for now
        }
        # set oai cache
        self.cache_folder = cache_folder
        self.use_cache = use_cache
        oai.ChatCompletion.set_cache(seed=41, cache_path=self.cache_folder)


    def make_conversation(self, problem, saving_folder):
        query_handler = QueryHandler()

        # initialize the conversation
        config = copy.deepcopy(self.default_config)
        config['messages'].append({"role": "user", "content": PROMPT + remove_asy_sections(problem['problem'])})

        # save a readable conversation in a txt file
        conversation_saver = open(os.path.join(saving_folder, problem['problem_id'] + '.txt'), 'a')
        separate_line = '\n' + '-' * 40 + '\n'
        conversation_saver.write(f'Problem: {self.str_splitter(problem["problem"])}\n\n {separate_line}')

        # init parameters
        is_valid_reply = False  # only valid when we detect \box
        consecutive_fail = False  # for query
        token_used, total_cost = 0, 0
        response_with_ans = ""  # save the response with \box to get the answer
        for _ in range(self.max_round):
            # 1. get the response from the assistant
            raw_responses = oai.ChatCompletion.create(None, **config, use_cache=self.use_cache)
            if raw_responses == -1:
                break  # catch the error when there is no valid reply
            responses = [r["message"]["content"].rstrip() for r in raw_responses["choices"]]
            conversation_saver.write(f'assistant: {self.str_splitter(responses[0])}{separate_line}')
            token_used = raw_responses['usage']['total_tokens']
            total_cost += oai.ChatCompletion.cost(self.default_config['model'], raw_responses)
            config['messages'].append({"role": "assistant", "content": responses[0]})  # append the response to the conversation
            if '\\box' in responses[0]:
                # if the assistant gives a valid reply, stop the conversation
                is_valid_reply = True
                response_with_ans = responses[0]
                break
            elif token_used > 8192 - config['max_tokens']:
                # if the assistant uses too many tokens, stop the conversation. max prompt tokens + max response tokens allowed = 8192
                break
            assert len(responses) == 1, 'More than one response'  # right now we only use one response

            # 2. handle the response and get the query
            query_response, is_query_success = query_handler.handle_query(responses[0])
            if len(query_response) > 2000:
                # prevent long responses by string length, 2000 chars -> around 500-1000 tokens
                query_response = 'Your requested query response is too long. You might have made a mistake. Please revise your reasoning and query.'
                is_query_success = False
            config['messages'].append({"role": "user", "content": query_response})
            if not is_query_success:
                if consecutive_fail:
                    # if this query is invalid and the last query also failed, replace the last message with a skip-query message
                    assert config['messages'][-1]['role'] == 'user', 'The last message should be from the user'
                    skip_query_str = 'Please solve this step yourself and do not use the tools. Then start the next step and give new queries if needed.'
                    config['messages'][-1]['content'] = skip_query_str
                    conversation_saver.write(f'****: Replacing {query_response}****\n')
                    consecutive_fail = False
                else:
                    consecutive_fail = True
            conversation_saver.write('user: {a}{s}'.format(a=config['messages'][-1]['content'], s=separate_line))
            conversation_saver.flush()

        conversation_saver.write('Solution: ' + problem['solution'])
        conversation_saver.close()
        return {
            'valid_q_count': query_handler.valid_q_count,  # number of valid queries
            'total_q_count': query_handler.total_q_count,
            'is_valid_reply': is_valid_reply,  # whether the assistant gives a valid reply
            'response_with_ans': response_with_ans,
            'messages': config['messages'],
            'round': len(config['messages']) // 2 + 1,
            'cost': total_cost,
        }
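
    # How a round plays out (illustrative note, not in the original diff): the
    # assistant's message may contain a json tool query; handle_query runs it and
    # the tool output is appended as the next user message, so the model sees the
    # results before writing its next step. The loop ends once a reply contains
    # '\box', the 8192-token budget is reached, or max_round rounds have passed.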


    def str_splitter(self, string, length=130):
        """
        Add '\n' every 'length' characters to make the output more readable.
        If at 'length' there is a word, add '\n' before the word.

        Args:
            string (str): The input string to be processed.
            length (int): The maximum number of characters in a line before adding a newline.

        Returns:
            str: The processed string with newlines added.
        """

        words = string.split(' ')
        current_line = []
        current_length = 0
        result = []

        for word in words:
            # len(current_line) accounts for the spaces that ' '.join will add
            if current_length + len(word) + len(current_line) > length:
                result.append(' '.join(current_line))
                current_line = []
                current_length = 0

            current_line.append(word)
            current_length += len(word)

        if current_line:
            result.append(' '.join(current_line))

        return '\n'.join(result)
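
    # Example (illustrative): str_splitter wraps a long string at roughly
    # `length` characters without splitting words; it is only used to keep the
    # .txt conversation logs readable and does not affect the model input.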


    def solve_one_category(self, problem_set, saving_folder):
        """
        Solve all problems in a category.
        Assumption 1: all problems are of the same type.
        Assumption 2: when resuming from a previous run, the sequence of problems is the same as in the previous run (same shuffling seed).

        Args:
            problem_set (list): a list of problems
            saving_folder (str): the result folder to save the solved problems; the category folder will be created inside

        Returns:
            None
        """

        # assume all problems are of the same type. TODO: ensure this assumption
        saving_folder = os.path.join(saving_folder, math_type_mapping[problem_set[0]['type']])
        # assign a temporary problem_id
        for i in range(len(problem_set)):
            problem_set[i]['problem_id'] = str(i)
        # mkdir if not exist
        nestmkdir(saving_folder, verbose=True)

        # load solved problems from the saving folder
        done_problems = set([int(f.split('.')[0]) for f in os.listdir(saving_folder) if 'json' in f])

        correct_counts = 0
        for count, problem in enumerate(problem_set):
            problem_path = os.path.join(saving_folder, problem['problem_id'] + '.json')

            # 1. if the problem is already solved, continue
            if int(problem['problem_id']) in done_problems:
                problem = json.load(open(problem_path, 'r'))
                correct_counts += problem['is_correct']
                print(f'{count}: {correct_counts}/{count+1} successes. valid response: {problem["is_valid_reply"]}, Correct: {problem["is_correct"]}, {problem["round"]} rounds. (This problem is loaded from a previous run)')
                continue

            # 2. solve the problem
            result = self.make_conversation(problem, saving_folder)
            metrics = eval_math_responses([result['response_with_ans']], problem['solution'])

            # 3. save the result
            correct_ans = remove_boxed(last_boxed_only_string(problem['solution']))
            problem.update({
                'is_valid_reply': result['is_valid_reply'],
                'is_correct': bool(metrics['success_vote']),
                'correct_ans': correct_ans,
                'voted_answer': remove_boxed(last_boxed_only_string(metrics['voted_answer'])),
                'round': result['round'],
                'valid_q_count': result['valid_q_count'],  # total number of valid queries
                'total_q_count': result['total_q_count'],  # total number of queries
                'cost': result['cost'],  # total cost of the conversation
                'messages': result['messages'],  # the conversation
            })
            write_json(problem, problem_path)

            # 4. continue to the next problem
            correct_counts += problem['is_correct']
            print(f'{problem["problem_id"]} Is Valid: {problem["is_valid_reply"]}, Is Correct: {bool(problem["is_correct"])}, Conversation Round: {problem["round"]}, Accum Successes: {correct_counts}/{count+1}')

        tp = problem_set[0]['type']
        print(f'{tp} correct rate: {correct_counts}/{len(problem_set)} = {correct_counts/len(problem_set)}')
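
A minimal driver sketch for this class (hypothetical, not part of the PR: the data file, model name, and working directory are assumptions; solve_one_category only needs a list of problem dicts with 'problem', 'solution', and 'type' keys, and flaml.oai credentials are assumed to be configured as elsewhere in the repo):

import json
from MathSolver import MathSolver  # assumes cwd is flaml/autogen/math_solver, matching the file's own imports

# load MATH-style problems; each entry must carry 'problem', 'solution', and 'type'
problems = json.load(open('algebra_level5.json'))  # hypothetical data file

solver = MathSolver(model='gpt-4', max_tokens=2048, max_round=10)
solver.solve_one_category(problems, saving_folder='./results')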

