# inference.py: draw samples from a saved VDM checkpoint and save them as image grids.
import os
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import einops
import numpy as np
import matplotlib.pyplot as plt
from vdm.models import VDM, NoiseScheduleNN, ScoreNetwork
from vdm.sample import sample_fn
from vdm.utils import get_sharding
def image_shaper(images):
    """Tile a batch of images into one square grid image.

    Args:
        images: Array whose ``shape`` unpacks as ``(n, c, h, w)``. ``n`` must
            be a perfect square so the batch fills a ``b x b`` grid.

    Returns:
        ``images`` rearranged to shape ``(b * h, b * w, c)`` with
        ``b = sqrt(n)``. Image ``i`` is placed at grid row ``i % b`` and grid
        column ``i // b`` (the grid is filled column by column).

    Raises:
        ValueError: If ``n`` is not a perfect square.
    """
    n, c, h, w = images.shape
    # Round instead of truncating, then verify: a silent truncation would
    # only surface later as an opaque shape mismatch inside einops.
    b = int(round(float(np.sqrt(n))))
    if b * b != n:
        raise ValueError(
            f"image_shaper needs a perfect-square number of images, got {n}"
        )
    return einops.rearrange(
        images,
        "(b1 b2) c h w -> (b2 h) (b1 w) c",
        b1=b,
        b2=b,
        h=h,
        w=w,
        c=c,
    )
if __name__ == "__main__":
    # Script entry point: rebuild the VDM, load its saved leaves from disk,
    # then draw five batches of samples and save each as a two-panel figure.
    from data.cifar10 import cifar10

    key = jr.PRNGKey(0)
    dataset = cifar10(key)
    sharding = get_sharding()

    # Data hyper-parameters
    context_dim = None
    data_shape = dataset.data_shape
    dataset_name = dataset.name

    # Model hyper-parameters
    model_name = "vdm_" + dataset_name
    init_gamma_0 = -13.3
    init_gamma_1 = 5.
    # NOTE: `activation` and `T_train` are not referenced anywhere below in
    # this script; presumably kept to mirror the training configuration.
    activation = jax.nn.tanh
    T_train = 0
    T_sample = 1000
    n_sample = 64

    # Plotting
    proj_dir = "./"
    imgs_dir = os.path.join(proj_dir, "imgs_" + dataset_name)
    # savefig does not create missing directories, so make sure the output
    # folder exists before the first figure is written.
    os.makedirs(imgs_dir, exist_ok=True)

    key_s, key_n = jr.split(key)

    # Build the model with the same constructor arguments, then overwrite its
    # leaves with the values stored under `model_name`.
    score_network = ScoreNetwork(
        data_shape,
        context_dim,
        init_gamma_0,
        init_gamma_1,
        key=key_s,
    )
    noise_schedule = NoiseScheduleNN(
        init_gamma_0, init_gamma_1, key=key_n
    )
    vdm = VDM(score_network, noise_schedule)
    vdm = eqx.tree_deserialise_leaves(model_name, vdm)
    print("Loaded:", model_name)

    for i in range(5):
        print("inference ", i)

        # The key is folded cumulatively (each round builds on the previous
        # key); kept as-is so previously generated samples stay reproducible.
        key = jr.fold_in(key, i)
        # The second return value (x_preds) is not plotted below, so it is
        # neither rescaled nor tiled.
        zs, _x_preds, samples = sample_fn(
            key, vdm, n_sample, T_sample, data_shape, sharding=sharding
        )
        print("Sampled", samples.min(), samples.max())

        # Undo the dataset scaling and tile each batch into one grid image.
        samples = image_shaper(dataset.scaler.reverse(samples))
        zs = image_shaper(dataset.scaler.reverse(zs))
        print("Scaled", samples.min(), samples.max())

        # Left panel: `zs` grid; right panel: `samples` grid.
        fig, axs = plt.subplots(1, 2, figsize=(16., 8.), dpi=300)
        for ax, grid in zip(axs, (zs, samples)):
            ax.imshow(grid)
            ax.axis("off")
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        fig.savefig(
            os.path.join(imgs_dir, f"inferences_{i}.png"),
            bbox_inches="tight",
        )
        # Close explicitly so figures do not accumulate across iterations.
        plt.close(fig)