-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcgcnn_modules.py
182 lines (153 loc) · 5.84 KB
/
cgcnn_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
@reference: https://github.com/txie-93/cgcnn
"""
import torch.nn as nn
import torch
from nwr_gae_layers import PairNorm
class ConvLayer(nn.Module):
    """
    Gated convolution over the neighbors of every atom in a graph.

    For each (atom, neighbor) pair the atom features, the neighbor's atom
    features and the bond features are concatenated and passed through one
    linear layer. The result is batch-normalised and split into a sigmoid
    "filter" half and a softplus "core" half; their product is summed over
    the neighbors, batch-normalised again and added back to the input
    features (residual connection) before a final softplus.
    """

    def __init__(self, atom_fea_len, nbr_fea_len):
        """
        Build the sub-modules of the layer.

        Parameters
        ----------
        atom_fea_len: int
            Number of atom hidden features.
        nbr_fea_len: int
            Number of bond features.
        """
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len

        # Filter and core halves are produced by a single linear map, hence
        # an output twice as wide as the atom feature vector.
        gated_width = 2 * atom_fea_len
        self.fc_full = nn.Linear(gated_width + nbr_fea_len, gated_width)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(gated_width)
        self.bn2 = nn.BatchNorm1d(atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Apply one convolution step.

        N: Total number of atoms in the batch
        M: Max number of neighbors

        Parameters
        ----------
        atom_in_fea: torch.Tensor shape (N, atom_fea_len)
            Atom hidden features before convolution
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Indices of M neighbors of each atom

        Returns
        -------
        atom_out_fea: torch.Tensor shape (N, atom_fea_len)
            Atom hidden features after convolution
        """
        # TODO will there be problems with the index zero padding?
        n_atoms, n_nbrs = nbr_fea_idx.shape
        hidden = self.atom_fea_len

        # Assemble [centre atom | neighbor atom | bond] for every pair.
        centre = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, hidden)
        neighbours = atom_in_fea[nbr_fea_idx, :]
        pair_fea = torch.cat([centre, neighbours, nbr_fea], dim=2)

        # BatchNorm1d is applied on a 2-D view, so the (N, M) axes are
        # flattened for normalisation and restored afterwards.
        gated = self.fc_full(pair_fea)
        gated = self.bn1(gated.view(-1, 2 * hidden))
        gated = gated.view(n_atoms, n_nbrs, 2 * hidden)

        gate, core = gated.chunk(2, dim=2)
        messages = self.sigmoid(gate) * self.softplus1(core)

        # Sum the gated messages over the neighbor axis, then residual add.
        aggregated = self.bn2(torch.sum(messages, dim=1))
        return self.softplus2(atom_in_fea + aggregated)
class CrystalGraphConvNet(nn.Module):
    """
    Crystal graph convolutional encoder.

    Atom features are linearly embedded and passed through a stack of
    ``ConvLayer`` modules. Depending on ``pre_train_mode`` the forward pass
    returns either the per-layer atom features or one mean-pooled feature
    vector per crystal.
    """

    def __init__(self, orig_atom_fea_len, nbr_fea_len,
                 atom_fea_len, n_conv):
        """
        Initialize CrystalGraphConvNet.

        Parameters
        ----------
        orig_atom_fea_len: int
            Number of atom features in the input.
        nbr_fea_len: int
            Number of bond features.
        atom_fea_len: int
            Number of hidden atom features in the convolutional layers
        n_conv: int
            Number of convolutional layers
        """
        super().__init__()
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len,
                                              nbr_fea_len=nbr_fea_len)
                                    for _ in range(n_conv)])
        self.softplus = nn.Softplus()
        # NOTE: pair normalisation is currently disabled.
        # self.pair_norm = PairNorm(mode='PN-SCS')
        self.pre_train_mode = False

    def set_pre_train(self, pre_train):
        """
        Switch the output of ``forward`` between per-layer atom features
        (pre-training) and pooled crystal features.

        Parameters
        ----------
        pre_train: bool
            Value stored in ``self.pre_train_mode``.
        """
        self.pre_train_mode = pre_train

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors
        N0: Total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
            Atom features from atom type
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
            Indices of M neighbors of each atom
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx. Only used when
            ``pre_train_mode`` is False.

        Returns
        -------
        list of torch.Tensor, when ``pre_train_mode`` is True
            The raw input ``atom_fea`` followed by the atom features
            produced by each convolutional layer (``n_conv + 1`` entries).
        torch.Tensor shape (N0, atom_fea_len), otherwise
            Softplus of the mean-pooled atom features of each crystal.
        """
        # The first entry is the un-embedded input; the embedding output
        # itself is not recorded, only what each conv layer produces.
        outputs_per_layer = [atom_fea]
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
            # atom_fea = self.pair_norm(atom_fea)
            outputs_per_layer.append(atom_fea)

        if self.pre_train_mode:
            return outputs_per_layer
        return self.softplus(self.pooling(atom_fea, crystal_atom_idx))

    def pooling(self, atom_fea, crystal_atom_idx):
        """
        Pool the atom features to crystal features by averaging.

        N: Total number of atoms in the batch
        N0: Total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, atom_fea_len)
            Atom feature vectors of the batch
        crystal_atom_idx: list of torch.LongTensor of length N0
            Mapping from the crystal idx to atom idx

        Returns
        -------
        torch.Tensor shape (N0, atom_fea_len)
            Mean of the atom feature vectors belonging to each crystal.

        Raises
        ------
        AssertionError
            If the index maps do not cover exactly N atoms in total.
        """
        n_atoms = atom_fea.shape[0]
        n_indexed = sum(len(idx_map) for idx_map in crystal_atom_idx)
        assert n_indexed == n_atoms, (
            "crystal_atom_idx indexes {} atoms but atom_fea has {} rows"
            .format(n_indexed, n_atoms))
        # keepdim=True keeps each crystal as a (1, atom_fea_len) row so the
        # rows can be concatenated into (N0, atom_fea_len).
        pooled_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
                      for idx_map in crystal_atom_idx]
        return torch.cat(pooled_fea, dim=0)