[Model] add model example GCN-based Anti-Spam (dmlc#3145)

* add model example GCN-based Anti-Spam
* update example index
* add usage info
* improvements as per comments
* fix image invisible problem
* add image file

Co-authored-by: zhjwy9343 <[email protected]>
**README.md**
# DGL Implementation of the GAS Paper

This DGL example implements the Heterogeneous GCN part of the model proposed in the paper [Spam Review Detection with Graph Convolutional Networks](https://arxiv.org/abs/1908.10679).

Example implementor
----------------------
This example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.

Dependencies
----------------------
- Python 3.7.10
- PyTorch 1.8.1
- dgl 0.7.0
- scikit-learn 0.23.2
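
One possible way to install the pinned dependencies from PyPI (CPU-only builds; exact wheel names can vary by platform, so treat this as a sketch):

```
pip install torch==1.8.1 dgl==0.7.0 scikit-learn==0.23.2
```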

Dataset
---------------------------------------
The datasets used for edge classification are variants of DGL's built-in [fake news datasets](https://github.com/dmlc/dgl/blob/master/python/dgl/data/fakenews.py). The process of converting the tree-structured graphs into bipartite graphs is shown in the figure below.



**NOTE**: Like the original fake news dataset, this variant is for academic use only; commercial use is prohibited. The statistics are summarized as follows:

**Politifact**

- Nodes:
    - user (u): 276,277
    - news (v): 581
- Edges:
    - forward: 399,016
    - backward: 399,016
- Number of Classes: 2
- Node feature size: 300
- Edge feature size: 300

**Gossipcop**

- Nodes:
    - user (u): 565,660
    - news (v): 10,333
- Edges:
    - forward: 1,254,469
    - backward: 1,254,469
- Number of Classes: 2
- Node feature size: 300
- Edge feature size: 300
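
Once downloaded, the statistics above can be verified with the `GASDataset` loader from `dataloader.py` (included later in this commit), for example:

```
from dataloader import GASDataset

dataset = GASDataset('pol')    # 'pol' = Politifact, 'gos' = Gossipcop
graph = dataset[0]             # the dataset holds a single heterogeneous graph
print(graph.num_nodes('u'), graph.num_nodes('v'))
print(graph.num_edges('forward'), graph.num_edges('backward'))
```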

How to run
--------------------------------
In the gas folder, run

```
python main.py
```

To use a GPU, run

```
python main.py --gpu 0
```

To run mini-batch training on a GPU, run

```
python main_sampling.py --gpu 0
```

Performance
-------------------------
| Dataset              | Xianyu Graph (paper reported) | Fake News Politifact | Fake News Gossipcop |
| -------------------- | ----------------------------- | -------------------- | ------------------- |
| F1                   | 0.8143                        | 0.9994               | 0.9942              |
| AUC                  | 0.9860                        | 1.0000               | 0.9991              |
| Recall@90% precision | 0.6702                        | 0.9999               | 0.9976              |

**dataloader.py**

```
import os
import dgl
import torch as th
import numpy as np
import scipy.io as sio
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import save_graphs, load_graphs, _get_dgl_url


class GASDataset(DGLBuiltinDataset):
    file_urls = {
        'pol': 'dataset/GASPOL.zip',
        'gos': 'dataset/GASGOS.zip'
    }

    def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
        assert name in ['gos', 'pol'], "Only supports 'gos' or 'pol'."
        self.seed = random_seed
        self.train_size = train_size
        self.val_size = val_size
        url = _get_dgl_url(self.file_urls[name])
        super(GASDataset, self).__init__(name=name,
                                         url=url,
                                         raw_dir=raw_dir)

    def process(self):
        """Process raw data into the graph, labels and masks."""
        data = sio.loadmat(os.path.join(self.raw_path, f'{self.name}_retweet_graph.mat'))

        # the adjacency matrix stores each edge in both directions;
        # keep the first half and rebuild the symmetric edge list explicitly
        adj = data['graph'].tocoo()
        num_edges = len(adj.row)
        row, col = adj.row[:int(num_edges/2)], adj.col[:int(num_edges/2)]

        graph = dgl.graph((np.concatenate((row, col)), np.concatenate((col, row))))
        news_labels = data['label'].squeeze()
        num_news = len(news_labels)

        node_feature = np.load(os.path.join(self.raw_path, f'{self.name}_node_feature.npy'))
        edge_feature = np.load(os.path.join(self.raw_path, f'{self.name}_edge_feature.npy'))[:int(num_edges/2)]

        graph.ndata['feat'] = th.tensor(node_feature)
        graph.edata['feat'] = th.tensor(np.tile(edge_feature, (2, 1)))
        pos_news = news_labels.nonzero()[0]

        # an edge is labeled positive if it touches a fake-news node
        edge_labels = th.zeros(num_edges)
        edge_labels[graph.in_edges(pos_news, form='eid')] = 1
        edge_labels[graph.out_edges(pos_news, form='eid')] = 1
        graph.edata['label'] = edge_labels

        # news nodes come first (type 0 -> 'v'); the rest are users (type 1 -> 'u')
        ntypes = th.ones(graph.num_nodes(), dtype=int)
        etypes = th.ones(graph.num_edges(), dtype=int)

        ntypes[graph.nodes() < num_news] = 0
        etypes[:int(num_edges/2)] = 0

        graph.ndata['_TYPE'] = ntypes
        graph.edata['_TYPE'] = etypes

        hg = dgl.to_heterogeneous(graph, ['v', 'u'], ['forward', 'backward'])
        self._random_split(hg, self.seed, self.train_size, self.val_size)

        self.graph = hg

    def save(self):
        """Save the graph and the labels."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        save_graphs(str(graph_path), self.graph)

    def has_cache(self):
        """Check whether there is processed data in `self.save_path`."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        return os.path.exists(graph_path)

    def load(self):
        """Load processed data from directory `self.save_path`."""
        graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')

        graph, _ = load_graphs(str(graph_path))
        self.graph = graph[0]

    @property
    def num_classes(self):
        """Number of classes for each graph, i.e. number of prediction tasks."""
        return 2

    def __getitem__(self, idx):
        r"""Get graph object

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        :class:`dgl.DGLGraph`
        """
        assert idx == 0, "This dataset has only one graph"
        return self.graph

    def __len__(self):
        r"""Number of data examples

        Returns
        -------
        int
        """
        # the dataset wraps a single graph
        return 1

    def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):
        """Split the dataset into training, validation and test sets."""

        assert 0 <= train_size + val_size <= 1, \
            "The sum of the training set size and the validation set size " \
            "must be between 0 and 1 (inclusive)."

        num_edges = graph.num_edges(etype='forward')
        index = np.arange(num_edges)

        # shuffle deterministically, then carve out train / test / val slices
        index = np.random.RandomState(seed).permutation(index)
        train_idx = index[:int(train_size * num_edges)]
        val_idx = index[num_edges - int(val_size * num_edges):]
        test_idx = index[int(train_size * num_edges):num_edges - int(val_size * num_edges)]
        train_mask = np.zeros(num_edges, dtype=bool)
        val_mask = np.zeros(num_edges, dtype=bool)
        test_mask = np.zeros(num_edges, dtype=bool)
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True
        # forward and backward edges share the same split, since they are
        # the two directions of the same retweet relation
        graph.edges['forward'].data['train_mask'] = th.tensor(train_mask)
        graph.edges['forward'].data['val_mask'] = th.tensor(val_mask)
        graph.edges['forward'].data['test_mask'] = th.tensor(test_mask)
        graph.edges['backward'].data['train_mask'] = th.tensor(train_mask)
        graph.edges['backward'].data['val_mask'] = th.tensor(val_mask)
        graph.edges['backward'].data['test_mask'] = th.tensor(test_mask)
```
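
A quick sanity check of the split produced by `_random_split` (a sketch, assuming the `pol` variant has already been downloaded and processed):

```
from dataloader import GASDataset

ds = GASDataset('pol')
g = ds[0]
tm = g.edges['forward'].data['train_mask']
vm = g.edges['forward'].data['val_mask']
sm = g.edges['forward'].data['test_mask']
# the three masks are disjoint and together cover every forward edge
assert not (tm & vm).any() and not (tm & sm).any() and not (vm & sm).any()
assert (tm | vm | sm).all()
print(tm.sum().item(), vm.sum().item(), sm.sum().item())
```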

**main.py**

```
import argparse
import torch as th
import torch.optim as optim
import torch.nn.functional as F
from dataloader import GASDataset
from model import GAS
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load dataset
    dataset = GASDataset(args.dataset)
    graph = dataset[0]

    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
    else:
        device = 'cpu'

    # binary classification
    num_classes = dataset.num_classes

    # retrieve ground-truth labels
    labels = graph.edges['forward'].data['label'].to(device).long()

    # extract node and edge features
    e_feat = graph.edges['forward'].data['feat'].to(device)
    u_feat = graph.nodes['u'].data['feat'].to(device)
    v_feat = graph.nodes['v'].data['feat'].to(device)

    # retrieve masks for train/validation/test
    train_mask = graph.edges['forward'].data['train_mask']
    val_mask = graph.edges['forward'].data['val_mask']
    test_mask = graph.edges['forward'].data['test_mask']

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = GAS(e_in_dim=e_feat.shape[-1],
                u_in_dim=u_feat.shape[-1],
                v_in_dim=v_feat.shape[-1],
                e_hid_dim=args.e_hid_dim,
                u_hid_dim=args.u_hid_dim,
                v_hid_dim=args.v_hid_dim,
                out_dim=num_classes,
                num_layers=args.num_layers,
                dropout=args.dropout,
                activation=F.relu)

    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Step 4: Training epochs ================================================================ #
    for epoch in range(args.max_epoch):
        # Training and validation using a full graph
        model.train()
        logits = model(graph, e_feat, u_feat, v_feat)

        # compute training loss and metrics
        tr_loss = loss_fn(logits[train_idx], labels[train_idx])
        tr_f1 = f1_score(labels[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu())
        tr_auc = roc_auc_score(labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu())
        tr_pre, tr_re, _ = precision_recall_curve(labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu())
        # recall at the target precision: best recall among operating
        # points whose precision exceeds args.precision
        tr_rap = tr_re[tr_pre > args.precision].max()

        # validation
        valid_loss = loss_fn(logits[val_idx], labels[val_idx])
        valid_f1 = f1_score(labels[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu())
        valid_auc = roc_auc_score(labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu())
        valid_pre, valid_re, _ = precision_recall_curve(labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu())
        valid_rap = valid_re[valid_pre > args.precision].max()

        # backward
        optimizer.zero_grad()
        tr_loss.backward()
        optimizer.step()

        # Print out performance
        print("In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
              "Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
              format(epoch, tr_rap, tr_f1, tr_auc, tr_loss.item(), valid_rap, valid_f1, valid_auc, valid_loss.item()))

    # Test after all epochs
    model.eval()

    # forward
    logits = model(graph, e_feat, u_feat, v_feat)

    # compute test loss and metrics
    test_loss = loss_fn(logits[test_idx], labels[test_idx])
    test_f1 = f1_score(labels[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu())
    test_auc = roc_auc_score(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu())
    test_pre, test_re, _ = precision_recall_curve(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu())
    test_rap = test_re[test_pre > args.precision].max()

    print("Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".
          format(test_rap, test_f1, test_auc, test_loss.item()))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
    parser.add_argument("--dataset", type=str, default="pol", help="'pol' or 'gos'")
    parser.add_argument("--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU.")
    parser.add_argument("--e_hid_dim", type=int, default=128, help="Hidden layer dimension for edges")
    parser.add_argument("--u_hid_dim", type=int, default=128, help="Hidden layer dimension for source nodes")
    parser.add_argument("--v_hid_dim", type=int, default=128, help="Hidden layer dimension for destination nodes")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of GCN layers")
    parser.add_argument("--max_epoch", type=int, default=100, help="The max number of epochs. Default: 100")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight decay. Default: 0.0005")
    parser.add_argument("--precision", type=float, default=0.9, help="The value p in recall@p precision. Default: 0.9")

    args = parser.parse_args()
    print(args)
    main(args)
```
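
The recall-at-precision metric reported above takes the best recall among operating points whose precision exceeds `args.precision`. A toy illustration (the labels and scores below are hypothetical, not model outputs):

```
import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 0, 1])
scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20, 0.90])
precision, recall, _ = precision_recall_curve(y_true, scores)
# best recall among thresholds whose precision exceeds 0.9,
# mirroring `tr_rap = tr_re[tr_pre > args.precision].max()` in main.py
rap = recall[precision > 0.9].max()
print("Recall@90% precision: {:.4f}".format(rap))
```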