forked from roimehrez/contextualLoss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
35 lines (27 loc) · 1.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# ---------------------------------------------------
# code credits: https://github.com/CQFIO/PhotographicImageSynthesis
# ---------------------------------------------------
import tensorflow.contrib.slim as slim
from vgg_model import *
from config import *
# This function has been modified so that the images are portrait rather than landscape.
def recursive_generator(input_image, width):
    """Recursively build the cascaded-refinement generator for one resolution level.

    A non-coarsest level does the following, in order:
    1. Area-downsamples ``input_image`` by 2.
    2. Recurses on the downsampled image.
    3. Bilinearly upsamples the coarse prediction back to this level's size.
    4. Concatenates the upsampled prediction with ``input_image`` along axis 3.
    5. Applies two 3x3 conv layers (layer norm + ``lrelu``).

    The coarsest level (``width <= 4``) applies the two conv layers directly
    to ``input_image``. At the top level
    (``width == config.TRAIN.sp * config.TRAIN.aspect_ratio``), a final 1x1
    conv produces 3 channels. The result is then rescaled with
    ``(x + 1) / 2 * 255``, which maps [-1, 1] to [0, 255]. The output is not
    clipped.

    Args:
        input_image: Input tensor for this level. It is presumably NHWC, since
            the concat is on axis 3 (TODO confirm).
        width: Size of the first spatial dimension at this level. It is passed
            as the first entry of the resize size, i.e. the height in TF's
            (height, width) order. The second dimension is
            ``width // config.TRAIN.aspect_ratio``.

    Returns:
        The output tensor of this level.
    """
    train_cfg = config.TRAIN
    ar = train_cfg.aspect_ratio
    # Coarse levels (width < 128) use twice as many channels as fine ones.
    dim = (512 if width >= 128 else 1024) // train_cfg.reduce_dim
    # Variable-scope names must stay unchanged (presumably existing
    # checkpoints are keyed on them).
    scope_prefix = 'g_' + str(width)

    if width <= 4:
        # Coarsest level: no recursion. `<=` (not `== 4`) guarantees the
        # recursion terminates even when the starting width is not 4 * 2**k.
        # With `== 4`, halving could skip past 4 (..., 3, 1, 0, 0, ...) and
        # never bottom out.
        features = input_image
    else:
        downsampled_width = width // 2
        # int() tolerates a float aspect ratio. In that case `//` yields a
        # float, but TF resize ops expect integer sizes.
        downsampled_size = (downsampled_width, int(downsampled_width // ar))
        downsampled_input = tf.image.resize_area(input_image, downsampled_size, align_corners=False)
        coarse_prediction = recursive_generator(downsampled_input, downsampled_width)
        full_size = (width, int(width // ar))
        upsampled_prediction = tf.image.resize_bilinear(coarse_prediction, full_size, align_corners=True)
        features = tf.concat([upsampled_prediction, input_image], 3)

    net = slim.conv2d(features, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope=scope_prefix + '_conv1')
    net = slim.conv2d(net, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope=scope_prefix + '_conv2')

    if width == train_cfg.sp * train_cfg.aspect_ratio:
        # Top level only: project to 3 channels and rescale to the [0, 255] range.
        net = slim.conv2d(net, 3, [1, 1], rate=1, activation_fn=None, scope=scope_prefix + '_conv100')
        net = (net + 1.0) / 2.0 * 255.0
    return net