forked from lcswillems/rl-starter-files
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
104 lines (85 loc) · 3.45 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch_rl
# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def initialize_parameters(m):
    """Re-initialise a module in place if its class name contains ``Linear``.

    Meant to be passed to ``nn.Module.apply``. For a matching module:

    * the weight is redrawn from N(0, 1);
    * each weight row (summed over dim 1) is rescaled to unit L2 norm;
    * the bias, if present, is zeroed.

    Any other module is left untouched.
    """
    if 'Linear' not in type(m).__name__:
        return

    weight = m.weight.data
    weight.normal_(0, 1)
    # Per-row L2 norm, kept as a column so it broadcasts across each row.
    row_norms = weight.pow(2).sum(1, keepdim=True).sqrt()
    weight *= 1 / row_norms

    if m.bias is not None:
        m.bias.data.fill_(0)
class ACModel(nn.Module, torch_rl.RecurrentACModel):
    """Actor-critic network over image (and optional instruction) observations.

    Pipeline:

    1. The image passes through a small CNN and is flattened.
    2. With ``use_memory``, the flattened features go through an
       ``nn.LSTMCell``.
    3. With ``use_instr``, a GRU encoding of the instruction tokens is
       concatenated onto the result.
    4. Two tanh MLP heads read that joint embedding. One produces action
       logits (actor); the other produces a scalar value (critic).

    NOTE(review): ``image_embedding_size`` is hard-coded to 64. That matches
    the flattened CNN output only when the CNN's spatial output is 1x1, e.g.
    a 7x7 input (7 -> 6 -> 3 -> 2 -> 1). Other image sizes will mismatch the
    LSTM/head input width.
    """

    def __init__(self, obs_space, action_space, use_instr=True, use_memory=True):
        super().__init__()

        # Optional components.
        self.use_instr = use_instr
        self.use_memory = use_memory

        # Image encoder: 3 input channels -> image_embedding_size channels,
        # via three 2x2 convolutions with one 2x2/stride-2 max-pool.
        self.image_embedding_size = 64
        conv_layers = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=self.image_embedding_size, kernel_size=(2, 2)),
            nn.ReLU(),
        ]
        self.image_conv = nn.Sequential(*conv_layers)

        # Recurrent memory over the image features. The state is an (h, c)
        # pair, each semi_memory_size wide.
        if self.use_memory:
            self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size)

        # Width of the joint embedding consumed by the actor and critic heads.
        # It starts at semi_memory_size and grows by the instruction
        # encoding's width when instructions are enabled.
        self.embedding_size = self.semi_memory_size
        if self.use_instr:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["instr"], self.word_embedding_size)
            self.instr_embedding_size = 128
            self.instr_rnn = nn.GRU(self.word_embedding_size, self.instr_embedding_size, batch_first=True)
            self.embedding_size += self.instr_embedding_size

        def make_head(out_features):
            # Two-layer tanh MLP: joint embedding -> 64 -> out_features.
            return nn.Sequential(
                nn.Linear(self.embedding_size, 64),
                nn.Tanh(),
                nn.Linear(64, out_features),
            )

        self.actor = make_head(action_space.n)  # one logit per action
        self.critic = make_head(1)              # scalar state value

        # Re-initialise every Linear layer via initialize_parameters.
        self.apply(initialize_parameters)

    @property
    def memory_size(self):
        """Width of the flat memory tensor (LSTM h and c side by side)."""
        return self.semi_memory_size * 2

    @property
    def semi_memory_size(self):
        """Width of a single LSTM state (h or c); equals image_embedding_size."""
        return self.image_embedding_size

    def forward(self, obs, memory):
        """Compute the action distribution, value estimate and next memory.

        Args:
            obs: observation batch.
                ``obs.image`` is channels-last with 3 channels, presumably
                ``(batch, H, W, 3)`` (TODO confirm).
                When ``use_instr`` is set, ``obs.instr`` holds token indices
                shaped ``(batch, seq_len)``.
            memory: ``(batch, memory_size)`` tensor. The first
                ``semi_memory_size`` columns hold the LSTM ``h`` state and the
                remaining columns hold ``c``. It is returned unchanged when
                ``use_memory`` is False.

        Returns:
            A tuple ``(dist, value, memory)``:
                dist: ``Categorical`` over actions.
                value: critic output with the trailing size-1 dim squeezed,
                    i.e. shape ``(batch,)``.
                memory: the (possibly updated) memory tensor.
        """
        # Move the channel axis from last to second: (B, H, W, C) -> (B, C, H, W).
        image = obs.image.transpose(1, 3).transpose(2, 3)
        features = self.image_conv(image)
        features = features.reshape(features.shape[0], -1)

        if not self.use_memory:
            embedding = features
        else:
            split = self.semi_memory_size
            h, c = self.memory_rnn(features, (memory[:, :split], memory[:, split:]))
            embedding = h
            memory = torch.cat((h, c), dim=1)

        if self.use_instr:
            instr_embedding = self._get_embed_instr(obs.instr)
            embedding = torch.cat((embedding, instr_embedding), dim=1)

        # NOTE(review): Categorical normalises `logits` itself, so this
        # log_softmax is likely redundant; kept to preserve exact numerics.
        logits = F.log_softmax(self.actor(embedding), dim=1)
        dist = Categorical(logits=logits)
        value = self.critic(embedding).squeeze(1)
        return dist, value, memory

    def _get_embed_instr(self, instr):
        """Embed instruction tokens, run the GRU, return the top layer's final hidden state."""
        word_vectors = self.word_embedding(instr)
        _, final_hidden = self.instr_rnn(word_vectors)
        # NOTE(review): the GRU runs over every position, so any padding
        # tokens also shape final_hidden. Confirm instructions are unpadded
        # or that this is intended.
        return final_hidden[-1]