Skip to content

Commit

Permalink
fix bugs in link prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
WwZzz committed Jul 10, 2023
1 parent 3cdff94 commit 09a35a8
Show file tree
Hide file tree
Showing 8 changed files with 765 additions and 22 deletions.
21 changes: 20 additions & 1 deletion flgo/benchmark/citeseer_link_prediction/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,32 @@
import flgo.benchmark
import os
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
import torch


class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(3703, 128)
self.conv2 = GCNConv(128, 64)

    def forward(self, tdata):
        """Score candidate edges of ``tdata``.

        In training mode, negative edges are sampled on the fly and the
        logits for positive+negative edges are returned together with the
        sampled ``neg_edge_index``; in eval mode only the logits for
        ``tdata.edge_label_index`` are returned.
        """
        # Node embeddings from the two-layer GCN encoder.
        z = self.encode(tdata.x, tdata.edge_index)
        if self.training:
            # Exclude both message-passing edges and supervision edges from
            # the candidate pool so sampled negatives are true non-edges.
            neg_edge_index = negative_sampling(
                edge_index=torch.cat([tdata.edge_index, tdata.edge_label_index],dim=-1,), num_nodes=tdata.num_nodes,
                num_neg_samples=tdata.edge_label_index.size(1),
                force_undirected=tdata.is_undirected()
            )
            # Logits for positives followed by negatives, flattened to 1-D.
            outputs = self.decode(z, torch.cat([tdata.edge_label_index, neg_edge_index],dim=-1,)).view(-1)
            return outputs, neg_edge_index
        else:
            outputs = self.decode(z, tdata.edge_label_index).view(-1)
            return outputs

def encode(self, x, edge_index):
x = self.conv1(x, edge_index.long())
x = x.relu()
Expand All @@ -22,8 +40,9 @@ def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()

trans = T.NormalizeFeatures()
path = os.path.join(flgo.benchmark.path,'RAW_DATA', 'CITESEER')
dataset = torch_geometric.datasets.Planetoid(path, name='Citeseer')
dataset = torch_geometric.datasets.Planetoid(path, name='Citeseer', transform=trans)
train_data = dataset[0]

def get_model():
Expand Down
19 changes: 18 additions & 1 deletion flgo/benchmark/cora_link_prediction/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@
import flgo.benchmark
import os
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(1433, 128)
self.conv2 = GCNConv(128, 64)

    def forward(self, tdata):
        """Score candidate edges of ``tdata``.

        Training mode returns ``(logits, neg_edge_index)`` where negatives
        are freshly sampled; eval mode returns only the logits for
        ``tdata.edge_label_index``.
        """
        # Node embeddings from the two-layer GCN encoder.
        z = self.encode(tdata.x, tdata.edge_index)
        if self.training:
            # Exclude both message-passing edges and supervision edges from
            # the candidate pool so sampled negatives are true non-edges.
            neg_edge_index = negative_sampling(
                edge_index=torch.cat([tdata.edge_index, tdata.edge_label_index],dim=-1,), num_nodes=tdata.num_nodes,
                num_neg_samples=tdata.edge_label_index.size(1),
                force_undirected=tdata.is_undirected()
            )
            # Logits for positives followed by negatives, flattened to 1-D.
            outputs = self.decode(z, torch.cat([tdata.edge_label_index, neg_edge_index],dim=-1,)).view(-1)
            return outputs, neg_edge_index
        else:
            outputs = self.decode(z, tdata.edge_label_index).view(-1)
            return outputs

def encode(self, x, edge_index):
x = self.conv1(x, edge_index.long())
x = x.relu()
Expand All @@ -22,8 +38,9 @@ def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()

transform = T.NormalizeFeatures()
path = os.path.join(flgo.benchmark.path,'RAW_DATA', 'CORA')
dataset = torch_geometric.datasets.Planetoid(path, name='Cora')
dataset = torch_geometric.datasets.Planetoid(path, name='Cora', transform=transform)
train_data = dataset[0]

def get_model():
Expand Down
610 changes: 610 additions & 0 deletions flgo/benchmark/partition.py

Large diffs are not rendered by default.

20 changes: 19 additions & 1 deletion flgo/benchmark/pubmed_link_prediction/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,31 @@
import flgo.benchmark
import os
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
import torch

