forked from dorarad/gansformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrained_networks.py
39 lines (32 loc) · 1.55 KB
/
pretrained_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# List of pre-trained GANsformer networks
import pickle
import dnnlib
import dnnlib.tflib as tflib
# Aliases for the pretrained snapshots used by default (training / generation).
# Keys are "gdrive:<file>" aliases; values are Google Drive download URLs.
gdrive_urls = {
    "gdrive:clevr-snapshot.pkl": "https://drive.google.com/uc?id=1zBh-U2kyVgN3C_P_7GqsMEHvBdz2lobu",
    "gdrive:cityscapes-snapshot.pkl": "https://drive.google.com/uc?id=1XPGYzUP_1ETFtz5bUhpUFPha1IBNTuZh",
    "gdrive:ffhq-snapshot.pkl": "https://drive.google.com/uc?id=1tgs-hHaziWrh0piuX3sEd8PwE9gFwlNh",
    "gdrive:bedrooms-snapshot.pkl": "https://drive.google.com/uc?id=19TykSlMgXIjyIiDLBakmNVEQmD5dTd0s",
}

# Superset of `gdrive_urls` that also contains the snapshots looked up when
# `eval=True` is requested. This is a separate dict object; the shared entries come first.
eval_gdrive_urls = {
    **gdrive_urls,
    "gdrive:cityscapes-snapshot-2048.pkl": "https://drive.google.com/uc?id=1Zw1cFxxN6-iC_M4x6Zbf9lwH9wKryW3p",
    "gdrive:ffhq-snapshot-1024.pkl": "https://drive.google.com/uc?id=10V4yK_rQWb4F6Q4vwqkO5XNKX721k3zl",
}
def get_path_or_url(path_or_gdrive_path, eval=False):
    """Translate a known ``gdrive:`` alias into its download URL.

    Args:
        path_or_gdrive_path: A key of the pretrained-network tables
            (e.g. ``"gdrive:ffhq-snapshot.pkl"``) or any other path/URL.
        eval: If true, resolve against ``eval_gdrive_urls``, which also holds
            the evaluation-only snapshots; otherwise use ``gdrive_urls``.

    Returns:
        The mapped URL when the argument is a known alias, otherwise the
        argument itself, unchanged.
    """
    table = eval_gdrive_urls if eval else gdrive_urls
    if path_or_gdrive_path in table:
        return table[path_or_gdrive_path]
    # Not an alias: assume the caller passed a real local path or URL.
    return path_or_gdrive_path
# Memoized results of `load_networks`, keyed by the resolved path/URL.
# Entries are never evicted, so they live for the rest of the process.
_cached_networks = {}


def load_networks(path_or_gdrive_path, eval=False):
    """Load the ``(G, D, Gs)`` networks from a pickled snapshot, with caching.

    Args:
        path_or_gdrive_path: A ``gdrive:`` alias (see ``get_path_or_url``), a
            URL, or a local file path of a snapshot pickle.
        eval: Forwarded to ``get_path_or_url``; if true, evaluation-only
            aliases are also recognised.

    Returns:
        A tuple ``(G, D, Gs)`` made of the first three items stored in the pickle.
        Repeated calls that resolve to the same path/URL return the same cached
        tuple, so all callers share the same network objects.
    """
    path_or_url = get_path_or_url(path_or_gdrive_path, eval)
    if path_or_url in _cached_networks:
        return _cached_networks[path_or_url]

    if dnnlib.util.is_url(path_or_url):
        stream = dnnlib.util.open_url(path_or_url, cache_dir=".GANsformer-cache")
    else:
        stream = open(path_or_url, "rb")

    # Enter the context manager right away so the stream is closed even if
    # TensorFlow initialization (or unpickling) raises. Previously init_tf()
    # ran before the `with`, which leaked the stream on failure.
    with stream:
        # TF must be initialized before unpickling the networks.
        tflib.init_tf()
        # SECURITY: pickle.load can execute arbitrary code. Only load snapshots
        # from trusted sources (this includes URLs passed in by the caller).
        # encoding="latin1" is presumably for pickles written by older Python
        # versions — TODO confirm.
        G, D, Gs = pickle.load(stream, encoding="latin1")[:3]

    networks = (G, D, Gs)
    _cached_networks[path_or_url] = networks
    return networks