-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
59 lines (43 loc) · 2.46 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import time
import tensorflow as tf
import tensorflow_addons as tfa
from dataset import CifarDataset, ImagenetDataset
from model import DiffusionModel
from sampling_callback import SamplingCallback
import unet
# Training hyperparameters.
BATCH_SIZE = 128  # per-device batch size; scaled by the device count to get the global batch size
EPOCHS = 145  # number of epochs handed to model.fit
LEARNING_RATE = 1e-4  # learning rate of the underlying Adam optimizer

# Command-line flags, defined through TensorFlow's v1-compatible flags module.
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="checkpoint_dir",
    default=None,
    help="Directory to load model state from to resume training.",
)
flags.DEFINE_string(
    name="experiment_name",
    default=None,
    help="Name of the experiment being run.",
)
flags.DEFINE_bool(
    name="use_mixed_precision",
    default=False,
    help="Whether to use float16 mixed precision training.",
)
# Opt-in mixed precision: install 'mixed_float16' as the global Keras dtype
# policy before any layers are constructed.
if FLAGS.use_mixed_precision:
    policy = tf.keras.mixed_precision.Policy(name='mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)
# Synchronous data-parallel training through tf.distribute.MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()

# Global batch size = per-device batch size * number of replicas the strategy
# actually runs. Counting physical GPUs instead yields a batch size of 0 on a
# machine without GPUs, where the strategy still runs one replica.
batch_size = strategy.num_replicas_in_sync * BATCH_SIZE

# Build the dataset, network, diffusion model and optimizer under the strategy
# scope so that the variables they create are managed by the strategy.
with strategy.scope():
    dataset = CifarDataset(batch_size)

    # Bound to `unet_model` (not `unet`) so the imported `unet` module is not
    # shadowed by the network instance.
    unet_model = unet.Unet(
        dim=128,
        num_res_blocks=2,
        dropout=0.3,
        dim_mults=[1, 2, 2, 2],
        attention_resolutions=(2, 4),
        resblock_updown=True,
        num_classes=dataset.num_classes,
        learned_variance=True,
    )
    model = DiffusionModel(
        dataset.image_size,
        dataset.betas,
        unet_model,
        model_var_type='learned_range',
    )

    # Adam wrapped in tensorflow_addons' MovingAverage optimizer, which keeps
    # a moving average of the weights with the given decay.
    adam = tf.keras.optimizers.Adam(LEARNING_RATE)
    optimizer = tfa.optimizers.MovingAverage(adam, average_decay=0.9999)
    model.compile(optimizer=optimizer)
# Resolve the run directory, in order of precedence:
#   1. --checkpoint_dir: resume from the newest checkpoint stored there;
#   2. --experiment_name: start a run under checkpoints/<experiment_name>;
#   3. otherwise: start a run under checkpoints/<timestamp>.
if FLAGS.checkpoint_dir:
    checkpoint_dir = FLAGS.checkpoint_dir
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        # latest_checkpoint returns None when the directory contains no
        # checkpoint; fail with a clear message rather than handing None to
        # load_weights.
        raise FileNotFoundError(
            'No checkpoint found in checkpoint_dir {!r}; cannot resume training.'.format(checkpoint_dir)
        )
    model.load_weights(latest)
elif FLAGS.experiment_name:
    checkpoint_dir = 'checkpoints/{}'.format(FLAGS.experiment_name)
else:
    # Timestamp resolution is one minute (month_day_year-hour_minute).
    checkpoint_dir = 'checkpoints/{}'.format(time.strftime("%m_%d_%y-%H_%M"))

# File path prefix handed to the checkpoint callback.
checkpoint_path = checkpoint_dir + "/checkpoint.ckpt"
# Checkpoint callback from tensorflow_addons, configured to save weights only
# (save_weights_only=True) and with update_weights=False.
cp_callback = tfa.callbacks.AverageModelCheckpoint(
    filepath=checkpoint_path,
    update_weights=False,
    save_weights_only=True,
    verbose=1,
)

# Keras backup/restore callback; its backup state lives in the run directory.
buar_callback = tf.keras.callbacks.experimental.BackupAndRestore(backup_dir=checkpoint_dir)

# Project-defined sampling callback, configured to run every 5 (run_every=5).
sampling_callback = SamplingCallback(
    checkpoint_dir=checkpoint_dir,
    batch_size=batch_size,
    sample_classes=dataset.get_sample_classes(),
    run_every=5,
    image_size=dataset.image_size,
    model_var_type='learned_range',
)

# Launch training with all three callbacks attached.
training_callbacks = [cp_callback, buar_callback, sampling_callback]
model.fit(dataset.load(), epochs=EPOCHS, callbacks=training_callbacks)