-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_images.py
executable file
·79 lines (62 loc) · 2.97 KB
/
test_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
import argparse
import base64
import io
import os
import sys
import time

import openai
from PIL import Image
# Shared API client aimed at a locally hosted OpenAI-compatible server.
client = openai.Client(
    base_url="http://localhost:5005/v1",
)

# Directory that receives the generated PNGs and the markdown reports.
TEST_DIR = "test"

# Prefix prepended to a prompt for the "-not-enhanced" runs, asking the
# model to use the prompt verbatim instead of elaborating on it.
not_enhanced = (
    "I NEED to test how the tool works with extremely simple prompts. "
    "DO NOT add any detail, just use it AS-IS:"
)
def generate_image(prompt, model, res, f, n=1, suffix=''):
    """Request ``n`` images for ``prompt`` and record the results as markdown.

    Every returned image is base64-decoded and saved as a PNG inside
    ``TEST_DIR``. For each image, a markdown image link (plus the server's
    ``revised_prompt``, when present) is written to the open report ``f``.

    Args:
        prompt: Text prompt sent to the images endpoint.
        model: Model name passed through to the server, e.g. ``'dall-e-3'``.
        res: Size string such as ``'1024x1024'``.
        f: Writable text file that receives the markdown report.
        n: Number of images requested in the single API call.
        suffix: Extra tag inserted into the PNG filenames, e.g. ``'-not-enhanced'``.

    Returns:
        List of the PNG filenames written, relative to ``TEST_DIR``.
    """
    # perf_counter is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    response = client.images.generate(prompt=prompt, model=model, size=res, response_format='b64_json', n=n)
    elapsed = time.perf_counter() - start
    print(f"## {model} {res} took {elapsed:.1f} seconds", file=f)

    written = []
    for i, img in enumerate(response.data, 1):
        # Name encodes model, size, optional suffix and "<index>_<batch size>".
        fname = f"test_image_{model}_{res}{suffix}-{i:02d}_{n:02d}.png"
        with open(os.path.join(TEST_DIR, fname), 'wb') as png:
            png.write(base64.b64decode(img.b64_json))
        written.append(fname)
        # The link is a bare filename; the callers in this file open their
        # reports inside TEST_DIR, next to the PNGs, so it resolves there.
        print(f"![{prompt}]({fname})", file=f)
        if img.revised_prompt:
            print("revised_prompt: " + img.revised_prompt, file=f)
        print("\n", file=f, flush=True)
    print("-" * 50, file=f)
    print("\n", file=f, flush=True)
    return written
def full_test(prompt, n=1):
    """Run ``prompt`` through every model/size combination.

    Writes one markdown report per model (``test_images-<model>.md``) into
    ``TEST_DIR``. Each dall-e-3 size is also run with the ``not_enhanced``
    prefix, so the rewritten and the verbatim prompt can be compared.

    Args:
        prompt: Text prompt to render.
        n: Number of images requested per API call.
    """
    # open() below fails with FileNotFoundError if the directory is missing.
    os.makedirs(TEST_DIR, exist_ok=True)

    for model in ['dall-e-1', 'dall-e-2']:
        with open(os.path.join(TEST_DIR, f"test_images-{model}.md"), "w", encoding="utf-8") as f:
            print(f"# {prompt}", file=f)
            for res in ['256x256', '512x512', '1024x1024']:
                generate_image(prompt, model, res, f, n=n)

    model = 'dall-e-3'
    with open(os.path.join(TEST_DIR, f"test_images-{model}.md"), "w", encoding="utf-8") as f:
        print(f"# {prompt}", file=f)
        # The OpenAI images API accepts 1024x1024, 1792x1024 and 1024x1792
        # for dall-e-3; the earlier "1796" values were typos.
        for res in ['1024x1024', '1024x1792', '1792x1024']:
            generate_image(prompt, model, res, f, n=n)
            generate_image(not_enhanced + prompt, model, res, f, n=n, suffix='-not-enhanced')
def quick_test(prompt, n=1):
    """Run ``prompt`` through a short, fixed set of model/size combinations.

    All results go to a single report, ``test_images_quick.md``, in
    ``TEST_DIR``. This includes one dall-e-3 run with the ``not_enhanced``
    prefix.

    Args:
        prompt: Text prompt to render.
        n: Number of images requested per API call.
    """
    # open() below fails with FileNotFoundError if the directory is missing.
    os.makedirs(TEST_DIR, exist_ok=True)
    with open(os.path.join(TEST_DIR, "test_images_quick.md"), "w", encoding="utf-8") as f:
        print(f"# {prompt}", file=f)
        generate_image(prompt, "dall-e-1", "512x512", f, n=n)
        generate_image(prompt, "dall-e-1", "1024x1024", f, n=n)
        generate_image(prompt, "dall-e-2", "1024x1024", f, n=n)
        generate_image(not_enhanced + prompt, "dall-e-3", "1024x1024", f, n=n, suffix='-not-enhanced')
        generate_image(prompt, "dall-e-3", "1024x1024", f, n=n)
def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``prompt`` (str), ``quick`` (bool),
        ``full`` (bool) and ``batch`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Generate test images against the local images endpoint.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # ArgumentDefaultsHelpFormatter only appends "(default: ...)" to options
    # that have help text, so every option gets some.
    parser.add_argument('-p', '--prompt', action='store', type=str, default="A cute baby sea otter",
                        help="prompt to send to the images endpoint")
    parser.add_argument('-q', '--quick', action='store_true',
                        help="run the quick test (a few model/size combinations)")
    parser.add_argument('-f', '--full', action='store_true',
                        help="run every model/size combination")
    parser.add_argument('-n', '--batch', action='store', type=int, default=1,
                        help="number of images requested per API call")
    # Honour the caller's argv; previously it was ignored and sys.argv was
    # always parsed.
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args(sys.argv[1:])
    # --quick takes precedence when both mode flags are given.
    if args.quick:
        quick_test(args.prompt, n=args.batch)
    elif args.full:
        full_test(args.prompt, n=args.batch)
    else:
        # Without a mode flag there is nothing to run; fail loudly rather
        # than exit silently with status 0.
        sys.exit("Nothing to do: pass --quick (-q) or --full (-f). See --help.")