Skip to content

Commit

Permalink
reduce memory usage of metapath counting
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCatt91 committed Oct 18, 2024
1 parent 05e3ee3 commit 5ec0f3e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 34 deletions.
70 changes: 41 additions & 29 deletions src/kg_topology_toolbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _composition_count_worker(
n_nodes = adj_csr.shape[1]
n_rels = adj_csr.shape[0] // n_nodes
adj_2hop = adj_csr @ adj_csc
adj_composition = (adj_2hop.tocsc() * (adj_mask > 0)).tocoo()
adj_composition = (adj_2hop.tocsc() * adj_mask).tocoo()
if n_rels > 1:
h, r1 = np.divmod(adj_composition.row, n_rels)
r2, t = np.divmod(adj_composition.col + tail_shift, n_nodes)
Expand Down Expand Up @@ -250,19 +250,15 @@ def composition_count(

n_nodes = df[["h", "t"]].max().max() + 1
n_rels = df["r"].max() + 1
adj = coo_array(
(np.ones(len(df)), (df.h, df.t)),
shape=[n_nodes, n_nodes],
).astype(np.uint16)
if metapaths:
adj_repeated = csc_array(
(
np.ones(n_rels * n_rels * len(df)),
(
(n_rels * df.h.values[:, None] + np.arange(n_rels)).repeat(n_rels),
np.tile(
df.t.values[:, None] + n_nodes * np.arange(n_rels), n_rels
).flatten(),
),
),
shape=[n_nodes * n_rels, n_nodes * n_rels],
).astype(np.uint16)
if not directed:
raise NotImplementedError(
"Metapath counting only implemented for directed triangles"
)
adj_csr = csr_array(
(np.ones(len(df)), (df.h * n_rels + df.r, df.t)),
shape=[n_nodes * n_rels, n_nodes],
Expand All @@ -271,26 +267,24 @@ def composition_count(
(np.ones(len(df)), (df.h, df.r * n_nodes + df.t)),
shape=[n_nodes, n_nodes * n_rels],
).astype(np.uint16)
n_cols = adj_csc.shape[1]
adj_repeated_slices = {
i: adj_repeated[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)]
for i in range(int(np.ceil(n_cols / chunk_size)))
}
if not directed:
raise NotImplementedError(
"Metapath counting only implemented for directed triangles"
)
# boolean mask to filter results with only the edges in the KG
msk = csc_array(
(
[True] * (len(adj.data) * n_rels),
(
(n_rels * adj.row + np.arange(n_rels)[:, None]).flatten(),
np.tile(adj.col, n_rels),
),
),
shape=[n_nodes * n_rels, n_nodes],
)
else:
adj = coo_array(
(np.ones(len(df)), (df.h, df.t)),
shape=[n_nodes, n_nodes],
).astype(np.uint16)
if not directed:
adj = adj + adj.T
adj_csr = adj.tocsr()
adj_csc = adj.tocsc()
n_cols = adj_csc.shape[1]

n_cols = adj_csc.shape[1]
adj_csc_slices = {
i: adj_csc[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)]
for i in range(int(np.ceil(n_cols / chunk_size)))
Expand All @@ -304,7 +298,16 @@ def composition_count(
(
adj_csr,
adj_csc_slice,
adj_repeated_slices[i] if metapaths else adj_csc_slice,
(
# relevant slice of mask (with wraparound)
msk[
:,
(i * chunk_size + np.arange(adj_csc_slice.shape[1]))
% msk.shape[1],
]
if metapaths
else adj_csc_slice > 0
),
i * chunk_size,
)
for i, adj_csc_slice in adj_csc_slices.items()
Expand All @@ -315,7 +318,16 @@ def composition_count(
_composition_count_worker(
adj_csr,
adj_csc_slice,
adj_repeated_slices[i] if metapaths else adj_csc_slice,
(
# relevant slice of mask (with wraparound)
msk[
:,
(i * chunk_size + np.arange(adj_csc_slice.shape[1]))
% msk.shape[1],
]
if metapaths
else adj_csc_slice > 0
),
i * chunk_size,
)
for i, adj_csc_slice in adj_csc_slices.items()
Expand Down
13 changes: 8 additions & 5 deletions tests/test_edge_topology_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@


def test_edge_metapath_count() -> None:
res = kgtt.edge_metapath_count()
res = kgtt.edge_metapath_count(composition_chunk_size=3)
assert np.allclose(res["index"], [2, 2])
assert np.allclose(res["h"], [0, 0])
assert np.allclose(res["r"], [0, 0])
assert np.allclose(res["t"], [2, 2])
assert np.allclose(res["r1"], [0, 1])
assert np.allclose(res["r2"], [1, 1])
assert set(zip(res["r1"].values.tolist(), res["r2"].values.tolist())) == set(
[(0, 1), (1, 1)]
)
assert np.allclose(res["n_triangles"], [1, 1])


Expand Down Expand Up @@ -71,7 +72,9 @@ def test_edge_degree_cardinality_summary() -> None:
@pytest.mark.parametrize("return_metapath_list", [True, False])
def test_edge_pattern_summary(return_metapath_list: bool) -> None:
# relation pattern symmetry
res = kgtt.edge_pattern_summary(return_metapath_list=return_metapath_list)
res = kgtt.edge_pattern_summary(
return_metapath_list=return_metapath_list, composition_chunk_size=3
)
assert np.allclose(
res["is_loop"], [False, False, False, False, False, False, True, True]
)
Expand All @@ -96,7 +99,7 @@ def test_edge_pattern_summary(return_metapath_list: bool) -> None:
assert np.allclose(res["n_triangles"], [0, 0, 2, 0, 0, 0, 0, 0])
assert np.allclose(res["n_undirected_triangles"], [3, 3, 2, 6, 2, 2, 0, 0])
if return_metapath_list:
assert res["metapath_list"][2] == ["0-1", "1-1"]
assert set(res["metapath_list"][2]) == set(["0-1", "1-1"])


def test_filter_relations() -> None:
Expand Down

0 comments on commit 5ec0f3e

Please sign in to comment.