-
Notifications
You must be signed in to change notification settings - Fork 1
/
arw.py
145 lines (115 loc) · 5.51 KB
/
arw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import igraph as ig
import numpy as np
import random
import time
from collections import defaultdict, Counter, deque
import itertools
# Base cap on the number of random-walk steps taken per new node
# (hyperparameter); add_node divides it by max(p_same, p_diff).
MAX_LENGTH = 10 ** 3
class RandomWalkSingleAttribute(object):
    """Grow a directed graph from a seed graph with an attributed random walk.

    New nodes are added in chunks (see add_nodes). Each new node starts a
    walk at a seed node (get_seed_nid). At every step it links to the
    current node with probability ``p_same`` if their attribute values
    match, and ``p_diff`` otherwise (see link). It then either jumps back to
    the seed with probability ``jump`` or moves to a random neighbour. Out-
    edges are chosen with probability ``out`` and in-edges otherwise. Edges
    point from the new node to every node it linked to.

    Args:
        p_diff: link probability between nodes with different attribute
            values; the only link probability when the graph is not attributed.
        p_same: link probability between nodes with equal attribute values.
        jump: per-step probability of returning to the walk's seed node.
        out: per-step probability of following an out-edge instead of an
            in-edge.
        gpre: seed graph, accessed through the python-igraph API
            (``vs``, ``es``, ``get_adjlist``).
        attr_name: vertex attribute used for homophily. Attributed mode is
            enabled only if ``gpre`` has this attribute.
        debug: print progress and a graph summary while growing.
    """

    def __init__(self, p_diff, p_same, jump, out, gpre, attr_name='single_attr', debug=True):
        self.p_diff = p_diff
        self.p_same = p_same
        self.jump = jump
        self.out = out
        self.gpre = gpre
        self.attr_name = attr_name
        self.directed = True
        # Output graph; it is populated by build_graph().
        self.g = ig.Graph(directed=self.directed)
        self.attributed = attr_name in gpre.vs.attributes()
        self.debug = debug
        self.setup()

    def summary(self):
        """Return the igraph summary string of the grown graph ``self.g``."""
        return self.g.summary()

    def flip(self, p):
        """Return True with probability ``p``."""
        return random.random() < p

    def setup(self):
        """Initialise walk state and node index maps from the seed graph."""
        # With probability seed_same, get_seed_nid starts the walk at a node
        # sharing the new node's attribute value.
        self.seed_same = self.p_same / (self.p_same + self.p_diff)
        self.seed_diff = 1. - self.seed_same
        self.n0 = len(self.gpre.vs)
        self.total_edges = len(self.gpre.es)
        # New node ids continue right after the seed graph's ids.
        self.next_nid = self.n0
        # Updated by add_nodes after each chunk; not read elsewhere in this class.
        self.chunk_nid = self.next_nid - 1
        self.chunk_size = 1
        # Undirected adjacency of the seed graph only (not updated as nodes are added).
        self.nbors = defaultdict(list)
        # Directed adjacency used by the walk; grows as nodes are added.
        self.out_nbors = defaultdict(list)
        self.in_nbors = defaultdict(list)
        self.nid_chunk_map = {}
        self.nid_attr_map = {}
        self.attr_nid_map = defaultdict(list)
        for nid, nbor_nids in enumerate(self.gpre.get_adjlist(mode='ALL')):
            self.nbors[nid] = nbor_nids
        for nid, nbor_nids in enumerate(self.gpre.get_adjlist(mode='OUT')):
            self.out_nbors[nid] = nbor_nids
        for nid, nbor_nids in enumerate(self.gpre.get_adjlist(mode='IN')):
            self.in_nbors[nid] = nbor_nids
        for node in self.gpre.vs:
            self.nid_attr_map[node.index] = node[self.attr_name] if self.attributed else None
        for nid, attr in self.nid_attr_map.items():
            self.attr_nid_map[attr].append(nid)
        # Every seed-graph node is tagged as chunk 0.
        for nid in self.gpre.vs.indices:
            self.nid_chunk_map[nid] = 0

    def add_nodes(self, chunk_seq, mean_seq, chunk_attr_sampler=None):
        """Grow the graph chunk by chunk, then build ``self.g``.

        Args:
            chunk_seq: number of new nodes in each chunk.
            mean_seq: per-chunk value stored as ``self.m``; add_node derives
                each new node's link budget from it.
            chunk_attr_sampler: required in attributed mode. Entry ``idx`` is
                a sequence of attribute values for chunk ``idx``. It is copied,
                and values are consumed with ``pop()`` (from the end), one per
                new node.
        """
        if self.attributed:
            assert chunk_attr_sampler, 'chunk_attr_sampler is required for attributed graphs'
        num_chunks = len(chunk_seq)
        # Report progress about every 10% of chunks. Clamp to 1 so that fewer
        # than 10 chunks does not cause a modulo by zero below.
        chunk_debug = max(1, num_chunks // 10)
        if self.debug:
            print("Total chunks: {}".format(num_chunks))
        for idx, (chunk_size, m) in enumerate(zip(chunk_seq, mean_seq)):
            if self.debug and (idx + 1) % chunk_debug == 0:
                print(idx, end=' ')
            self.chunk_size = chunk_size
            self.m = m
            # NOTE(review): chunk ids start at 0, the same id setup() gives the
            # seed-graph nodes. Confirm this overlap is intended.
            self.add_chunk(idx, attr_sampler=chunk_attr_sampler[idx][:] if self.attributed else None)
            self.chunk_nid = self.next_nid - 1
        self.build_graph()

    def add_chunk(self, chunk_id, attr_sampler=None):
        """Add ``self.chunk_size`` new nodes tagged with ``chunk_id``.

        Each node's edges are recorded right away. Later nodes in the same
        chunk can therefore reach, and link to, earlier ones.

        Args:
            chunk_id: chunk tag stored in ``nid_chunk_map`` for each new node.
            attr_sampler: required in attributed mode. A list of attribute
                values that is consumed with ``pop()``, one per node.
        """
        if self.attributed:
            assert attr_sampler, 'attr_sampler is required for attributed graphs'
        for _ in range(self.chunk_size):
            new_nid = self.next_nid
            self.next_nid += 1
            self.nid_chunk_map[new_nid] = chunk_id
            attrs = attr_sampler.pop() if self.attributed else None
            marked = self.add_node(new_nid, attrs=attrs)
            self.update_node(new_nid, marked)

    def update_node(self, nid, marked):
        """Record a directed edge ``nid -> v`` for every ``v`` in ``marked``."""
        for nbor_nid in marked:
            self.out_nbors[nid].append(nbor_nid)
            self.in_nbors[nbor_nid].append(nid)

    def build_graph(self):
        """Write the adjacency lists and per-node data into ``self.g``.

        Only vertices missing from ``self.g`` are added, so repeated calls do
        not duplicate vertices. Edges that already exist are added again and
        then merged away by ``simplify()``.
        """
        self.edges = edges = {(node, nbor) for node, nbors in self.out_nbors.items() for nbor in nbors}
        self.g.add_vertices(self.next_nid - self.g.vcount())
        self.g.add_edges(list(edges))
        self.g.simplify()
        self.g.vs['chunk_id'] = [self.nid_chunk_map[n] for n in self.g.vs.indices]
        if self.attributed:
            self.g.vs[self.attr_name] = [self.nid_attr_map[n] for n in self.g.vs.indices]
        if self.debug:
            print("\n{}".format(self.g.summary()))

    def link(self, cur_nid, attrs=None):
        """Decide whether a new node with ``attrs`` links to ``cur_nid``.

        The link probability is ``p_same`` when the attribute values are
        equal and ``p_diff`` otherwise. Without attributes, ``p_diff`` is
        always used.
        """
        if not self.attributed:
            return random.random() < self.p_diff
        p = self.p_same if self.nid_attr_map[cur_nid] == attrs else self.p_diff
        return random.random() < p

    def get_seed_nid(self, new_nid, attrs=None):
        """Pick the node where a new node's walk (re)starts.

        The seed is drawn uniformly from the existing ids ``0 .. new_nid - 1``
        in these cases:
        - the graph is not attributed;
        - a draw with probability ``seed_diff`` succeeds;
        - no existing node has value ``attrs``.
        Otherwise it is drawn uniformly from the nodes with value ``attrs``.
        """
        if not self.attributed or random.random() < self.seed_diff:
            return random.randint(0, new_nid - 1)
        same_nids = self.attr_nid_map[attrs]
        if same_nids:
            return random.choice(same_nids)
        # Use the same RNG and inclusive upper bound as the other branches, so
        # node new_nid - 1 is reachable and new_nid == 1 still works.
        return random.randint(0, new_nid - 1)

    def add_node(self, new_nid, attrs=None):
        """Run one random walk for ``new_nid`` and return the nodes it links to.

        The walk stops after the link budget ``m`` is reached or after
        ``MAX_LENGTH / max(p_same, p_diff)`` steps. In attributed mode the new
        node's attribute value is registered afterwards.

        Returns:
            Set of node ids that ``new_nid`` will point to.
        """
        marked = set()
        # NOTE(review): with Python 3's round-half-to-even, the mean link
        # budget is generally not self.m. For example, m=2.0 always gives 2,
        # while m=3.0 gives 3 or 4. Confirm this sampling is intended.
        m = int(round(self.m if self.flip(0.5) else self.m + 0.5))
        cur_nid = seed_nid = self.get_seed_nid(new_nid, attrs=attrs)
        # Lower link probabilities get proportionally longer walks.
        max_length = MAX_LENGTH / max(self.p_same, self.p_diff)
        length = 0
        while len(marked) < m:
            length += 1
            if length > max_length:
                break
            if cur_nid not in marked and self.link(cur_nid, attrs):
                marked.add(cur_nid)
            if random.random() < self.jump:
                cur_nid = seed_nid
                continue
            use_out = random.random() < self.out
            if use_out:
                primary, fallback = self.out_nbors, self.in_nbors
            else:
                primary, fallback = self.in_nbors, self.out_nbors
            # Fall back to the other edge direction if the chosen one is empty.
            nbors = primary[cur_nid] or fallback[cur_nid]
            if nbors:
                cur_nid = random.choice(nbors)
            else:
                # No neighbours in either direction: restart from a new seed.
                cur_nid = seed_nid = self.get_seed_nid(new_nid, attrs=attrs)
        if self.attributed:
            self.nid_attr_map[new_nid] = attrs
            self.attr_nid_map[attrs].append(new_nid)
        return marked