Skip to content

Commit

Permalink
refactor metapath counting with sparse matmuls
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCatt91 committed Oct 15, 2024
1 parent 432615f commit 7c6d257
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 41 deletions.
82 changes: 61 additions & 21 deletions src/kg_topology_toolbox/topology_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,56 @@ def edge_cardinality(self) -> pd.DataFrame:
).astype(str)
return df_res

def edge_metapath_count(
    self,
    filter_relations: list[int] = [],
    composition_chunk_size: int = 2**8,
    composition_workers: int = min(32, mp.cpu_count() - 1 or 1),
) -> pd.DataFrame:
    """
    For each edge in the KG, compute the number of triangles of different
    metapaths, i.e. the unique tuples (r1, r2) of relation types
    of the two additional edges of the triangle.

    :param filter_relations:
        If not empty, compute the output only for the edges with relation
        in this list of relation IDs.
    :param composition_chunk_size:
        Size of column chunks of sparse adjacency matrix
        to compute the triangle count. Default: 2**8.
    :param composition_workers:
        Number of workers to compute the triangle count. By default, assigned
        based on number of available threads (max: 32).
    :return:
        The output dataframe has one row for each (h, t, r1, r2) such that
        there exists at least one triangle of metapath (r1, r2) over (any) edge
        connecting h, t.
        The number of metapath triangles is given in the column **n_triangles**.
    """
    # A self-loop cannot serve as one of the two edges closing a directed
    # triangle: discard loops before counting.
    df_wo_loops = self.df[self.df.h != self.df.t]
    if filter_relations:
        # Restrict candidate triangle edges to those that could close a
        # triangle over an edge with a filtered relation: their head must be
        # the head (or their tail the tail) of some filtered edge.
        # NOTE(review): this keeps a superset of the filtered edges'
        # (h, t) pairs; callers merging on their own edge set will drop
        # the extra rows — confirm this is the intended contract.
        rel_df = self.df[self.df.r.isin(filter_relations)]
        df_triangles = df_wo_loops[
            np.logical_or(
                df_wo_loops.h.isin(rel_df.h.unique()),
                df_wo_loops.t.isin(rel_df.t.unique()),
            )
        ]
    else:
        # No relation filter: all loop-free edges are candidates. (The dead
        # `rel_df = self.df` alias of the original implementation was unused
        # and has been removed.)
        df_triangles = df_wo_loops

    # Delegate the counting to the sparse-matmul helper, one row per
    # (h, t, r1, r2) metapath instance.
    return composition_count(
        df_triangles,
        chunk_size=composition_chunk_size,
        workers=composition_workers,
        metapaths=True,
        directed=True,
    )

