Skip to content

Commit abd96f4

Browse files
committed
[release] TF1.0
1 parent fceb8cb commit abd96f4

29 files changed

+4016
-1191
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ TensorFlow / TensorLayer implementation of [Deep Convolutional Generative Advers
1111
## Prerequisites
1212

1313
- Python 2.7 or Python 3.3+
14-
- [TensorFlow==0.10.0 or higher](https://www.tensorflow.org/)
15-
- [TensorLayer==1.2.6 or higher](https://github.com/zsdonghao/tensorlayer) (already in this repo)
14+
- [TensorFlow==1.0+](https://www.tensorflow.org/)
15+
- [TensorLayer==1.4+](https://github.com/zsdonghao/tensorlayer)
1616

1717

1818
## Usage
@@ -25,4 +25,6 @@ To train a model with downloaded dataset:
2525

2626
$ python main.py
2727

28+
## Result
2829

30+
![alt tag](result.png)

main.py

Lines changed: 56 additions & 297 deletions
Large diffs are not rendered by default.

model.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
2+
import tensorflow as tf
3+
import tensorlayer as tl
4+
from tensorlayer.layers import *
5+
6+
7+
flags = tf.app.flags
8+
FLAGS = flags.FLAGS
9+
10+
11+
12+
def generator_simplified_api(inputs, is_train=True, reuse=False):
13+
image_size = 64
14+
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
15+
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
16+
c_dim = FLAGS.c_dim # n_color 3
17+
batch_size = FLAGS.batch_size # 64
18+
19+
w_init = tf.random_normal_initializer(stddev=0.02)
20+
gamma_init = tf.random_normal_initializer(1., 0.02)
21+
22+
with tf.variable_scope("generator", reuse=reuse):
23+
tl.layers.set_name_reuse(reuse)
24+
25+
net_in = InputLayer(inputs, name='g/in')
26+
net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
27+
act = tf.identity, name='g/h0/lin')
28+
net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
29+
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
30+
gamma_init=gamma_init, name='g/h0/batch_norm')
31+
32+
net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s8, s8), strides=(2, 2),
33+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
34+
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
35+
gamma_init=gamma_init, name='g/h1/batch_norm')
36+
37+
net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s4, s4), strides=(2, 2),
38+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
39+
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
40+
gamma_init=gamma_init, name='g/h2/batch_norm')
41+
42+
net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), out_size=(s2, s2), strides=(2, 2),
43+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
44+
net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
45+
gamma_init=gamma_init, name='g/h3/batch_norm')
46+
47+
net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
48+
padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
49+
logits = net_h4.outputs
50+
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
51+
return net_h4, logits
52+
53+
54+
def discriminator_simplified_api(inputs, is_train=True, reuse=False):
55+
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
56+
c_dim = FLAGS.c_dim # n_color 3
57+
batch_size = FLAGS.batch_size # 64
58+
59+
w_init = tf.random_normal_initializer(stddev=0.02)
60+
gamma_init = tf.random_normal_initializer(1., 0.02)
61+
62+
with tf.variable_scope("discriminator", reuse=reuse):
63+
tl.layers.set_name_reuse(reuse)
64+
65+
net_in = InputLayer(inputs, name='d/in')
66+
net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
67+
padding='SAME', W_init=w_init, name='d/h0/conv2d')
68+
69+
net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
70+
padding='SAME', W_init=w_init, name='d/h1/conv2d')
71+
net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
72+
is_train=is_train, gamma_init=gamma_init, name='d/h1/batch_norm')
73+
74+
net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
75+
padding='SAME', W_init=w_init, name='d/h2/conv2d')
76+
net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
77+
is_train=is_train, gamma_init=gamma_init, name='d/h2/batch_norm')
78+
79+
net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
80+
padding='SAME', W_init=w_init, name='d/h3/conv2d')
81+
net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
82+
is_train=is_train, gamma_init=gamma_init, name='d/h3/batch_norm')
83+
84+
net_h4 = FlattenLayer(net_h3, name='d/h4/flatten')
85+
net_h4 = DenseLayer(net_h4, n_units=1, act=tf.identity,
86+
W_init = w_init, name='d/h4/lin_sigmoid')
87+
logits = net_h4.outputs
88+
net_h4.outputs = tf.nn.sigmoid(net_h4.outputs)
89+
return net_h4, logits

tensorlayer/__init__.py

100644100755
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@
2525
from . import rein
2626

2727

28-
__version__ = "1.2.3"
28+
__version__ = "1.4.2"
29+
30+
global_flag = {}
31+
global_dict = {}
914 Bytes
Binary file not shown.
Binary file not shown.
19.1 KB
Binary file not shown.
28.1 KB
Binary file not shown.
9.11 KB
Binary file not shown.
169 KB
Binary file not shown.
31 KB
Binary file not shown.
4.71 KB
Binary file not shown.
52.4 KB
Binary file not shown.
3.21 KB
Binary file not shown.
13.8 KB
Binary file not shown.
9.94 KB
Binary file not shown.

tensorlayer/activation.py

100644100755
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,13 @@ def pixel_wise_softmax(output, name='pixel_wise_softmax'):
9797
- `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`_
9898
"""
9999
with tf.name_scope(name) as scope:
100-
exp_map = tf.exp(output)
101-
if output.get_shape().ndims == 4: # 2d image
102-
evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
103-
elif output.get_shape().ndims == 5: # 3d image
104-
evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
105-
else:
106-
raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
107-
return tf.div(exp_map, evidence)
100+
return tf.nn.softmax(output)
101+
## old implementation
102+
# exp_map = tf.exp(output)
103+
# if output.get_shape().ndims == 4: # 2d image
104+
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
105+
# elif output.get_shape().ndims == 5: # 3d image
106+
# evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
107+
# else:
108+
# raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
109+
# return tf.div(exp_map, evidence)

0 commit comments

Comments
 (0)