-
Notifications
You must be signed in to change notification settings - Fork 2
/
layers.py
120 lines (103 loc) · 3.82 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
################################################################################
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Custom layers for CURL."""
from absl import logging
import sonnet as snt
import tensorflow.compat.v1 as tf
tfc = tf.compat.v1
class ResidualStack(snt.AbstractModule):
"""A stack of ResNet V2 blocks."""
def __init__(self,
num_hiddens,
num_residual_layers,
num_residual_hiddens,
filter_size=3,
initializers=None,
data_format='NHWC',
activation=tf.nn.relu,
name='residual_stack'):
"""Instantiate a ResidualStack."""
super(ResidualStack, self).__init__(name=name)
self._num_hiddens = num_hiddens
self._num_residual_layers = num_residual_layers
self._num_residual_hiddens = num_residual_hiddens
self._filter_size = filter_size
self._initializers = initializers
self._data_format = data_format
self._activation = activation
def _build(self, h):
for i in range(self._num_residual_layers):
h_i = self._activation(h)
h_i = snt.Conv2D(
output_channels=self._num_residual_hiddens,
kernel_shape=(self._filter_size, self._filter_size),
stride=(1, 1),
initializers=self._initializers,
data_format=self._data_format,
name='res_nxn_%d' % i)(
h_i)
h_i = self._activation(h_i)
h_i = snt.Conv2D(
output_channels=self._num_hiddens,
kernel_shape=(1, 1),
stride=(1, 1),
initializers=self._initializers,
data_format=self._data_format,
name='res_1x1_%d' % i)(
h_i)
h += h_i
return self._activation(h)
class SharedConvModule(snt.AbstractModule):
"""Convolutional decoder."""
def __init__(self,
filters,
kernel_size,
activation,
strides,
name='shared_conv_encoder'):
super(SharedConvModule, self).__init__(name=name)
self._filters = filters
self._kernel_size = kernel_size
self._activation = activation
self.strides = strides
assert len(strides) == len(filters) - 1
self.conv_shapes = None
def _build(self, x, is_training=True):
with tf.control_dependencies([tfc.assert_rank(x, 4)]):
self.conv_shapes = [x.shape.as_list()] # Needed by deconv module
conv = x
for i, (filter_i,
stride_i) in enumerate(zip(self._filters, self.strides), 1):
conv = tf.layers.Conv2D(
filters=filter_i,
kernel_size=self._kernel_size,
padding='same',
activation=self._activation,
strides=stride_i,
name='enc_conv_%d' % i)(
conv)
self.conv_shapes.append(conv.shape.as_list())
conv_flat = snt.BatchFlatten()(conv)
enc_mlp = snt.nets.MLP(
name='enc_mlp',
output_sizes=[self._filters[-1]],
activation=self._activation,
activate_final=True)
h = enc_mlp(conv_flat)
logging.info('Shared conv module layer shapes:')
logging.info('\n'.join([str(el) for el in self.conv_shapes]))
logging.info(h.shape.as_list())
return h