forked from harpArk614/3d-pose-warping
-
Notifications
You must be signed in to change notification settings - Fork 1
/
encoder.py
112 lines (91 loc) · 3.2 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# ---------------------------------------------------------------------------
# 2D residual building blocks
# ---------------------------------------------------------------------------
def relu_bn(inputs: Tensor) -> Tensor:
    """Apply a ReLU activation, then batch normalization.

    Note the ordering: activation comes first and normalization second.

    Args:
        inputs: Feature map produced by the previous layer.

    Returns:
        The activated, batch-normalized feature map.
    """
    activated = ReLU()(inputs)
    return BatchNormalization()(activated)
def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    """Build a basic two-convolution 2D residual block.

    Main path:  Conv2D(stride 1 or 2) -> relu_bn -> Conv2D(stride 1).
    Shortcut:   identity, or a 1x1 Conv2D projection when the block
                downsamples OR the input channel count differs from
                ``filters`` (an identity shortcut would make ``Add`` fail on
                mismatched shapes in that case).
    The two paths are summed and passed through ``relu_bn``.

    Args:
        x: Input feature map. Channels are read from the last axis, i.e. the
            default channels-last layout is assumed (TODO confirm no caller
            switches ``data_format``).
        downsample: If True, the first main-path convolution and the
            shortcut projection use stride 2, halving height and width.
        filters: Number of output channels of the block.
        kernel_size: Kernel size of the two main-path convolutions.

    Returns:
        The block output after the residual sum and a final ReLU/BN.
    """
    stride = 2 if downsample else 1

    y = Conv2D(kernel_size=kernel_size,
               strides=stride,
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    # Project the shortcut whenever its shape could not match the main path;
    # when neither resolution nor channels change, keep the identity.
    if downsample or x.shape[-1] != filters:
        x = Conv2D(kernel_size=1,
                   strides=stride,
                   filters=filters,
                   padding="same")(x)

    out = Add()([x, y])
    return relu_bn(out)
def residual_block_decode(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    """Build a two-convolution residual block out of transposed convolutions.

    Despite the parameter name, ``downsample=True`` UPSAMPLES: a stride-2
    ``Conv2DTranspose`` with ``padding="same"`` doubles height and width.
    The name is kept for backward compatibility with existing callers.

    Main path:  Conv2DTranspose(stride 1 or 2) -> relu_bn -> Conv2DTranspose(stride 1).
    Shortcut:   identity, or a 1x1 Conv2DTranspose projection when the block
                changes resolution OR the input channel count differs from
                ``filters`` (otherwise ``Add`` would fail on mismatched shapes).

    Args:
        x: Input feature map. Channels are read from the last axis, i.e. the
            default channels-last layout is assumed (TODO confirm).
        downsample: If True, use stride 2 in the first transposed convolution
            and in the shortcut projection (spatial upsampling by 2).
        filters: Number of output channels of the block.
        kernel_size: Kernel size of the two main-path transposed convolutions.

    Returns:
        The block output after the residual sum and a final ReLU/BN.
    """
    stride = 2 if downsample else 1

    y = layers.Conv2DTranspose(kernel_size=kernel_size,
                               strides=stride,
                               filters=filters,
                               padding="same")(x)
    y = relu_bn(y)
    y = layers.Conv2DTranspose(kernel_size=kernel_size,
                               strides=1,
                               filters=filters,
                               padding="same")(y)

    # Project the shortcut whenever its shape could not match the main path.
    if downsample or x.shape[-1] != filters:
        x = layers.Conv2DTranspose(kernel_size=1,
                                   strides=stride,
                                   filters=filters,
                                   padding="same")(x)

    out = Add()([x, y])
    return relu_bn(out)
# ---------------------------------------------------------------------------
# 3D residual building blocks
# ---------------------------------------------------------------------------
def relu_bn3d(inputs: Tensor) -> Tensor:
    """Apply a ReLU activation, then group normalization.

    This is the counterpart of ``relu_bn`` used by the 3D blocks. Despite the
    "bn" in its name, it uses ``GroupNormalization`` (default arguments)
    rather than ``BatchNormalization``.
    NOTE(review): relies on GroupNormalization's default ``groups`` value;
    confirm it divides the channel count of every tensor passed in.

    Args:
        inputs: Feature volume produced by the previous layer.

    Returns:
        The activated, group-normalized feature volume.
    """
    activated = ReLU()(inputs)
    return GroupNormalization()(activated)
def residual_block3d(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    """Build a basic two-convolution 3D residual block.

    Main path:  Conv3D(stride 1 or 2) -> relu_bn3d -> Conv3D(stride 1).
    Shortcut:   identity, or a 1x1x1 Conv3D projection when the block
                downsamples OR the input channel count differs from
                ``filters`` (an identity shortcut would make ``Add`` fail on
                mismatched shapes in that case).
    The two paths are summed and passed through ``relu_bn3d``.

    Args:
        x: Input feature volume. Channels are read from the last axis, i.e.
            the default channels-last layout is assumed (TODO confirm).
        downsample: If True, the first main-path convolution and the
            shortcut projection use stride 2 along every spatial axis.
        filters: Number of output channels of the block.
        kernel_size: Kernel size of the two main-path convolutions.

    Returns:
        The block output after the residual sum and a final ReLU/GroupNorm.
    """
    stride = 2 if downsample else 1

    y = Conv3D(kernel_size=kernel_size,
               strides=stride,
               filters=filters,
               padding="same")(x)
    y = relu_bn3d(y)
    y = Conv3D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    # Project the shortcut whenever its shape could not match the main path.
    if downsample or x.shape[-1] != filters:
        x = Conv3D(kernel_size=1,
                   strides=stride,
                   filters=filters,
                   padding="same")(x)

    out = Add()([x, y])
    return relu_bn3d(out)
# Function to create Encoder model
def create_resnet_encoder():
    """Build the symbolic graph of the 2D -> 3D ResNet encoder.

    Shape flow (batch axis omitted):
      * Input image: 256x256x3.
      * Stem: BatchNormalization, 3x3 Conv2D (64 filters), relu_bn.
      * 2D trunk: four stages of 2, 2, 3 and 3 ``residual_block``s
        (64 filters). The first block of every stage except the first
        downsamples, so three halvings give 32x32x64.
      * A 3x3 Conv2D lifts this to 32x32x128. Those 128 channels are
        reshaped to 32x32x64x2, which the following Conv3D treats as a
        32x32x64 volume with 2 channels (assuming the default channels-last
        format).
      * 3D trunk: a 3x3x3 Conv3D (64 filters), then one stage of two
        non-downsampling ``residual_block3d``s, giving 32x32x64x64.

    The hard-coded Reshape target only matches the fixed 256x256 input and
    the three downsampling stages above.

    Returns:
        ``(inputs, outputs)``: the Keras input tensor and the final feature
        tensor. No ``Model`` is built here; the caller is expected to wrap
        them.
    """
    inputs = Input(shape=(256, 256, 3))
    t = BatchNormalization()(inputs)
    t = Conv2D(kernel_size=3,
               strides=1,
               filters=64,
               padding="same")(t)
    t = relu_bn(t)

    # 2D trunk: every stage after the first opens with a stride-2 block.
    blocks_per_stage_2d = [2, 2, 3, 3]
    for stage, num_blocks in enumerate(blocks_per_stage_2d):
        for block_idx in range(num_blocks):
            t = residual_block(t,
                               downsample=(block_idx == 0 and stage != 0),
                               filters=64)

    # Lift to 128 channels and fold them into a (64, 2) trailing pair so the
    # 2D feature map becomes a 3D volume.
    t = Conv2D(kernel_size=3,
               strides=1,
               filters=128,
               padding="same")(t)
    t = layers.Reshape((32, 32, 64, 2))(t)
    t = layers.Conv3D(64, (3, 3, 3), padding='same', strides=(1, 1, 1))(t)

    # 3D trunk: a single stage, so no block downsamples.
    blocks_per_stage_3d = [2]
    for stage, num_blocks in enumerate(blocks_per_stage_3d):
        for block_idx in range(num_blocks):
            t = residual_block3d(t,
                                 downsample=(block_idx == 0 and stage != 0),
                                 filters=64)

    outputs = t
    return (inputs, outputs)