@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils


class CustomConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(CustomConv, self).__init__(aggr='add')  # "add" aggregation
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin_self = nn.Linear(in_channels, out_channels)

        torch.nn.init.xavier_uniform_(self.lin.weight, gain=0.001)
        torch.nn.init.xavier_uniform_(self.lin_self.weight, gain=0.001)

    def forward(self, x, edge_index, edge_attribute):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Add self-loops to the adjacency matrix.
        edge_index, _ = pyg_utils.add_self_loops(edge_index)

        # Transform the node feature matrix for the self-connection.
        self_x = self.lin_self(x)

        # The linear layer self.lin is applied to the neighbours.
        return self_x + self.propagate(edge_index, edge_attr=edge_attribute,
                                       size=(x.size(0), x.size(0)), x=self.lin(x))

    def message(self, x_j, edge_index, size, edge_attr):
        # Constructs a message for each edge (j, i). By convention, node i is the
        # central node that aggregates information and node j is a neighbouring
        # node, so the suffix _j refers to neighbour features.
        # x_j has shape [E, out_channels]
        index_targets, index_neighbours = edge_index

        # Symmetric GCN-style normalisation D^{-1/2} A D^{-1/2}; the self-loops
        # added in forward() guarantee deg >= 1, so the inverse square root is finite.
        deg = pyg_utils.degree(index_targets, size[0], dtype=x_j.dtype)
        degree_inverse_sqrt = deg.pow(-0.5)
        norm = degree_inverse_sqrt[index_targets] * degree_inverse_sqrt[index_neighbours]

        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        return aggr_out
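
A minimal smoke test for CustomConv on a toy graph; the node count, feature sizes and edge list below are illustrative assumptions, not part of the commit, and CustomConv is assumed to be in scope:

import torch

conv = CustomConv(in_channels=3, out_channels=8)
x = torch.randn(4, 3)                                  # 4 nodes, 3 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])              # 4 directed edges
out = conv(x, edge_index, edge_attribute=None)         # edge_attr is unused by message()
print(out.shape)                                       # -> torch.Size([4, 8])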
@@ -0,0 +1,56 @@
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

import MNIST.EquivariantLayer as EquivariantLayer


class EquivariantGNNStack(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, task='node'):
        super(EquivariantGNNStack, self).__init__()
        self.task = task
        self.num_layers = 3

        self.convs = nn.ModuleList()
        for _ in range(self.num_layers):
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))

        # pre-processing: embed the raw node features into the hidden dimension
        self.embedding_in = nn.Linear(input_dim, hidden_dim)

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim))

        if not (self.task == 'node' or self.task == 'graph'):
            raise RuntimeError('Unknown task.')

    def build_conv_model(self, input_dim, hidden_dim):
        return EquivariantLayer.EquivariantLayer(input_dim, hidden_dim)

    def forward(self, data):
        x, edge_index, batch, edge_attributes = data.x, data.edge_index, data.batch, data.edge_attr

        x = self.embedding_in(x)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_attributes, batch)

        if self.task == 'graph':
            x = pyg_nn.global_mean_pool(x, batch)

        x = self.post_mp(x)

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
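
A shape check for the stack on a single toy graph; the Data fields and sizes are illustrative assumptions, and the MNIST.EquivariantLayer import above is assumed to resolve to the EquivariantLayer defined in the next file:

import torch
from torch_geometric.data import Data, Batch

graph = Data(x=torch.randn(4, 1),
             edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
             edge_attr=torch.rand(4),                  # one scalar attribute per edge
             y=torch.tensor([3]))
batch = Batch.from_data_list([graph])

model = EquivariantGNNStack(input_dim=1, hidden_dim=16, output_dim=10, task='graph')
pred = model(batch)                                    # -> [1, 10] log-probabilities
loss = model.loss(pred, batch.y)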
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.typing import Size


class EquivariantLayer(pyg_nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EquivariantLayer, self).__init__(aggr='add')  # "add" aggregation
        act_fn = nn.SiLU()

        self.dropout = nn.Dropout(0.25)
        size_distance = 1  # one scalar edge attribute (e.g. a distance) per edge
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_channels + in_channels + size_distance, out_channels),
            self.dropout,
            act_fn,
            nn.Linear(out_channels, in_channels),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(in_channels + in_channels, out_channels),
            self.dropout,
            act_fn,
            nn.Linear(out_channels, in_channels))

        self.node_dim = 0

    def forward(self, x, edge_index, edge_attr, batch):
        hidden_out = self.propagate(edge_index, x=x, edge_attr=edge_attr,
                                    batch=batch)

        return hidden_out

    def propagate(self, edge_index, size: Size = None, **kwargs):
        # propagate is overridden so the node update can combine the aggregated
        # messages with the input features through node_mlp.
        size = self.__check_input__(edge_index, size)
        hidden_feats = kwargs["x"]
        coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)

        # compute one message per edge
        m_ij = self.message(**msg_kwargs)

        # aggregate the messages per target node
        m_i = self.aggregate(m_ij, **aggr_kwargs)

        # update the node features, with a residual connection to the input
        hidden_out = self.node_mlp(torch.cat([hidden_feats, m_i], dim=-1))
        hidden_out = hidden_feats + hidden_out

        return hidden_out

    def message(self, x_i, x_j, edge_index, size, edge_attr):
        # add a trailing dimension so edge_attr has shape [E, 1]
        edge_attr = edge_attr[:, None]
        message_input = torch.cat([x_j, x_i, edge_attr], dim=-1)
        message_transformed = self.edge_mlp(message_input)

        return message_transformed

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        return aggr_out
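
The layer can be exercised on its own; the sizes below are illustrative assumptions, and the overridden propagate relies on PyTorch Geometric internals (__check_input__, __collect__, inspector.distribute) as they existed around PyG 2.0, so newer releases may need adjusting. Note that despite its name, out_channels only sets the hidden width of the two MLPs; the layer maps in_channels back to in_channels:

import torch

layer = EquivariantLayer(in_channels=8, out_channels=8)
x = torch.randn(5, 8)                                  # 5 nodes
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])           # a 5-cycle
edge_attr = torch.rand(5)                              # one scalar per edge
out = layer(x, edge_index, edge_attr, batch=None)      # -> shape [5, 8]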
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import utils as functions


class GNNStack(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNStack, self).__init__()

        self.dropout = 0.25
        self.number_layers = 2

        self.convs = nn.ModuleList()
        # the first layer has different input dimensions
        self.convs.append(pyg_nn.GCNConv(input_dim, hidden_dim))
        for _ in range(self.number_layers - 1):
            self.convs.append(CustomConv(hidden_dim, hidden_dim))

        self.post_message_passing = nn.ModuleList()
        for _ in range(self.number_layers - 1):
            self.post_message_passing.append(nn.LayerNorm(hidden_dim))

        self.post_processing = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        # x is the convention for node features
        x, edge_index, batch, edge_attribute = data.x, data.edge_index, data.batch, data.edge_attr

        i_last_layer = self.number_layers - 1

        for i_layer in range(self.number_layers):
            x = self.convs[i_layer](x, edge_index, edge_attribute)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            is_last_layer = i_layer == i_last_layer

            if not is_last_layer:
                x = self.post_message_passing[i_layer](x)

        x = pyg_nn.global_mean_pool(x, batch)
        out = self.post_processing(x)

        return out

    def predict(self, data):
        prediction = self.forward(data)
        prediction = nn.Sigmoid()(prediction)
        prediction = functions.converting_probability_array_to_binary(prediction, threshold=0.5)

        return prediction


class CustomConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        aggregation_method = "add"
        super(CustomConv, self).__init__(aggr=aggregation_method)
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin_self = nn.Linear(in_channels, out_channels)

        torch.nn.init.xavier_uniform_(self.lin.weight, gain=0.001)
        torch.nn.init.xavier_uniform_(self.lin_self.weight, gain=0.001)

    def forward(self, x, edge_index, edge_attribute):
        # add self-loops so every node also receives its own message
        edge_index, _ = pyg_utils.add_self_loops(edge_index)

        # transform the node features (only lin_self is applied here;
        # self.lin is defined above but not used in forward)
        x = self.lin_self(x)

        # symmetric GCN-style normalisation D^{-1/2} A D^{-1/2}, with isolated
        # nodes masked so the inverse square root stays finite
        row, col = edge_index
        deg = pyg_utils.degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        propagation = self.propagate(edge_index=edge_index, x=x,
                                     norm=norm, edge_attr=edge_attribute,
                                     size=(x.size(0), x.size(0)))

        return propagation

    def message(self, x_j, edge_index, size, edge_attr, norm):
        # x_j holds the features of the neighbouring node for each edge
        message = norm.view(-1, 1) * x_j

        return message
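
An inference sketch for this binary-classification stack; the graph sizes are illustrative assumptions, and the project-local utils.converting_probability_array_to_binary is assumed to threshold probabilities into a 0/1 tensor as its name suggests:

import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch

graph = Data(x=torch.randn(6, 4),
             edge_index=torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]]),
             edge_attr=torch.rand(6))                  # GCNConv uses these as scalar edge weights
batch = Batch.from_data_list([graph])

model = GNNStack(input_dim=4, hidden_dim=16, output_dim=1)
logits = model(batch)                                  # -> shape [1, 1], raw scores
loss = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
labels = model.predict(batch)                          # sigmoid + 0.5 threshold -> 0/1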
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn


class GNNStack(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, task='node'):
        super(GNNStack, self).__init__()
        self.task = task
        self.convs = nn.ModuleList()
        self.convs.append(self.build_conv_model(input_dim, hidden_dim))
        self.lns = nn.ModuleList()
        self.lns.append(nn.LayerNorm(hidden_dim))
        self.lns.append(nn.LayerNorm(hidden_dim))
        for _ in range(2):
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim))
        if not (self.task == 'node' or self.task == 'graph'):
            raise RuntimeError('Unknown task.')

        self.dropout = 0.25
        self.num_layers = 3

    def build_conv_model(self, input_dim, hidden_dim):
        # refer to the pytorch geometric nn module for other GNN implementations
        if self.task == 'node':
            return pyg_nn.GCNConv(input_dim, hidden_dim)
        else:
            return pyg_nn.GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                                nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if data.num_node_features == 0:
            # featureless graphs get a constant scalar feature per node
            x = torch.ones(data.num_nodes, 1, device=edge_index.device)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if i != self.num_layers - 1:
                x = self.lns[i](x)

        if self.task == 'graph':
            x = pyg_nn.global_mean_pool(x, batch)

        x = self.post_mp(x)

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
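
A minimal training-loop sketch for the graph-classification task; the dataset, hyperparameters and epoch count are illustrative assumptions (TUDataset/ENZYMES is a standard benchmark, and DataLoader lives in torch_geometric.loader on PyG >= 2.0):

import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GNNStack(dataset.num_node_features, 32, dataset.num_classes, task='graph')
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    for batch in loader:
        optimizer.zero_grad()
        pred = model(batch)                            # log-probabilities per graph
        loss = model.loss(pred, batch.y)
        loss.backward()
        optimizer.step()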