-
Notifications
You must be signed in to change notification settings - Fork 0
/
cant_count.py
129 lines (103 loc) · 3.16 KB
/
cant_count.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import itertools
import operator
import os
import re
import sys
import argparse
from loguru import logger
import openai
# Reset loguru's handlers, then log to stderr at INFO level; the
# logger.debug(...) calls in this file are therefore hidden by default.
logger.remove()
logger.add(sys.stderr, level="INFO")
# Name of the chat model every request is sent to.
_MODEL = "gpt-4o"
# API key read from the environment; os.getenv returns None when it is unset.
_API_KEY = os.getenv("OPENAI_API_KEY")
# Single OpenAI client, constructed once at import time and shared by all
# requests.
_CLIENT = openai.OpenAI(api_key=_API_KEY)
def _send_request(prompt):
    """Send ``prompt`` to the chat model as a single user message.

    Returns whatever ``_CLIENT.chat.completions.create`` returns for the
    model named by ``_MODEL``.
    """
    user_message = {"role": "user", "content": prompt}
    return _CLIENT.chat.completions.create(
        model=_MODEL,
        messages=[user_message],
    )
def _str_to_int(s):
try:
return int(s.replace(",", ""))
except:
logger.warning(f"Model returned non-parsable response: '{s}'")
return None
def _extract_ints_from_string(s) -> list[int]:
return list(filter(None, map(_str_to_int, re.findall(r"[\d,]+", s))))
def _get_response_answers(response) -> tuple[int, ...]:
ints = []
for choice in response.choices:
ints += _extract_ints_from_string(choice.message.content)
return tuple(ints)
def _send_and_verify(a, b, op, operator_func):
    """Ask the model to evaluate ``a op b`` and check its reply.

    The prompt is the bare expression followed by ``=``. The reply counts as
    correct when any integer extracted from it equals
    ``operator_func(a, b)``.

    Returns:
        A ``(is_correct, response)`` pair, where ``response`` is the raw
        completion returned by ``_send_request``.
    """
    response = _send_request(f"{a}{op}{b}=")
    found = _get_response_answers(response)
    expected = operator_func(a, b)
    is_correct = expected in found
    return is_correct, response
def run(min_n, max_n, max_attempts, op):
    """Probe the model with ever-larger operand pairs until it gets one wrong.

    Operand pairs ``(a, b)`` with ``min_n <= a, b <= max_n`` are indexed in
    the order ``itertools.product(range(min_n, max_n + 1), repeat=2)`` would
    yield them (``b`` varies fastest). Probing starts at the first pair and,
    after each correct answer, jumps halfway towards the last pair (an
    "up-only" binary search). It stops at the first incorrect answer, once
    the last pair has been answered correctly, or after ``max_attempts``
    requests, whichever comes first.

    Args:
        min_n: Smallest operand value (inclusive).
        max_n: Largest operand value (inclusive).
        max_attempts: Maximum number of requests to send.
        op: One of ``"+"``, ``"*"`` or ``"-"``.

    Raises:
        KeyError: If ``op`` is not one of the supported operators.
        ValueError: If ``min_n`` is greater than ``max_n``.
    """
    operator_func = {
        "+": operator.add,
        "*": operator.mul,
        "-": operator.sub,
    }[op]
    if min_n > max_n:
        raise ValueError(
            f"Empty range: min ({min_n}) must not exceed max ({max_n})."
        )

    # Pairs are derived from their index arithmetically rather than by
    # materialising the whole cartesian product, whose size is the square of
    # the range width.
    side = max_n - min_n + 1
    last_idx = side * side - 1
    curr_idx = 0

    logger.info(f"Testing {_MODEL} on operator '{op}' from {min_n:,} to {max_n:,}...")
    num_attempts = 0
    while num_attempts < max_attempts:
        num_attempts += 1
        row, col = divmod(curr_idx, side)
        a, b = min_n + row, min_n + col
        logger.debug(f"Testing {a:,} {op} {b:,}")
        is_correct, response = _send_and_verify(a, b, op, operator_func)
        if not is_correct:
            logger.info(f"{a:,} {op} {b:,} INCORRECT!!!")
            logger.info(f"Model answered: '{response.choices[0].message.content}'.")
            logger.info(f"Correct answer was: {operator_func(a,b):,}.")
            break

        logger.info(f"{a:,} {op} {b:,} correct.")
        logger.debug(
            f"Model answered correctly: '{response.choices[0].message.content}'."
        )
        if curr_idx >= last_idx:
            # The final pair has been verified; there is nothing larger to
            # try, so do not keep re-sending the same request.
            logger.info("Reached the top of the range without an incorrect answer.")
            break
        # Move halfway to the end, but always by at least one index so the
        # search cannot stall just below the last pair.
        curr_idx += max(1, (last_idx - curr_idx) // 2)
        logger.debug(f"Next idx: {curr_idx}")
    logger.info(f"Done after {num_attempts} attempts.")
if __name__ == "__main__":
    # Command-line entry point: parse the operand range, the request budget
    # and the operator, then hand them to run().
    arg_parser = argparse.ArgumentParser(
        description="Find where the model starts getting integer arithmetic wrong.",
    )
    arg_parser.add_argument(
        "--min",
        type=int,
        required=True,
        help="Smallest operand to test (inclusive).",
    )
    arg_parser.add_argument(
        "--max",
        type=int,
        required=True,
        help="Largest operand to test (inclusive).",
    )
    arg_parser.add_argument(
        "--max_attempts",
        type=int,
        required=True,
        help="Maximum number of requests to send to the model.",
    )
    arg_parser.add_argument(
        "--op",
        type=str,
        default="+",
        # Restrict to the operators run() knows about so a typo produces a
        # usage error instead of a KeyError traceback.
        choices=("+", "*", "-"),
        help='Operator to test; quote "*" so the shell does not expand it.',
    )
    options = arg_parser.parse_args()
    run(
        min_n=options.min,
        max_n=options.max,
        max_attempts=options.max_attempts,
        op=options.op,
    )