-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathrun_benchmark.py
352 lines (297 loc) · 14.7 KB
/
run_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
#
# run_benchmark.py
#
# Benchmark queries for AI assistant. Used for testing and assessing the quality of assistant
# responses. This script talks to a production endpoint and not the Python assistant server
# directly. Simply run:
#
# python run_benchmark.py tests/tests.json
#
# Use --help for more instructions.
#
import argparse
from datetime import datetime
from enum import Enum
import json
import os
import requests
from typing import List, Optional
import numpy as np
from pydantic import BaseModel, RootModel
from models import Capability, MultimodalResponse
####################################################################################################
# Test Case JSON and Evaluation
####################################################################################################
class UserMessage(BaseModel):
    """
    One user turn of a benchmark conversation, together with the capabilities the assistant is
    expected to use when answering it.
    """

    # Prompt text sent to the assistant
    text: str

    # Path of an image file to upload with this turn, if any
    image: str | None = None

    # Capabilities that are required to have been used
    capabilities: List[Capability] | None = None

    # Must use at least one of the capabilities listed here
    capabilities_any: List[Capability] | None = None
class TestCase(BaseModel):
    """
    A named benchmark test consisting of one or more conversations. Each conversation is a list of
    user turns, given either as plain strings or as full UserMessage objects.
    """

    # Name used to report the test and to select it with --test
    name: str

    # Inactive tests are skipped unless explicitly requested by name
    active: bool

    # Image used for any turn in this test that does not specify its own
    default_image: str | None = None

    # Each inner list is one conversation, run with its own message history
    conversations: List[List[UserMessage | str]]
class TestCaseFile(RootModel):
    """Schema of a test file: the top-level JSON value is a list of test cases."""
    root: List[TestCase]
class TestResult(str, Enum):
    """Outcome of evaluating a single assistant response against a user message's expectations."""
    FAILED = "FAILED"
    IGNORED = "IGNORED"     # the user message specified nothing to evaluate against
    PASSED = "PASSED"
def load_tests(filepath: str) -> List[TestCase]:
    """
    Loads and validates the test cases stored in a JSON file.

    Parameters
    ----------
    filepath : str
        Path of the JSON test file. Its top-level value must be a list of test case objects.

    Returns
    -------
    List[TestCase]
        The test cases contained in the file.

    Raises
    ------
    OSError
        If the file cannot be opened or read.
    pydantic.ValidationError
        If the contents are not valid JSON or do not match the TestCaseFile schema.
    """
    # JSON is UTF-8 by specification, so decode it explicitly rather than with whatever the
    # platform's locale encoding happens to be
    with open(file=filepath, mode="r", encoding="utf-8") as fp:
        text = fp.read()
    return TestCaseFile.model_validate_json(json_data=text).root
def evaluate_capabilities_used(input: UserMessage, output: MultimodalResponse) -> TestResult:
    """
    Checks the capabilities reported in an assistant response against what the user message
    expects.

    Parameters
    ----------
    input : UserMessage
        User message carrying the expectations: `capabilities` (all must have been used) and
        `capabilities_any` (at least one must have been used).
    output : MultimodalResponse
        Assistant response; only its `capabilities_used` is inspected.

    Returns
    -------
    TestResult
        IGNORED if the message specifies no expectations (both lists absent or empty), FAILED if
        any expectation is unmet, PASSED otherwise.
    """
    # Absent and empty lists are treated the same way: nothing to check
    required = input.capabilities or []
    interchangeable = input.capabilities_any or []
    if not required and not interchangeable:
        return TestResult.IGNORED

    used = output.capabilities_used

    # Every required capability must appear (an AND function)
    if any(capability not in used for capability in required):
        return TestResult.FAILED

    # At least one interchangeable capability must appear (an OR function)
    if interchangeable and not any(capability in used for capability in interchangeable):
        return TestResult.FAILED

    return TestResult.PASSED
####################################################################################################
# Helper Functions
####################################################################################################
def load_binary_file(filepath: str) -> bytes:
    """
    Reads an entire file in binary mode.

    Parameters
    ----------
    filepath : str
        Path of the file to read.

    Returns
    -------
    bytes
        The complete, undecoded contents of the file.
    """
    with open(filepath, "rb") as stream:
        contents = stream.read()
    return contents
