forked from CarperAI/trlx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simulacra.py
38 lines (31 loc) · 1.17 KB
/
simulacra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Optimize prompts with ILQL by training on the prompt-rating pairs of the
# Simulacra Aesthetic Captions dataset: https://github.com/JD-P/simulacra-aesthetic-captions
import os
import sqlite3
from urllib.request import urlretrieve
from accelerate import Accelerator
import trlx
from trlx.data.default_configs import default_ilql_config
# Remote SQLite snapshot of the Simulacra Aesthetic Captions data (2022-06-29 dump).
url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
# Local cache location: the file-name component of `url`, relative to the working directory.
dbpath = url.rsplit("/", 1)[-1]
def fetch_db(url, path):
    """Download *url* to *path* unless *path* already exists.

    The payload is first written to ``path + ".part"`` and only moved onto
    *path* after ``urlretrieve`` returns. An interrupted download therefore
    never leaves a truncated file at *path*, which a later run would otherwise
    treat as complete (it only checks for existence) and never re-fetch.
    """
    if os.path.exists(path):
        return
    print(f"fetching {path}")
    partial_path = path + ".part"
    urlretrieve(url, partial_path)
    # Move the finished download into place, overwriting any stale destination.
    os.replace(partial_path, path)


def load_prompt_ratings(path):
    """Read every rated (prompt, rating) pair from the SQLite database at *path*.

    Each rating is joined to its prompt via ``ratings.iid -> images.id`` and
    ``images.gid -> generations.id``; rows whose rating is NULL are skipped.

    Returns:
        A ``(prompts, ratings)`` tuple of two equal-length lists, aligned so
        that ``ratings[i]`` is the rating given for ``prompts[i]``.

    Raises:
        ValueError: if the query returns no rated prompts.
        sqlite3.Error: if the query fails, e.g. the expected tables are missing.
    """
    conn = sqlite3.connect(path)
    try:
        rows = conn.execute(
            "SELECT prompt, rating FROM ratings "
            "JOIN images ON images.id=ratings.iid "
            "JOIN generations ON images.gid=generations.id "
            "WHERE rating IS NOT NULL;"
        ).fetchall()
    finally:
        # The connection's context manager only commits/rolls back; close explicitly.
        conn.close()
    if not rows:
        raise ValueError(f"no rated prompts found in {path}")
    prompts = [prompt for prompt, _ in rows]
    ratings = [rating for _, rating in rows]
    return prompts, ratings


if __name__ == "__main__":
    accelerator = Accelerator()
    # Only the local main process downloads the database. Every process then
    # waits at the barrier, so none opens the file before it is in place.
    if accelerator.is_local_main_process:
        fetch_db(url, dbpath)
    accelerator.wait_for_everyone()

    prompts, ratings = load_prompt_ratings(dbpath)
    trlx.train(
        config=default_ilql_config(),
        samples=prompts,
        rewards=ratings,
        # Evaluation uses one fixed prompt, repeated 64 times.
        eval_prompts=["An astronaut riding a horse"] * 64,
    )