class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(500, 128)
self.conv2 = GCNConv(128, 64)

    def forward(self, tdata):
        """Score candidate edges of ``tdata``.

        Training mode returns ``(logits, neg_edge_index)`` where negatives
        are freshly sampled; eval mode returns only the logits for
        ``tdata.edge_label_index``.
        """
        # Node embeddings from the two-layer GCN encoder.
        z = self.encode(tdata.x, tdata.edge_index)
        if self.training:
            # Exclude both message-passing edges and supervision edges from
            # the candidate pool so sampled negatives are true non-edges.
            neg_edge_index = negative_sampling(
                edge_index=torch.cat([tdata.edge_index, tdata.edge_label_index],dim=-1,), num_nodes=tdata.num_nodes,
                num_neg_samples=tdata.edge_label_index.size(1),
                force_undirected=tdata.is_undirected()
            )
            # Logits for positives followed by negatives, flattened to 1-D.
            outputs = self.decode(z, torch.cat([tdata.edge_label_index, neg_edge_index],dim=-1,)).view(-1)
            return outputs, neg_edge_index
        else:
            outputs = self.decode(z, tdata.edge_label_index).view(-1)
            return outputs

def encode(self, x, edge_index):
x = self.conv1(x, edge_index.long())
x = x.relu()
Expand All @@ -22,8 +39,9 @@ def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()

trans = T.NormalizeFeatures()
path = os.path.join(flgo.benchmark.path,'RAW_DATA', 'PUBMED')
dataset = torch_geometric.datasets.Planetoid(path, name='Pubmed')
dataset = torch_geometric.datasets.Planetoid(path, name='Pubmed', transform=trans)
train_data = dataset[0]

def get_model():
Expand Down
33 changes: 16 additions & 17 deletions flgo/benchmark/toolkits/graph/link_prediction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

class FromDatasetGenerator(flgo.benchmark.base.FromDatasetGenerator):
def prepare_data_for_partition(self):
return torch_geometric.utils.to_networkx(self.train_data, to_undirected=self.train_data.is_undirected(), node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'])
edge_attrs=['edge_attr'] if hasattr(self.train_data, 'edge_attr') and self.train_data.edge_attr is not None else None
return torch_geometric.utils.to_networkx(self.train_data, to_undirected=self.train_data.is_undirected(), node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], edge_attrs=edge_attrs)

class FromDatasetPipe(flgo.benchmark.base.FromDatasetPipe):
def save_task(self, generator):
client_names = self.gen_client_names(len(generator.local_datas))
feddata = {'client_names': client_names,}
partitioner_name = generator.partitioner.__class__.__name__.lower()
feddata['partition_by_edge'] = 'link' in partitioner_name or 'edge' in partitioner_name
for cid in range(len(client_names)):
feddata[client_names[cid]] = {'data': generator.local_datas[cid], }
with open(os.path.join(self.task_path, 'data.json'), 'w') as outf:
Expand Down Expand Up @@ -59,16 +62,20 @@ def load_data(self, running_time_option) -> dict:
neg_sampling_ratio=neg_sampling_ratio, is_undirected=server_val_data.is_undirected(),
num_test=0.2, num_val=0.0)(server_val_data)
task_data = {'server': {'test': server_test_data, 'val': server_val_data}}
G = torch_geometric.utils.to_networkx(train_data, to_undirected=train_data.is_undirected(), node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'])
edge_attrs=['edge_attr'] if hasattr(train_data, 'edge_attr') and train_data.edge_attr is not None else None
G = torch_geometric.utils.to_networkx(train_data, to_undirected=train_data.is_undirected(), node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], edge_attrs=edge_attrs)
num_val = running_time_option['train_holdout']
if num_val>0 and running_time_option['local_test']:
num_test, num_val = 0.5*num_val, 0.5*num_val
else:
num_test = 0.0
all_edges = list(G.edges)
for cid, cname in enumerate(self.feddata['client_names']):
c_dataset = from_networkx(nx.subgraph(G, self.feddata[cname]['data']))
ctrans = T.RandomLinkSplit(neg_sampling_ratio=neg_sampling_ratio, is_undirected=c_dataset.is_undirected(), num_test=num_test, num_val=num_val,
add_negative_train_samples=False, disjoint_train_ratio=disjoint_train_ratio)
if self.feddata['partition_by_edge']:
c_dataset = from_networkx(nx.edge_subgraph(G, [all_edges[eid] for eid in self.feddata[cname]['data']]))
else:
c_dataset = from_networkx(nx.subgraph(G, self.feddata[cname]['data']))
ctrans = T.RandomLinkSplit(neg_sampling_ratio=neg_sampling_ratio, is_undirected=c_dataset.is_undirected(), num_test=num_test, num_val=num_val, add_negative_train_samples=False, disjoint_train_ratio=disjoint_train_ratio)
cdata_train, cdata_val, cdata_test = ctrans(c_dataset)
task_data[cname] = {'train': cdata_train, 'val': cdata_val if len(cdata_val.edge_label)>0 else None, 'test':cdata_test if len(cdata_test.edge_label)>0 else None}
return task_data
Expand Down Expand Up @@ -167,30 +174,22 @@ def __init__(self, device, optimizer_name='sgd'):
self.DataLoader = torch_geometric.loader.DataLoader