####################################################################################################
# Markdown Report Generation
####################################################################################################
class ReportGenerator:
    """
    Writes benchmark results to a markdown report.

    The report is written to "<base name of the test file>.md" in the current working directory.
    When constructed with generate_markdown=False no file is created and every method returns
    without writing anything, so callers can invoke the reporting hooks unconditionally.

    Expected call order per test: begin_test(), then for each conversation begin_conversation(),
    add_result() for every turn, end_conversation(), and finally end_test().
    """

    def __init__(self, test_filepath: str, generate_markdown: bool):
        """
        Parameters
        ----------
        test_filepath : str
            Path of the test file being run. Used for the report title and to derive the report's
            file name.
        generate_markdown : bool
            Whether to produce a report at all.
        """
        self._generate_markdown = generate_markdown

        # Defined up front so that __del__ and the timing code never encounter a missing attribute,
        # whether reporting is disabled or opening the report file fails
        self._fp = None
        self._total_times = []
        if not generate_markdown:
            return

        base = os.path.splitext(os.path.basename(test_filepath))[0]
        filename = f"{base}.md"

        # Prompts and responses may contain arbitrary Unicode, so do not rely on the locale encoding
        self._fp = open(file=filename, mode="w", encoding="utf-8")
        self._fp.write(f"# {test_filepath}\n\n")

    def __del__(self):
        # getattr() because __del__ also runs on objects whose __init__ did not complete
        fp = getattr(self, "_fp", None)
        if fp is not None:
            fp.close()

    def begin_test(self, name: str):
        """Starts a new test section: resets the collected timings and writes the table header."""
        self._total_times = []
        if not self._generate_markdown:
            return
        self._fp.write(f"## {name}\n\n")
        self._fp.write("|Passed?|User|Assistant|Image|Debug|\n")
        self._fp.write("|-------|----|---------|-----|-----|\n")

    def begin_conversation(self):
        """Writes a table row of literal dashes that visually separates conversations."""
        if not self._generate_markdown:
            return
        self._fp.write("|\\-\\-\\-\\-\\-\\-\\-\\-|\\-\\-\\-\\-\\-\\-\\-\\-|\\-\\-\\-\\-\\-\\-\\-\\-|\\-\\-\\-\\-\\-\\-\\-\\-|\\-\\-\\-\\-\\-\\-\\-\\-|\n")

    def end_conversation(self):
        """Marks the end of a conversation. Currently writes nothing."""
        pass

    def add_result(self, user_message: UserMessage, response: MultimodalResponse, assistant_response: str, test_result: TestResult):
        """
        Appends one table row for a user turn and records the response's total time.

        Parameters
        ----------
        user_message : UserMessage
            The user turn; its text and image path are shown.
        response : MultimodalResponse
            Full assistant response; `debug_tools` is shown and `timings` (a JSON string) is parsed
            for its "total_time" entry.
        assistant_response : str
            Text to show in the assistant column.
        test_result : TestResult
            Outcome of the turn.
        """
        if not self._generate_markdown:
            return

        passed_column = f"{test_result.value}"
        user_column = self._escape(user_message.text)
        assistant_column = self._escape(assistant_response)
        image_column = f"<img src=\"{user_message.image}\" alt=\"image\" style=\"width:200px;\"/>" if user_message.image is not None else ""
        debug_column = f"```{response.debug_tools}```"
        self._fp.write(f"|{passed_column}|{user_column}|{assistant_column}|{image_column}|{debug_column}|\n")

        # Timings are best effort: a response with missing or malformed timing data still gets its
        # row, it just does not contribute to the statistics. Only the errors that parsing can
        # produce are suppressed, so KeyboardInterrupt and the like still propagate.
        try:
            timings = json.loads(response.timings)
            self._total_times.append(float(timings["total_time"]))
        except (TypeError, ValueError, KeyError):
            pass

    def end_test(self, num_passed: int, num_evaluated: int):
        """
        Finishes the current test section by writing its score and timing statistics.

        Parameters
        ----------
        num_passed : int
            Number of evaluated turns that passed.
        num_evaluated : int
            Number of turns that were evaluated (i.e., not ignored). May be zero.
        """
        if not self._generate_markdown:
            return

        # A blank line is needed to terminate the results table; otherwise markdown renderers
        # treat the score line as one more table row
        self._fp.write("\n")
        if num_evaluated == 0:
            self._fp.write("**Score: N/A**\n\n")
        else:
            self._fp.write(f"**Score: {100.0 * num_passed / num_evaluated : .1f}%**\n\n")

        # NumPy's min/max/quantile raise ValueError on an empty sequence, which happens whenever
        # every request in the test failed or returned no usable timings
        if not self._total_times:
            self._fp.write("**Timings: N/A**\n\n")
            return

        mean_time = np.mean(self._total_times)
        median_time = np.median(self._total_times)
        min_time = np.min(self._total_times)
        max_time = np.max(self._total_times)
        pct90_time = np.quantile(self._total_times, q=0.9)
        pct95_time = np.quantile(self._total_times, q=0.95)
        pct99_time = np.quantile(self._total_times, q=0.99)

        # Blank line after the heading so that the table is recognized by all renderers
        self._fp.write("**Timings**\n\n")
        self._fp.write("|Mean|Median|Min|Max|90%|95%|99%|\n")
        self._fp.write("|----|------|---|---|---|---|---|\n")
        self._fp.write(f"|{mean_time:.1f}|{median_time:.1f}|{min_time:.1f}|{max_time:.1f}|{pct90_time:.1f}|{pct95_time:.1f}|{pct99_time:.1f}|\n\n")

    @staticmethod
    def _escape(text: str) -> str:
        """
        Makes arbitrary text safe to place in a single markdown table cell: markdown punctuation
        (including the cell delimiter "|") is backslash-escaped and newlines become spaces.
        """
        special_chars = "\\`'\"*_{}[]()#+-.!|"
        escaped_text = "".join("\\" + char if char in special_chars else char for char in text)
        return escaped_text.replace("\n", " ")
