-
Notifications
You must be signed in to change notification settings - Fork 1
/
cli.py
71 lines (70 loc) · 1.51 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
# Command-line interface for the beta-VAE MNIST / dSprites experiments.
# This block only builds the module-level ``parser``; parsing
# (``parser.parse_args()``) is left to the caller.
parser = argparse.ArgumentParser(description='beta-VAE MNIST / dSprites')

# Positional, but optional thanks to nargs='?', so the default is used when omitted.
parser.add_argument(
    'model_name',
    type=str,
    default='simple',
    metavar='MODEL',
    nargs='?',
    help='model name (default: simple)'
)
parser.add_argument(
    '--data',
    type=str,
    default='MNIST',
    metavar='D',
    help='dataset name (default: MNIST, also: dSprites)'
)
parser.add_argument(
    '--z-dim',
    type=int,
    default=15,
    metavar='Z',
    help='number of latent variables z (default: 15)'
)
parser.add_argument(
    '--beta',
    # float, not int: the default is 5.0 and fractional weights such as 0.5
    # must be accepted from the command line.
    type=float,
    default=5.0,
    metavar='B',
    help='regularisation coefficient * the KLD (default: 5.0)'
)
parser.add_argument(
    '--batch-size',
    type=int,
    default=128,
    metavar='N',
    help='input batch size for training (default: 128)'
)
parser.add_argument(
    '--epochs',
    type=int,
    default=10,
    metavar='N',
    help='number of epochs to train (default: 10)'
)
parser.add_argument(
    '--no-cuda',
    action='store_true',
    default=False,
    # Passing this flag sets no_cuda=True, i.e. it turns CUDA *off*.
    help='disables CUDA training'
)
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)'
)
parser.add_argument(
    '--log-interval',
    type=int,
    default=100,
    metavar='N',
    help='how many batches to wait before logging training status'
)
parser.add_argument(
    '--tensorboard',
    action='store_true',
    default=False,
    help='plots losses with tensorboard (default: False), to view run `$ tensorboard --logdir runs`'
)