-
Notifications
You must be signed in to change notification settings - Fork 25
/
GPT4V_ZS.py
115 lines (93 loc) · 3.98 KB
/
GPT4V_ZS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from openai import OpenAI
import base64
import glob
import os
import json
from datetime import datetime
import time
import hashlib
import random
from config_path import find_path
import yaml
# Module-level OpenAI API client shared by the request loop below
# (presumably authenticates via the OPENAI_API_KEY environment
# variable -- confirm against the deployment's configuration).
client = OpenAI()
# Function to encode the image
def encode_image(image_path):
    """Return the file at *image_path* as a base64 string (UTF-8 decoded)."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
# ---- Dataset configuration -------------------------------------------------
dataset_name = 'dtd'  # ["Your Dataset Name Here"]
config_path = find_path(dataset_name)
with open(config_path, 'r') as f:
    # NOTE(review): yaml.FullLoader can construct arbitrary-ish Python objects;
    # switch to yaml.safe_load if this config could ever come from an
    # untrusted source.
    config = yaml.load(f, Loader=yaml.FullLoader)
categories = config['categories']  # class names the model must choose among
data_path = config['data_path']    # root directory laid out as <class>/<image>.jpg

# Collect every image, then shuffle with a fixed seed so repeated runs
# process images in the same (pseudo-random) order.
image_files = sorted(glob.glob(os.path.join(data_path, "*", "*.jpg")))
random.seed(666)
random.shuffle(image_files)
total_num = len(image_files)
print('Total images: {}'.format(total_num))
print('Total classes: {}'.format(len(categories)))

# Output directory for predictions. makedirs(exist_ok=True) tolerates a
# pre-existing directory and nested paths, where os.mkdir would raise.
output_dir = f'./GPT4V_Pred_{dataset_name}'
os.makedirs(output_dir, exist_ok=True)
# hash encoding image names
# Each image gets a short (10 hex chars) SHA-256 digest of its basename so
# original file names are not sent to the API; hash_dict.json records the
# digest -> original-name mapping for decoding predictions later.
# NOTE(review): only the basename is hashed and digests are truncated to 10
# characters, so identically-named files in different class folders collide
# in hash_dict -- confirm basenames are unique across the dataset.
all_names = []
hash_dict = {}
for path in image_files:
    original = os.path.basename(path)
    digest = hashlib.sha256(original.encode('utf-8')).hexdigest()[:10]
    all_names.append(digest)
    hash_dict[digest] = original
hash_dict_path = os.path.join(output_dir, 'hash_dict.json')
with open(hash_dict_path, "w") as file:
    json.dump(hash_dict, file, indent=4)
log_path = 'log_error.txt'   # failed batches are appended here with a timestamp
processed_image_num = 1      # batch size (images per API request)
# Number of batches, rounded up so a final partial batch is still processed.
processed_step = (total_num + processed_image_num - 1) // processed_image_num
MAX_RETRIES = 5              # give up on a batch after this many consecutive failures

i = 0
retries = 0
while i < processed_step:
    st = i * processed_image_num
    end = (i + 1) * processed_image_num
    subset = image_files[st:end]      # images in this batch
    image_names = all_names[st:end]   # their hashed (anonymised) names
    try:
        print(f'processing {end} images')
        # Zero-shot classification prompt: ask for the top-5 categories per
        # image, returned as a dict keyed by the hashed image name.
        text_prompt = "I want you to act as a Texture Image Classifier with a ranking system. I will provide you with an image and a list of optional categories. Your task is to choose the 5 most relevant categories for the image and rank them from most to least likely to accurately describe the image. Provide the output in a dict format, key is the image name, value is the list of top-5 category (no line wrap). Do not provide explanations for your choices or any additional information. Here is the image({}) and its optional categories({}). You have to choose strictly among the given categories and do not give any predictions that are not in the given category.".format(image_names, categories)
        base64Images = [encode_image(p) for p in subset]
        PROMPT_MESSAGES = [
            {
                "role": "user",
                "content": [
                    text_prompt,
                    *map(lambda x: {"image": x, "resize": 512}, base64Images),
                ],
            }
        ]
        params = {
            "model": "gpt-4-vision-preview",
            "messages": PROMPT_MESSAGES,
            "max_tokens": 4096,  # upper bound on the reply length
        }
        # request
        result = client.chat.completions.create(**params)
        markdown_str = result.choices[0].message.content
        # Strip the braces so successive batch replies can be appended to one
        # file. NOTE(review): the resulting *_pred.json is therefore NOT valid
        # JSON as written -- downstream parsing presumably reassembles it.
        json_str = markdown_str.replace('{', '').replace('}', ',\n')
        output_path = os.path.join(output_dir, '{}_pred.json'.format(dataset_name))
        # Append this batch's predictions to the accumulating output file.
        with open(output_path, 'a+') as file:
            file.write(json_str)
        print(result.usage)
        i = i + 1
        retries = 0  # batch succeeded: reset the retry budget
    except Exception as e:
        # Log the failure, then retry the same batch after a short pause.
        # Previously the loop retried forever, hanging on any persistent
        # error; now the batch is skipped after MAX_RETRIES attempts.
        timestamp = datetime.now().strftime("%D:%H:%M:%S")
        with open(log_path, 'a') as f:
            f.write('{} {}\n'.format(timestamp, e))
        retries += 1
        if retries >= MAX_RETRIES:
            retries = 0
            i = i + 1  # skip this batch and move on
        time.sleep(2)