def compute_loss(self, model, data):
model.train()
tdata = self.to_device(data)
z = model.encode(tdata.x, tdata.edge_label_index)
neg_edge_index = negative_sampling(
edge_index=tdata.edge_index, num_nodes=tdata.num_nodes,
num_neg_samples=tdata.edge_label_index.size(1)
)
edge_label_index = torch.cat(
[tdata.edge_label_index, neg_edge_index],
dim=-1,
)
outputs, neg_edge_index = model(tdata)
edge_label = torch.cat([
tdata.edge_label,
tdata.edge_label.new_zeros(neg_edge_index.size(1))
], dim=0)
outputs = model.decode(z, edge_label_index).view(-1)
loss = self.criterion(outputs, edge_label)
return {'loss': loss}

@torch.no_grad()
def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
res = {}
model.eval()
tdata = self.to_device(dataset)
z = model.encode(tdata.x, tdata.edge_index)
outputs = model.decode(z, tdata.edge_label_index).view(-1)
outputs = model(tdata)
loss = self.criterion(outputs, tdata.edge_label)
res['loss'] = loss.item()
sigmoid_out = outputs.sigmoid()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
r"""
This module is a template for benchmark of link prediction within a graph in horizontalFL. To use this module,
one can write codes as below
Example::
"""
from .model import default_model
import flgo.benchmark.toolkits.visualization
import flgo.benchmark.toolkits.partition
Expand Down
74 changes: 73 additions & 1 deletion flgo/benchmark/toolkits/graph/link_prediction/temp/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import os
import torch_geometric
import networkx as nx
import torch_geometric.transforms as T
from torch_geometric.utils import from_networkx
from flgo.benchmark.toolkits.graph.link_prediction import GeneralCalculator, FromDatasetPipe, FromDatasetGenerator
from .config import train_data
try:
import ujson as json
except:
import json
try:
from .config import test_data
except:
Expand All @@ -11,12 +19,76 @@
val_data = None

class TaskGenerator(FromDatasetGenerator):
def __init__(self):
def __init__(self, disjoint_train_ratio=0.0, neg_sampling_ratio=1.0, add_negative_train_samples=False):
super(TaskGenerator, self).__init__(benchmark=os.path.split(os.path.dirname(__file__))[-1],
train_data=train_data, val_data=val_data, test_data=test_data)
self.disjoin_train_ratio = disjoint_train_ratio
self.neg_sampling_ratio = neg_sampling_ratio
self.add_negative_train_samples = add_negative_train_samples

class TaskPipe(FromDatasetPipe):
def __init__(self, task_path):
super(TaskPipe, self).__init__(task_path, train_data=train_data, val_data=val_data, test_data=test_data)