####################################################################################################
# Main Program
####################################################################################################
if __name__ == "__main__":
    # Runs every selected test against the endpoint, one request per user turn, printing each
    # exchange and its evaluation, then prints token and timing statistics for the whole run.
    parser = argparse.ArgumentParser("run_benchmark")
    parser.add_argument("file", nargs=1)
    parser.add_argument("--endpoint", action="store", default="https://api.brilliant.xyz/dev/noa/mm", help="Address to send request to (Noa server)")
    parser.add_argument("--token", action="store", help="Noa API token")
    parser.add_argument("--test", metavar="name", help="Run specific test")
    parser.add_argument("--markdown", action="store_true", help="Produce report in markdown file")
    parser.add_argument("--vision", action="store", help="Vision model to use (gpt-4o, gpt-4-vision-preview, claude-3-haiku-20240307, claude-3-sonnet-20240229, claude-3-opus-20240229)", default="gpt-4o")
    parser.add_argument("--address", action="store", default="San Francisco, CA 94115", help="Simulated location")
    options = parser.parse_args()

    # Load tests
    tests = load_tests(filepath=options.file[0])

    # Markdown report generator
    report = ReportGenerator(test_filepath=options.file[0], generate_markdown=options.markdown)

    # Authorization header: --token takes precedence over the BRILLIANT_API_KEY environment variable
    headers = {
        "Authorization": options.token if options.token is not None else os.getenv("BRILLIANT_API_KEY")
    }

    # Metrics
    total_user_prompts = 0
    total_tokens_in = 0
    total_tokens_out = 0
    total_times = []

    # "--endpoint localhost" selects http://localhost:8000/mm and a different request format, in
    # which all parameters are JSON-encoded into a single "mm" form field
    localhost = options.endpoint == "localhost"
    if localhost:
        options.endpoint = "http://localhost:8000/mm"

    # Run all active tests
    for test in tests:
        if not options.test:
            # No specific test, run all that are active
            if not test.active:
                continue
        else:
            # A test requested by name runs even if it is marked inactive
            if test.name.lower().strip() != options.test.lower().strip():
                continue

        print(f"Test: {test.name}")
        report.begin_test(name=test.name)
        num_evaluated = 0
        num_passed = 0

        for conversation in test.conversations:
            report.begin_conversation()

            # Create new message history for each conversation
            history = []
            for user_message in conversation:
                # Each user message can be either a string or a UserMessage object
                assert isinstance(user_message, (str, UserMessage))
                if isinstance(user_message, str):
                    user_message = UserMessage(text=user_message)

                # If there is no image associated with this message, use the default image, if it
                # exists. Work on a copy so the loaded test case itself is left unmodified.
                if user_message.image is None and test.default_image is not None:
                    user_message = user_message.model_copy()
                    user_message.image = test.default_image

                # Construct API call data
                local_time = datetime.now().strftime("%A, %B %d, %Y, %I:%M %p")
                if localhost:
                    data = {
                        "mm": json.dumps({
                            "prompt": user_message.text,
                            "messages": history,
                            "address": options.address,
                            "local_time": local_time,
                            "search_api": "perplexity",
                            "config": { "engine": "google_lens" },
                            "experiment": "1",
                            "vision": options.vision
                        }),
                    }
                else:
                    data = {
                        "prompt": user_message.text,
                        "messages": json.dumps(history),
                        "address": options.address,
                        "local_time": local_time,
                        "search_api": "perplexity",
                        "config": json.dumps({ "engine": "google_lens" }),
                        "experiment": "1",  # this activates the passthrough to the Python ai-experiments code
                        "vision": options.vision
                    }
                files = {}
                if user_message.image is not None:
                    files["image"] = (os.path.basename(user_message.image), load_binary_file(filepath=user_message.image))

                # Make the call and evaluate. The request itself is inside the try block so that a
                # connection failure on one turn is reported like any other error instead of
                # aborting the entire benchmark run.
                try:
                    response = requests.post(url=options.endpoint, files=files, data=data, headers=headers)
                    if response.status_code != 200:
                        print(f"Error: {response.status_code}")
                        print(response.content)
                        response.raise_for_status()

                    mm_response = MultimodalResponse.model_validate_json(json_data=response.content)

                    test_result = evaluate_capabilities_used(input=user_message, output=mm_response)
                    if test_result != TestResult.IGNORED:
                        num_evaluated += 1
                        num_passed += (1 if test_result == TestResult.PASSED else 0)

                    # Extend the history that is sent along with the next turn of this conversation
                    history.append({ "role": "user", "content": user_message.text })
                    assistant_response = ""
                    if len(mm_response.response) > 0:
                        assistant_response = mm_response.response
                    elif len(mm_response.image) > 0:
                        assistant_response = "<generated image>"
                    if len(assistant_response) > 0:
                        history.append({ "role": "assistant", "content": assistant_response })

                    timings = json.loads(mm_response.timings)
                    print(f"User: {user_message.text}" + (f" ({user_message.image})" if user_message.image else ""))
                    print(f"Response: {assistant_response}")
                    print(f"Tools: {mm_response.debug_tools}")
                    print(f"Timings: {timings}")
                    print(f"Test: {test_result}")
                    print("")
                    report.add_result(user_message=user_message, response=mm_response, assistant_response=assistant_response, test_result=test_result)

                    total_user_prompts += 1
                    total_tokens_in += mm_response.input_tokens
                    total_tokens_out += mm_response.output_tokens
                    total_times.append(float(timings["total_time"]))
                except Exception as e:
                    # Report and move on to the next turn; one bad response should not end the run
                    print(f"Error: {e}")

            report.end_conversation()

        # Print test results
        print("")
        print(f"TEST RESULTS: {test.name}")
        if num_evaluated == 0:
            print(f"  Score: N/A")
        else:
            print(f"  Score: {num_passed}/{num_evaluated} = {100.0 * num_passed / num_evaluated : .1f}%")
        report.end_test(num_passed=num_passed, num_evaluated=num_evaluated)

    # Summary
    print(f"User messages: {total_user_prompts}")
    print(f"Total input tokens: {total_tokens_in}")
    print(f"Total output tokens: {total_tokens_out}")
    if total_user_prompts > 0:
        print(f"Average input tokens: {total_tokens_in / total_user_prompts}")
        print(f"Average output tokens: {total_tokens_out / total_user_prompts}")
    else:
        # No test matched or every request failed: avoid dividing by zero
        print("Average input tokens: N/A")
        print("Average output tokens: N/A")

    # Timings
    print("")
    print("Timing")
    print("------")
    if total_times:
        mean_time = np.mean(total_times)
        median_time = np.median(total_times)
        min_time = np.min(total_times)
        max_time = np.max(total_times)
        pct90_time = np.quantile(total_times, q=0.9)
        pct95_time = np.quantile(total_times, q=0.95)
        pct99_time = np.quantile(total_times, q=0.99)
        print(f"Mean  : {mean_time:.1f}")
        print(f"Median: {median_time:.1f}")
        print(f"Min   : {min_time:.1f}")
        print(f"Max   : {max_time:.1f}")
        print(f"90%   : {pct90_time:.1f}")
        print(f"95%   : {pct95_time:.1f}")
        print(f"99%   : {pct99_time:.1f}")
    else:
        # NumPy's min/max/quantile raise ValueError on an empty sequence
        print("N/A")