def edge_degree_cardinality_summary(
self, filter_relations: list[int] = [], aggregate_by_r: bool = False
) -> pd.DataFrame:
Expand Down Expand Up @@ -425,8 +475,6 @@ def edge_pattern_summary(
self.df.h.isin(filter_tails), self.df.t.isin(filter_heads)
)
]
df_triangles_out = df_wo_loops[df_wo_loops.h.isin(filter_heads)]
df_triangles_in = df_wo_loops[df_wo_loops.t.isin(filter_tails)]
df_triangles = df_wo_loops[
np.logical_or(
df_wo_loops.h.isin(filter_heads), df_wo_loops.t.isin(filter_tails)
Expand All @@ -440,9 +488,7 @@ def edge_pattern_summary(
]
else:
rel_df = inference_df = inverse_df = self.df
df_triangles = df_triangles_und = df_triangles_out = df_triangles_in = (
df_wo_loops
)
df_triangles = df_triangles_und = df_wo_loops
df_res = df_res = pd.DataFrame(
{"h": rel_df.h, "r": rel_df.r, "t": rel_df.t, "is_symmetric": False}
)
Expand Down Expand Up @@ -501,30 +547,24 @@ def edge_pattern_summary(

# composition & metapaths
if return_metapath_list:
# 2-hop paths
df_bridges = df_triangles_out.merge(
df_triangles_in, left_on="t", right_on="h", how="inner"
counts = self.edge_metapath_count(
filter_relations,
composition_chunk_size,
composition_workers,
)
df_res_triangles = df_res[df_res.h != df_res.t].merge(
df_bridges, left_on=["h", "t"], right_on=["h_x", "t_y"], how="inner"
counts["metapath"] = (
counts["r1"].astype(str) + "-" + counts["r2"].astype(str)
)
df_res_triangles["metapath"] = (
df_res_triangles["r_x"].astype(str)
+ "-"
+ df_res_triangles["r_y"].astype(str)
)
grouped_triangles = df_res_triangles.groupby(
["h", "r", "t"], as_index=False
).agg(
n_triangles=("metapath", "count"), metapath_list=("metapath", "unique")
grouped_triangles = counts.groupby(["h", "t"], as_index=False).agg(
n_triangles=("n_triangles", "sum"), metapath_list=("metapath", list)
)
df_res = df_res.merge(
grouped_triangles,
on=["h", "r", "t"],
on=["h", "t"],
how="left",
)
df_res["metapath_list"] = df_res["metapath_list"].apply(
lambda agg: agg.tolist() if isinstance(agg, np.ndarray) else []
lambda agg: agg if isinstance(agg, list) else []
)
df_res["n_triangles"] = df_res["n_triangles"].fillna(0).astype(int)
else:
Expand Down
101 changes: 81 additions & 20 deletions src/kg_topology_toolbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,40 @@ def jaccard_similarity(


def _composition_count_worker(
adj_csr: csr_array, adj_csc: csc_array, tail_shift: int = 0
adj_csr: csr_array, adj_csc: csc_array, adj_mask: csc_array, tail_shift: int = 0
) -> pd.DataFrame:
n_nodes = adj_csr.shape[1]
n_rels = adj_csr.shape[0] // n_nodes
adj_2hop = adj_csr @ adj_csc
adj_composition = (adj_2hop.tocsc() * (adj_csc > 0)).tocoo()
df_composition = pd.DataFrame(
dict(
h=adj_composition.row,
t=adj_composition.col + tail_shift,
n_triangles=adj_composition.data,
adj_composition = (adj_2hop.tocsc() * (adj_mask > 0)).tocoo()
col_shift = adj_composition.col + tail_shift
if n_rels > 1:
df_composition = pd.DataFrame(
dict(
h=adj_composition.row // n_rels,
t=col_shift % n_nodes,
r1=adj_composition.row % n_rels,
r2=col_shift // n_nodes,
n_triangles=adj_composition.data,
)
)
else:
df_composition = pd.DataFrame(
dict(
h=adj_composition.row,
t=col_shift,
n_triangles=adj_composition.data,
)
)
)
return df_composition


def composition_count(
df: pd.DataFrame, chunk_size: int, workers: int, directed: bool = True
df: pd.DataFrame,
chunk_size: int,
workers: int,
metapaths: bool = False,
directed: bool = True,
) -> pd.DataFrame:
"""A helper function to compute the composition count of a graph.
Expand All @@ -227,15 +245,48 @@ def composition_count(
"""

n_nodes = df[["h", "t"]].max().max() + 1
adj = coo_array(
(np.ones(len(df)), (df.h, df.t)),
shape=[n_nodes, n_nodes],
).astype(np.uint16)
if not directed:
adj = adj + adj.T
n_cols = adj.shape[1]
adj_csr = adj.tocsr()
adj_csc = adj.tocsc()
n_rels = df["r"].max() + 1
if metapaths:
adj_repeated = csc_array(
(
np.ones(n_rels * n_rels * len(df)),
(
(n_rels * df.h.values[:, None] + np.arange(n_rels)).repeat(n_rels),
np.tile(
df.t.values[:, None] + n_nodes * np.arange(n_rels), n_rels
).flatten(),
),
),
shape=[n_nodes * n_rels, n_nodes * n_rels],
).astype(np.uint16)
adj_csr = csr_array(
(np.ones(len(df)), (df.h * n_rels + df.r, df.t)),
shape=[n_nodes * n_rels, n_nodes],
).astype(np.uint16)
adj_csc = csc_array(
(np.ones(len(df)), (df.h, df.r * n_nodes + df.t)),
shape=[n_nodes, n_nodes * n_rels],
).astype(np.uint16)
n_cols = adj_csc.shape[1]
adj_repeated_slices = {
i: adj_repeated[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)]
for i in range(int(np.ceil(n_cols / chunk_size)))
}
if not directed:
raise NotImplementedError(
"Metapath counting only implemented for directed triangles"
)
else:
adj = coo_array(
(np.ones(len(df)), (df.h, df.t)),
shape=[n_nodes, n_nodes],
).astype(np.uint16)
if not directed:
adj = adj + adj.T
adj_csr = adj.tocsr()
adj_csc = adj.tocsc()
n_cols = adj_csc.shape[1]

adj_csc_slices = {
i: adj_csc[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)]
for i in range(int(np.ceil(n_cols / chunk_size)))
Expand All @@ -246,13 +297,23 @@ def composition_count(
df_composition_list = pool.starmap(
_composition_count_worker,
(
(adj_csr, adj_csc_slice, i * chunk_size)
(
adj_csr,
adj_csc_slice,
adj_repeated_slices[i] if metapaths else adj_csc_slice,
i * chunk_size,
)
for i, adj_csc_slice in adj_csc_slices.items()
),
)
else:
df_composition_list = [
_composition_count_worker(adj_csr, adj_csc_slice, i * chunk_size)
_composition_count_worker(
adj_csr,
adj_csc_slice,
adj_repeated_slices[i] if metapaths else adj_csc_slice,
i * chunk_size,
)
for i, adj_csc_slice in adj_csc_slices.items()
]

Expand Down

0 comments on commit 7c6d257

Please sign in to comment.