From 9a0511c8e91a7f633c9c3292fccbcbad5281d1f5 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 4 Nov 2019 06:29:42 +0800 Subject: [PATCH] [NN] nn modules & examples update (#890) * upd * damn it * fuck * fuck pylint * fudge * remove some comments about MXNet * upd * upd * damn it * damn it * fuck * fuck * upd * upd * pylint bastard * upd * upd * upd * upd * upd * upd * upd * upd * upd --- docs/source/api/python/nn.mxnet.rst | 100 ++++++ examples/mxnet/appnp/README.md | 32 ++ examples/mxnet/appnp/appnp.py | 170 ++++++++++ examples/mxnet/gat/gat.py | 97 +----- examples/mxnet/gcn/train.py | 1 + examples/mxnet/gin/README.md | 41 +++ examples/mxnet/gin/dataloader.py | 103 ++++++ examples/mxnet/gin/gin.py | 163 +++++++++ examples/mxnet/gin/main.py | 158 +++++++++ examples/mxnet/gin/parser.py | 86 +++++ examples/mxnet/graphsage/README.md | 27 ++ examples/mxnet/graphsage/main.py | 163 +++++++++ examples/mxnet/monet/README.md | 16 + examples/mxnet/monet/citation.py | 173 ++++++++++ examples/mxnet/sgc/README.md | 34 ++ examples/mxnet/sgc/sgc.py | 122 +++++++ examples/pytorch/appnp/appnp.py | 2 - examples/pytorch/appnp/train.py | 7 - examples/pytorch/gin/dataloader.py | 14 +- examples/pytorch/gin/gin.py | 55 ++- examples/pytorch/gin/main.py | 20 +- examples/pytorch/graphsage/graphsage.py | 1 - .../pytorch/model_zoo/geometric/.gitignore | 1 + .../pytorch/model_zoo/geometric/README.md | 25 ++ .../pytorch/model_zoo/geometric/coarsening.py | 312 ++++++++++++++++++ .../pytorch/model_zoo/geometric/coordinate.py | 30 ++ .../pytorch/model_zoo/geometric/grid_graph.py | 69 ++++ examples/pytorch/model_zoo/geometric/mnist.py | 182 ++++++++++ examples/pytorch/monet/README.md | 19 ++ examples/pytorch/monet/citation.py | 189 +++++++++++ python/dgl/data/gindt.py | 5 +- python/dgl/nn/mxnet/conv/__init__.py | 19 +- python/dgl/nn/mxnet/conv/agnnconv.py | 66 ++++ python/dgl/nn/mxnet/conv/appnpconv.py | 75 +++++ python/dgl/nn/mxnet/conv/chebconv.py | 123 +++++++ python/dgl/nn/mxnet/conv/densechebconv.py | 100 ++++++ python/dgl/nn/mxnet/conv/densegraphconv.py | 98 ++++++ python/dgl/nn/mxnet/conv/densesageconv.py | 85 +++++ python/dgl/nn/mxnet/conv/edgeconv.py | 76 +++++ python/dgl/nn/mxnet/conv/gatconv.py | 123 +++++++ python/dgl/nn/mxnet/conv/gatedgraphconv.py | 95 ++++++ python/dgl/nn/mxnet/conv/ginconv.py | 79 +++++ python/dgl/nn/mxnet/conv/gmmconv.py | 128 +++++++ python/dgl/nn/mxnet/conv/nnconv.py | 109 ++++++ python/dgl/nn/mxnet/conv/sageconv.py | 121 +++++++ python/dgl/nn/mxnet/conv/sgconv.py | 100 ++++++ python/dgl/nn/mxnet/utils.py | 35 +- python/dgl/nn/pytorch/conv/agnnconv.py | 4 +- python/dgl/nn/pytorch/conv/appnpconv.py | 10 +- python/dgl/nn/pytorch/conv/chebconv.py | 2 +- python/dgl/nn/pytorch/conv/densesageconv.py | 2 +- python/dgl/nn/pytorch/conv/edgeconv.py | 6 +- python/dgl/nn/pytorch/conv/gatedgraphconv.py | 29 +- python/dgl/nn/pytorch/conv/gmmconv.py | 8 +- python/dgl/nn/pytorch/conv/sgconv.py | 10 +- tests/mxnet/test_nn.py | 218 ++++++++++++ tests/pytorch/test_nn.py | 16 +- third_party/dmlc-core | 2 +- 58 files changed, 3950 insertions(+), 206 deletions(-) create mode 100644 examples/mxnet/appnp/README.md create mode 100644 examples/mxnet/appnp/appnp.py create mode 100644 examples/mxnet/gin/README.md create mode 100644 examples/mxnet/gin/dataloader.py create mode 100644 examples/mxnet/gin/gin.py create mode 100644 examples/mxnet/gin/main.py create mode 100644 examples/mxnet/gin/parser.py create mode 100644 examples/mxnet/graphsage/README.md create mode 100644 examples/mxnet/graphsage/main.py create 
mode 100644 examples/mxnet/monet/README.md create mode 100644 examples/mxnet/monet/citation.py create mode 100644 examples/mxnet/sgc/README.md create mode 100644 examples/mxnet/sgc/sgc.py create mode 100644 examples/pytorch/model_zoo/geometric/.gitignore create mode 100644 examples/pytorch/model_zoo/geometric/README.md create mode 100644 examples/pytorch/model_zoo/geometric/coarsening.py create mode 100644 examples/pytorch/model_zoo/geometric/coordinate.py create mode 100644 examples/pytorch/model_zoo/geometric/grid_graph.py create mode 100644 examples/pytorch/model_zoo/geometric/mnist.py create mode 100644 examples/pytorch/monet/README.md create mode 100644 examples/pytorch/monet/citation.py create mode 100644 python/dgl/nn/mxnet/conv/agnnconv.py create mode 100644 python/dgl/nn/mxnet/conv/appnpconv.py create mode 100644 python/dgl/nn/mxnet/conv/chebconv.py create mode 100644 python/dgl/nn/mxnet/conv/densechebconv.py create mode 100644 python/dgl/nn/mxnet/conv/densegraphconv.py create mode 100644 python/dgl/nn/mxnet/conv/densesageconv.py create mode 100644 python/dgl/nn/mxnet/conv/edgeconv.py create mode 100644 python/dgl/nn/mxnet/conv/gatconv.py create mode 100644 python/dgl/nn/mxnet/conv/gatedgraphconv.py create mode 100644 python/dgl/nn/mxnet/conv/ginconv.py create mode 100644 python/dgl/nn/mxnet/conv/gmmconv.py create mode 100644 python/dgl/nn/mxnet/conv/nnconv.py create mode 100644 python/dgl/nn/mxnet/conv/sageconv.py create mode 100644 python/dgl/nn/mxnet/conv/sgconv.py diff --git a/docs/source/api/python/nn.mxnet.rst b/docs/source/api/python/nn.mxnet.rst index 5e4b125da3d9..707354fa2a26 100644 --- a/docs/source/api/python/nn.mxnet.rst +++ b/docs/source/api/python/nn.mxnet.rst @@ -38,6 +38,106 @@ TAGConv :members: forward :show-inheritance: +GATConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.GATConv + :members: forward + :show-inheritance: + +EdgeConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.EdgeConv + :members: forward + :show-inheritance: + +SAGEConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.SAGEConv + :members: forward + :show-inheritance: + +SGConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.SGConv + :members: forward + :show-inheritance: + +APPNPConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.APPNPConv + :members: forward + :show-inheritance: + +GINConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.GINConv + :members: forward + :show-inheritance: + +GatedGraphConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.GatedGraphConv + :members: forward + :show-inheritance: + +GMMConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.GMMConv + :members: forward + :show-inheritance: + +ChebConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.ChebConv + :members: forward + :show-inheritance: + +AGNNConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.AGNNConv + :members: forward + :show-inheritance: + +NNConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.NNConv + :members: forward + :show-inheritance + +Dense Conv Layers +---------------------------------------- + +DenseGraphConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.mxnet.conv.DenseGraphConv + :members: forward + :show-inheritance: + +DenseSAGEConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: dgl.nn.mxnet.conv.DenseSAGEConv + :members: forward + :show-inheritance + +DenseChebConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.pytorch.conv.DenseChebConv + :members: forward + :show-inheritance: Global Pooling Layers ---------------------------------------- diff --git a/examples/mxnet/appnp/README.md b/examples/mxnet/appnp/README.md new file mode 100644 index 000000000000..3fd40fe7b773 --- /dev/null +++ b/examples/mxnet/appnp/README.md @@ -0,0 +1,32 @@ +Predict then Propagate: Graph Neural Networks meet Personalized PageRank (APPNP) +============ + +- Paper link: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997) +- Author's code repo: [https://github.com/klicperajo/ppnp](https://github.com/klicperajo/ppnp). + +Dependencies +------------ +- MXNET 1.5+ +- requests + +``bash +pip install torch requests +`` + +Code +----- +The folder contains an implementation of APPNP (`appnp.py`). + +Results +------- + +Run with following (available dataset: "cora", "citeseer", "pubmed") +```bash +DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 +``` + +* cora: 0.8370 (paper: 0.850) +* citeseer: 0.713 (paper: 0.757) +* pubmed: 0.798 (paper: 0.797) + +Experiments were done on dgl datasets (GCN settings) which are different from those used in the original implementation. (discrepancies are detailed in experimental section of the original paper) \ No newline at end of file diff --git a/examples/mxnet/appnp/appnp.py b/examples/mxnet/appnp/appnp.py new file mode 100644 index 000000000000..82d045992eb8 --- /dev/null +++ b/examples/mxnet/appnp/appnp.py @@ -0,0 +1,170 @@ +import argparse, time +import numpy as np +import dgl +import mxnet as mx +from mxnet import nd, gluon +from mxnet.gluon import nn +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.mxnet.conv import APPNPConv + +class APPNP(nn.Block): + def __init__(self, + g, + in_feats, + hiddens, + n_classes, + activation, + feat_drop, + edge_drop, + alpha, + k): + super(APPNP, self).__init__() + self.g = g + + with self.name_scope(): + self.layers = nn.Sequential() + # input layer + self.layers.add(nn.Dense(hiddens[0], in_units=in_feats)) + # hidden layers + for i in range(1, len(hiddens)): + self.layers.add(nn.Dense(hiddens[i], in_units=hiddens[i - 1])) + # output layer + self.layers.add(nn.Dense(n_classes, in_units=hiddens[-1])) + self.activation = activation + if feat_drop: + self.feat_drop = nn.Dropout(feat_drop) + else: + self.feat_drop = lambda x: x + self.propagate = APPNPConv(k, alpha, edge_drop) + + def forward(self, features): + # prediction step + h = features + h = self.feat_drop(h) + h = self.activation(self.layers[0](h)) + for layer in self.layers[1:-1]: + h = self.activation(layer(h)) + h = self.layers[-1](self.feat_drop(h)) + # propagation step + h = self.propagate(self.g, h) + return h + +def evaluate(model, features, labels, mask): + pred = model(features).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = nd.array(data.features) + labels = nd.array(data.labels) + train_mask = nd.array(data.train_mask) + val_mask = nd.array(data.val_mask) + test_mask = nd.array(data.test_mask) + + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train 
samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + ctx = mx.cpu() + else: + ctx = mx.gpu(args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # graph preprocess and calculate normalization factor + g = DGLGraph(data.graph) + n_edges = g.number_of_edges() + # add self loop + g.add_edges(g.nodes(), g.nodes()) + g.set_n_initializer(dgl.init.zero_initializer) + g.set_e_initializer(dgl.init.zero_initializer) + + # create APPNP model + model = APPNP(g, + in_feats, + args.hidden_sizes, + n_classes, + nd.relu, + args.in_drop, + args.edge_drop, + args.alpha, + args.k) + + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + # use optimizer + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(features) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + loss.asscalar() + dur.append(time.time() - t0) + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, features, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='APPNP') + register_data_args(parser) + parser.add_argument("--in-drop", type=float, default=0.5, + help="input feature dropout") + parser.add_argument("--edge-drop", type=float, default=0.5, + help="edge propagation dropout") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--hidden_sizes", type=int, nargs='+', default=[64], + help="hidden unit sizes for appnp") + parser.add_argument("--k", type=int, default=10, + help="Number of propagation steps") + parser.add_argument("--alpha", type=float, default=0.1, + help="Teleport Probability") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + args = parser.parse_args() + print(args) + + main(args) \ No newline at end of file diff --git a/examples/mxnet/gat/gat.py b/examples/mxnet/gat/gat.py index 6f41bd244eb4..0489018e0a14 100644 --- a/examples/mxnet/gat/gat.py +++ b/examples/mxnet/gat/gat.py @@ -7,87 +7,8 @@ Pytorch implementation: https://github.com/Diego999/pyGAT """ -import mxnet as mx -from mxnet import gluon -import mxnet.ndarray as nd import mxnet.gluon.nn as nn -import dgl.function as fn -from dgl.nn.mxnet import edge_softmax - - -class GraphAttention(gluon.Block): - def __init__(self, - g, - in_dim, - out_dim, - num_heads, - feat_drop, - attn_drop, - alpha, - residual=False): - super(GraphAttention, self).__init__() - self.g = g 
- self.num_heads = num_heads - self.fc = nn.Dense(num_heads * out_dim, use_bias=False, - weight_initializer=mx.init.Xavier()) - if feat_drop: - self.feat_drop = nn.Dropout(feat_drop) - else: - self.feat_drop = lambda x : x - if attn_drop: - self.attn_drop = nn.Dropout(attn_drop) - else: - self.attn_drop = lambda x : x - self.attn_l = self.params.get("left_att", grad_req="add", - shape=(1, num_heads, out_dim), - init=mx.init.Xavier()) - self.attn_r = self.params.get("right_att", grad_req="add", - shape=(1, num_heads, out_dim), - init=mx.init.Xavier()) - self.alpha = alpha - self.softmax = edge_softmax - self.residual = residual - if residual: - if in_dim != out_dim: - self.res_fc = nn.Dense(num_heads * out_dim, use_bias=False, - weight_initializer=mx.init.Xavier()) - else: - self.res_fc = None - - def forward(self, inputs): - # prepare - h = self.feat_drop(inputs) # NxD - ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD' - a1 = (ft * self.attn_l.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1 - a2 = (ft * self.attn_r.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1 - self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2}) - # 1. compute edge attention - self.g.apply_edges(self.edge_attention) - # 2. compute softmax - self.edge_softmax() - # 3. compute the aggregated node features - self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), - fn.sum('ft', 'ft')) - ret = self.g.ndata['ft'] - # 4. residual - if self.residual: - if self.res_fc is not None: - resval = self.res_fc(h).reshape( - (h.shape[0], self.num_heads, -1)) # NxHxD' - else: - resval = nd.expand_dims(h, axis=1) # Nx1xD' - ret = resval + ret - return ret - - def edge_attention(self, edges): - # an edge UDF to compute unnormalized attention values from src and dst - a = nd.LeakyReLU(edges.src['a1'] + edges.dst['a2'], slope=self.alpha) - return {'a' : a} - - def edge_softmax(self): - attention = self.softmax(self.g, self.g.edata.pop('a')) - # Dropout attention scores and save them - self.g.edata['a_drop'] = self.attn_drop(attention) +from dgl.nn.mxnet.conv import GATConv class GAT(nn.Block): @@ -109,18 +30,18 @@ def __init__(self, self.gat_layers = [] self.activation = activation # input projection (no residual) - self.gat_layers.append(GraphAttention( - g, in_dim, num_hidden, heads[0], + self.gat_layers.append(GATConv( + in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False)) # hidden layers for l in range(1, num_layers): # due to multi-head, the in_dim = num_hidden * num_heads - self.gat_layers.append(GraphAttention( - g, num_hidden * heads[l-1], num_hidden, heads[l], + self.gat_layers.append(GATConv( + num_hidden * heads[l-1], num_hidden, heads[l], feat_drop, attn_drop, alpha, residual)) # output projection - self.gat_layers.append(GraphAttention( - g, num_hidden * heads[-2], num_classes, heads[-1], + self.gat_layers.append(GATConv( + num_hidden * heads[-2], num_classes, heads[-1], feat_drop, attn_drop, alpha, residual)) for i, layer in enumerate(self.gat_layers): self.register_child(layer, "gat_layer_{}".format(i)) @@ -128,8 +49,8 @@ def __init__(self, def forward(self, inputs): h = inputs for l in range(self.num_layers): - h = self.gat_layers[l](h).flatten() + h = self.gat_layers[l](self.g, h).flatten() h = self.activation(h) # output projection - logits = self.gat_layers[-1](h).mean(1) + logits = self.gat_layers[-1](self.g, h).mean(1) return logits diff --git a/examples/mxnet/gcn/train.py b/examples/mxnet/gcn/train.py index 037514b48e94..6dbe37242aa4 100644 --- 
a/examples/mxnet/gcn/train.py +++ b/examples/mxnet/gcn/train.py @@ -25,6 +25,7 @@ def main(args): train_mask = mx.nd.array(data.train_mask) val_mask = mx.nd.array(data.val_mask) test_mask = mx.nd.array(data.test_mask) + in_feats = features.shape[1] n_classes = data.num_labels n_edges = data.graph.number_of_edges() diff --git a/examples/mxnet/gin/README.md b/examples/mxnet/gin/README.md new file mode 100644 index 000000000000..1e0a9192736c --- /dev/null +++ b/examples/mxnet/gin/README.md @@ -0,0 +1,41 @@ +Graph Isomorphism Network (GIN) +============ + +- Paper link: [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km) +- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns). + +Dependencies +------------ +- MXNet 1.5+ +- sklearn +- tqdm + +``bash +pip install torch sklearn tqdm +`` + +How to run +---------- + +An experiment on the GIN in default settings can be run with + +```bash +DGLBACKEND=mxnet python main.py +``` + +An experiment on the GIN in customized settings can be run with +```bash +DGLBACKEND=mxnet python main.py [--device 0 | --disable-cuda] --dataset COLLAB \ + --graph_pooling_type max --neighbor_pooling_type sum +``` + +Results +------- + +Run with following with the double SUM pooling way: +(tested dataset: "MUTAG"(default), "COLLAB", "IMDBBINARY", "IMDBMULTI") +```bash +DGLBACKEND=mxnet python main.py --dataset MUTAG --device 0 \ + --graph_pooling_type sum --neighbor_pooling_type sum +``` + diff --git a/examples/mxnet/gin/dataloader.py b/examples/mxnet/gin/dataloader.py new file mode 100644 index 000000000000..74574bc8228e --- /dev/null +++ b/examples/mxnet/gin/dataloader.py @@ -0,0 +1,103 @@ +""" +MxNet compatible dataloader +""" + +from mxnet.gluon.data import DataLoader, Sampler + +import math +import numpy as np +from mxnet import nd +from sklearn.model_selection import StratifiedKFold +import dgl + +class SubsetRandomSampler(Sampler): + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return iter([self.indices[i] for i in np.random.permutation(len(self.indices))]) + + def __len__(self): + return len(self.indices) + +# default collate function +def collate(samples): + # The input `samples` is a list of pairs (graph, label). 
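+    # dgl.batch below merges the individual graphs into one batched graph;
+    # node features are first converted to mxnet NDArrays (this dataset carries
+    # no edge features, so only node data needs conversion).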
+ graphs, labels = map(list, zip(*samples)) + for g in graphs: + # deal with node feats + for key in g.node_attr_schemes().keys(): + g.ndata[key] = nd.array(g.ndata[key]) + # no edge feats + batched_graph = dgl.batch(graphs) + labels = nd.array(labels) + return batched_graph, labels + +class GraphDataLoader(): + def __init__(self, + dataset, + batch_size, + collate_fn=collate, + seed=0, + shuffle=True, + split_name='fold10', + fold_idx=0, + split_ratio=0.7): + + self.shuffle = shuffle + self.seed = seed + + labels = [l for _, l in dataset] + + if split_name == 'fold10': + train_idx, valid_idx = self._split_fold10( + labels, fold_idx, seed, shuffle) + elif split_name == 'rand': + train_idx, valid_idx = self._split_rand( + labels, split_ratio, seed, shuffle) + else: + raise NotImplementedError() + + train_sampler = SubsetRandomSampler(train_idx) + valid_sampler = SubsetRandomSampler(valid_idx) + + self.train_loader = DataLoader( + dataset, sampler=train_sampler, + batch_size=batch_size, batchify_fn=collate_fn) + self.valid_loader = DataLoader( + dataset, sampler=valid_sampler, + batch_size=batch_size, batchify_fn=collate_fn) + + def train_valid_loader(self): + return self.train_loader, self.valid_loader + + def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True): + ''' 10 flod ''' + assert 0 <= fold_idx and fold_idx < 10, print( + "fold_idx must be from 0 to 9.") + + skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed) + idx_list = [] + for idx in skf.split(np.zeros(len(labels)), labels): # split(x, y) + idx_list.append(idx) + train_idx, valid_idx = idx_list[fold_idx] + + print( + "train_set : test_set = %d : %d", + len(train_idx), len(valid_idx)) + + return train_idx, valid_idx + + def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True): + num_entries = len(labels) + indices = list(range(num_entries)) + np.random.seed(seed) + np.random.shuffle(indices) + split = int(math.floor(split_ratio * num_entries)) + train_idx, valid_idx = indices[:split], indices[split:] + + print( + "train_set : test_set = %d : %d", + len(train_idx), len(valid_idx)) + + return train_idx, valid_idx diff --git a/examples/mxnet/gin/gin.py b/examples/mxnet/gin/gin.py new file mode 100644 index 000000000000..b45ec84e60eb --- /dev/null +++ b/examples/mxnet/gin/gin.py @@ -0,0 +1,163 @@ +""" +How Powerful are Graph Neural Networks +https://arxiv.org/abs/1810.00826 +https://openreview.net/forum?id=ryGs6iA5Km +Author's implementation: https://github.com/weihua916/powerful-gnns +""" + +import mxnet as mx +from mxnet import nd, gluon +from mxnet.gluon import nn +from dgl.nn.mxnet.conv import GINConv +from dgl.nn.mxnet.glob import SumPooling, AvgPooling, MaxPooling + + +class ApplyNodeFunc(nn.Block): + """Update the node feature hv with MLP, BN and ReLU.""" + def __init__(self, mlp): + super(ApplyNodeFunc, self).__init__() + with self.name_scope(): + self.mlp = mlp + self.bn = nn.BatchNorm(in_channels=self.mlp.output_dim) + + def forward(self, h): + h = self.mlp(h) + h = self.bn(h) + h = nd.relu(h) + return h + + +class MLP(nn.Block): + """MLP with linear output""" + def __init__(self, num_layers, input_dim, hidden_dim, output_dim): + """MLP layers construction + + Paramters + --------- + num_layers: int + The number of linear layers + input_dim: int + The dimensionality of input features + hidden_dim: int + The dimensionality of hidden units at ALL layers + output_dim: int + The number of classes for prediction + """ + super(MLP, self).__init__() + self.linear_or_not = True + 
self.num_layers = num_layers + self.output_dim = output_dim + + with self.name_scope(): + if num_layers < 1: + raise ValueError("number of layers should be positive!") + elif num_layers == 1: + # Linear model + self.linear = nn.Dense(output_dim, in_units=input_dim) + else: + self.linear_or_not = False + self.linears = nn.Sequential() + self.batch_norms = nn.Sequential() + + self.linears.add(nn.Dense(hidden_dim, in_units=input_dim)) + for layer in range(num_layers - 2): + self.linears.add(nn.Dense(hidden_dim, in_units=hidden_dim)) + self.linears.add(nn.Dense(output_dim, in_units=hidden_dim)) + + for layer in range(num_layers - 1): + self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim)) + + def forward(self, x): + if self.linear_or_not: + return self.linear(x) + else: + h = x + for i in range(self.num_layers - 1): + h = nd.relu(self.batch_norms[i](self.linears[i](h))) + return self.linears[-1](h) + + +class GIN(nn.Block): + """GIN model""" + def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, + output_dim, final_dropout, learn_eps, graph_pooling_type, + neighbor_pooling_type): + """model parameters setting + + Paramters + --------- + num_layers: int + The number of linear layers in the neural network + num_mlp_layers: int + The number of linear layers in mlps + input_dim: int + The dimensionality of input features + hidden_dim: int + The dimensionality of hidden units at ALL layers + output_dim: int + The number of classes for prediction + final_dropout: float + dropout ratio on the final linear layer + learn_eps: boolean + If True, learn epsilon to distinguish center nodes from neighbors + If False, aggregate neighbors and center nodes altogether. + neighbor_pooling_type: str + how to aggregate neighbors (sum, mean, or max) + graph_pooling_type: str + how to aggregate entire nodes in a graph (sum, mean or max) + + """ + super(GIN, self).__init__() + self.num_layers = num_layers + self.learn_eps = learn_eps + + with self.name_scope(): + # List of MLPs + self.ginlayers = nn.Sequential() + self.batch_norms = nn.Sequential() + + for i in range(self.num_layers - 1): + if i == 0: + mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim) + else: + mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) + + self.ginlayers.add( + GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) + self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim)) + + self.linears_prediction = nn.Sequential() + + for i in range(num_layers): + if i == 0: + self.linears_prediction.add(nn.Dense(output_dim, in_units=input_dim)) + else: + self.linears_prediction.add(nn.Dense(output_dim, in_units=hidden_dim)) + + self.drop = nn.Dropout(final_dropout) + + if graph_pooling_type == 'sum': + self.pool = SumPooling() + elif graph_pooling_type == 'mean': + self.pool = AvgPooling() + elif graph_pooling_type == 'max': + self.pool = MaxPooling() + else: + raise NotImplementedError + + def forward(self, g, h): + hidden_rep = [h] + + for i in range(self.num_layers - 1): + h = self.ginlayers[i](g, h) + h = self.batch_norms[i](h) + h = nd.relu(h) + hidden_rep.append(h) + + score_over_layer = 0 + # perform pooling over all nodes in each graph in every layer + for i, h in enumerate(hidden_rep): + pooled_h = self.pool(g, h) + score_over_layer = score_over_layer + self.drop(self.linears_prediction[i](pooled_h)) + + return score_over_layer diff --git a/examples/mxnet/gin/main.py b/examples/mxnet/gin/main.py new file mode 100644 index 000000000000..52f055b73ae9 --- /dev/null +++ 
b/examples/mxnet/gin/main.py @@ -0,0 +1,158 @@ +import sys +import numpy as np +from tqdm import tqdm + +import mxnet as mx +from mxnet import gluon, nd +from mxnet.gluon import nn + +from dgl.data.gindt import GINDataset +from dataloader import GraphDataLoader, collate +from parser import Parser +from gin import GIN + + +def train(args, net, trainloader, trainer, criterion, epoch): + running_loss = 0 + total_iters = len(trainloader) + # setup the offset to avoid the overlap with mouse cursor + bar = tqdm(range(total_iters), unit='batch', position=2, file=sys.stdout) + + for pos, (graphs, labels) in zip(bar, trainloader): + # batch graphs will be shipped to device in forward part of model + labels = labels.as_in_context(args.device) + feat = graphs.ndata['attr'].astype('float32').as_in_context(args.device) + + with mx.autograd.record(): + outputs = net(graphs, feat) + loss = criterion(outputs, labels) + loss = loss.sum() / len(labels) + + running_loss += loss.asscalar() + + # backprop + loss.backward() + trainer.step(batch_size=1) + + # report + bar.set_description('epoch-{}'.format(epoch)) + bar.close() + # the final batch will be aligned + running_loss = running_loss / total_iters + + return running_loss + + +def eval_net(args, net, dataloader, criterion): + total = 0 + total_loss = 0 + total_correct = 0 + + for data in dataloader: + graphs, labels = data + labels = labels.as_in_context(args.device) + feat = graphs.ndata['attr'].astype('float32').as_in_context(args.device) + + total += len(labels) + + outputs = net(graphs, feat) + predicted = nd.argmax(outputs, axis=1) + + total_correct += (predicted == labels).sum().asscalar() + loss = criterion(outputs, labels) + # crossentropy(reduce=True) for default + total_loss += loss.sum().asscalar() + + loss, acc = 1.0 * total_loss / total, 1.0*total_correct / total + + return loss, acc + + +def main(args): + + # set up seeds, args.seed supported + mx.random.seed(0) + np.random.seed(seed=0) + + + + if args.device >= 0: + args.device = mx.gpu(args.device) + else: + args.device = mx.cpu() + + dataset = GINDataset(args.dataset, not args.learn_eps) + + trainloader, validloader = GraphDataLoader( + dataset, batch_size=args.batch_size, + collate_fn=collate, seed=args.seed, shuffle=True, + split_name='fold10', fold_idx=args.fold_idx).train_valid_loader() + # or split_name='rand', split_ratio=0.7 + + model = GIN( + args.num_layers, args.num_mlp_layers, + dataset.dim_nfeats, args.hidden_dim, dataset.gclasses, + args.final_dropout, args.learn_eps, + args.graph_pooling_type, args.neighbor_pooling_type) + model.initialize(ctx=args.device) + + criterion = gluon.loss.SoftmaxCELoss() + + print(model.collect_params()) + lr_scheduler = mx.lr_scheduler.FactorScheduler(50, 0.5) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'lr_scheduler': lr_scheduler}) + + # it's not cost-effective to hanle the cursor and init 0 + # https://stackoverflow.com/a/23121189 + tbar = tqdm(range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout) + vbar = tqdm(range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout) + lrbar = tqdm(range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout) + + for epoch, _, _ in zip(tbar, vbar, lrbar): + train(args, model, trainloader, trainer, criterion, epoch) + + train_loss, train_acc = eval_net( + args, model, trainloader, criterion) + tbar.set_description( + 'train set - average loss: {:.4f}, accuracy: {:.0f}%' + .format(train_loss, 100. 
* train_acc)) + + valid_loss, valid_acc = eval_net( + args, model, validloader, criterion) + vbar.set_description( + 'valid set - average loss: {:.4f}, accuracy: {:.0f}%' + .format(valid_loss, 100. * valid_acc)) + + if not args.filename == "": + with open(args.filename, 'a') as f: + f.write('%s %s %s %s' % ( + args.dataset, + args.learn_eps, + args.neighbor_pooling_type, + args.graph_pooling_type + )) + f.write("\n") + f.write("%f %f %f %f" % ( + train_loss, + train_acc, + valid_loss, + valid_acc + )) + f.write("\n") + + lrbar.set_description( + "Learning eps with learn_eps={}: {}".format( + args.learn_eps, [layer.eps.data(args.device).asscalar() for layer in model.ginlayers])) + + tbar.close() + vbar.close() + lrbar.close() + + +if __name__ == '__main__': + args = Parser(description='GIN').args + print('show all arguments configuration...') + print(args) + + main(args) \ No newline at end of file diff --git a/examples/mxnet/gin/parser.py b/examples/mxnet/gin/parser.py new file mode 100644 index 000000000000..007ecbae9ae9 --- /dev/null +++ b/examples/mxnet/gin/parser.py @@ -0,0 +1,86 @@ +"""Parser for arguments + +Put all arguments in one file and group similar arguments +""" +import argparse + + +class Parser(): + + def __init__(self, description): + ''' + arguments parser + ''' + self.parser = argparse.ArgumentParser(description=description) + self.args = None + self._parse() + + def _parse(self): + # dataset + self.parser.add_argument( + '--dataset', type=str, default="MUTAG", + help='name of dataset (default: MUTAG)') + self.parser.add_argument( + '--batch_size', type=int, default=32, + help='batch size for training and validation (default: 32)') + self.parser.add_argument( + '--fold_idx', type=int, default=0, + help='the index(<10) of fold in 10-fold validation.') + self.parser.add_argument( + '--filename', type=str, default="", + help='output file') + + # device + self.parser.add_argument( + '--disable-cuda', action='store_true', + help='Disable CUDA') + self.parser.add_argument( + '--device', type=int, default=0, + help='which gpu device to use (default: 0)') + + # net + self.parser.add_argument( + '--net', type=str, default="gin", + help='gnn net (default: gin)') + self.parser.add_argument( + '--num_layers', type=int, default=5, + help='number of layers (default: 5)') + self.parser.add_argument( + '--num_mlp_layers', type=int, default=2, + help='number of MLP layers(default: 2). 
1 means linear model.') + self.parser.add_argument( + '--hidden_dim', type=int, default=64, + help='number of hidden units (default: 64)') + + # graph + self.parser.add_argument( + '--graph_pooling_type', type=str, + default="sum", choices=["sum", "mean", "max"], + help='type of graph pooling: sum, mean or max') + self.parser.add_argument( + '--neighbor_pooling_type', type=str, + default="sum", choices=["sum", "mean", "max"], + help='type of neighboring pooling: sum, mean or max') + self.parser.add_argument( + '--learn_eps', action="store_true", + help='learn the epsilon weighting') + self.parser.add_argument( + '--degree_as_tag', action="store_true", + help='take the degree of nodes as input feature') + + # learning + self.parser.add_argument( + '--seed', type=int, default=0, + help='random seed (default: 0)') + self.parser.add_argument( + '--epochs', type=int, default=350, + help='number of epochs to train (default: 350)') + self.parser.add_argument( + '--lr', type=float, default=0.01, + help='learning rate (default: 0.01)') + self.parser.add_argument( + '--final_dropout', type=float, default=0.5, + help='final layer dropout (default: 0.5)') + + # done + self.args = self.parser.parse_args() \ No newline at end of file diff --git a/examples/mxnet/graphsage/README.md b/examples/mxnet/graphsage/README.md new file mode 100644 index 000000000000..856b5784edb1 --- /dev/null +++ b/examples/mxnet/graphsage/README.md @@ -0,0 +1,27 @@ +Inductive Representation Learning on Large Graphs (GraphSAGE) +============ + +- Paper link: [http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf) +- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple). Note that the original code is +simple reference implementation of GraphSAGE. + +Requirements +------------ +- requests + +``bash +pip install requests +`` + + +Results +------- + +Run with following (available dataset: "cora", "citeseer", "pubmed") +```bash +python3 main.py --dataset cora --gpu 0 +``` + +* cora: ~0.817 +* citeseer: ~0.699 +* pubmed: ~0.790 \ No newline at end of file diff --git a/examples/mxnet/graphsage/main.py b/examples/mxnet/graphsage/main.py new file mode 100644 index 000000000000..a04db488e058 --- /dev/null +++ b/examples/mxnet/graphsage/main.py @@ -0,0 +1,163 @@ +""" +Inductive Representation Learning on Large Graphs +Paper: http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf +Code: https://github.com/williamleif/graphsage-simple +Simple reference implementation of GraphSAGE. 
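+This example trains and evaluates on the full graph; no neighbor sampling is used.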
+""" +import argparse +import time +import numpy as np +import networkx as nx +import mxnet as mx +from mxnet import nd, gluon +from mxnet.gluon import nn +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.mxnet.conv import SAGEConv + + +class GraphSAGE(nn.Block): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout, + aggregator_type): + super(GraphSAGE, self).__init__() + self.g = g + + with self.name_scope(): + self.layers = nn.Sequential() + # input layer + self.layers.add(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.add(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation)) + # output layer + self.layers.add(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None + + def forward(self, features): + h = features + for layer in self.layers: + h = layer(self.g, h) + return h + +def evaluate(model, features, labels, mask): + pred = model(features).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = nd.array(data.features) + labels = nd.array(data.labels) + train_mask = nd.array(data.train_mask) + val_mask = nd.array(data.val_mask) + test_mask = nd.array(data.test_mask) + + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + ctx = mx.cpu(0) + else: + ctx = mx.gpu(args.gpu) + print("use cuda:", args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # graph preprocess and calculate normalization factor + g = data.graph + g.remove_edges_from(nx.selfloop_edges(g)) + g = DGLGraph(g) + n_edges = g.number_of_edges() + + # create GraphSAGE model + model = GraphSAGE(g, + in_feats, + args.n_hidden, + n_classes, + args.n_layers, + nd.relu, + args.dropout, + args.aggregator_type + ) + + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(features) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + loss.asscalar() + dur.append(time.time() - t0) + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". 
format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, features, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GraphSAGE') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + parser.add_argument("--aggregator-type", type=str, default="gcn", + help="Aggregator type: mean/gcn/pool/lstm") + args = parser.parse_args() + print(args) + + main(args) \ No newline at end of file diff --git a/examples/mxnet/monet/README.md b/examples/mxnet/monet/README.md new file mode 100644 index 000000000000..54eb71e40c1c --- /dev/null +++ b/examples/mxnet/monet/README.md @@ -0,0 +1,16 @@ +MoNet +===== + +- paper link: [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/pdf/1611.08402.pdf) + +Dependencies +============ + +- MXNet 1.5+ + +Results +======= + +Node classification on citation networks: +- Cora: ~0.814 +- Pubmed: ~0.748 diff --git a/examples/mxnet/monet/citation.py b/examples/mxnet/monet/citation.py new file mode 100644 index 000000000000..cd4fdaf5dc55 --- /dev/null +++ b/examples/mxnet/monet/citation.py @@ -0,0 +1,173 @@ +import argparse +import time +import numpy as np +import networkx as nx +import mxnet as mx +from mxnet import gluon, nd +from mxnet.gluon import nn +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.mxnet.conv import GMMConv + + +class MoNet(nn.Block): + def __init__(self, + g, + in_feats, + n_hidden, + out_feats, + n_layers, + dim, + n_kernels): + super(MoNet, self).__init__() + self.g = g + with self.name_scope(): + self.layers = nn.Sequential() + self.pseudo_proj = nn.Sequential() + + # Input layer + self.layers.add( + GMMConv(in_feats, n_hidden, dim, n_kernels)) + self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh')) + + # Hidden layer + for _ in range(n_layers - 1): + self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels)) + self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh')) + + # Output layer + self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels)) + self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation='tanh')) + + def forward(self, feat, pseudo): + h = feat + for i in range(len(self.layers)): + h = self.layers[i]( + self.g, h, self.pseudo_proj[i](pseudo)) + return h + + +def evaluate(model, features, pseudo, labels, mask): + pred = model(features, pseudo).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = nd.array(data.features) + labels = nd.array(data.labels) + train_mask = nd.array(data.train_mask) + val_mask = nd.array(data.val_mask) + test_mask = nd.array(data.test_mask) + + in_feats = features.shape[1] + n_classes = data.num_labels + 
n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + ctx = mx.cpu(0) + else: + ctx = mx.gpu(args.gpu) + print("use cuda:", args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # graph preprocess and calculate normalization factor + g = data.graph + g.remove_edges_from(nx.selfloop_edges(g)) + g = DGLGraph(g) + n_edges = g.number_of_edges() + us, vs = g.edges() + pseudo = [] + for i in range(g.number_of_edges()): + pseudo.append([ + 1 / np.sqrt(g.in_degree(us[i].asscalar())), + 1 / np.sqrt(g.in_degree(vs[i].asscalar())) + ]) + pseudo = nd.array(pseudo, ctx=ctx) + + # create GraphSAGE model + model = MoNet(g, + in_feats, + args.n_hidden, + n_classes, + args.n_layers, + args.pseudo_dim, + args.n_kernels, + ) + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(features, pseudo) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + loss.asscalar() + dur.append(time.time() - t0) + acc = evaluate(model, features, pseudo, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". 
format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, features, pseudo, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='MoNet on citation network') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--pseudo-dim", type=int, default=2, + help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed") + parser.add_argument("--n-kernels", type=int, default=3, + help="Number of kernels in GMMConv layer") + parser.add_argument("--weight-decay", type=float, default=5e-5, + help="Weight for L2 loss") + args = parser.parse_args() + print(args) + + main(args) \ No newline at end of file diff --git a/examples/mxnet/sgc/README.md b/examples/mxnet/sgc/README.md new file mode 100644 index 000000000000..d1e060a364b4 --- /dev/null +++ b/examples/mxnet/sgc/README.md @@ -0,0 +1,34 @@ +Simple Graph Convolution (SGC) +============ + +- Paper link: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153) +- Author's code repo: [https://github.com/Tiiiger/SGC](https://github.com/Tiiiger/SGC). + +Dependencies +------------ +- MXNET 1.5+ +- requests + +``bash +pip install torch requests +`` + +Codes +----- +The folder contains an implementation of SGC (`sgc.py`). + +Results +------- + +Run with following (available dataset: "cora", "citeseer", "pubmed") +```bash +DGLBACKEND=mxnet python3 sgc.py --dataset cora --gpu 0 +DGLBACKEND=mxnet python3 sgc.py --dataset citeseer --weight-decay 5e-5 --n-epochs 150 --bias --gpu 0 +DGLBACKEND=mxnet python3 sgc.py --dataset pubmed --weight-decay 5e-5 --bias --gpu 0 +``` + +On NVIDIA V100 + +* cora: 0.818 (paper: 0.810) +* citeseer: 0.725 (paper: 0.719) +* pubmed: 0.788 (paper: 0.789) diff --git a/examples/mxnet/sgc/sgc.py b/examples/mxnet/sgc/sgc.py new file mode 100644 index 000000000000..fae6595ea99b --- /dev/null +++ b/examples/mxnet/sgc/sgc.py @@ -0,0 +1,122 @@ +""" +This code was modified from the GCN implementation in DGL examples. +Simplifying Graph Convolutional Networks +Paper: https://arxiv.org/abs/1902.07153 +Code: https://github.com/Tiiiger/SGC +SGC implementation in DGL. 
+""" +import argparse, time, math +import numpy as np +import mxnet as mx +from mxnet import nd, gluon +from mxnet.gluon import nn +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.mxnet.conv import SGConv + + +def evaluate(model, g, features, labels, mask): + pred = model(g, features).argmax(axis=1) + accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar() + return accuracy.asscalar() + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = nd.array(data.features) + labels = nd.array(data.labels) + train_mask = nd.array(data.train_mask) + val_mask = nd.array(data.val_mask) + test_mask = nd.array(data.test_mask) + + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().asscalar(), + val_mask.sum().asscalar(), + test_mask.sum().asscalar())) + + if args.gpu < 0: + ctx = mx.cpu(0) + else: + ctx = mx.gpu(args.gpu) + + features = features.as_in_context(ctx) + labels = labels.as_in_context(ctx) + train_mask = train_mask.as_in_context(ctx) + val_mask = val_mask.as_in_context(ctx) + test_mask = test_mask.as_in_context(ctx) + + # graph preprocess and calculate normalization factor + g = DGLGraph(data.graph) + n_edges = g.number_of_edges() + # add self loop + g.add_edges(g.nodes(), g.nodes()) + + # create SGC model + model = SGConv(in_feats, + n_classes, + k=2, + cached=True, + bias=args.bias) + + model.initialize(ctx=ctx) + n_train_samples = train_mask.sum().asscalar() + loss_fcn = gluon.loss.SoftmaxCELoss() + + # use optimizer + print(model.collect_params()) + trainer = gluon.Trainer(model.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.weight_decay}) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with mx.autograd.record(): + pred = model(g, features) + loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1)) + loss = loss.sum() / n_train_samples + + loss.backward() + trainer.step(batch_size=1) + + if epoch >= 3: + loss.asscalar() + dur.append(time.time() - t0) + acc = evaluate(model, g, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". 
format( + epoch, np.mean(dur), loss.asscalar(), acc, n_edges / np.mean(dur) / 1000)) + + # test set accuracy + acc = evaluate(model, g, features, labels, test_mask) + print("Test accuracy {:.2%}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SGC') + register_data_args(parser) + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=0.2, + help="learning rate") + parser.add_argument("--bias", action='store_true', default=False, + help="flag to use bias") + parser.add_argument("--n-epochs", type=int, default=100, + help="number of training epochs") + parser.add_argument("--weight-decay", type=float, default=5e-6, + help="Weight for L2 loss") + args = parser.parse_args() + print(args) + + main(args) \ No newline at end of file diff --git a/examples/pytorch/appnp/appnp.py b/examples/pytorch/appnp/appnp.py index c7ee38116f9d..f9dc64a591e9 100644 --- a/examples/pytorch/appnp/appnp.py +++ b/examples/pytorch/appnp/appnp.py @@ -5,9 +5,7 @@ Paper: https://arxiv.org/abs/1810.05997 Author's code: https://github.com/klicperajo/ppnp """ -import torch import torch.nn as nn -import dgl.function as fn from dgl.nn.pytorch.conv import APPNPConv diff --git a/examples/pytorch/appnp/train.py b/examples/pytorch/appnp/train.py index 59d9c5899b4e..befc2cfec925 100644 --- a/examples/pytorch/appnp/train.py +++ b/examples/pytorch/appnp/train.py @@ -65,13 +65,6 @@ def main(args): g.add_edges(g.nodes(), g.nodes()) g.set_n_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer) - # normalization - degs = g.in_degrees().float() - norm = torch.pow(degs, -0.5) - norm[torch.isinf(norm)] = 0 - if cuda: - norm = norm.cuda() - g.ndata['norm'] = norm.unsqueeze(1) # create APPNP model model = APPNP(g, diff --git a/examples/pytorch/gin/dataloader.py b/examples/pytorch/gin/dataloader.py index 6ed1858d704b..cd46eea8671f 100644 --- a/examples/pytorch/gin/dataloader.py +++ b/examples/pytorch/gin/dataloader.py @@ -1,6 +1,5 @@ """ PyTorch compatible dataloader - """ @@ -19,12 +18,8 @@ def collate(samples): graphs, labels = map(list, zip(*samples)) for g in graphs: # deal with node feats - for feat in g.node_attr_schemes().keys(): - # TODO torch.Tensor is not recommended - # torch.DoubleTensor and torch.tensor - # will meet error in executor.py@runtime line 472, tensor.py@backend line 147 - # RuntimeError: expected type torch.cuda.DoubleTensor but got torch.cuda.FloatTensor - g.ndata[feat] = torch.Tensor(g.ndata[feat]) + for key in g.node_attr_schemes().keys(): + g.ndata[key] = torch.from_numpy(g.ndata[key]).float() # no edge feats batched_graph = dgl.batch(graphs) labels = torch.tensor(labels) @@ -63,10 +58,10 @@ def __init__(self, self.train_loader = DataLoader( dataset, sampler=train_sampler, - batch_size=batch_size, collate_fn=collate, **self.kwargs) + batch_size=batch_size, collate_fn=collate_fn, **self.kwargs) self.valid_loader = DataLoader( dataset, sampler=valid_sampler, - batch_size=batch_size, collate_fn=collate, **self.kwargs) + batch_size=batch_size, collate_fn=collate_fn, **self.kwargs) def train_valid_loader(self): return self.train_loader, self.valid_loader @@ -76,7 +71,6 @@ def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True): assert 0 <= fold_idx and fold_idx < 10, print( "fold_idx must be from 0 to 9.") - idx_list = [] skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed) idx_list = [] for idx in skf.split(np.zeros(len(labels)), labels): # split(x, y) 
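A minimal usage sketch of the data-loading path touched above, assuming the `GINDataset` and `GraphDataLoader` interfaces as defined in this patch; the dataset name, batch size, and fold index are illustrative, and the snippet is meant to be run from the GIN example directory so that `dataloader.py` is importable:

```python
# Sketch only: the concrete values below are illustrative, not part of the patch.
from dgl.data.gindt import GINDataset
from dataloader import GraphDataLoader, collate

# (dataset name, self_loop) positional arguments, as used in main.py
dataset = GINDataset('MUTAG', True)
loader = GraphDataLoader(dataset, batch_size=32, collate_fn=collate,
                         seed=0, shuffle=True,
                         split_name='fold10', fold_idx=0)
train_loader, valid_loader = loader.train_valid_loader()

for batched_graph, labels in train_loader:
    feat = batched_graph.ndata['attr']  # node features, batched by dgl.batch in collate
    print(batched_graph.number_of_nodes(), labels.shape)
    break
```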
diff --git a/examples/pytorch/gin/gin.py b/examples/pytorch/gin/gin.py index e0b93c88e2d2..8a5408eb97cc 100644 --- a/examples/pytorch/gin/gin.py +++ b/examples/pytorch/gin/gin.py @@ -10,9 +10,7 @@ import torch.nn as nn import torch.nn.functional as F from dgl.nn.pytorch.conv import GINConv - -import dgl -import dgl.function as fn +from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling class ApplyNodeFunc(nn.Module): @@ -77,16 +75,16 @@ def forward(self, x): else: # If MLP h = x - for layer in range(self.num_layers - 1): - h = F.relu(self.batch_norms[layer](self.linears[layer](h))) - return self.linears[self.num_layers - 1](h) + for i in range(self.num_layers - 1): + h = F.relu(self.batch_norms[i](self.linears[i](h))) + return self.linears[-1](h) class GIN(nn.Module): """GIN model""" def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, - neighbor_pooling_type, device): + neighbor_pooling_type): """model parameters setting Paramters @@ -110,15 +108,10 @@ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, how to aggregate neighbors (sum, mean, or max) graph_pooling_type: str how to aggregate entire nodes in a graph (sum, mean or max) - device: str - which device to use """ super(GIN, self).__init__() - self.final_dropout = final_dropout - self.device = device self.num_layers = num_layers - self.graph_pooling_type = graph_pooling_type self.learn_eps = learn_eps # List of MLPs @@ -147,36 +140,32 @@ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, self.linears_prediction.append( nn.Linear(hidden_dim, output_dim)) - def forward(self, g): - h = g.ndata['attr'] - h = h.to(self.device) + self.drop = nn.Dropout(final_dropout) + + if graph_pooling_type == 'sum': + self.pool = SumPooling() + elif graph_pooling_type == 'mean': + self.pool = AvgPooling() + elif graph_pooling_type == 'max': + self.pool = MaxPooling() + else: + raise NotImplementedError + def forward(self, g, h): # list of hidden representation at each layer (including input) hidden_rep = [h] - for layer in range(self.num_layers - 1): - h = self.ginlayers[layer](g, h) - h = self.batch_norms[layer](h) + for i in range(self.num_layers - 1): + h = self.ginlayers[i](g, h) + h = self.batch_norms[i](h) h = F.relu(h) hidden_rep.append(h) score_over_layer = 0 # perform pooling over all nodes in each graph in every layer - for layer, h in enumerate(hidden_rep): - g.ndata['h'] = h - if self.graph_pooling_type == 'sum': - pooled_h = dgl.sum_nodes(g, 'h') - elif self.graph_pooling_type == 'mean': - pooled_h = dgl.mean_nodes(g, 'h') - elif self.graph_pooling_type == 'max': - pooled_h = dgl.max_nodes(g, 'h') - else: - raise NotImplementedError() - - score_over_layer += F.dropout( - self.linears_prediction[layer](pooled_h), - self.final_dropout, - training=self.training) + for i, h in enumerate(hidden_rep): + pooled_h = self.pool(g, h) + score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) return score_over_layer diff --git a/examples/pytorch/gin/main.py b/examples/pytorch/gin/main.py index c78134e010a7..eeb36501eefe 100644 --- a/examples/pytorch/gin/main.py +++ b/examples/pytorch/gin/main.py @@ -23,16 +23,16 @@ def train(args, net, trainloader, optimizer, criterion, epoch): for pos, (graphs, labels) in zip(bar, trainloader): # batch graphs will be shipped to device in forward part of model labels = labels.to(args.device) - outputs = net(graphs) + feat = graphs.ndata['attr'].to(args.device) + outputs = net(graphs, 
feat) loss = criterion(outputs, labels) running_loss += loss.item() # backprop - if optimizer is not None: - optimizer.zero_grad() - loss.backward() - optimizer.step() + optimizer.zero_grad() + loss.backward() + optimizer.step() # report bar.set_description('epoch-{}'.format(epoch)) @@ -50,15 +50,12 @@ def eval_net(args, net, dataloader, criterion): total_loss = 0 total_correct = 0 - # total_iters = len(dataloader) - for data in dataloader: graphs, labels = data + feat = graphs.ndata['attr'].to(args.device) labels = labels.to(args.device) - total += len(labels) - - outputs = net(graphs) + outputs = net(graphs, feat) _, predicted = torch.max(outputs.data, 1) total_correct += (predicted == labels.data).sum().item() @@ -99,8 +96,7 @@ def main(args): args.num_layers, args.num_mlp_layers, dataset.dim_nfeats, args.hidden_dim, dataset.gclasses, args.final_dropout, args.learn_eps, - args.graph_pooling_type, args.neighbor_pooling_type, - args.device).to(args.device) + args.graph_pooling_type, args.neighbor_pooling_type).to(args.device) criterion = nn.CrossEntropyLoss() # defaul reduce is true optimizer = optim.Adam(model.parameters(), lr=args.lr) diff --git a/examples/pytorch/graphsage/graphsage.py b/examples/pytorch/graphsage/graphsage.py index 2919a89f3099..7fcbffe2b953 100644 --- a/examples/pytorch/graphsage/graphsage.py +++ b/examples/pytorch/graphsage/graphsage.py @@ -6,7 +6,6 @@ """ import argparse import time -import abc import numpy as np import networkx as nx import torch diff --git a/examples/pytorch/model_zoo/geometric/.gitignore b/examples/pytorch/model_zoo/geometric/.gitignore new file mode 100644 index 000000000000..689af916ab0a --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/.gitignore @@ -0,0 +1 @@ +MNIST/ diff --git a/examples/pytorch/model_zoo/geometric/README.md b/examples/pytorch/model_zoo/geometric/README.md new file mode 100644 index 000000000000..48940e9fea48 --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/README.md @@ -0,0 +1,25 @@ +Geometric Deep Learning models +========= + +This example shows how to use geometric deep learning models defined in `dgl.nn.pytorch.conv` for +graph classification. + +Currently we support following models: +- [ChebNet](https://arxiv.org/pdf/1606.09375.pdf) +- [MoNet](https://arxiv.org/pdf/1611.08402.pdf) + +## Image Classification on MNIST + +By transforming images to graphs, graph classifcation algorithms could +be applied to image classification problems. + +### Usage +```bash +python mnist.py --model cheb --gpu 0 +python mnist.py --model monet --gpu 0 +``` + +### Acknowledgement +We thank [Xavier Bresson](https://github.com/xbresson) for providing +code for graph coarsening algorithm and grid graph building in +[CE7454_2019 Labs](https://github.com/xbresson/CE7454_2019/tree/master/codes/labs_lecture14/lab01_ChebGCNs). diff --git a/examples/pytorch/model_zoo/geometric/coarsening.py b/examples/pytorch/model_zoo/geometric/coarsening.py new file mode 100644 index 000000000000..0f5cb3baefa3 --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/coarsening.py @@ -0,0 +1,312 @@ +# author: xbresson +# code link: https://github.com/xbresson/CE7454_2019/blob/master/codes/labs_lecture14/lab01_ChebGCNs/lib/coarsening.py + +import numpy as np +import scipy.sparse +import sklearn.metrics + + +def laplacian(W, normalized=True): + """Return graph Laplacian""" + + # Degree matrix. + d = W.sum(axis=0) + + # Laplacian matrix. 
+ if not normalized: + D = scipy.sparse.diags(d.A.squeeze(), 0) + L = D - W + else: + d += np.spacing(np.array(0, W.dtype)) + d = 1 / np.sqrt(d) + D = scipy.sparse.diags(d.A.squeeze(), 0) + I = scipy.sparse.identity(d.size, dtype=W.dtype) + L = I - D * W * D + + assert np.abs(L - L.T).mean() < 1e-9 + assert type(L) is scipy.sparse.csr.csr_matrix + return L + + +def rescale_L(L, lmax=2): + """Rescale Laplacian eigenvalues to [-1,1]""" + M, M = L.shape + I = scipy.sparse.identity(M, format='csr', dtype=L.dtype) + L /= lmax * 2 + L -= I + return L + + +def lmax_L(L): + """Compute largest Laplacian eigenvalue""" + return scipy.sparse.linalg.eigsh(L, k=1, which='LM', return_eigenvectors=False)[0] + + +# graph coarsening with Heavy Edge Matching +def coarsen(A, levels): + graphs, parents = HEM(A, levels) + perms = compute_perm(parents) + + laplacians = [] + for i, A in enumerate(graphs): + M, M = A.shape + + if i < levels: + A = perm_adjacency(A, perms[i]) + + A = A.tocsr() + A.eliminate_zeros() + Mnew, Mnew = A.shape + print('Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges'.format(i, Mnew, Mnew - M, A.nnz // 2)) + + L = laplacian(A, normalized=True) + laplacians.append(L) + + return laplacians, perms[0] if len(perms) > 0 else None + + +def HEM(W, levels, rid=None): + """ + Coarsen a graph multiple times using the Heavy Edge Matching (HEM). + Input + W: symmetric sparse weight (adjacency) matrix + levels: the number of coarsened graphs + Output + graph[0]: original graph of size N_1 + graph[2]: coarser graph of size N_2 < N_1 + graph[levels]: coarsest graph of Size N_levels < ... < N_2 < N_1 + parents[i] is a vector of size N_i with entries ranging from 1 to N_{i+1} + which indicate the parents in the coarser graph[i+1] + nd_sz{i} is a vector of size N_i that contains the size of the supernode in the graph{i} + Note + if "graph" is a list of length k, then "parents" will be a list of length k-1 + """ + + N, N = W.shape + + if rid is None: + rid = np.random.permutation(range(N)) + + ss = np.array(W.sum(axis=0)).squeeze() + rid = np.argsort(ss) + + parents = [] + degree = W.sum(axis=0) - W.diagonal() + graphs = [] + graphs.append(W) + + print('Heavy Edge Matching coarsening with Xavier version') + + for _ in range(levels): + + # CHOOSE THE WEIGHTS FOR THE PAIRING + # weights = ones(N,1) # metis weights + weights = degree # graclus weights + # weights = supernode_size # other possibility + weights = np.array(weights).squeeze() + + # PAIR THE VERTICES AND CONSTRUCT THE ROOT VECTOR + idx_row, idx_col, val = scipy.sparse.find(W) + cc = idx_row + rr = idx_col + vv = val + + # TO BE SPEEDUP + if not (list(cc) == list(np.sort(cc))): + tmp = cc + cc = rr + rr = tmp + + cluster_id = HEM_one_level(cc, rr, vv, rid, weights) # cc is ordered + parents.append(cluster_id) + + # COMPUTE THE EDGES WEIGHTS FOR THE NEW GRAPH + nrr = cluster_id[rr] + ncc = cluster_id[cc] + nvv = vv + Nnew = cluster_id.max() + 1 + # CSR is more appropriate: row,val pairs appear multiple times + W = scipy.sparse.csr_matrix((nvv, (nrr, ncc)), shape=(Nnew, Nnew)) + W.eliminate_zeros() + + # Add new graph to the list of all coarsened graphs + graphs.append(W) + N, N = W.shape + + # COMPUTE THE DEGREE (OMIT OR NOT SELF LOOPS) + degree = W.sum(axis=0) + # degree = W.sum(axis=0) - W.diagonal() + + # CHOOSE THE ORDER IN WHICH VERTICES WILL BE VISTED AT THE NEXT PASS + # [~, rid]=sort(ss); # arthur strategy + # [~, rid]=sort(supernode_size); # thomas strategy + # rid=randperm(N); # metis/graclus strategy + ss = 
np.array(W.sum(axis=0)).squeeze() + rid = np.argsort(ss) + + return graphs, parents + + +# Coarsen a graph given by rr,cc,vv. rr is assumed to be ordered +def HEM_one_level(rr, cc, vv, rid, weights): + nnz = rr.shape[0] + N = rr[nnz - 1] + 1 + + marked = np.zeros(N, np.bool) + rowstart = np.zeros(N, np.int32) + rowlength = np.zeros(N, np.int32) + cluster_id = np.zeros(N, np.int32) + + oldval = rr[0] + count = 0 + clustercount = 0 + + for ii in range(nnz): + rowlength[count] = rowlength[count] + 1 + if rr[ii] > oldval: + oldval = rr[ii] + rowstart[count + 1] = ii + count = count + 1 + + for ii in range(N): + tid = rid[ii] + if not marked[tid]: + wmax = 0.0 + rs = rowstart[tid] + marked[tid] = True + bestneighbor = -1 + for jj in range(rowlength[tid]): + nid = cc[rs + jj] + if marked[nid]: + tval = 0.0 + else: + + # First approach + if 2 == 1: + tval = vv[rs + jj] * (1.0 / weights[tid] + 1.0 / weights[nid]) + + # Second approach + if 1 == 1: + Wij = vv[rs + jj] + Wii = vv[rowstart[tid]] + Wjj = vv[rowstart[nid]] + di = weights[tid] + dj = weights[nid] + tval = (2. * Wij + Wii + Wjj) * 1. / (di + dj + 1e-9) + + if tval > wmax: + wmax = tval + bestneighbor = nid + + cluster_id[tid] = clustercount + + if bestneighbor > -1: + cluster_id[bestneighbor] = clustercount + marked[bestneighbor] = True + + clustercount += 1 + + return cluster_id + + +def compute_perm(parents): + """ + Return a list of indices to reorder the adjacency and data matrices so + that the union of two neighbors from layer to layer forms a binary tree. + """ + + # Order of last layer is random (chosen by the clustering algorithm). + indices = [] + if len(parents) > 0: + M_last = max(parents[-1]) + 1 + indices.append(list(range(M_last))) + + for parent in parents[::-1]: + + # Fake nodes go after real ones. + pool_singeltons = len(parent) + + indices_layer = [] + for i in indices[-1]: + indices_node = list(np.where(parent == i)[0]) + assert 0 <= len(indices_node) <= 2 + + # Add a node to go with a singelton. + if len(indices_node) is 1: + indices_node.append(pool_singeltons) + pool_singeltons += 1 + + # Add two nodes as children of a singelton in the parent. + elif len(indices_node) is 0: + indices_node.append(pool_singeltons + 0) + indices_node.append(pool_singeltons + 1) + pool_singeltons += 2 + + indices_layer.extend(indices_node) + indices.append(indices_layer) + + # Sanity checks. + for i, indices_layer in enumerate(indices): + M = M_last * 2 ** i + # Reduction by 2 at each layer (binary tree). + assert len(indices[0] == M) + # The new ordering does not omit an indice. + assert sorted(indices_layer) == list(range(M)) + + return indices[::-1] + + +assert (compute_perm([np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]) + == [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]) + + +def perm_adjacency(A, indices): + """ + Permute adjacency matrix, i.e. exchange node ids, + so that binary unions form the clustering tree. + """ + if indices is None: + return A + + M, M = A.shape + Mnew = len(indices) + A = A.tocoo() + + # Add Mnew - M isolated vertices. + rows = scipy.sparse.coo_matrix((Mnew - M, M), dtype=np.float32) + cols = scipy.sparse.coo_matrix((Mnew, Mnew - M), dtype=np.float32) + A = scipy.sparse.vstack([A, rows]) + A = scipy.sparse.hstack([A, cols]) + + # Permute the rows and the columns. 
+ perm = np.argsort(indices) + A.row = np.array(perm)[A.row] + A.col = np.array(perm)[A.col] + + assert np.abs(A - A.T).mean() < 1e-8 # 1e-9 + assert type(A) is scipy.sparse.coo.coo_matrix + return A + + +def perm_data(x, indices): + """ + Permute data matrix, i.e. exchange node ids, + so that binary unions form the clustering tree. + """ + if indices is None: + return x + + N, M = x.shape + Mnew = len(indices) + assert Mnew >= M + xnew = np.empty((N, Mnew)) + for i, j in enumerate(indices): + # Existing vertex, i.e. real data. + if j < M: + xnew[:, i] = x[:, j] + # Fake vertex because of singeltons. + # They will stay 0 so that max pooling chooses the singelton. + # Or -infty ? + else: + xnew[:, i] = np.zeros(N) + return xnew diff --git a/examples/pytorch/model_zoo/geometric/coordinate.py b/examples/pytorch/model_zoo/geometric/coordinate.py new file mode 100644 index 000000000000..cdd473a84890 --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/coordinate.py @@ -0,0 +1,30 @@ +import torch as th + +"""Compute x,y coordinate for nodes in the graph""" +eps = 1e-8 +def get_coordinates(graphs, grid_side, coarsening_levels, perm): + rst = [] + for l in range(coarsening_levels + 1): + xs, ys = [], [] + for i in range(graphs[l].number_of_nodes()): + cnt = eps + x_accum = 0 + y_accum = 0 + for j in range(i * 2 ** l, (i + 1) * 2 ** l): + if perm[j] < grid_side ** 2: + x_accum += (perm[j] // grid_side) + y_accum += (perm[j] % grid_side) + cnt += 1 + xs.append(x_accum / cnt) + ys.append(y_accum / cnt) + rst.append(th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1)) + return rst + +"""Cartesian coordinate to polar coordinate""" +def z2polar(edges): + z = edges.dst['xy'] - edges.src['xy'] + rho = th.norm(z, dim=-1, p=2) + x, y = z.unbind(dim=-1) + phi = th.atan2(y, x) + return {'u': th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)} + diff --git a/examples/pytorch/model_zoo/geometric/grid_graph.py b/examples/pytorch/model_zoo/geometric/grid_graph.py new file mode 100644 index 000000000000..74f36d255a13 --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/grid_graph.py @@ -0,0 +1,69 @@ +# author: xbresson +# code link: https://github.com/xbresson/CE7454_2019/blob/master/codes/labs_lecture14/lab01_ChebGCNs/lib/grid_graph.py + +import sklearn +import sklearn.metrics +import scipy.sparse, scipy.sparse.linalg # scipy.spatial.distance +import numpy as np + + +def grid_graph(grid_side,number_edges,metric): + """Generate graph of a grid""" + z = grid(grid_side) + dist, idx = distance_sklearn_metrics(z, k=number_edges, metric=metric) + A = adjacency(dist, idx) + print("nb edges: ",A.nnz) + return A + + +def grid(m, dtype=np.float32): + """Return coordinates of grid points""" + M = m**2 + x = np.linspace(0,1,m, dtype=dtype) + y = np.linspace(0,1,m, dtype=dtype) + xx, yy = np.meshgrid(x, y) + z = np.empty((M,2), dtype) + z[:,0] = xx.reshape(M) + z[:,1] = yy.reshape(M) + return z + + +def distance_sklearn_metrics(z, k=4, metric='euclidean'): + """Compute pairwise distances""" + #d = sklearn.metrics.pairwise.pairwise_distances(z, metric=metric, n_jobs=-2) + d = sklearn.metrics.pairwise.pairwise_distances(z, metric=metric, n_jobs=1) + # k-NN + idx = np.argsort(d)[:,1:k+1] + d.sort() + d = d[:,1:k+1] + return d, idx + + +def adjacency(dist, idx): + """Return adjacency matrix of a kNN graph""" + M, k = dist.shape + assert M, k == idx.shape + assert dist.min() >= 0 + assert dist.max() <= 1 + + # Pairwise distances + sigma2 = np.mean(dist[:,-1])**2 + dist = np.exp(- dist**2 / sigma2) + + # 
Weight matrix + I = np.arange(0, M).repeat(k) + J = idx.reshape(M*k) + V = dist.reshape(M*k) + W = scipy.sparse.coo_matrix((V, (I, J)), shape=(M, M)) + + # No self-connections + W.setdiag(0) + + # Undirected graph + bigger = W.T > W + W = W - W.multiply(bigger) + W.T.multiply(bigger) + + assert W.nnz % 2 == 0 + assert np.abs(W - W.T).mean() < 1e-10 + assert type(W) is scipy.sparse.csr.csr_matrix + return W diff --git a/examples/pytorch/model_zoo/geometric/mnist.py b/examples/pytorch/model_zoo/geometric/mnist.py new file mode 100644 index 000000000000..e2dc837c725f --- /dev/null +++ b/examples/pytorch/model_zoo/geometric/mnist.py @@ -0,0 +1,182 @@ +import argparse +import time +import numpy as np +import networkx as nx +import torch +import torch.nn as nn +import torch.nn.functional as F +import dgl +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.pytorch.conv import ChebConv, GMMConv +from dgl.nn.pytorch.glob import MaxPooling +from grid_graph import grid_graph +from coarsening import coarsen +from coordinate import get_coordinates, z2polar + +argparser = argparse.ArgumentParser("MNIST") +argparser.add_argument("--gpu", type=int, default=-1, + help="gpu id, use cpu if set to -1") +argparser.add_argument("--model", type=str, default="chebnet", + help="model to use, chebnet/monet") +argparser.add_argument("--batch-size", type=int, default=100, + help="batch size") +args = argparser.parse_args() + +grid_side = 28 +number_edges = 8 +metric = 'euclidean' + +A = grid_graph(28, 8, metric) + +coarsening_levels = 4 +L, perm = coarsen(A, coarsening_levels) +g_arr = [DGLGraph(csr) for csr in L] + +coordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm) +for g, coordinate_arr in zip(g_arr, coordinate_arr): + g.ndata['xy'] = coordinate_arr + g.apply_edges(z2polar) + +def batcher(batch): + g_batch = [[] for _ in range(coarsening_levels + 1)] + x_batch = [] + y_batch = [] + for x, y in batch: + x = torch.cat([x.view(-1), x.new_zeros(928 - 28 ** 2)], 0) + x = x[perm] + x_batch.append(x) + y_batch.append(y) + for i in range(coarsening_levels + 1): + g_batch[i].append(g_arr[i]) + + x_batch = torch.cat(x_batch).unsqueeze(-1) + y_batch = torch.LongTensor(y_batch) + g_batch = [dgl.batch(g) for g in g_batch] + return g_batch, x_batch, y_batch + +trainset = datasets.MNIST(root='.', train=True, download=True, transform=transforms.ToTensor()) +testset = datasets.MNIST(root='.', train=False, download=True, transform=transforms.ToTensor()) + +train_loader = DataLoader(trainset, + batch_size=args.batch_size, + shuffle=True, + collate_fn=batcher, + num_workers=6) +test_loader = DataLoader(testset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=batcher, + num_workers=6) + +class MoNet(nn.Module): + def __init__(self, + n_kernels, + in_feats, + hiddens, + out_feats): + super(MoNet, self).__init__() + self.pool = nn.MaxPool1d(2) + self.layers = nn.ModuleList() + self.readout = MaxPooling() + + # Input layer + self.layers.append( + GMMConv(in_feats, hiddens[0], 2, n_kernels)) + + # Hidden layer + for i in range(1, len(hiddens)): + self.layers.append(GMMConv(hiddens[i - 1], hiddens[i], 2, n_kernels)) + + self.cls = nn.Sequential( + nn.Linear(hiddens[-1], out_feats), + nn.LogSoftmax() + ) + + def forward(self, g_arr, feat): + for g, layer in zip(g_arr, self.layers): + u = g.edata['u'].to(feat.device) + feat = self.pool(layer(g, feat, u).transpose(-1, 
-2).unsqueeze(0))\ + .squeeze(0).transpose(-1, -2) + return self.cls(self.readout(g_arr[-1], feat)) + +class ChebNet(nn.Module): + def __init__(self, + k, + in_feats, + hiddens, + out_feats): + super(ChebNet, self).__init__() + self.pool = nn.MaxPool1d(2) + self.layers = nn.ModuleList() + self.readout = MaxPooling() + + # Input layer + self.layers.append( + ChebConv(in_feats, hiddens[0], k)) + + for i in range(1, len(hiddens)): + self.layers.append( + ChebConv(hiddens[i - 1], hiddens[i], k)) + + self.cls = nn.Sequential( + nn.Linear(hiddens[-1], out_feats), + nn.LogSoftmax() + ) + + def forward(self, g_arr, feat): + for g, layer in zip(g_arr, self.layers): + feat = self.pool(layer(g, feat, [2] * g.batch_size).transpose(-1, -2).unsqueeze(0))\ + .squeeze(0).transpose(-1, -2) + return self.cls(self.readout(g_arr[-1], feat)) + +if args.gpu == -1: + device = torch.device('cpu') +else: + device = torch.device(args.gpu) + +if args.model == 'chebnet': + model = ChebNet(2, 1, [32, 64, 128, 256], 10) +else: + model = MoNet(10, 1, [32, 64, 128, 256], 10) + +model = model.to(device) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +log_interval = 50 + +for epoch in range(10): + print('epoch {} starts'.format(epoch)) + model.train() + hit, tot = 0, 0 + loss_accum = 0 + for i, (g, x, y) in enumerate(train_loader): + x = x.to(device) + y = y.to(device) + out = model(g, x) + hit += (out.max(-1)[1] == y).sum().item() + tot += len(y) + loss = F.nll_loss(out, y) + loss_accum += loss.item() + + if (i + 1) % log_interval == 0: + print('loss: {}, acc: {}'.format(loss_accum / log_interval, hit / tot)) + hit, tot = 0, 0 + loss_accum = 0 + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + model.eval() + hit, tot = 0, 0 + for g, x, y in test_loader: + x = x.to(device) + y = y.to(device) + out = model(g, x) + hit += (out.max(-1)[1] == y).sum().item() + tot += len(y) + + print('test acc: ', hit / tot) diff --git a/examples/pytorch/monet/README.md b/examples/pytorch/monet/README.md new file mode 100644 index 000000000000..c3a979bd61d6 --- /dev/null +++ b/examples/pytorch/monet/README.md @@ -0,0 +1,19 @@ +MoNet +===== + +- paper link: [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/pdf/1611.08402.pdf) + +Dependencies +============ + +- pytorch 1.1+ + +Results +======= + +Node classification on citation networks: +- Cora: ~0.816 +- Pubmed: ~0.763 + +Image classification on MNIST: +- please refer to [model_zoo/geometric](../model_zoo/geometric). 
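+
+Usage
+=====
+
+A suggested invocation for `citation.py` (the `--dataset` flag is assumed to
+come from `dgl.data.register_data_args`; the remaining flags are defined in
+`citation.py` itself, with `--pseudo-dim` set per its help string):
+
+```bash
+python citation.py --dataset cora --gpu 0 --pseudo-dim 2
+python citation.py --dataset pubmed --gpu 0 --pseudo-dim 3
+```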
\ No newline at end of file diff --git a/examples/pytorch/monet/citation.py b/examples/pytorch/monet/citation.py new file mode 100644 index 000000000000..67bc29c13520 --- /dev/null +++ b/examples/pytorch/monet/citation.py @@ -0,0 +1,189 @@ +import argparse +import time +import numpy as np +import networkx as nx +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgl.nn.pytorch.conv import GMMConv + + +class MoNet(nn.Module): + def __init__(self, + g, + in_feats, + n_hidden, + out_feats, + n_layers, + dim, + n_kernels, + dropout): + super(MoNet, self).__init__() + self.g = g + self.layers = nn.ModuleList() + self.pseudo_proj = nn.ModuleList() + + # Input layer + self.layers.append( + GMMConv(in_feats, n_hidden, dim, n_kernels)) + self.pseudo_proj.append( + nn.Sequential(nn.Linear(2, dim), nn.Tanh())) + + # Hidden layer + for _ in range(n_layers - 1): + self.layers.append(GMMConv(n_hidden, n_hidden, dim, n_kernels)) + self.pseudo_proj.append( + nn.Sequential(nn.Linear(2, dim), nn.Tanh())) + + # Output layer + self.layers.append(GMMConv(n_hidden, out_feats, dim, n_kernels)) + self.pseudo_proj.append( + nn.Sequential(nn.Linear(2, dim), nn.Tanh())) + self.dropout = nn.Dropout(dropout) + + def forward(self, feat, pseudo): + h = feat + for i in range(len(self.layers)): + if i != 0: + h = self.dropout(h) + h = self.layers[i]( + self.g, h, self.pseudo_proj[i](pseudo)) + return h + +def evaluate(model, features, pseudo, labels, mask): + model.eval() + with torch.no_grad(): + logits = model(features, pseudo) + logits = logits[mask] + labels = labels[mask] + _, indices = torch.max(logits, dim=1) + correct = torch.sum(indices == labels) + return correct.item() * 1.0 / len(labels) + +def main(args): + # load and preprocess dataset + data = load_data(args) + features = torch.FloatTensor(data.features) + labels = torch.LongTensor(data.labels) + if False: #hasattr(torch, 'BoolTensor'): + train_mask = torch.BoolTensor(data.train_mask) + val_mask = torch.BoolTensor(data.val_mask) + test_mask = torch.BoolTensor(data.test_mask) + else: + train_mask = torch.ByteTensor(data.train_mask) + val_mask = torch.ByteTensor(data.val_mask) + test_mask = torch.ByteTensor(data.test_mask) + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.sum().item(), + val_mask.sum().item(), + test_mask.sum().item())) + + if args.gpu < 0: + cuda = False + else: + cuda = True + torch.cuda.set_device(args.gpu) + features = features.cuda() + labels = labels.cuda() + train_mask = train_mask.cuda() + val_mask = val_mask.cuda() + test_mask = test_mask.cuda() + print("use cuda:", args.gpu) + + # graph preprocess and calculate normalization factor + g = data.graph + g.remove_edges_from(nx.selfloop_edges(g)) + g = DGLGraph(g) + n_edges = g.number_of_edges() + us, vs = g.edges() + pseudo = [] + for i in range(g.number_of_edges()): + pseudo.append([ + 1 / np.sqrt(g.in_degree(us[i])), + 1 / np.sqrt(g.in_degree(vs[i])) + ]) + pseudo = torch.Tensor(pseudo) + if cuda: + pseudo = pseudo.cuda() + + # create GraphSAGE model + model = MoNet(g, + in_feats, + args.n_hidden, + n_classes, + args.n_layers, + args.pseudo_dim, + args.n_kernels, + args.dropout + ) + + if cuda: + model.cuda() + loss_fcn = torch.nn.CrossEntropyLoss() + + # use optimizer + 
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + model.train() + if epoch >= 3: + t0 = time.time() + # forward + logits = model(features, pseudo) + loss = loss_fcn(logits[train_mask], labels[train_mask]) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if epoch >= 3: + dur.append(time.time() - t0) + + acc = evaluate(model, features, pseudo, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(), + acc, n_edges / np.mean(dur) / 1000)) + + print() + acc = evaluate(model, features, pseudo, labels, test_mask) + print("Test Accuracy {:.4f}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='MoNet on citation network') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--pseudo-dim", type=int, default=2, + help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed") + parser.add_argument("--n-kernels", type=int, default=3, + help="Number of kernels in GMMConv layer") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + args = parser.parse_args() + print(args) + + main(args) diff --git a/python/dgl/data/gindt.py b/python/dgl/data/gindt.py index 90adf72d5efb..f89b5857d241 100644 --- a/python/dgl/data/gindt.py +++ b/python/dgl/data/gindt.py @@ -10,6 +10,8 @@ import os import numpy as np +from .. 
import backend as F + from .utils import download, extract_archive, get_download_dir, _get_dgl_url from ..graph import DGLGraph @@ -235,8 +237,7 @@ def _load(self): for g in self.graphs: g.ndata['attr'] = np.zeros(( g.number_of_nodes(), len(label2idx))) - g.ndata['attr'][range(g.number_of_nodes( - )), [label2idx[nl.item()] for nl in g.ndata['label']]] = 1 + g.ndata['attr'][:, [label2idx[F.as_scalar(nl)] for nl in g.ndata['label']]] = 1 # after load, get the #classes and #dim self.gclasses = len(self.glabel_dict) diff --git a/python/dgl/nn/mxnet/conv/__init__.py b/python/dgl/nn/mxnet/conv/__init__.py index 063612539d4a..6096d57786ec 100644 --- a/python/dgl/nn/mxnet/conv/__init__.py +++ b/python/dgl/nn/mxnet/conv/__init__.py @@ -4,5 +4,22 @@ from .graphconv import GraphConv from .relgraphconv import RelGraphConv from .tagconv import TAGConv +from .gatconv import GATConv +from .sageconv import SAGEConv +from .gatedgraphconv import GatedGraphConv +from .chebconv import ChebConv +from .agnnconv import AGNNConv +from .appnpconv import APPNPConv +from .densegraphconv import DenseGraphConv +from .densesageconv import DenseSAGEConv +from .densechebconv import DenseChebConv +from .edgeconv import EdgeConv +from .ginconv import GINConv +from .gmmconv import GMMConv +from .nnconv import NNConv +from .sgconv import SGConv -__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv'] +__all__ = ['GraphConv', 'TAGConv', 'RelGraphConv', 'GATConv', + 'SAGEConv', 'GatedGraphConv', 'ChebConv', 'AGNNConv', + 'APPNPConv', 'DenseGraphConv', 'DenseSAGEConv', 'DenseChebConv', + 'EdgeConv', 'GINConv', 'GMMConv', 'NNConv', 'SGConv'] diff --git a/python/dgl/nn/mxnet/conv/agnnconv.py b/python/dgl/nn/mxnet/conv/agnnconv.py new file mode 100644 index 000000000000..d22b0232d284 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/agnnconv.py @@ -0,0 +1,66 @@ +"""MXNet Module for Attention-based Graph Neural Network layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import mxnet as mx +from mxnet.gluon import nn + +from .... import function as fn +from ..softmax import edge_softmax +from ..utils import normalize + +class AGNNConv(nn.Block): + r"""Attention-based Graph Neural Network layer from paper `Attention-based + Graph Neural Network for Semi-Supervised Learning + `__. + + .. math:: + H^{l+1} = P H^{l} + + where :math:`P` is computed as: + + .. math:: + P_{ij} = \mathrm{softmax}_i ( \beta \cdot \cos(h_i^l, h_j^l)) + + Parameters + ---------- + init_beta : float, optional + The :math:`\beta` in the formula. + learn_beta : bool, optional + If True, :math:`\beta` will be learnable parameter. + """ + def __init__(self, + init_beta=1., + learn_beta=True): + super(AGNNConv, self).__init__() + with self.name_scope(): + self.beta = self.params.get('beta', + shape=(1,), + grad_req='write' if learn_beta else 'null', + init=mx.init.Constant(init_beta)) + + def forward(self, graph, feat): + r"""Compute AGNN Layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, *)` :math:`N` is the + number of nodes, and :math:`*` could be of any shape. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, *)` where :math:`*` + should be the same as input shape. 
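+
+        A minimal usage sketch (the 4-node cycle graph and random
+        features below are illustrative only):
+
+        >>> import dgl
+        >>> import mxnet as mx
+        >>> from dgl.nn.mxnet.conv import AGNNConv
+        >>> g = dgl.DGLGraph()
+        >>> g.add_nodes(4)
+        >>> g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
+        >>> conv = AGNNConv()
+        >>> conv.initialize()
+        >>> feat = mx.nd.random.uniform(shape=(4, 8))
+        >>> out = conv(g, feat)  # same shape as the input: (4, 8)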
+ """ + graph = graph.local_var() + graph.ndata['h'] = feat + graph.ndata['norm_h'] = normalize(feat, p=2, axis=-1) + # compute cosine distance + graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos')) + cos = graph.edata.pop('cos') + e = self.beta.data(feat.context) * cos + graph.edata['p'] = edge_softmax(graph, e) + graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h')) + return graph.ndata.pop('h') diff --git a/python/dgl/nn/mxnet/conv/appnpconv.py b/python/dgl/nn/mxnet/conv/appnpconv.py new file mode 100644 index 000000000000..6db5e97304b9 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/appnpconv.py @@ -0,0 +1,75 @@ +"""MXNet Module for APPNPConv""" +# pylint: disable= no-member, arguments-differ, invalid-name +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + +from .... import function as fn + +class APPNPConv(nn.Block): + r"""Approximate Personalized Propagation of Neural Predictions + layer from paper `Predict then Propagate: Graph Neural Networks + meet Personalized PageRank `__. + + .. math:: + H^{0} & = X + + H^{t+1} & = (1-\alpha)\left(\hat{D}^{-1/2} + \hat{A} \hat{D}^{-1/2} H^{t} + \alpha H^{0}\right) + + Parameters + ---------- + k : int + Number of iterations :math:`K`. + alpha : float + The teleport probability :math:`\alpha`. + edge_drop : float, optional + Dropout rate on edges that controls the + messages received by each node. Default: ``0``. + """ + def __init__(self, + k, + alpha, + edge_drop=0.): + super(APPNPConv, self).__init__() + self._k = k + self._alpha = alpha + with self.name_scope(): + self.edge_drop = nn.Dropout(edge_drop) + + def forward(self, graph, feat): + r"""Compute APPNP layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mx.NDArray + The input feature of shape :math:`(N, *)` :math:`N` is the + number of nodes, and :math:`*` could be of any shape. + + Returns + ------- + mx.NDArray + The output feature of shape :math:`(N, *)` where :math:`*` + should be the same as input shape. + """ + graph = graph.local_var() + norm = mx.nd.power(mx.nd.clip( + graph.in_degrees().astype(feat.dtype), a_min=1, a_max=float("inf")), -0.5) + shp = norm.shape + (1,) * (feat.ndim - 1) + norm = norm.reshape(shp).as_in_context(feat.context) + feat_0 = feat + for _ in range(self._k): + # normalization by src node + feat = feat * norm + graph.ndata['h'] = feat + graph.edata['w'] = self.edge_drop( + nd.ones((graph.number_of_edges(), 1), ctx=feat.context)) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + # normalization by dst node + feat = feat * norm + feat = (1 - self._alpha) * feat + self._alpha * feat_0 + return feat diff --git a/python/dgl/nn/mxnet/conv/chebconv.py b/python/dgl/nn/mxnet/conv/chebconv.py new file mode 100644 index 000000000000..b63f3dc8b642 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/chebconv.py @@ -0,0 +1,123 @@ +"""MXNet Module for Chebyshev Spectral Graph Convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + +from .... import laplacian_lambda_max, broadcast_nodes, function as fn + + +class ChebConv(nn.Block): + r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional + Neural Networks on Graphs with Fast Localized Spectral Filtering + `__. + + .. 
math:: + h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l}z_i^{k, l} + + Z^{0, l} &= H^{l} + + Z^{1, l} &= \hat{L} \cdot H^{l} + + Z^{k, l} &= 2 \cdot \hat{L} \cdot Z^{k-1, l} - Z^{k-2, l} + + \hat{L} &= 2\left(I - \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}\right)/\lambda_{max} - I + + Parameters + ---------- + in_feats: int + Number of input features. + out_feats: int + Number of output features. + k : int + Chebyshev filter size. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + """ + def __init__(self, + in_feats, + out_feats, + k, + bias=True): + super(ChebConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._k = k + with self.name_scope(): + self.fc = nn.Sequential() + for _ in range(k): + self.fc.add( + nn.Dense(out_feats, use_bias=False, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), + in_units=in_feats) + ) + if bias: + self.bias = self.params.get('bias', shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + + def forward(self, graph, feat, lambda_max=None): + r"""Compute ChebNet layer. + + Parameters + ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + lambda_max : list or mxnet.NDArray or None, optional. + A list(tensor) with length :math:`B`, stores the largest eigenvalue + of the normalized laplacian of each individual graph in ``graph``, + where :math:`B` is the batch size of the input graph. Default: None. + If None, this method would compute the list by calling + ``dgl.laplacian_lambda_max``. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + with graph.local_scope(): + degs = graph.in_degrees().astype('float32') + norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5) + norm = norm.expand_dims(-1).as_in_context(feat.context) + if lambda_max is None: + lambda_max = laplacian_lambda_max(graph) + if isinstance(lambda_max, list): + lambda_max = nd.array(lambda_max).as_in_context(feat.context) + if lambda_max.ndim == 1: + lambda_max = lambda_max.expand_dims(-1) + # broadcast from (B, 1) to (N, 1) + lambda_max = broadcast_nodes(graph, lambda_max) + # T0(X) + Tx_0 = feat + rst = self.fc[0](Tx_0) + # T1(X) + if self._k > 1: + graph.ndata['h'] = Tx_0 * norm + graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h = graph.ndata.pop('h') * norm + # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I + # = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I + Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1) + rst = rst + self.fc[1](Tx_1) + # Ti(x), i = 2...k + for i in range(2, self._k): + graph.ndata['h'] = Tx_1 * norm + graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) + h = graph.ndata.pop('h') * norm + # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2) + # = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) + + # (4 / lambda_max - 2) Tx_(k-1) - + # Tx_(k-2) + Tx_2 = -4. * h / lambda_max + Tx_1 * (4. 
/ lambda_max - 2) - Tx_0 + rst = rst + self.fc[i](Tx_2) + Tx_1, Tx_0 = Tx_2, Tx_1 + # add bias + if self.bias is not None: + rst = rst + self.bias.data(feat.context) + return rst diff --git a/python/dgl/nn/mxnet/conv/densechebconv.py b/python/dgl/nn/mxnet/conv/densechebconv.py new file mode 100644 index 000000000000..cb500fe575dc --- /dev/null +++ b/python/dgl/nn/mxnet/conv/densechebconv.py @@ -0,0 +1,100 @@ +"""MXNet Module for DenseChebConv""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + + +class DenseChebConv(nn.Block): + r"""Chebyshev Spectral Graph Convolution layer from paper `Convolutional + Neural Networks on Graphs with Fast Localized Spectral Filtering + `__. + + We recommend to use this module when inducing ChebConv operations on dense + graphs / k-hop graphs. + + Parameters + ---------- + in_feats: int + Number of input features. + out_feats: int + Number of output features. + k : int + Chebyshev filter size. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + + See also + -------- + ChebConv + """ + def __init__(self, + in_feats, + out_feats, + k, + bias=True): + super(DenseChebConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._k = k + with self.name_scope(): + self.fc = nn.Sequential() + for _ in range(k): + self.fc.add( + nn.Dense(out_feats, in_units=in_feats, use_bias=False, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) + ) + if bias: + self.bias = self.params.get('bias', shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + + def forward(self, adj, feat, lambda_max=None): + r"""Compute (Dense) Chebyshev Spectral Graph Convolution layer. + + Parameters + ---------- + adj : mxnet.NDArray + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + lambda_max : float or None, optional + A float value indicates the largest eigenvalue of given graph. + Default: None. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + A = adj.astype(feat.dtype).as_in_context(feat.context) + num_nodes = A.shape[0] + + in_degree = 1. / nd.clip(A.sum(axis=1), 1, float('inf')).sqrt() + D_invsqrt = nd.diag(in_degree) + I = nd.eye(num_nodes, ctx=A.context) + L = I - nd.dot(D_invsqrt, nd.dot(A, D_invsqrt)) + + if lambda_max is None: + # NOTE(zihao): this only works for directed graph. 
+ lambda_max = (nd.linalg.syevd(L)[1]).max() + + L_hat = 2 * L / lambda_max - I + Z = [nd.eye(num_nodes, ctx=A.context)] + Zh = self.fc[0](feat) + for i in range(1, self._k): + if i == 1: + Z.append(L_hat) + else: + Z.append(2 * nd.dot(L_hat, Z[-1]) - Z[-2]) + Zh = Zh + nd.dot(Z[i], self.fc[i](feat)) + + if self.bias is not None: + Zh = Zh + self.bias.data(feat.context) + return Zh diff --git a/python/dgl/nn/mxnet/conv/densegraphconv.py b/python/dgl/nn/mxnet/conv/densegraphconv.py new file mode 100644 index 000000000000..ca6409588296 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/densegraphconv.py @@ -0,0 +1,98 @@ +"""MXNet Module for DenseGraphConv""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + + +class DenseGraphConv(nn.Block): + """Graph Convolutional Network layer where the graph structure + is given by an adjacency matrix. + We recommend user to use this module when inducing graph convolution + on dense graphs / k-hop graphs. + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + norm : bool + If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + + See also + -------- + GraphConv + """ + def __init__(self, + in_feats, + out_feats, + norm=True, + bias=True, + activation=None): + super(DenseGraphConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + with self.name_scope(): + self.weight = self.params.get('weight', shape=(in_feats, out_feats), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + if bias: + self.bias = self.params.get('bias', shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + self._activation = activation + + def forward(self, adj, feat): + r"""Compute (Dense) Graph Convolution layer. + + Parameters + ---------- + adj : mxnet.NDArray + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + adj = adj.astype(feat.dtype).as_in_context(feat.context) + if self._norm: + in_degrees = adj.sum(axis=1) + norm = nd.power(in_degrees, -0.5) + shp = norm.shape + (1,) * (feat.ndim - 1) + norm = norm.reshape(shp).as_in_context(feat.context) + feat = feat * norm + + if self._in_feats > self._out_feats: + # mult W first to reduce the feature size for aggregation. 
+ feat = nd.dot(feat, self.weight.data(feat.context)) + rst = nd.dot(adj, feat) + else: + # aggregate first then mult W + rst = nd.dot(adj, feat) + rst = nd.dot(rst, self.weight.data(feat.context)) + + if self._norm: + rst = rst * norm + + if self.bias is not None: + rst = rst + self.bias.data(feat.context) + + if self._activation is not None: + rst = self._activation(rst) + + return rst diff --git a/python/dgl/nn/mxnet/conv/densesageconv.py b/python/dgl/nn/mxnet/conv/densesageconv.py new file mode 100644 index 000000000000..50b1beb6bb47 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/densesageconv.py @@ -0,0 +1,85 @@ +"""MXNet Module for DenseGraphSAGE""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + + +class DenseSAGEConv(nn.Block): + """GraphSAGE layer where the graph structure is given by an + adjacency matrix. + We recommend to use this module when inducing GraphSAGE operations + on dense graphs / k-hop graphs. + + Note that we only support gcn aggregator in DenseSAGEConv. + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + feat_drop : float, optional + Dropout rate on features. Default: 0. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization to the updated node features. + activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + + See also + -------- + SAGEConv + """ + def __init__(self, + in_feats, + out_feats, + feat_drop=0., + bias=True, + norm=None, + activation=None): + super(DenseSAGEConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + with self.name_scope(): + self.feat_drop = nn.Dropout(feat_drop) + self.activation = activation + self.fc = nn.Dense(out_feats, in_units=in_feats, use_bias=bias, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) + + def forward(self, adj, feat): + r"""Compute (Dense) Graph SAGE layer. + + Parameters + ---------- + adj : mxnet.NDArray + The adjacency matrix of the graph to apply Graph Convolution on, + should be of shape :math:`(N, N)`, where a row represents the destination + and a column represents the source. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + adj = adj.astype(feat.dtype).as_in_context(feat.context) + feat = self.feat_drop(feat) + in_degrees = adj.sum(axis=1, keepdims=True) + h_neigh = (nd.dot(adj, feat) + feat) / (in_degrees + 1) + rst = self.fc(h_neigh) + # activation + if self.activation is not None: + rst = self.activation(rst) + # normalization + if self._norm is not None: + rst = self._norm(rst) + + return rst diff --git a/python/dgl/nn/mxnet/conv/edgeconv.py b/python/dgl/nn/mxnet/conv/edgeconv.py new file mode 100644 index 000000000000..2f79b3157113 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/edgeconv.py @@ -0,0 +1,76 @@ +"""MXNet Module for EdgeConv Layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import mxnet as mx +from mxnet.gluon import nn + +from .... 
import function as fn + + +class EdgeConv(nn.Block): + r"""EdgeConv layer. + + Introduced in "`Dynamic Graph CNN for Learning on Point Clouds + `__". Can be described as follows: + + .. math:: + x_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}( + \Theta \cdot (x_j^{(l)} - x_i^{(l)}) + \Phi \cdot x_i^{(l)}) + + where :math:`\mathcal{N}(i)` is the neighbor of :math:`i`. + + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + batch_norm : bool + Whether to include batch normalization on messages. + """ + def __init__(self, + in_feat, + out_feat, + batch_norm=False): + super(EdgeConv, self).__init__() + self.batch_norm = batch_norm + + with self.name_scope(): + self.theta = nn.Dense(out_feat, in_units=in_feat, + weight_initializer=mx.init.Xavier()) + self.phi = nn.Dense(out_feat, in_units=in_feat, + weight_initializer=mx.init.Xavier()) + + if batch_norm: + self.bn = nn.BatchNorm(in_channels=out_feat) + + def message(self, edges): + r"""The message computation function + """ + theta_x = self.theta(edges.dst['x'] - edges.src['x']) + phi_x = self.phi(edges.src['x']) + return {'e': theta_x + phi_x} + + def forward(self, g, h): + r"""Forward computation + + Parameters + ---------- + g : DGLGraph + The graph. + h : mxnet.NDArray + :math:`(N, D)` where :math:`N` is the number of nodes and + :math:`D` is the number of feature dimensions. + Returns + ------- + mxnet.NDArray + New node features. + """ + with g.local_scope(): + g.ndata['x'] = h + if not self.batch_norm: + g.update_all(self.message, fn.max('e', 'x')) + else: + g.apply_edges(self.message) + g.edata['e'] = self.bn(g.edata['e']) + g.update_all(fn.copy_e('e', 'm'), fn.max('m', 'x')) + return g.ndata['x'] diff --git a/python/dgl/nn/mxnet/conv/gatconv.py b/python/dgl/nn/mxnet/conv/gatconv.py new file mode 100644 index 000000000000..532c537a8f2f --- /dev/null +++ b/python/dgl/nn/mxnet/conv/gatconv.py @@ -0,0 +1,123 @@ +"""MXNet modules for graph attention networks(GAT).""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet.gluon import nn +from mxnet.gluon.contrib.nn import Identity + +from .... import function as fn +from ..softmax import edge_softmax + +#pylint: enable=W0235 +class GATConv(nn.Block): + r"""Apply `Graph Attention Network `__ + over an input signal. + + .. math:: + h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} + + where :math:`\alpha_{ij}` is the attention score bewteen node :math:`i` and + node :math:`j`: + + .. math:: + \alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l}) + + e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + num_heads : int + Number of heads in Multi-Head Attention. + feat_drop : float, optional + Dropout rate on feature, defaults: ``0``. + attn_drop : float, optional + Dropout rate on attention weight, defaults: ``0``. + negative_slope : float, optional + LeakyReLU angle of negative slope. + residual : bool, optional + If True, use residual connection. + activation : callable activation function/layer or None, optional. + If not None, applies an activation function to the updated node features. + Default: ``None``. 
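+
+    A minimal usage sketch (the 3-node graph and feature sizes below are
+    illustrative only):
+
+    >>> import dgl
+    >>> import mxnet as mx
+    >>> from dgl.nn.mxnet.conv import GATConv
+    >>> g = dgl.DGLGraph()
+    >>> g.add_nodes(3)
+    >>> g.add_edges([0, 1, 2], [1, 2, 0])
+    >>> conv = GATConv(in_feats=5, out_feats=6, num_heads=2)
+    >>> conv.initialize()
+    >>> out = conv(g, mx.nd.random.uniform(shape=(3, 5)))  # shape: (3, 2, 6)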
+ """ + def __init__(self, + in_feats, + out_feats, + num_heads, + feat_drop=0., + attn_drop=0., + negative_slope=0.2, + residual=False, + activation=None): + super(GATConv, self).__init__() + self._num_heads = num_heads + self._in_feats = in_feats + self._out_feats = out_feats + with self.name_scope(): + self.fc = nn.Dense(out_feats * num_heads, use_bias=False, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), + in_units=in_feats) + self.attn_l = self.params.get('attn_l', + shape=(1, num_heads, out_feats), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + self.attn_r = self.params.get('attn_r', + shape=(1, num_heads, out_feats), + init=mx.init.Xavier(magnitude=math.sqrt(2.0))) + self.feat_drop = nn.Dropout(feat_drop) + self.attn_drop = nn.Dropout(attn_drop) + self.leaky_relu = nn.LeakyReLU(negative_slope) + if residual: + if in_feats != out_feats: + self.res_fc = nn.Dense(out_feats * num_heads, use_bias=False, + weight_initializer=mx.init.Xavier( + magnitude=math.sqrt(2.0)), + in_units=in_feats) + else: + self.res_fc = Identity() + else: + self.res_fc = None + self.activation = activation + + def forward(self, graph, feat): + r"""Compute graph attention network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, H, D_{out})` where :math:`H` + is the number of heads, and :math:`D_{out}` is size of output feature. + """ + graph = graph.local_var() + h = self.feat_drop(feat) + feat = self.fc(h).reshape(-1, self._num_heads, self._out_feats) + el = (feat * self.attn_l.data(feat.context)).sum(axis=-1).expand_dims(-1) + er = (feat * self.attn_r.data(feat.context)).sum(axis=-1).expand_dims(-1) + graph.ndata.update({'ft': feat, 'el': el, 'er': er}) + # compute edge attention + graph.apply_edges(fn.u_add_v('el', 'er', 'e')) + e = self.leaky_relu(graph.edata.pop('e')) + # compute softmax + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + graph.update_all(fn.u_mul_e('ft', 'a', 'm'), + fn.sum('m', 'ft')) + rst = graph.ndata['ft'] + # residual + if self.res_fc is not None: + resval = self.res_fc(h).reshape(h.shape[0], -1, self._out_feats) + rst = rst + resval + # activation + if self.activation: + rst = self.activation(rst) + return rst diff --git a/python/dgl/nn/mxnet/conv/gatedgraphconv.py b/python/dgl/nn/mxnet/conv/gatedgraphconv.py new file mode 100644 index 000000000000..060a0160f474 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/gatedgraphconv.py @@ -0,0 +1,95 @@ +"""MXNet Module for Gated Graph Convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop +import mxnet as mx +from mxnet import gluon, nd +from mxnet.gluon import nn + +from .... import function as fn + +class GatedGraphConv(nn.Block): + r"""Gated Graph Convolution layer from paper `Gated Graph Sequence + Neural Networks `__. + + .. math:: + h_{i}^{0} & = [ x_i \| \mathbf{0} ] + + a_{i}^{t} & = \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t} + + h_{i}^{t+1} & = \mathrm{GRU}(a_{i}^{t}, h_{i}^{t}) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + n_steps : int + Number of recurrent steps. + n_etypes : int + Number of edge types. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + Can only be set to True in MXNet. 
+ """ + def __init__(self, + in_feats, + out_feats, + n_steps, + n_etypes, + bias=True): + super(GatedGraphConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._n_steps = n_steps + self._n_etypes = n_etypes + if not bias: + raise KeyError('MXNet do not support disabling bias in GRUCell.') + with self.name_scope(): + self.linears = nn.Sequential() + for _ in range(n_etypes): + self.linears.add( + nn.Dense(out_feats, + weight_initializer=mx.init.Xavier(), + in_units=out_feats) + ) + self.gru = gluon.rnn.GRUCell(out_feats, input_size=out_feats) + + def forward(self, graph, feat, etypes): + """Compute Gated Graph Convolution layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + etypes : torch.LongTensor + The edge type tensor of shape :math:`(E,)` where :math:`E` is + the number of edges of the graph. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. + """ + graph = graph.local_var() + zero_pad = nd.zeros((feat.shape[0], self._out_feats - feat.shape[1]), ctx=feat.context) + feat = nd.concat(feat, zero_pad, dim=-1) + + for _ in range(self._n_steps): + graph.ndata['h'] = feat + for i in range(self._n_etypes): + eids = (etypes.asnumpy() == i).nonzero()[0] + eids = nd.from_numpy(eids, zero_copy=True) + if len(eids) > 0: + graph.apply_edges( + lambda edges: {'W_e*h': self.linears[i](edges.src['h'])}, + eids + ) + graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a')) + a = graph.ndata.pop('a') + feat = self.gru(a, [feat])[0] + return feat diff --git a/python/dgl/nn/mxnet/conv/ginconv.py b/python/dgl/nn/mxnet/conv/ginconv.py new file mode 100644 index 000000000000..b6c90660dd8a --- /dev/null +++ b/python/dgl/nn/mxnet/conv/ginconv.py @@ -0,0 +1,79 @@ +"""MXNet Module for Graph Isomorphism Network layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import mxnet as mx +from mxnet.gluon import nn + +from .... import function as fn + + +class GINConv(nn.Block): + r"""Graph Isomorphism Network layer from paper `How Powerful are Graph + Neural Networks? `__. + + .. math:: + h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} + + \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i) + \right\}\right)\right) + + Parameters + ---------- + apply_func : callable activation function/layer or None + If not None, apply this function to the updated node feature, + the :math:`f_\Theta` in the formula. + aggregator_type : str + Aggregator type to use (``sum``, ``max`` or ``mean``). + init_eps : float, optional + Initial :math:`\epsilon` value, default: ``0``. + learn_eps : bool, optional + If True, :math:`\epsilon` will be a learnable parameter. 
+ """ + def __init__(self, + apply_func, + aggregator_type, + init_eps=0, + learn_eps=False): + super(GINConv, self).__init__() + if aggregator_type == 'sum': + self._reducer = fn.sum + elif aggregator_type == 'max': + self._reducer = fn.max + elif aggregator_type == 'mean': + self._reducer = fn.mean + else: + raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type)) + + with self.name_scope(): + self.apply_func = apply_func + self.eps = self.params.get('eps', + shape=(1,), + grad_req='write' if learn_eps else 'null', + init=mx.init.Constant(init_eps)) + + def forward(self, graph, feat): + r"""Compute Graph Isomorphism Network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : torch.Tensor + The input feature of shape :math:`(N, D)` where :math:`D` + could be any positive integer, :math:`N` is the number + of nodes. If ``apply_func`` is not None, :math:`D` should + fit the input dimensionality requirement of ``apply_func``. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` where + :math:`D_{out}` is the output dimensionality of ``apply_func``. + If ``apply_func`` is None, :math:`D_{out}` should be the same + as input dimensionality. + """ + graph = graph.local_var() + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) + rst = (1 + self.eps.data(feat.context)) * feat + graph.ndata['neigh'] + if self.apply_func is not None: + rst = self.apply_func(rst) + return rst diff --git a/python/dgl/nn/mxnet/conv/gmmconv.py b/python/dgl/nn/mxnet/conv/gmmconv.py new file mode 100644 index 000000000000..17af044003ce --- /dev/null +++ b/python/dgl/nn/mxnet/conv/gmmconv.py @@ -0,0 +1,128 @@ +"""Torch Module for GMM Conv""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn +from mxnet.gluon.contrib.nn import Identity + +from .... import function as fn + + +class GMMConv(nn.Block): + r"""The Gaussian Mixture Model Convolution layer from `Geometric Deep + Learning on Graphs and Manifolds using Mixture Model CNNs + `__. + + .. math:: + h_i^{l+1} & = \mathrm{aggregate}\left(\left\{\frac{1}{K} + \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right) + + w_k(u) & = \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right) + + Parameters + ---------- + in_feats : int + Number of input features. + out_feats : int + Number of output features. + dim : int + Dimensionality of pseudo-coordinte. + n_kernels : int + Number of kernels :math:`K`. + aggregator_type : str + Aggregator type (``sum``, ``mean``, ``max``). Default: ``sum``. + residual : bool + If True, use residual connection inside this layer. Default: ``False``. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. 
+ """ + def __init__(self, + in_feats, + out_feats, + dim, + n_kernels, + aggregator_type='sum', + residual=False, + bias=True): + super(GMMConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._dim = dim + self._n_kernels = n_kernels + if aggregator_type == 'sum': + self._reducer = fn.sum + elif aggregator_type == 'mean': + self._reducer = fn.mean + elif aggregator_type == 'max': + self._reducer = fn.max + else: + raise KeyError("Aggregator type {} not recognized.".format(aggregator_type)) + + with self.name_scope(): + self.mu = self.params.get('mu', + shape=(n_kernels, dim), + init=mx.init.Normal(0.1)) + self.inv_sigma = self.params.get('inv_sigma', + shape=(n_kernels, dim), + init=mx.init.Constant(1)) + self.fc = nn.Dense(n_kernels * out_feats, + in_units=in_feats, + use_bias=False, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0))) + if residual: + if in_feats != out_feats: + self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False) + else: + self.res_fc = Identity() + else: + self.res_fc = None + + if bias: + self.bias = self.params.get('bias', + shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + + def forward(self, graph, feat, pseudo): + """Compute Gaussian Mixture Model Convolution layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + pseudo : mxnet.NDArray + The pseudo coordinate tensor of shape :math:`(E, D_{u})` where + :math:`E` is the number of edges of the graph and :math:`D_{u}` + is the dimensionality of pseudo coordinate. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. + """ + graph = graph.local_var() + graph.ndata['h'] = self.fc(feat).reshape(-1, self._n_kernels, self._out_feats) + E = graph.number_of_edges() + # compute gaussian weight + gaussian = -0.5 * ((pseudo.reshape(E, 1, self._dim) - + self.mu.data(feat.context).reshape(1, self._n_kernels, self._dim)) ** 2) + gaussian = gaussian *\ + (self.inv_sigma.data(feat.context).reshape(1, self._n_kernels, self._dim) ** 2) + gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True)) # (E, K, 1) + graph.edata['w'] = gaussian + graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) + rst = graph.ndata['h'].sum(1) + # residual connection + if self.res_fc is not None: + rst = rst + self.res_fc(feat) + # bias + if self.bias is not None: + rst = rst + self.bias.data(feat.context) + return rst diff --git a/python/dgl/nn/mxnet/conv/nnconv.py b/python/dgl/nn/mxnet/conv/nnconv.py new file mode 100644 index 000000000000..d0e3ec5ce2ad --- /dev/null +++ b/python/dgl/nn/mxnet/conv/nnconv.py @@ -0,0 +1,109 @@ +"""MXNet Module for NNConv layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import mxnet as mx +from mxnet.gluon import nn +from mxnet.gluon.contrib.nn import Identity + +from .... import function as fn + + +class NNConv(nn.Block): + r"""Graph Convolution layer introduced in `Neural Message Passing + for Quantum Chemistry `__. + + .. math:: + h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{ + f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. 
+ edge_func : callable activation function/layer + Maps each edge feature to a vector of shape + ``(in_feats * out_feats)`` as weight to compute + messages. + Also is the :math:`f_\Theta` in the formula. + aggregator_type : str + Aggregator type to use (``sum``, ``mean`` or ``max``). + residual : bool, optional + If True, use residual connection. Default: ``False``. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. + """ + def __init__(self, + in_feats, + out_feats, + edge_func, + aggregator_type, + residual=False, + bias=True): + super(NNConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + if aggregator_type == 'sum': + self.reducer = fn.sum + elif aggregator_type == 'mean': + self.reducer = fn.mean + elif aggregator_type == 'max': + self.reducer = fn.max + else: + raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type)) + self._aggre_type = aggregator_type + + with self.name_scope(): + self.edge_nn = edge_func + if residual: + if in_feats != out_feats: + self.res_fc = nn.Dense(out_feats, in_units=in_feats, use_bias=False, + weight_initializer=mx.init.Xavier()) + else: + self.res_fc = Identity() + else: + self.res_fc = None + + if bias: + self.bias = self.params.get('bias', + shape=(out_feats,), + init=mx.init.Zero()) + else: + self.bias = None + + def forward(self, graph, feat, efeat): + r"""Compute MPNN Graph Convolution layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`N` + is the number of nodes of the graph and :math:`D_{in}` is the + input feature size. + efeat : mxnet.NDArray + The edge feature of shape :math:`(N, *)`, should fit the input + shape requirement of ``edge_nn``. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is the output feature size. + """ + graph = graph.local_var() + # (n, d_in, 1) + graph.ndata['h'] = feat.expand_dims(-1) + # (n, d_in, d_out) + graph.edata['w'] = self.edge_nn(efeat).reshape(-1, self._in_feats, self._out_feats) + # (n, d_in, d_out) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) + rst = graph.ndata.pop('neigh').sum(axis=1) # (n, d_out) + # residual connection + if self.res_fc is not None: + rst = rst + self.res_fc(feat) + # bias + if self.bias is not None: + rst = rst + self.bias.data(feat.context) + return rst diff --git a/python/dgl/nn/mxnet/conv/sageconv.py b/python/dgl/nn/mxnet/conv/sageconv.py new file mode 100644 index 000000000000..af803a25bcd6 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/sageconv.py @@ -0,0 +1,121 @@ +"""MXNet Module for GraphSAGE layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import math +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + +from .... import function as fn + +class SAGEConv(nn.Block): + r"""GraphSAGE layer from paper `Inductive Representation Learning on + Large Graphs `__. + + .. math:: + h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate} + \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right) + + h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat} + (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1} + b) \right) + + h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{l}) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + feat_drop : float + Dropout rate on features, default: ``0``. 
+ aggregator_type : str + Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``). + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization to the updated node features. + activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + """ + def __init__(self, + in_feats, + out_feats, + aggregator_type='mean', + feat_drop=0., + bias=True, + norm=None, + activation=None): + super(SAGEConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._aggre_type = aggregator_type + with self.name_scope(): + self.norm = norm + self.feat_drop = nn.Dropout(feat_drop) + self.activation = activation + if aggregator_type == 'pool': + self.fc_pool = nn.Dense(in_feats, use_bias=bias, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), + in_units=in_feats) + if aggregator_type == 'lstm': + raise NotImplementedError + if aggregator_type != 'gcn': + self.fc_self = nn.Dense(out_feats, use_bias=bias, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), + in_units=in_feats) + self.fc_neigh = nn.Dense(out_feats, use_bias=bias, + weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)), + in_units=in_feats) + + def forward(self, graph, feat): + r"""Compute GraphSAGE layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + graph = graph.local_var() + feat = self.feat_drop(feat) + h_self = feat + if self._aggre_type == 'mean': + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'gcn': + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) + # divide in degrees + degs = graph.in_degrees().astype(feat.dtype) + degs = degs.as_in_context(feat.context) + h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.expand_dims(-1) + 1) + elif self._aggre_type == 'pool': + graph.ndata['h'] = nd.relu(self.fc_pool(feat)) + graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'lstm': + raise NotImplementedError + else: + raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) + + if self._aggre_type == 'gcn': + rst = self.fc_neigh(h_neigh) + else: + rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + # activation + if self.activation is not None: + rst = self.activation(rst) + # normalization + if self.norm is not None: + rst = self.norm(rst) + return rst diff --git a/python/dgl/nn/mxnet/conv/sgconv.py b/python/dgl/nn/mxnet/conv/sgconv.py new file mode 100644 index 000000000000..636274097f94 --- /dev/null +++ b/python/dgl/nn/mxnet/conv/sgconv.py @@ -0,0 +1,100 @@ +"""MXNet Module for Simplifying Graph Convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name + +import mxnet as mx +from mxnet import nd +from mxnet.gluon import nn + +from .... import function as fn + + +class SGConv(nn.Block): + r"""Simplifying Graph Convolution layer from paper `Simplifying Graph + Convolutional Networks `__. + + .. 
math:: + H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l} + + Parameters + ---------- + in_feats : int + Number of input features. + out_feats : int + Number of output features. + k : int + Number of hops :math:`K`. Defaults:``1``. + cached : bool + If True, the module would cache + + .. math:: + (\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta + + at the first forward call. This parameter should only be set to + ``True`` in Transductive Learning setting. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization to the updated node features. + """ + def __init__(self, + in_feats, + out_feats, + k=1, + cached=False, + bias=True, + norm=None): + super(SGConv, self).__init__() + self._cached = cached + self._cached_h = None + self._k = k + with self.name_scope(): + self.norm = norm + self.fc = nn.Dense(out_feats, in_units=in_feats, use_bias=bias, + weight_initializer=mx.init.Xavier()) + + def forward(self, graph, feat): + r"""Compute Simplifying Graph Convolution layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : mxnet.NDArray + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + mxnet.NDArray + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + + Notes + ----- + If ``cache`` is se to True, ``feat`` and ``graph`` should not change during + training, or you will get wrong results. + """ + graph = graph.local_var() + if self._cached_h is not None: + feat = self._cached_h + else: + # compute normalization + degs = nd.clip(graph.in_degrees().astype(feat.dtype), 1, float('inf')) + norm = nd.power(degs, -0.5).expand_dims(1) + norm = norm.as_in_context(feat.context) + # compute (D^-1 A D)^k X + for _ in range(self._k): + feat = feat * norm + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + feat = feat * norm + + if self.norm is not None: + feat = self.norm(feat) + + # cache feature + if self._cached: + self._cached_h = feat + return self.fc(feat) diff --git a/python/dgl/nn/mxnet/utils.py b/python/dgl/nn/mxnet/utils.py index f9446a97b872..20381ce34e03 100644 --- a/python/dgl/nn/mxnet/utils.py +++ b/python/dgl/nn/mxnet/utils.py @@ -30,14 +30,14 @@ def matmul_maybe_select(A, B): Parameters ---------- - A : torch.Tensor + A : mxnet.NDArray lhs tensor - B : torch.Tensor + B : mxnet.NDArray rhs tensor Returns ------- - C : torch.Tensor + C : mxnet.NDArray result tensor """ if A.dtype in (np.int32, np.int64) and len(A.shape) == 1: @@ -67,16 +67,16 @@ def bmm_maybe_select(A, B, index): Parameters ---------- - A : torch.Tensor + A : mxnet.NDArray lhs tensor - B : torch.Tensor + B : mxnet.NDArray rhs tensor - index : torch.Tensor + index : mxnet.NDArray index tensor Returns ------- - C : torch.Tensor + C : mxnet.NDArray return tensor """ if A.dtype in (np.int32, np.int64) and len(A.shape) == 1: @@ -84,3 +84,24 @@ def bmm_maybe_select(A, B, index): else: BB = nd.take(B, index, axis=0) return nd.batch_dot(A.expand_dims(1), BB).squeeze() + +def normalize(x, p=2, axis=1, eps=1e-12): + r"""Performs :math:`L_p` normalization of inputs over specified dimension. 
+
+    For a tensor :attr:`x` of sizes :math:`(n_0, ..., n_{axis}, ..., n_k)`, each
+    :math:`n_{axis}` -element vector :math:`v` along dimension :attr:`axis` is transformed as
+
+    .. math::
+        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
+
+    With the default arguments it uses the Euclidean norm over vectors along dimension
+    :math:`1` for normalization.
+
+    Args:
+        x: input ndarray of any shape
+        p (float): the exponent value in the norm formulation. Default: 2
+        axis (int): the dimension to reduce. Default: 1
+        eps (float): small value to avoid division by zero. Default: 1e-12
+    """
+    denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf'))
+    return x / denom
diff --git a/python/dgl/nn/pytorch/conv/agnnconv.py b/python/dgl/nn/pytorch/conv/agnnconv.py
index 22d15a1b8d44..62625e6f8bd9 100644
--- a/python/dgl/nn/pytorch/conv/agnnconv.py
+++ b/python/dgl/nn/pytorch/conv/agnnconv.py
@@ -58,8 +58,8 @@ def forward(self, graph, feat):
         graph.ndata['h'] = feat
         graph.ndata['norm_h'] = F.normalize(feat, p=2, dim=-1)
         # compute cosine distance
-        graph.apply_edges(fn.u_mul_v('norm_h', 'norm_h', 'cos'))
-        cos = graph.edata.pop('cos').sum(-1)
+        graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
+        cos = graph.edata.pop('cos')
         e = self.beta * cos
         graph.edata['p'] = edge_softmax(graph, e)
         graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
diff --git a/python/dgl/nn/pytorch/conv/appnpconv.py b/python/dgl/nn/pytorch/conv/appnpconv.py
index fd2ab8a97795..f06a9c50f2ea 100644
--- a/python/dgl/nn/pytorch/conv/appnpconv.py
+++ b/python/dgl/nn/pytorch/conv/appnpconv.py
@@ -4,7 +4,6 @@
 from torch import nn
 
 from .... import function as fn
-from ..utils import Identity
 
 
 class APPNPConv(nn.Module):
@@ -35,7 +34,7 @@ def __init__(self,
         super(APPNPConv, self).__init__()
         self._k = k
         self._alpha = alpha
-        self.edge_drop = nn.Dropout(edge_drop) if edge_drop > 0 else Identity()
+        self.edge_drop = nn.Dropout(edge_drop)
 
     def forward(self, graph, feat):
         r"""Compute APPNP layer.
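For reference, the propagation step that the forward pass in the next hunk implements is X <- (1 - alpha) * (D^-1/2 A D^-1/2) X + alpha * X0, repeated k times. A dense-matrix sketch of the same rule (toy adjacency and hypothetical alpha/k, not part of the patch):

    import torch as th

    A = th.tensor([[0., 1., 0.],
                   [1., 0., 1.],
                   [0., 1., 0.]])                       # toy adjacency
    norm = A.sum(1).clamp(min=1).pow(-0.5)
    A_hat = norm.unsqueeze(1) * A * norm.unsqueeze(0)   # D^-1/2 A D^-1/2
    alpha, k = 0.1, 3
    X = X0 = th.randn(3, 4)
    for _ in range(k):
        X = (1 - alpha) * (A_hat @ X) + alpha * X0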
@@ -56,10 +55,11 @@ def forward(self, graph, feat): """ graph = graph.local_var() norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) - norm = norm.unsqueeze(-1).to(feat.device) + shp = norm.shape + (1,) * (feat.dim() - 1) + norm = th.reshape(norm, shp).to(feat.device) feat_0 = feat for _ in range(self._k): - # normalization by src + # normalization by src node feat = feat * norm graph.ndata['h'] = feat graph.edata['w'] = self.edge_drop( @@ -67,7 +67,7 @@ def forward(self, graph, feat): graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h')) feat = graph.ndata.pop('h') - # normalization by dst + # normalization by dst node feat = feat * norm feat = (1 - self._alpha) * feat + self._alpha * feat_0 return feat diff --git a/python/dgl/nn/pytorch/conv/chebconv.py b/python/dgl/nn/pytorch/conv/chebconv.py index a019ac5439e8..4d14b23a21dc 100644 --- a/python/dgl/nn/pytorch/conv/chebconv.py +++ b/python/dgl/nn/pytorch/conv/chebconv.py @@ -93,7 +93,7 @@ def forward(self, graph, feat, lambda_max=None): lambda_max = laplacian_lambda_max(graph) if isinstance(lambda_max, list): lambda_max = th.Tensor(lambda_max).to(feat.device) - if lambda_max.dim() < 1: + if lambda_max.dim() == 1: lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1) # broadcast from (B, 1) to (N, 1) lambda_max = broadcast_nodes(graph, lambda_max) diff --git a/python/dgl/nn/pytorch/conv/densesageconv.py b/python/dgl/nn/pytorch/conv/densesageconv.py index cdfe7f96d08e..40619161ac86 100644 --- a/python/dgl/nn/pytorch/conv/densesageconv.py +++ b/python/dgl/nn/pytorch/conv/densesageconv.py @@ -73,7 +73,7 @@ def forward(self, adj, feat): """ adj = adj.float().to(feat.device) feat = self.feat_drop(feat) - in_degrees = adj.sum(dim=1).unsqueeze(-1) + in_degrees = adj.sum(dim=1, keepdim=True) h_neigh = (adj @ feat + feat) / (in_degrees + 1) rst = self.fc(h_neigh) # activation diff --git a/python/dgl/nn/pytorch/conv/edgeconv.py b/python/dgl/nn/pytorch/conv/edgeconv.py index d0ae58cf31ed..e8a1184dc617 100644 --- a/python/dgl/nn/pytorch/conv/edgeconv.py +++ b/python/dgl/nn/pytorch/conv/edgeconv.py @@ -12,7 +12,6 @@ class EdgeConv(nn.Module): `__". Can be described as follows: .. math:: - x_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}( \Theta \cdot (x_j^{(l)} - x_i^{(l)}) + \Phi \cdot x_i^{(l)}) @@ -27,7 +26,10 @@ class EdgeConv(nn.Module): batch_norm : bool Whether to include batch normalization on messages. 
""" - def __init__(self, in_feat, out_feat, batch_norm=False): + def __init__(self, + in_feat, + out_feat, + batch_norm=False): super(EdgeConv, self).__init__() self.batch_norm = batch_norm diff --git a/python/dgl/nn/pytorch/conv/gatedgraphconv.py b/python/dgl/nn/pytorch/conv/gatedgraphconv.py index 7a974118e4d5..0f836d56f15a 100644 --- a/python/dgl/nn/pytorch/conv/gatedgraphconv.py +++ b/python/dgl/nn/pytorch/conv/gatedgraphconv.py @@ -1,5 +1,5 @@ """Torch Module for Gated Graph Convolution layer""" -# pylint: disable= no-member, arguments-differ, invalid-name +# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop import torch as th from torch import nn from torch.nn import init @@ -41,7 +41,10 @@ def __init__(self, self._in_feats = in_feats self._out_feats = out_feats self._n_steps = n_steps - self.edge_embed = nn.Embedding(n_etypes, out_feats * out_feats) + self._n_etypes = n_etypes + self.linears = nn.ModuleList( + [nn.Linear(out_feats, out_feats) for _ in range(n_etypes)] + ) self.gru = nn.GRUCell(out_feats, out_feats, bias=bias) self.reset_parameters() @@ -49,7 +52,9 @@ def reset_parameters(self): """Reinitialize learnable parameters.""" gain = init.calculate_gain('relu') self.gru.reset_parameters() - init.xavier_normal_(self.edge_embed.weight, gain=gain) + for linear in self.linears: + init.xavier_normal_(linear.weight, gain=gain) + init.zeros_(linear.bias) def forward(self, graph, feat, etypes): """Compute Gated Graph Convolution layer. @@ -75,13 +80,17 @@ def forward(self, graph, feat, etypes): graph = graph.local_var() zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1])) feat = th.cat([feat, zero_pad], -1) - # NOTE(zihao): there is still room to optimize, we may do kernel fusion - # for such operations in the future. - graph.edata['w'] = self.edge_embed(etypes).view(-1, self._out_feats, self._out_feats) + for _ in range(self._n_steps): - graph.ndata['h'] = feat.unsqueeze(-1) # (N, D, 1) - graph.update_all(fn.u_mul_e('h', 'w', 'm'), - fn.sum('m', 'a')) - a = graph.ndata.pop('a').sum(dim=1) # (N, D) + graph.ndata['h'] = feat + for i in range(self._n_etypes): + eids = (etypes == i).nonzero().view(-1) + if len(eids) > 0: + graph.apply_edges( + lambda edges: {'W_e*h': self.linears[i](edges.src['h'])}, + eids + ) + graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a')) + a = graph.ndata.pop('a') # (N, D) feat = self.gru(a, feat) return feat diff --git a/python/dgl/nn/pytorch/conv/gmmconv.py b/python/dgl/nn/pytorch/conv/gmmconv.py index f998df386d73..1a099e5c6963 100644 --- a/python/dgl/nn/pytorch/conv/gmmconv.py +++ b/python/dgl/nn/pytorch/conv/gmmconv.py @@ -32,7 +32,7 @@ class GMMConv(nn.Module): aggregator_type : str Aggregator type (``sum``, ``mean``, ``max``). residual : bool - If True, use residual connection inside this layer. + If True, use residual connection inside this layer. Default: ``False``. bias : bool If True, adds a learnable bias to the output. Default: ``True``. 
""" @@ -41,8 +41,8 @@ def __init__(self, out_feats, dim, n_kernels, - aggregator_type, - residual=True, + aggregator_type='sum', + residual=False, bias=True): super(GMMConv, self).__init__() self._in_feats = in_feats @@ -82,7 +82,7 @@ def reset_parameters(self): if isinstance(self.res_fc, nn.Linear): init.xavier_normal_(self.res_fc.weight, gain=gain) init.normal_(self.mu.data, 0, 0.1) - init.normal_(self.inv_sigma.data, 1, 0.1) + init.constant_(self.inv_sigma.data, 1) if self.bias is not None: init.zeros_(self.bias.data) diff --git a/python/dgl/nn/pytorch/conv/sgconv.py b/python/dgl/nn/pytorch/conv/sgconv.py index 80b081e1c19a..bd0a01f4b2bc 100644 --- a/python/dgl/nn/pytorch/conv/sgconv.py +++ b/python/dgl/nn/pytorch/conv/sgconv.py @@ -47,6 +47,13 @@ def __init__(self, self._cached_h = None self._k = k self.norm = norm + self.reset_parameters() + + def reset_parameters(self): + """Reinitialize learnable parameters.""" + nn.init.xavier_uniform_(self.fc.weight) + if self.fc.bias is not None: + nn.init.zeros_(self.fc.bias) def forward(self, graph, feat): r"""Compute Simplifying Graph Convolution layer. @@ -77,9 +84,8 @@ def forward(self, graph, feat): # compute normalization degs = graph.in_degrees().float().clamp(min=1) norm = th.pow(degs, -0.5) - norm[th.isinf(norm)] = 0 norm = norm.to(feat.device).unsqueeze(1) - # compute (D^-1 A D) X + # compute (D^-1 A^k D)^k X for _ in range(self._k): feat = feat * norm graph.ndata['h'] = feat diff --git a/tests/mxnet/test_nn.py b/tests/mxnet/test_nn.py index 4cfb2747a94f..c98c3c27af2a 100644 --- a/tests/mxnet/test_nn.py +++ b/tests/mxnet/test_nn.py @@ -112,6 +112,210 @@ def test_tagconv(): h1 = conv(g, h0) assert h1.shape[-1] == 2 +def test_gat_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + gat = nn.GATConv(10, 20, 5) # n_heads = 5 + gat.initialize(ctx=ctx) + print(gat) + + # test#1: basic + h0 = F.randn((20, 10)) + h1 = gat(g, h0) + assert h1.shape == (20, 5, 20) + +def test_sage_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + graphsage = nn.SAGEConv(10, 20) + graphsage.initialize(ctx=ctx) + print(graphsage) + + # test#1: basic + h0 = F.randn((20, 10)) + h1 = graphsage(g, h0) + assert h1.shape == (20, 20) + +def test_gg_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + gg_conv = nn.GatedGraphConv(10, 20, 3, 4) # n_step = 3, n_etypes = 4 + gg_conv.initialize(ctx=ctx) + print(gg_conv) + + # test#1: basic + h0 = F.randn((20, 10)) + etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx) + h1 = gg_conv(g, h0, etypes) + assert h1.shape == (20, 20) + +def test_cheb_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + cheb = nn.ChebConv(10, 20, 3) # k = 3 + cheb.initialize(ctx=ctx) + print(cheb) + + # test#1: basic + h0 = F.randn((20, 10)) + h1 = cheb(g, h0) + assert h1.shape == (20, 20) + +def test_agnn_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + agnn_conv = nn.AGNNConv(0.1, True) + agnn_conv.initialize(ctx=ctx) + print(agnn_conv) + + # test#1: basic + h0 = F.randn((20, 10)) + h1 = agnn_conv(g, h0) + assert h1.shape == (20, 10) + +def test_appnp_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + appnp_conv = nn.APPNPConv(3, 0.1, 0) + appnp_conv.initialize(ctx=ctx) + print(appnp_conv) + + # test#1: basic + h0 = F.randn((20, 10)) + h1 = appnp_conv(g, h0) + assert h1.shape == (20, 10) + +def test_dense_cheb_conv(): + for k in range(1, 4): + ctx = F.ctx() + g = 
dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).tostype('default') + cheb = nn.ChebConv(5, 2, k) + dense_cheb = nn.DenseChebConv(5, 2, k) + cheb.initialize(ctx=ctx) + dense_cheb.initialize(ctx=ctx) + + for i in range(len(cheb.fc)): + dense_cheb.fc[i].weight.set_data( + cheb.fc[i].weight.data()) + if cheb.bias is not None: + dense_cheb.bias.set_data( + cheb.bias.data()) + + feat = F.randn((100, 5)) + out_cheb = cheb(g, feat, [2.0]) + out_dense_cheb = dense_cheb(adj, feat, 2.0) + assert F.allclose(out_cheb, out_dense_cheb) + +def test_dense_graph_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).tostype('default') + conv = nn.GraphConv(5, 2, norm=False, bias=True) + dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True) + conv.initialize(ctx=ctx) + dense_conv.initialize(ctx=ctx) + dense_conv.weight.set_data( + conv.weight.data()) + dense_conv.bias.set_data( + conv.bias.data()) + feat = F.randn((100, 5)) + + out_conv = conv(g, feat) + out_dense_conv = dense_conv(adj, feat) + assert F.allclose(out_conv, out_dense_conv) + +def test_dense_sage_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + adj = g.adjacency_matrix(ctx=ctx).tostype('default') + sage = nn.SAGEConv(5, 2, 'gcn') + dense_sage = nn.DenseSAGEConv(5, 2) + sage.initialize(ctx=ctx) + dense_sage.initialize(ctx=ctx) + dense_sage.fc.weight.set_data( + sage.fc_neigh.weight.data()) + dense_sage.fc.bias.set_data( + sage.fc_neigh.bias.data()) + feat = F.randn((100, 5)) + + out_sage = sage(g, feat) + out_dense_sage = dense_sage(adj, feat) + assert F.allclose(out_sage, out_dense_sage) + +def test_edge_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + edge_conv = nn.EdgeConv(5, 2) + edge_conv.initialize(ctx=ctx) + print(edge_conv) + + # test #1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + h1 = edge_conv(g, h0) + assert h1.shape == (g.number_of_nodes(), 2) + +def test_gin_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + gin_conv = nn.GINConv(lambda x: x, 'mean', 0.1) + gin_conv.initialize(ctx=ctx) + print(gin_conv) + + # test #1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + h1 = gin_conv(g, h0) + assert h1.shape == (g.number_of_nodes(), 5) + +def test_gmm_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max') + gmm_conv.initialize(ctx=ctx) + print(gmm_conv) + + # test #1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + pseudo = F.randn((g.number_of_edges(), 5)) + h1 = gmm_conv(g, h0, pseudo) + assert h1.shape == (g.number_of_nodes(), 2) + +def test_nn_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max') + nn_conv.initialize(ctx=ctx) + print(nn_conv) + + # test #1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx) + h1 = nn_conv(g, h0, etypes) + assert h1.shape == (g.number_of_nodes(), 2) + +def test_sg_conv(): + g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) + ctx = F.ctx() + + sgc = nn.SGConv(5, 2, 2) + sgc.initialize(ctx=ctx) + print(sgc) + + # test #1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + h1 = sgc(g, h0) + assert h1.shape == (g.number_of_nodes(), 2) + def test_set2set(): g = dgl.DGLGraph(nx.path_graph(10)) ctx = F.ctx() @@ -306,6 +510,20 @@ def 
test_rgcn(): if __name__ == '__main__': test_graph_conv() + test_gat_conv() + test_sage_conv() + test_gg_conv() + test_cheb_conv() + test_agnn_conv() + test_appnp_conv() + test_dense_cheb_conv() + test_dense_graph_conv() + test_dense_sage_conv() + test_edge_conv() + test_gin_conv() + test_gmm_conv() + test_nn_conv() + test_sg_conv() test_edge_softmax() test_partial_edge_softmax() test_set2set() diff --git a/tests/pytorch/test_nn.py b/tests/pytorch/test_nn.py index 8a13bbb6986a..c79c1a4184c1 100644 --- a/tests/pytorch/test_nn.py +++ b/tests/pytorch/test_nn.py @@ -403,7 +403,6 @@ def test_gat_conv(): if F.gpu_ctx(): gat = gat.to(ctx) - feat = feat.to(ctx) h = gat(g, feat) assert h.shape[-1] == 2 and h.shape[-2] == 4 @@ -417,7 +416,6 @@ def test_sage_conv(): if F.gpu_ctx(): sage = sage.to(ctx) - feat = feat.to(ctx) h = sage(g, feat) assert h.shape[-1] == 10 @@ -431,7 +429,6 @@ def test_sgc_conv(): if F.gpu_ctx(): sgc = sgc.to(ctx) - feat = feat.to(ctx) h = sgc(g, feat) assert h.shape[-1] == 10 @@ -455,7 +452,6 @@ def test_appnp_conv(): if F.gpu_ctx(): appnp = appnp.to(ctx) - feat = feat.to(ctx) h = appnp(g, feat) assert h.shape[-1] == 5 @@ -472,7 +468,6 @@ def test_gin_conv(): if F.gpu_ctx(): gin = gin.to(ctx) - feat = feat.to(ctx) h = gin(g, feat) assert h.shape[-1] == 12 @@ -485,7 +480,6 @@ def test_agnn_conv(): if F.gpu_ctx(): agnn = agnn.to(ctx) - feat = feat.to(ctx) h = agnn(g, feat) assert h.shape[-1] == 5 @@ -499,7 +493,6 @@ def test_gated_graph_conv(): if F.gpu_ctx(): ggconv = ggconv.to(ctx) - feat = feat.to(ctx) etypes = etypes.to(ctx) h = ggconv(g, feat, etypes) @@ -516,8 +509,6 @@ def test_nn_conv(): if F.gpu_ctx(): nnconv = nnconv.to(ctx) - feat = feat.to(ctx) - efeat = efeat.to(ctx) h = nnconv(g, feat, efeat) # currently we only do shape check @@ -532,8 +523,6 @@ def test_gmm_conv(): if F.gpu_ctx(): gmmconv = gmmconv.to(ctx) - feat = feat.to(ctx) - pseudo = pseudo.to(ctx) h = gmmconv(g, feat, pseudo) # currently we only do shape check @@ -551,7 +540,6 @@ def test_dense_graph_conv(): if F.gpu_ctx(): conv = conv.to(ctx) dense_conv = dense_conv.to(ctx) - feat = feat.to(ctx) out_conv = conv(g, feat) out_dense_conv = dense_conv(adj, feat) @@ -561,7 +549,7 @@ def test_dense_sage_conv(): ctx = F.ctx() g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) adj = g.adjacency_matrix(ctx=ctx).to_dense() - sage = nn.SAGEConv(5, 2, 'gcn',) + sage = nn.SAGEConv(5, 2, 'gcn') dense_sage = nn.DenseSAGEConv(5, 2) dense_sage.fc.weight.data = sage.fc_neigh.weight.data dense_sage.fc.bias.data = sage.fc_neigh.bias.data @@ -569,7 +557,6 @@ def test_dense_sage_conv(): if F.gpu_ctx(): sage = sage.to(ctx) dense_sage = dense_sage.to(ctx) - feat = feat.to(ctx) out_sage = sage(g, feat) out_dense_sage = dense_sage(adj, feat) @@ -590,7 +577,6 @@ def test_dense_cheb_conv(): if F.gpu_ctx(): cheb = cheb.to(ctx) dense_cheb = dense_cheb.to(ctx) - feat = feat.to(ctx) out_cheb = cheb(g, feat, [2.0]) out_dense_cheb = dense_cheb(adj, feat, 2.0) diff --git a/third_party/dmlc-core b/third_party/dmlc-core index 7ce90a342b0b..0f3ddbc7240e 160000 --- a/third_party/dmlc-core +++ b/third_party/dmlc-core @@ -1 +1 @@ -Subproject commit 7ce90a342b0bda9b7f88e707a326496324d60efd +Subproject commit 0f3ddbc7240efa05bfffd5bca808ec262ce3630e
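As a quick smoke test of the new MXNet GINConv, mirroring `test_gin_conv` above (the Dense `apply_func` and the sizes are illustrative choices, not part of the patch, and the layer is assumed to be re-exported from `dgl.nn.mxnet`):

    import mxnet as mx
    import networkx as nx
    import dgl
    from dgl.nn.mxnet import GINConv
    from mxnet.gluon import nn as gnn

    g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
    conv = GINConv(gnn.Dense(16), 'sum', init_eps=0.1, learn_eps=True)
    conv.initialize(ctx=mx.cpu())
    feat = mx.nd.random.randn(20, 5)
    out = conv(g, feat)    # apply_func maps the aggregated features to shape (20, 16)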