def save_task(self, generator):
client_names = self.gen_client_names(len(generator.local_datas))
feddata = {'client_names': client_names, 'disjoint_train_ratio': generator.disjoin_train_ratio, 'neg_sampling_ratio': generator.neg_sampling_ratio, 'add_negative_train_samples':generator.add_negative_train_samples}
partitioner_name = generator.partitioner.__class__.__name__.lower()
feddata['partition_by_edge'] = 'link' in partitioner_name or 'edge' in partitioner_name
for cid in range(len(client_names)):
feddata[client_names[cid]] = {'data': generator.local_datas[cid], }
with open(os.path.join(self.task_path, 'data.json'), 'w') as outf:
json.dump(feddata, outf)
return

    def load_data(self, running_time_option) -> dict:
        """Load and split the federated link-prediction task data.

        Returns a dict mapping ``'server'`` and each client name to their
        train/val/test ``torch_geometric`` data objects, produced by
        ``T.RandomLinkSplit`` according to the options recorded in
        ``self.feddata`` and the runtime holdout options.
        """
        train_data = self.train_data
        # Split hyper-parameters recorded by save_task at generation time.
        neg_sampling_ratio = self.feddata['neg_sampling_ratio']
        disjoint_train_ratio = self.feddata['disjoint_train_ratio']
        add_negative_train_samples = self.feddata['add_negative_train_samples']
        test_data = self.test_data
        val_data = self.val_data
        server_test_data = test_data
        server_val_data = val_data
        if server_test_data is not None:
            if val_data is None:
                if running_time_option['test_holdout']>0:
                    # NOTE(review): num_test = 1 - holdout while num_val =
                    # holdout*0.2 looks asymmetric -- confirm intended split.
                    num_test,num_val = 1 - running_time_option['test_holdout'], running_time_option['test_holdout'] * 0.2
                else:
                    num_test, num_val = 0.2,0.0
                # split test_data into server-side val/test edge sets
                server_val_data, _, server_test_data = T.RandomLinkSplit(
                    neg_sampling_ratio=neg_sampling_ratio, is_undirected=server_test_data.is_undirected(),
                    num_test=num_test, num_val=num_val)(server_test_data)
                if num_val==0.0: server_val_data = None
            else:
                # Both test and val data supplied: give each its own link split.
                _tmp1, _tmp2, server_test_data = T.RandomLinkSplit(
                    neg_sampling_ratio=neg_sampling_ratio, is_undirected=server_test_data.is_undirected(),
                    num_test=0.2, num_val=0.0)(server_test_data)
                _tmp1, _tmp2, server_val_data = T.RandomLinkSplit(
                    neg_sampling_ratio=neg_sampling_ratio, is_undirected=server_val_data.is_undirected(),
                    num_test=0.2, num_val=0.0)(server_val_data)
        elif server_val_data is not None:
            _tmp1, _tmp2, server_val_data = T.RandomLinkSplit(
                neg_sampling_ratio=neg_sampling_ratio, is_undirected=server_val_data.is_undirected(),
                num_test=0.2, num_val=0.0)(server_val_data)
        task_data = {'server': {'test': server_test_data, 'val': server_val_data}}
        # Only carry edge_attr through networkx when the dataset actually has it.
        edge_attrs=['edge_attr'] if hasattr(train_data, 'edge_attr') and train_data.edge_attr is not None else None
        G = torch_geometric.utils.to_networkx(train_data, to_undirected=train_data.is_undirected(), node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], edge_attrs=edge_attrs)
        num_val = running_time_option['train_holdout']
        if num_val>0 and running_time_option['local_test']:
            # Split the local holdout evenly between val and test.
            num_test, num_val = 0.5*num_val, 0.5*num_val
        else:
            num_test = 0.0
        # Indexable edge list so edge-id partitions can be resolved below.
        all_edges = list(G.edges)
        for cid, cname in enumerate(self.feddata['client_names']):
            if self.feddata['partition_by_edge']:
                # Edge partition: the stored ids are indices into all_edges.
                c_dataset = from_networkx(nx.edge_subgraph(G, [all_edges[eid] for eid in self.feddata[cname]['data']]))
            else:
                # Node partition: the stored ids are node ids.
                c_dataset = from_networkx(nx.subgraph(G, self.feddata[cname]['data']))
            ctrans = T.RandomLinkSplit(neg_sampling_ratio=neg_sampling_ratio, is_undirected=c_dataset.is_undirected(), num_test=num_test, num_val=num_val, add_negative_train_samples=add_negative_train_samples, disjoint_train_ratio=disjoint_train_ratio)
            cdata_train, cdata_val, cdata_test = ctrans(c_dataset)
            # Replace empty val/test splits with None so callers can skip them.
            task_data[cname] = {'train': cdata_train, 'val': cdata_val if len(cdata_val.edge_label)>0 else None, 'test':cdata_test if len(cdata_test.edge_label)>0 else None}
        return task_data

TaskCalculator = GeneralCalculator
3 changes: 2 additions & 1 deletion flgo/utils/fflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import flgo.experiment.logger
import flgo.experiment.device_scheduler
from flgo.simulator.base import BasicSimulator
import flgo.benchmark.partition
import flgo.benchmark.toolkits.partition
import flgo.algorithm

Expand Down Expand Up @@ -414,7 +415,7 @@ def gen_task(config={}, task_path:str= '', rawdata_path:str= '', seed:int=0):
Partitioner = gen_option['partitioner']['name']
if type(Partitioner) is str:
if Partitioner in globals().keys(): Partitioner = eval(Partitioner)
else: Partitioner = getattr(flgo.benchmark.toolkits.partition, Partitioner)
else: Partitioner = getattr(flgo.benchmark.partition, Partitioner)
partitioner = Partitioner(**gen_option['partitioner']['para'])
task_generator.register_partitioner(partitioner)
partitioner.register_generator(task_generator)
Expand Down

0 comments on commit 09a35a8

Please sign in to comment.