-
Notifications
You must be signed in to change notification settings - Fork 34
/
decoder.py
74 lines (54 loc) · 2.04 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
from encoder import OUT_DIM
class PixelDecoder(nn.Module):
    """Decode a latent feature vector back into an image-shaped observation.

    A linear layer (followed by ReLU) maps the feature vector to a
    ``num_filters`` x ``out_dim`` x ``out_dim`` feature map, which is then
    upsampled by ``num_layers`` transposed convolutions:

    * the first ``num_layers - 1`` use kernel 3, stride 1 and are followed by
      ReLU;
    * the last one uses kernel 3, stride 2, ``output_padding=1`` and produces
      ``obs_shape[0]`` channels with no output activation.

    Intermediate activations from the most recent ``forward`` call are kept in
    ``self.outputs`` so that ``log`` can emit histograms/images of them.

    Args:
        obs_shape: Observation shape; only ``obs_shape[0]`` (the number of
            output channels) is used here.
        feature_dim: Size of the last dimension of the input to ``forward``.
        num_layers: Number of transposed-convolution layers. Must be a key of
            ``OUT_DIM`` (imported from ``encoder``).
        num_filters: Channel count of the intermediate feature maps.
    """

    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
        super().__init__()

        self.num_layers = num_layers
        self.num_filters = num_filters
        # Spatial size of the initial feature map. Presumably this mirrors the
        # encoder's final conv output size so the decoder inverts it -- confirm
        # against encoder.OUT_DIM.
        self.out_dim = OUT_DIM[num_layers]

        self.fc = nn.Linear(
            feature_dim, num_filters * self.out_dim * self.out_dim
        )

        # With kernel 3 and no padding, each stride-1 layer grows H and W by 2.
        # The final stride-2 layer with output_padding=1 maps H -> 2 * H + 2.
        self.deconvs = nn.ModuleList()
        for _ in range(self.num_layers - 1):
            self.deconvs.append(
                nn.ConvTranspose2d(num_filters, num_filters, 3, stride=1)
            )
        self.deconvs.append(
            nn.ConvTranspose2d(
                num_filters, obs_shape[0], 3, stride=2, output_padding=1
            )
        )

        # Activations captured by the last forward() call; consumed by log().
        self.outputs = {}

    def forward(self, h):
        """Decode ``h`` into a reconstructed observation.

        Args:
            h: Tensor whose last dimension is ``feature_dim``. After the
                linear layer it is reshaped to
                ``(-1, num_filters, out_dim, out_dim)``, so it is presumably
                ``(N, feature_dim)`` -- confirm against callers.

        Returns:
            Tensor of shape ``(N, obs_shape[0], S, S)``, where
            ``S = 2 * (out_dim + 2 * (num_layers - 1)) + 2``. The final
            transposed convolution's output is returned without an activation.
        """
        h = torch.relu(self.fc(h))
        self.outputs['fc'] = h

        deconv = h.view(-1, self.num_filters, self.out_dim, self.out_dim)
        # The reshaped fc activation gets its own key (it is the input to
        # deconvs[0]). Keying it 'deconv1' would collide with the output of
        # deconvs[0] written by the loop below, losing it whenever
        # num_layers >= 2.
        self.outputs['deconv0'] = deconv

        # All layers except the last are followed by ReLU.
        for i, layer in enumerate(self.deconvs[:-1]):
            deconv = torch.relu(layer(deconv))
            # 'deconv<i+1>' matches the parameter names used in log().
            self.outputs['deconv%s' % (i + 1)] = deconv

        obs = self.deconvs[-1](deconv)
        self.outputs['obs'] = obs
        return obs

    def log(self, L, step, log_freq):
        """Log captured activations and layer parameters every ``log_freq`` steps.

        Args:
            L: Logger exposing ``log_histogram``, ``log_image`` and
                ``log_param`` (project type; not visible here).
            step: Current training step.
            log_freq: Logging period. Nothing is logged unless
                ``step % log_freq == 0``.
        """
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_decoder/%s_hist' % k, v, step)
            # Feature maps (rank > 2) are additionally logged as an image of
            # the first entry along dim 0.
            if len(v.shape) > 2:
                L.log_image('train_decoder/%s_i' % k, v[0], step)

        for i, layer in enumerate(self.deconvs):
            L.log_param('train_decoder/deconv%s' % (i + 1), layer, step)
        L.log_param('train_decoder/fc', self.fc, step)
_AVAILABLE_DECODERS = {'pixel': PixelDecoder}


def make_decoder(
    decoder_type, obs_shape, feature_dim, num_layers, num_filters
):
    """Build a decoder of the requested type.

    Args:
        decoder_type: Key into ``_AVAILABLE_DECODERS`` (currently only
            ``'pixel'``).
        obs_shape: Forwarded to the decoder constructor.
        feature_dim: Forwarded to the decoder constructor.
        num_layers: Forwarded to the decoder constructor.
        num_filters: Forwarded to the decoder constructor.

    Returns:
        A new instance of the registered decoder class.

    Raises:
        ValueError: If ``decoder_type`` is not a registered decoder.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn a bad type into an unexplained KeyError.
    try:
        decoder_cls = _AVAILABLE_DECODERS[decoder_type]
    except KeyError:
        raise ValueError(
            'unknown decoder_type %r; expected one of %s'
            % (decoder_type, sorted(_AVAILABLE_DECODERS))
        ) from None
    return decoder_cls(obs_shape, feature_dim, num_layers, num_filters)