-
Notifications
You must be signed in to change notification settings - Fork 0
/
api_utils.py
210 lines (160 loc) · 6.61 KB
/
api_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import base64
from openai import OpenAI
import os
import requests
def encode_image(image_path):
    """
    Read an image file and return its bytes as a base64 text string.

    Args:
        image_path: path of the file to read (opened in binary mode)
    Returns:
        str, the base64 encoding of the file's bytes, decoded as UTF-8 text
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
def get_prompts(caption_text):
    """
    Build the instruction prompt for one figure by inserting its caption
    into the fixed prompt template.

    The finished prompt is also printed to stdout.

    Args:
        caption_text: str, the caption of the figure
    Returns:
        filled_prompt: str, the prompt template with the caption substituted
        into its single placeholder
    """
    # The template text is sent to the model verbatim; it contains exactly one
    # '{}' placeholder, which receives the caption.
    template = """I have images generated by diffusion model, but the text in the picture is bad and I want to replace them.
You should guess the content in the red circle. There are three steps:
1. given caption of the figure, read and understand it
2. given pictures from the paper, summarize the content from the paper
3. given several pictures with red circles, guess the content in the red circle.
If the content is confusing, you should DEFINITELY discard the text and guess the content by the context!
You should first briefly output the summarization of step 2 in points within 200 tokens, then
ouput the top-3 possible contents in the circle in the format of i: (Guess words i, Reason i) separated by '@'
and at last output the original confusing OCR recognition result.
Example format: Step2*** ...(summarization) Step3*** 1: (guess1, reason1) @ 2: (guess2, reason2) @ 3: (guess3, reason3) OCR*** XXX.
Please rigorously follow the example format. Don't output any other contexts in step3, don't contain \' in the format. Don't have analysis on the top of output like 'based on ...'Let's begin.
Step1 caption is {}.
Step2 The pictures given except the last one is the reference from the paper.
Step3 The last picture which have red circle is used in step 3.
"""
    prompt_with_caption = template.format(caption_text)
    print(f'filled_prompt: {prompt_with_caption}')
    return prompt_with_caption
def get_result_list_from_content(content):
    """
    Parse the model's reply into a list of guesses plus the raw OCR text.

    The reply is expected to follow the layout requested by the prompt:
    ``Step2*** <summary> Step3*** 1: (guess1, reason1) @ 2: (guess2, reason2)
    @ 3: (guess3, reason3) OCR*** <ocr text>``.

    Args:
        content: str, the response text from the API
    Returns:
        result_list: a list of tuples, each tuple is (guess, reason), with
            surrounding whitespace and the single wrapping pair of
            parentheses removed
        ocr: str, the text following the 'OCR***' marker, stripped
    Raises:
        IndexError: if the 'Step3***' or 'OCR***' marker is missing (the
            exception type the bare list indexing used to raise, kept so
            existing callers keep working)
    """
    # Only the text after the Step 3 marker is parsed; the Step 2 summary is ignored.
    step3_sections = content.split('Step3***')
    if len(step3_sections) < 2:
        raise IndexError("response content has no 'Step3***' section")
    step3_text = step3_sections[1].strip()

    # The OCR result trails the guesses, introduced by its own marker.
    ocr_sections = step3_text.split('OCR***')
    if len(ocr_sections) < 2:
        raise IndexError("response content has no 'OCR***' section")
    guesses_text = ocr_sections[0]
    ocr = ocr_sections[1].strip()
    print(f'######### ocr: {ocr}')
    print(f'content = {guesses_text}')

    result_list = []
    # Each '@'-separated item should look like '1: (guess, reason)'.
    for line in guesses_text.strip().split('@'):
        if not line:
            continue
        parts = line.split(":", 1)
        print(f'line= {line}, parts = {parts}')
        if len(parts) != 2:
            continue
        body = parts[1].strip()
        # Drop exactly one wrapping pair of parentheses. str.strip("()") would
        # also eat parentheses that belong to the text, e.g. a reason ending
        # in "f(x)".
        if body.startswith("("):
            body = body[1:]
        if body.endswith(")"):
            body = body[:-1]
        try:
            # The first comma separates the guess from its reason.
            guess, reason = body.split(',', 1)
        except ValueError:
            # No comma: the item does not follow the expected format.
            print(f"Could not parse line: {line}")
            continue
        result_list.append((guess.strip(), reason.strip()))
    return result_list, ocr
def _image_content_part(image_path):
    """Return one ``image_url`` content entry embedding the file as a base64 data URL."""
    import mimetypes  # local import: keeps this block self-contained

    # Label the data URL with the type implied by the file name; fall back to
    # PNG (the previously hard-coded type) when it cannot be determined.
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"
    return {
        "type": "image_url",
        "image_url": {
            "url": f"data:{mime_type};base64,{encode_image(image_path)}",
        },
    }


def get_content_list(caption_text, reference_images_path, red_circle_recon_image_path):
    """
    Assemble the message content for the API: the filled prompt, then the
    reference images, then the red-circle image (always last, as the prompt
    tells the model to expect).

    Args:
        caption_text: str, the caption of the figure
        reference_images_path: list of str, paths of the reference images;
            may be empty or None, in which case only the red-circle image
            is attached
        red_circle_recon_image_path: str, path of the red circle image
    Returns:
        content_list: list of dict, the content list for the API (one
        "text" entry followed by "image_url" entries)
    """
    # The text prompt always comes first.
    content_list = [
        {"type": "text", "text": get_prompts(caption_text)},
    ]

    # Reference images, in the order given.
    if not reference_images_path:
        print('No reference images found!')
    else:
        print(f'Add reference images: {reference_images_path}')
        content_list.extend(
            _image_content_part(reference_image_path)
            for reference_image_path in reference_images_path
        )

    # The red-circle image goes last.
    content_list.append(_image_content_part(red_circle_recon_image_path))
    return content_list
def get_api_key():
    """
    Look up the API key in the process environment.

    Returns:
        api_key: str, the value of the NLP_API_KEY environment variable,
        or None when that variable is not set
    """
    api_key = os.getenv('NLP_API_KEY')
    return api_key
def get_response(content_list, model="gpt-4-vision-preview", max_tokens=500):
    """
    Send one user message to the chat completions API and return the reply.

    The API key is read via get_api_key(); it is not a parameter of this
    function.

    Args:
        content_list: list of dict, the content list for the API
        model: str, the model identifier to request (defaults to the
            previously hard-coded value)
        max_tokens: int, upper bound on the length of the generated reply
            (defaults to the previously hard-coded value)
    Returns:
        response: the response object returned by the API client
    """
    # NOTE(review): get_api_key() returns None when NLP_API_KEY is unset; how
    # the client reacts to a None key is decided by the OpenAI library.
    # NOTE(review): confirm the default model id is still served by the API.
    client = OpenAI(api_key=get_api_key())
    return client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": content_list,
            }
        ],
        max_tokens=max_tokens,
    )
def run_GPT4V_api_one_step(caption_text, reference_images_path, red_circle_recon_image_path):
    """
    Run a single GPT-4 Vision request: build the message content from the
    caption and images, then send it to the API.

    Args:
        caption_text: str, the caption of the figure
        reference_images_path: list of str, paths of the reference images
        red_circle_recon_image_path: str, path of the red circle image
    Returns:
        response: the response from the API
    """
    return get_response(
        get_content_list(
            caption_text,
            reference_images_path,
            red_circle_recon_image_path,
        )
    )