forked from dorarad/gansformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_network.py
556 lines (462 loc) · 31.9 KB
/
run_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
# Ignore all future warnings
from warnings import simplefilter
simplefilter(action = "ignore", category = FutureWarning)
import argparse
import copy
import glob
import sys
import os
import dnnlib
from dnnlib import EasyDict
from metrics.metric_defaults import metric_defaults
from training import misc
import pretrained_networks
# Suppress warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Conditional set: if property is not None, then assign d[name] := prop
# for every d in a set of dictionaries
def cset(dicts, name, prop):
if not isinstance(dicts, list):
dicts = [dicts]
if prop is not None:
for d in dicts:
d[name] = prop
# Conditional set: if dict[name] is not populated from the command line, then assign dict[name] := prop
def nset(args, name, prop):
flag = f"--{name.replace('_', '-')}"
if flag not in sys.argv:
args[name] = prop
# Set network (generator or discriminator): model, loss and optimizer
def set_net(net, reg_interval):
ret = EasyDict()
net2name = {
"D": "Discriminator",
"G": "Generator"
}
ret.args = EasyDict(func_name = f"training.networks.{net2name[net]}") # network options
ret.loss_args = EasyDict(func_name = f"training.loss.{net}_loss") # loss options
ret.opt_args = EasyDict(beta1 = 0.0, beta2 = 0.99, epsilon = 1e-8) # optimizer options
ret.reg_interval = reg_interval
return ret
def run(**args):
args = EasyDict(args)
train = EasyDict(run_func_name = "training.training_loop.training_loop") # training loop options
sched = EasyDict() # TrainingSchedule options
vis = EasyDict() # visualize.eval() options
grid = EasyDict(size = (3, 2), layout = "random") # setup_snapshot_img_grid() options
sc = dnnlib.SubmitConfig() # dnnlib.submit_run() options
# GANformer and baselines default settings
# ----------------------------------------------------------------------------
if args.ganformer_default:
task = args.dataset
nset(args, "recompile", args.pretrained_pkl is not None)
nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])
nset(args, "transformer", True)
nset(args, "components_num", {"clevr": 8}.get(task, 16))
nset(args, "latent_size", {"clevr": 128}.get(task, 512))
nset(args, "normalize", "layer")
nset(args, "integration", "mul")
nset(args, "kmeans", True)
nset(args, "use_pos", True)
nset(args, "mapping_ltnt2ltnt", task != "clevr")
nset(args, "style", task != "clevr")
nset(args, "g_arch", "resnet")
nset(args, "mapping_resnet", True)
gammas = {
"ffhq": 10,
"cityscapes": 20,
"clevr": 40,
"bedrooms": 100
}
nset(args, "gamma", gammas.get(task, 10))
if args.baseline == "GAN":
nset(args, "style", False)
nset(args, "latent_stem", True)
if args.baseline == "SAGAN":
nset(args, "style", False)
nset(args, "latent_stem", True)
nset(args, "g_img2img", 5)
if args.baseline == "kGAN":
nset(args, "kgan", True)
nset(args, "merge_layer", 5)
nset(args, "merge_type", "softmax")
nset(args, "components_num", 8)
# General setup
# ----------------------------------------------------------------------------
# If the flag is specified without arguments (--arg), set to True
for arg in ["summarize", "keep_samples", "style", "fused_modconv", "local_noise"]:
if args[arg] is None:
args[arg] = True
# Environment configuration
tf_config = {
"rnd.np_random_seed": 1000,
"allow_soft_placement": True,
"gpu_options.per_process_gpu_memory_fraction": 1.0
}
if args.gpus != "":
num_gpus = len(args.gpus.split(","))
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
assert num_gpus in [1, 2, 4, 8]
sc.num_gpus = num_gpus
# Data setup
# ----------------------------------------------------------------------------
# Dataset configuration
if not os.path.exists(f"{args.data_dir}/{args.dataset}"):
misc.error(f"The dataset {args.data_dir}/{args.dataset} directory does not exist")
# For bedrooms, we choose the most common ratio in the
# dataset and crop the other images into that ratio.
ratios = {
"clevr": 0.75,
"bedrooms": 188/256,
"cityscapes": 0.5,
"ffhq": 1.0
}
args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio <= 0.5 else None
dataset_args = EasyDict(
tfrecord_dir = args.dataset,
max_imgs = args.train_images_num,
num_threads = args.num_threads,
resolution = args.resolution,
crop_ratio = args.crop_ratio,
ratio = args.ratio
)
for arg in ["data_dir", "mirror_augment", "total_kimg"]:
cset(train, arg, args[arg])
# Optimization setup
# ----------------------------------------------------------------------------
# Networks configuration
cG = set_net("G", reg_interval = 4)
cD = set_net("D", reg_interval = 16)
cset([cG, cD], "crop_ratio", args.crop_ratio)
# Training and Optimizations configuration
if not any([args.train, args.eval, args.vis]):
misc.log("Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes", "red")
for arg in ["train", "eval", "vis", "recompile", "last_snapshots"]:
cset(train, arg, args[arg])
if args.batch_size % (args.batch_gpu * num_gpus) != 0:
misc.error("--batch-size should be divided by --batch-gpu * 'num_gpus'")
if args.latent_size % args.components_num != 0:
misc.error(f"--latent-size ({args.latent_size}) should be divisible by --components-num (k={k})")
sched_args = {
"G_lrate": "g_lr",
"D_lrate": "d_lr",
"batch_size": "batch_size",
"batch_gpu": "batch_gpu"
}
for arg, cmd_arg in sched_args.items():
cset(sched, arg, args[cmd_arg])
cset(train, "clip", args.clip)
# Evaluation and visualization
# ----------------------------------------------------------------------------
# Logging and metrics configuration
for metric in args.metrics:
if metric not in metric_defaults:
misc.error(f"Unknown metric: {metric}")
metrics = [metric_defaults[x] for x in args.metrics]
for arg in ["summarize", "eval_images_num"]:
cset(train, arg, args[arg])
cset(cG.args, "truncation_psi", args.truncation_psi)
for arg in ["keep_samples", "num_heads"]:
cset(vis, arg, args[arg])
# Visualization
args.vis_imgs = args.vis_images
args.vis_ltnts = args.vis_latents
vis_types = ["imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var", "style_mix"]
# Set of all the set visualization types option
vis.vis_types = {arg for arg in vis_types if args[f"vis_{arg}"]}
vis_args = {
"attention": "transformer",
"grid": "vis_grid",
"num": "vis_num",
"rich_num": "vis_rich_num",
"section_size": "vis_section_size",
"intrp_density": "interpolation_density",
# "intrp_per_component": "interpolation_per_component",
"alpha": "blending_alpha"
}
for arg, cmd_arg in vis_args.items():
cset(vis, arg, args[cmd_arg])
# Networks setup
# ----------------------------------------------------------------------------
# Networks architecture
cset(cG.args, "architecture", args.g_arch)
cset(cD.args, "architecture", args.d_arch)
cset(cG.args, "tanh", args.tanh)
# Latent sizes
if args.components_num > 1:
if not (args.transformer or args.kgan):
misc.error("--components-num > 1 but the model is not using components. " +
"Either add --transformer for GANformer or --kgan for k-GAN.")
args.latent_size = int(args.latent_size / args.components_num)
cD.args.a_dim = cG.args.z_dim = cG.args.w_dim = args.latent_size
cset([cG.args, cD.args, vis], "components_num", args.components_num)
# Mapping network
for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
field = f"mapping_{arg}"
cset(cG.args, field, args[field])
# StyleGAN settings
for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
cset(cG.args, arg, args[arg])
cD.args.mbstd_group_size = min(args.batch_size, 4)
# GANformer
cset(cG.args, "transformer", args.transformer)
cset(cD.args, "transformer", args.d_transformer)
args.norm = args.normalize
for arg in ["norm", "integration", "ltnt_gate", "img_gate", "iterative", "kmeans",
"kmeans_iters", "mapping_ltnt2ltnt"]:
cset(cG.args, arg, args[arg])
for arg in ["use_pos", "num_heads"]:
cset([cG.args, cD.args], arg, args[arg])
# Positional encoding
for arg in ["dim", "type", "init", "directions_num"]:
field = f"pos_{arg}"
cset([cG.args, cD.args], field, args[field])
# k-GAN
for arg in ["layer", "type", "same"]:
field = f"merge_{arg}"
cset(cG.args, field, args[field])
cset(cG.args, "merge", args.kgan)
if args.kgan and args.transformer:
misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")
# Attention
for arg in ["start_res", "end_res", "ltnt2ltnt", "img2img"]: # , "local_attention"
cset(cG.args, arg, args[f"g_{arg}"])
cset(cD.args, arg, args[f"d_{arg}"])
cset(cG.args, "img2ltnt", args.g_img2ltnt)
# cset(cD.args, "ltnt2img", args.d_ltnt2img)
# Mixing and dropout
for arg in ["style_mixing", "component_mixing", "component_dropout", "attention_dropout"]:
cset(cG.args, arg, args[arg])
# Loss and regularization
gloss_args = {
"loss_type": "g_loss",
"reg_weight": "g_reg_weight",
# "pathreg": "pathreg"
}
dloss_args = {
"loss_type": "d_loss",
"reg_type": "d_reg",
"gamma": "gamma"
}
for arg, cmd_arg in gloss_args.items():
cset(cG.loss_args, arg, args[cmd_arg])
for arg, cmd_arg in dloss_args.items():
cset(cD.loss_args, arg, args[cmd_arg])
# Setup and launching
# ----------------------------------------------------------------------------
##### Experiments management:
# Whenever we start a new experiment we store its result in a directory named 'args.expname:000'.
# When we rerun a training or evaluation command it restores the model from that directory by default.
# If we wish to restart the model training, we can set --restart and then we will store data in a new
# directory: 'args.expname:001' after the first restart, then 'args.expname:002' after the second, etc.
# Find the latest directory that matches the experiment
exp_dir = sorted(glob.glob(f"{args.result_dir}/{args.expname}-*"))
run_id = 0
if len(exp_dir) > 0:
run_id = int(exp_dir[-1].split("-")[-1])
# If restart, then work over a new directory
if args.restart:
run_id += 1
run_name = f"{args.expname}-{run_id:03d}"
train.printname = f"{misc.bold(args.expname)} "
snapshot, kimg, resume = None, 0, False
pkls = sorted(glob.glob(f"{args.result_dir}/{run_name}/network*.pkl"))
# Load a particular snapshot is specified
if args.pretrained_pkl is not None and args.pretrained_pkl != "None":
# Soft links support
if args.pretrained_pkl.startswith("gdrive"):
if args.pretrained_pkl not in pretrained_networks.gdrive_urls:
misc.error("--pretrained_pkl {} not available in the catalog (see pretrained_networks.py)")
snapshot = args.pretrained_pkl
else:
snapshot = glob.glob(args.pretrained_pkl)[0]
if os.path.islink(snapshot):
snapshot = os.readlink(snapshot)
# Extract training step from the snapshot if specified
try:
kimg = int(snapshot.split("-")[-1].split(".")[0])
except:
pass
# Find latest snapshot in the directory
elif len(pkls) > 0:
snapshot = pkls[-1]
kimg = int(snapshot.split("-")[-1].split(".")[0])
resume = True
if snapshot:
misc.log(f"Resuming {run_name}, from {snapshot}, kimg {kimg}", "white")
train.resume_pkl = snapshot
train.resume_kimg = kimg
else:
misc.log("Start model training from scratch", "white")
# Run environment configuration
sc.run_dir_root = args.result_dir
sc.run_desc = args.expname
sc.run_id = run_id
sc.run_name = run_name
sc.submit_target = dnnlib.SubmitTarget.LOCAL
sc.local.do_not_copy_source_files = True
kwargs = EasyDict(train)
kwargs.update(cG = cG, cD = cD)
kwargs.update(dataset_args = dataset_args, vis_args = vis, sched_args = sched,
grid_args = grid, metric_arg_list = metrics, tf_config = tf_config)
kwargs.submit_config = copy.deepcopy(sc)
kwargs.resume = resume
kwargs.load_config = args.reload
dnnlib.submit_run(**kwargs)
# ----------------------------------------------------------------------------
def _str_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1", ""):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Error: Boolean value expected")
def _parse_comma_sep(s):
if s is None or s.lower() == "none" or s == "":
return []
return s.split(",")
def main():
parser = argparse.ArgumentParser(description = "Train the GANformer")
# Framework
# ------------------------------------------------------------------------------------------------------
parser.add_argument("--expname", help = "Experiment name", default = "exp", type = str)
parser.add_argument("--train", help = "Train mode (default: False)", default = None, action = "store_true")
parser.add_argument("--eval", help = "Evaluation mode (default: False)", default = None, action = "store_true")
parser.add_argument("--vis", help = "Visualization mode (default: False)", default = None, action = "store_true")
parser.add_argument("--gpus", help = "Comma-separated list of GPUs to be used (default: %(default)s)", default = "0", type = str)
## Default configurations
parser.add_argument("--ganformer-default", help = "Select a default GANformer configuration, either pretrained (default) or from scratch (with --pretrained-pkl None)", default = None, action = "store_true")
parser.add_argument("--baseline", help = "Use a baseline model configuration", default = None, choices = ["GAN", "StyleGAN2", "kGAN", "SAGAN"], type = str)
## Resumption
parser.add_argument("--pretrained-pkl", help = "Filename for a snapshot to resume (optional)", default = None, type = str)
parser.add_argument("--restart", help = "Restart training from scratch", default = False, action = "store_true")
parser.add_argument("--reload", help = "Reload options from the original experiment configuration file. " +
"If False, uses the command line arguments when resuming training (default: %(default)s)", default = False, action = "store_true")
parser.add_argument("--recompile", help = "Recompile model from source code when resuming training. " +
"If False, loading modules created when the experiment first started", default = None, action = "store_true")
parser.add_argument("--last-snapshots", help = "Number of last snapshots to save. -1 for all (default: 10)", default = None, type = int)
## Dataset
parser.add_argument("--data-dir", help = "Datasets root directory (default: %(default)s)", default = "datasets", metavar = "DIR")
parser.add_argument("--dataset", help = "Training dataset name (subdirectory of data-dir)", required = True)
parser.add_argument("--ratio", help = "Image height/width ratio in the dataset", default = None, type = float)
parser.add_argument("--resolution", help = "Training resolution", default = 256, type = int)
parser.add_argument("--num-threads", help = "Number of input processing threads (default: %(default)s)", default = 4, type = int)
parser.add_argument("--mirror-augment", help = "Perform horizontal flip augmentation for the data (default: %(default)s)", default = None, action = "store_true")
parser.add_argument("--train-images-num", help = "Maximum number of images to train on. If not specified, train on the whole dataset.", default = None, type = int)
## Training
parser.add_argument("--batch-size", help = "Global batch size (optimization step) (default: %(default)s)", default = 32, type = int)
parser.add_argument("--batch-gpu", help = "Batch size per GPU, gradients will be accumulated to match batch-size (default: %(default)s)", default = 4, type = int)
parser.add_argument("--total-kimg", help = "Training length in thousands of images (default: %(default)s)", metavar = "KIMG", default = 25000, type = int)
parser.add_argument("--gamma", help = "R1 regularization weight (default: %(default)s)", default = 10, type = float)
parser.add_argument("--clip", help = "Gradient clipping threshold (optional)", default = None, type = float)
parser.add_argument("--g-lr", help = "Generator learning rate (default: %(default)s)", default = 0.002, type = float)
parser.add_argument("--d-lr", help = "Discriminator learning rate (default: %(default)s)", default = 0.002, type = float)
## Logging and evaluation
parser.add_argument("--result-dir", help = "Root directory for experiments (default: %(default)s)", default = "results", metavar = "DIR")
parser.add_argument("--metrics", help = "Comma-separated list of metrics or none (default: %(default)s)", default = "fid", type = _parse_comma_sep)
parser.add_argument("--summarize", help = "Create TensorBoard summaries (default: %(default)s)", default = True, metavar = "BOOL", type = _str_to_bool, nargs = "?")
parser.add_argument("--truncation-psi", help = "Truncation Psi to be used in producing sample images " +
"(used only for visualizations, _not used_ in training or for computing metrics) (default: %(default)s)", default = 0.75, type = float)
parser.add_argument("--keep-samples", help = "Keep all prior samples during training, or if False, just the most recent ones (default: %(default)s)", default = True, metavar = "BOOL", type = _str_to_bool, nargs = "?")
parser.add_argument("--eval-images-num", help = "Number of images to evaluate metrics on (default: 50,000)", default = None, type = int)
## Visualization
parser.add_argument("--vis-images", help = "Save image samples", default = None, action = "store_true")
parser.add_argument("--vis-latents", help = "Save latent vectors", default = None, action = "store_true")
parser.add_argument("--vis-maps", help = "Save attention maps (for GANformer only)", default = None, action = "store_true")
parser.add_argument("--vis-layer-maps", help = "Save attention maps for all layers (for GANformer only)", default = None, action = "store_true")
parser.add_argument("--vis-interpolations", help = "Create latent interpolations", default = None, action = "store_true")
parser.add_argument("--vis-noise-var", help = "Create noise variation visualization", default = None, action = "store_true")
parser.add_argument("--vis-style-mix", help = "Create style mixing visualization", default = None, action = "store_true")
parser.add_argument("--vis-grid", help = "Whether to save the samples in one large grid files (default: True in training)", default = None, action = "store_true")
parser.add_argument("--vis-num", help = "Number of images for which visualization will be created (default: grid-size/100 in train/eval)", default = None, type = int)
parser.add_argument("--vis-rich-num", help = "Number of samples for which richer visualizations will be created (default: 5)", default = None, type = int)
parser.add_argument("--vis-section-size", help = "Visualization section size to process at one (section-size <= vis-num) for memory footprint (default: 100)", default = None, type = int)
parser.add_argument("--blending-alpha", help = "Proportion for generated images and attention maps blends (default: 0.3)", default = None, type = float)
parser.add_argument("--interpolation-density", help = "Number of samples in between two end points of an interpolation (default: 8)", default = None, type = int)
# parser.add_argument("--interpolation-per-component", help = "Whether to perform interpolation along particular latent components when true, or all of them at once otherwise (default: False)", default = None, action = "store_true")
# Model
# ------------------------------------------------------------------------------------------------------
## General architecture
parser.add_argument("--g-arch", help = "Generator architecture type (default: skip)", default = None, choices = ["orig", "skip", "resnet"], type = str)
parser.add_argument("--d-arch", help = "Discriminator architecture type (default: resnet)", default = None, choices = ["orig", "skip", "resnet"], type = str)
parser.add_argument("--tanh", help = "tanh on generator output (default: False)", default = None, action = "store_true")
# Mapping network
parser.add_argument("--mapping-layersnum", help = "Number of mapping layers (default: 8)", default = None, type = int)
parser.add_argument("--mapping-lrmul", help = "Mapping network learning rate multiplier (default: 0.01)", default = None, type = float)
parser.add_argument("--mapping-dim", help = "Mapping layers dimension (default: latent_size)", default = None, type = int)
parser.add_argument("--mapping-resnet", help = "Use resent connections in mapping layers (default: False)", default = None, action = "store_true")
parser.add_argument("--mapping-shared-dim", help = "Perform one shared mapping to all latent components concatenated together using the set dimension (default: disabled)", default = None, type = int)
# Loss
# parser.add_argument("--pathreg", help = "Use path regularization in generator training (default: False)", default = None, action = "store_true")
parser.add_argument("--g-loss", help = "Generator loss type (default: %(default)s)", default = "logistic_ns", choices = ["logistic", "logistic_ns", "hinge", "wgan"], type = str)
parser.add_argument("--g-reg-weight", help = "Generator regularization weight (default: %(default)s)", default = 1.0, type = float)
parser.add_argument("--d-loss", help = "Discriminator loss type (default: %(default)s)", default = "logistic", choices = ["wgan", "logistic", "hinge"], type = str)
parser.add_argument("--d-reg", help = "Discriminator regularization type (default: %(default)s)", default = "r1", choices = ["non", "gp", "r1", "r2"], type = str)
# --gamma effectively functions as discriminator regularization weight
# Mixing and dropout
parser.add_argument("--style-mixing", help = "Style mixing (layerwise) probability (default: %(default)s)", default = 0.9, type = float)
parser.add_argument("--component-mixing", help = "Component mixing (objectwise) probability (default: %(default)s)", default = 0.0, type = float)
parser.add_argument("--component-dropout", help = "Component dropout (default: %(default)s)", default = 0.0, type = float)
parser.add_argument("--attention-dropout", help = "Attention dropout (default: 0.12)", default = None, type = float)
# StyleGAN additions
parser.add_argument("--style", help = "Global style modulation (default: %(default)s)", default = True, metavar = "BOOL", type = _str_to_bool, nargs = "?")
parser.add_argument("--latent-stem", help = "Input latent through the generator stem grid (default: False)", default = None, action = "store_true")
parser.add_argument("--fused-modconv", help = "Fuse modulation and convolution operations (default: %(default)s)", default = True, metavar = "BOOL", type = _str_to_bool, nargs = "?")
parser.add_argument("--local-noise", help = "Add stochastic local noise each layer (default: %(default)s)", default = True, metavar = "BOOL", type = _str_to_bool, nargs = "?")
## GANformer
parser.add_argument("--transformer", help = "Add transformer layers to the generator: top-down latents-to-image (default: False)", default = None, action = "store_true")
parser.add_argument("--latent-size", help = "Latent size, summing the dimension of all components (default: %(default)s)", default = 512, type = int)
parser.add_argument("--components-num", help = "Components number. Each component has latent dimension of 'latent-size / components-num'. " +
"1 for StyleGAN since it has one global latent vector (default: %(default)s)", default = 1, type = int)
parser.add_argument("--num-heads", help = "Number of attention heads (default: %(default)s)", default = 1, type = int)
parser.add_argument("--normalize", help = "Feature normalization type (optional)", default = None, choices = ["batch", "instance", "layer"], type = str)
parser.add_argument("--integration", help = "Feature integration type: additive, multiplicative or both (default: %(default)s)", default = "add", choices = ["add", "mul", "both"], type = str)
# Generator attention layers
# Transformer resolution layers
parser.add_argument("--g-start-res", help = "Transformer minimum generator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)", default = 0, type = int)
parser.add_argument("--g-end-res", help = "Transformer maximum generator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)", default = 8, type = int)
# Discriminator attention layers
parser.add_argument("--d-transformer", help = "Add transformer layers to the discriminator (bottom-up image-to-latents) (default: False)", default = None, action = "store_true")
parser.add_argument("--d-start-res", help = "Transformer minimum discriminator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)", default = 0, type = int)
parser.add_argument("--d-end-res", help = "Transformer maximum discriminator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)", default = 8, type = int)
# Attention
parser.add_argument("--ltnt-gate", help = "Gate attention from latents, such that components may not send information " +
"when gate value is low (default: False)", default = None, action = "store_true")
parser.add_argument("--img-gate", help = "Gate attention for images, such that some image positions may not get updated " +
"or receive information when gate value is low (default: False)", default = None, action = "store_true")
parser.add_argument("--kmeans", help = "Track and update image-to-latents assignment centroids, used in the duplex attention (default: False)", default = None, action = "store_true")
parser.add_argument("--kmeans-iters", help = "Number of K-means iterations per transformer layer. Note that centroids are carried from layer to layer (default: %(default)s)", default = 1, type = int) # -per-layer
parser.add_argument("--iterative", help = "Whether to carry over attention assignments across transformer layers of different resolutions (default: False)", default = None, action = "store_true")
# Attention directions
# format is A2B: Elements _from_ B attend _to_ elements in A, and B elements get updated accordingly.
# Note that it means that information propagates in the following direction: A -> B
parser.add_argument("--mapping-ltnt2ltnt", help = "Add self-attention over latents in the mapping network (default: False)", default = None, action = "store_true")
parser.add_argument("--g-ltnt2ltnt", help = "Add self-attention over latents in the synthesis network (default: False)", default = None, action = "store_true")
parser.add_argument("--g-img2img", help = "Add self-attention between images positions in that layer of the generator (SAGAN) (default: disabled)", default = 0, type = int)
parser.add_argument("--g-img2ltnt", help = "Add image to latents attention (bottom-up) (default: %(default)s)", default = None, action = "store_true")
# g-ltnt2img: default information flow direction when using --transformer
# parser.add_argument("--d-ltnt2img", help = "Add latents to image attention (top-down) (default: %(default)s)", default = None, action = "store_true")
parser.add_argument("--d-ltnt2ltnt", help = "Add self-attention over latents in the discriminator (default: False)", default = None, action = "store_true")
parser.add_argument("--d-img2img", help = "Add self-attention over images positions in that layer of the discriminator (SAGAN) (default: disabled)", default = 0, type = int)
# d-img2ltnt: default information flow direction when using --d-transformer
# Local attention operations (replacing convolution)
# parser.add_argument("--g-local-attention", help = "Local attention operations in the generation up to this layer (default: disabled)", default = None, type = int)
# parser.add_argument("--d-local-attention", help = "Local attention operations in the discriminator up to this layer (default: disabled)", default = None, type = int)
# Positional encoding
parser.add_argument("--use-pos", help = "Use positional encoding (default: False)", default = None, action = "store_true")
parser.add_argument("--pos-dim", help = "Positional encoding dimension (default: latent-size)", default = None, type = int)
parser.add_argument("--pos-type", help = "Positional encoding type (default: %(default)s)", default = "sinus", choices = ["linear", "sinus", "trainable", "trainable2d"], type = str)
parser.add_argument("--pos-init", help = "Positional encoding initialization distribution (default: %(default)s)", default = "uniform", choices = ["uniform", "normal"], type = str)
parser.add_argument("--pos-directions-num", help = "Positional encoding number of spatial directions (default: %(default)s)", default = 2, type = int)
## k-GAN
parser.add_argument("--kgan", help = "Generate components-num images and then merge them (k-GAN) (default: False)", default = None, action = "store_true")
parser.add_argument("--merge-layer", help = "Merge layer, where images get combined through alpha-composition (default: %(default)s)", default = -1, type = int)
parser.add_argument("--merge-type", help = "Merge type (default: sum)", default = None, choices = ["sum", "softmax", "max", "leaves"], type = str)
parser.add_argument("--merge-same", help = "Merge images with same alpha weights across all spatial positions (default: %(default)s)", default = None, action = "store_true")
args = parser.parse_args()
run(**vars(args))
if __name__ == "__main__":
main()