test_api_model.py
#!/usr/bin/env python
try:
    from dotenv import load_dotenv
    load_dotenv(override=True)
except ImportError:
    pass

import time
import json
import sys
import os
import requests
import argparse
import subprocess
import traceback
from datauri import DataURI
from openai import OpenAI
import torch
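
# This script runs a fixed battery of vision prompts (single images by URL and
# by base64 data URI, receipt OCR, a text-only question, and a two-image
# question) against an OpenAI-compatible chat completions endpoint and prints
# pass/fail per test plus an overall summary with tokens/second.
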
# tests are configured with model_conf_tests.json

all_results = []

client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL", 'http://localhost:5006/v1'),
    api_key=os.environ.get("OPENAI_API_KEY", 'sk-ip'),
)

urls = {
    'tree': 'https://images.freeimages.com/images/large-previews/e59/autumn-tree-1408307.jpg',
    'waterfall': 'https://images.freeimages.com/images/large-previews/242/waterfall-1537490.jpg',
    'horse': 'https://images.freeimages.com/images/large-previews/5fa/attenborough-nature-reserve-1398791.jpg',
    'leaf': 'https://images.freeimages.com/images/large-previews/cd7/gingko-biloba-1058537.jpg',
}

quality_urls = {
    '98.21': ('What is the total bill?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
    'walmart': ('What store is the receipt from?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
}

no_image = {
    '5': 'In the integer sequence: 1, 2, 3, 4, ... What number comes next after 4?'
}

green_pass = '\033[92mpass\033[0m✅'
red_fail = '\033[91mfail\033[0m❌'

def data_url_from_url(img_url: str) -> str:
    response = requests.get(img_url)
    img_data = response.content
    content_type = response.headers['content-type']
    return str(DataURI.make(mimetype=content_type, charset='utf-8', base64=True, data=img_data))
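
# The helper above inlines the downloaded image as a base64 data URI
# (e.g. 'data:image/jpeg;base64,...'), so the same prompts can be tested
# without the server having to fetch the remote URL itself.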

def record_result(cmd_args, results, t, mem, note):
    # update all_results with the test data
    all_results.extend([{
        'args': cmd_args,
        'results': results,
        'time': t,
        'mem': mem,
        'note': note
    }])
    result = all(results)
    print(f"test {green_pass if result else red_fail}, time: {t:.1f}s, mem: {mem:.1f}GB, {note}")

if __name__ == '__main__':
    # Initialize argparse
    parser = argparse.ArgumentParser(description='Test vision using OpenAI')
    parser.add_argument('-s', '--system-prompt', type=str, default=None)
    parser.add_argument('-m', '--max-tokens', type=int, default=None)
    parser.add_argument('-t', '--temperature', type=float, default=None)
    parser.add_argument('-p', '--top_p', type=float, default=None)
    parser.add_argument('-v', '--verbose', action='store_true', help="Verbose")
    parser.add_argument('--openai-model', type=str, default="gpt-4-vision-preview")
    parser.add_argument('--abort-on-fail', action='store_true', help="Abort testing on fail.")
    parser.add_argument('--quiet', action='store_true', help="Less test noise.")
    parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
    args = parser.parse_args()

    params = {}
    if args.max_tokens is not None:
        params['max_tokens'] = args.max_tokens
    if args.temperature is not None:
        params['temperature'] = args.temperature
    if args.top_p is not None:
        params['top_p'] = args.top_p
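
    # Example invocation (values are illustrative, not defaults):
    #   OPENAI_BASE_URL=http://localhost:5006/v1 python test_api_model.py -m 512 -t 0.0 -v
    # Only options that were explicitly given are added to params and forwarded
    # to the chat completions call.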

    def generate_response(image_url, prompt):
        messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []

        if isinstance(image_url, str):
            image_url = [image_url]

        content = []
        for url in image_url:
            content.extend([{ "type": "image_url", "image_url": { "url": url } }])
        content.extend([{ "type": "text", "text": prompt }])
        messages.extend([{ "role": "user", "content": content }])

        response = client.chat.completions.create(model=args.openai_model, messages=messages, **params)
        completion_tokens = 0
        answer = response.choices[0].message.content
        if response.usage:
            completion_tokens = response.usage.completion_tokens
        return answer, completion_tokens
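
    # The user message assembled above looks roughly like this (illustrative):
    #   {"role": "user", "content": [
    #       {"type": "image_url", "image_url": {"url": "https://..."}},
    #       {"type": "text", "text": "What is the subject of the image?"}]}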

    def generate_stream_response(image_url, prompt):
        messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []

        if isinstance(image_url, str):
            image_url = [image_url]

        content = []
        for url in image_url:
            content.extend([{ "type": "image_url", "image_url": { "url": url } }])
        content.extend([{ "type": "text", "text": prompt }])
        messages.extend([{ "role": "user", "content": content }])

        response = client.chat.completions.create(model=args.openai_model, messages=messages, **params, stream=True)
        answer = ''
        completion_tokens = 0
        for chunk in response:
            if chunk.choices[0].delta.content:
                answer += chunk.choices[0].delta.content
            if chunk.usage:
                completion_tokens = chunk.usage.completion_tokens
        return answer, completion_tokens
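
    # Note: completion token counts are only available when the server reports
    # `usage` on a streamed chunk; otherwise completion_tokens stays 0 and the
    # test is excluded from the tokens/sec summary below.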

    if True:
        # XXX TODO: timeout
        results = []

        ### Single round
        timing = []

        def single_test(url, question, right_answer, label, generator=generate_response):
            tps_time = time.time()
            answer, tok = generator(url, question)
            tps_time = time.time() - tps_time
            correct = right_answer in answer.lower()
            results.extend([correct])
            if not correct:
                print(f"{right_answer}[{label}]: {red_fail}, got: {answer}")
                #if args.abort_on_fail:
                #    break
            else:
                print(f"{right_answer}[{label}]: {green_pass}{', got: ' + answer if args.verbose else ''}")
            if tok > 1:
                timing.extend([(tok, tps_time)])
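
        # A test counts as passed when the expected keyword (the dict key, e.g.
        # 'tree' or 'walmart') appears as a substring of the lower-cased answer.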

        test_time = time.time()

        # url tests
        for name, url in urls.items():
            single_test(url, "What is the subject of the image?", name, "url", generate_response)

            data_url = data_url_from_url(url)
            single_test(data_url, "What is the subject of the image?", name, "data", generate_response)
            single_test(data_url, "What is the subject of the image?", name, "data_stream", generate_stream_response)
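
        # Each image is exercised three ways: by remote URL, as an inline
        # base64 data URI, and as a data URI with a streamed response.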

        ## OCR tests
        quality_urls = {
            '98.21': ('What is the total bill?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
            'walmart': ('What store is the receipt from?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
        }
        for name, question in quality_urls.items():
            prompt, data_url = question
            single_test(data_url, prompt, name, "quality", generate_stream_response)
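
        # Both OCR prompts reuse the same sample receipt image; the dict key
        # doubles as the expected answer ('98.21' total, 'walmart' store name).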

        # No image tests
        no_image = {
            '5': 'In the sequence of numbers: 1, 2, 3, 4, ... What number comes next after 4? Answer only the number.'
        }
        for name, prompt in no_image.items():
            single_test([], prompt, name, 'no_img', generate_response)

        # Multi-image test
        multi_image = {
            "water": ("What natural element is common in both images?",
                      ['https://images.freeimages.com/images/large-previews/e59/autumn-tree-1408307.jpg',
                       'https://images.freeimages.com/images/large-previews/242/waterfall-1537490.jpg'])
        }
        for name, question in multi_image.items():
            prompt, data_url = question
            single_test(data_url, prompt, name, "multi-image", generate_stream_response)
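
        # Both URLs go into a single user message, so the model has to reason
        # across the two images within one request.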

        test_time = time.time() - test_time

        result = all(results)
        note = f'{results.count(True)}/{len(results)} tests passed.'
        if timing:
            tok_total, tim_total = 0, 0.0
            for tok, tim in timing:
                if tok > 1 and tim > 0:
                    tok_total += tok
                    tim_total += tim
            if tim_total > 0.0:
                note += f', ({tok_total}/{tim_total:0.1f}s) {tok_total/tim_total:0.1f} T/s'

        print(f"test {green_pass if result else red_fail}, time: {test_time:.1f}s, {note}")