# blocks.py
from typing import Callable, Optional
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from control import CONTROL_DICT
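# CONTROL_DICT is assumed to map a control-type name to a factory that is
# called below as factory(feature_dim, **control_kwargs); GCNBlock treats the
# name "null" as "no control".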


class MLPBlock(nn.Module):
    """
    ReLU MLP with dropout; can be used as an encoder or a decoder.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        num_layers: int = 2,
        norm: Optional[Callable] = None,
    ):
        super().__init__()
        assert num_layers >= 2, "an MLP needs at least 2 layers"
        self.dropout_rate = dropout_rate
        if norm is None:
            norm = nn.Identity
        layers = [nn.Linear(input_dim, hidden_dim)]
        # only enters the loop if num_layers >= 3
        for i in range(num_layers - 2):
            layers.append(norm(hidden_dim))
            layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(norm(hidden_dim))
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        # no ReLU or dropout after the last linear layer
        x = self.layers[-1](x)
        return x
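

# Illustrative usage sketch (not part of the original file): builds an
# MLPBlock and runs a forward pass. The dimensions, dropout rate, and batch
# size are placeholder values, not settings taken from this repository.
def _mlp_example() -> torch.Tensor:
    mlp = MLPBlock(
        input_dim=16,
        output_dim=4,
        hidden_dim=32,
        dropout_rate=0.1,
        num_layers=3,  # one hidden Linear beyond the minimum of 2 layers
    )
    x = torch.randn(8, 16)  # batch of 8 input feature vectors
    return mlp(x)  # shape (8, 4); no ReLU or dropout after the final layer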


class GCNBlock(nn.Module):
    """
    A block of GCN layers with optional control.
    Flags make the block linear and/or time-invariant.
    """

    def __init__(
        self,
        feature_dim: int,
        depth: int,
        dropout_rate: float,
        linear: bool,
        time_inv: bool,
        residual: bool,
        control_type: str,
        **control_kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.linear = linear
        self.time_inv = time_inv
        self.residual = residual
        # only one (shared) layer if time_inv
        num_layers = 1 if self.time_inv else self.depth
        self.conv_layers = []
        for _ in range(num_layers):
            self.conv_layers.append(GCNConv(feature_dim, feature_dim))
        self.conv_layers = nn.ModuleList(self.conv_layers)
        if control_type != "null":
            control_factory = CONTROL_DICT[control_type]
            self.control_layers = []
            for _ in range(num_layers):
                self.control_layers.append(
                    control_factory(feature_dim, **control_kwargs)
                )
            self.control_layers = nn.ModuleList(self.control_layers)
        else:
            self.control_layers = None

    def forward(self, x, edge_index, control_edge_index=None):
        for i in range(self.depth):
            # handles both time_inv = True and time_inv = False
            layer_index = i % len(self.conv_layers)
            conv_out = self.conv_layers[layer_index](x, edge_index)
            if self.control_layers is not None:
                control_out = self.control_layers[layer_index](x, control_edge_index)
                out = conv_out + control_out
            else:
                out = conv_out
            if self.residual:
                x = x + out
            else:
                x = out
            # no dropout or ReLU (if non-linear) after the final conv
            if i != (self.depth - 1):
                if not self.linear:
                    x = F.relu(x)
                x = F.dropout(x, p=self.dropout_rate, training=self.training)
        return x
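

# Illustrative usage sketch (not part of the original file): runs GCNBlock on
# a tiny 3-node path graph with control disabled via the "null" sentinel,
# since the concrete control modules come from the external `control` module
# and their keyword arguments are not shown here. All sizes are placeholders.
def _gcn_example() -> torch.Tensor:
    block = GCNBlock(
        feature_dim=16,
        depth=3,
        dropout_rate=0.1,
        linear=False,
        time_inv=True,        # a single GCNConv is reused for all 3 steps
        residual=True,        # x <- x + out at each step
        control_type="null",  # no control layers, so control_edge_index is unused
    )
    x = torch.randn(3, 16)  # one 16-dim feature vector per node
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])  # undirected path 0-1-2
    return block(x, edge_index)  # shape (3, 16)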