From 5ec0f3e07d9ab8fddc4607d474af2b8455115b93 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Fri, 18 Oct 2024 17:21:56 +0000 Subject: [PATCH] reduce memory usage of metapath counting --- src/kg_topology_toolbox/utils.py | 70 +++++++++++++++++------------ tests/test_edge_topology_toolbox.py | 13 +++--- 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/src/kg_topology_toolbox/utils.py b/src/kg_topology_toolbox/utils.py index 8194572..d2e798f 100644 --- a/src/kg_topology_toolbox/utils.py +++ b/src/kg_topology_toolbox/utils.py @@ -193,7 +193,7 @@ def _composition_count_worker( n_nodes = adj_csr.shape[1] n_rels = adj_csr.shape[0] // n_nodes adj_2hop = adj_csr @ adj_csc - adj_composition = (adj_2hop.tocsc() * (adj_mask > 0)).tocoo() + adj_composition = (adj_2hop.tocsc() * adj_mask).tocoo() if n_rels > 1: h, r1 = np.divmod(adj_composition.row, n_rels) r2, t = np.divmod(adj_composition.col + tail_shift, n_nodes) @@ -250,19 +250,15 @@ def composition_count( n_nodes = df[["h", "t"]].max().max() + 1 n_rels = df["r"].max() + 1 + adj = coo_array( + (np.ones(len(df)), (df.h, df.t)), + shape=[n_nodes, n_nodes], + ).astype(np.uint16) if metapaths: - adj_repeated = csc_array( - ( - np.ones(n_rels * n_rels * len(df)), - ( - (n_rels * df.h.values[:, None] + np.arange(n_rels)).repeat(n_rels), - np.tile( - df.t.values[:, None] + n_nodes * np.arange(n_rels), n_rels - ).flatten(), - ), - ), - shape=[n_nodes * n_rels, n_nodes * n_rels], - ).astype(np.uint16) + if not directed: + raise NotImplementedError( + "Metapath counting only implemented for directed triangles" + ) adj_csr = csr_array( (np.ones(len(df)), (df.h * n_rels + df.r, df.t)), shape=[n_nodes * n_rels, n_nodes], @@ -271,26 +267,24 @@ def composition_count( (np.ones(len(df)), (df.h, df.r * n_nodes + df.t)), shape=[n_nodes, n_nodes * n_rels], ).astype(np.uint16) - n_cols = adj_csc.shape[1] - adj_repeated_slices = { - i: adj_repeated[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)] - for i in range(int(np.ceil(n_cols / chunk_size))) - } - if not directed: - raise NotImplementedError( - "Metapath counting only implemented for directed triangles" - ) + # boolean mask to filter results with only the edges in the KG + msk = csc_array( + ( + [True] * (len(adj.data) * n_rels), + ( + (n_rels * adj.row + np.arange(n_rels)[:, None]).flatten(), + np.tile(adj.col, n_rels), + ), + ), + shape=[n_nodes * n_rels, n_nodes], + ) else: - adj = coo_array( - (np.ones(len(df)), (df.h, df.t)), - shape=[n_nodes, n_nodes], - ).astype(np.uint16) if not directed: adj = adj + adj.T adj_csr = adj.tocsr() adj_csc = adj.tocsc() - n_cols = adj_csc.shape[1] + n_cols = adj_csc.shape[1] adj_csc_slices = { i: adj_csc[:, i * chunk_size : min((i + 1) * chunk_size, n_cols)] for i in range(int(np.ceil(n_cols / chunk_size))) @@ -304,7 +298,16 @@ def composition_count( ( adj_csr, adj_csc_slice, - adj_repeated_slices[i] if metapaths else adj_csc_slice, + ( + # relevant slice of mask (with wraparound) + msk[ + :, + (i * chunk_size + np.arange(adj_csc_slice.shape[1])) + % msk.shape[1], + ] + if metapaths + else adj_csc_slice > 0 + ), i * chunk_size, ) for i, adj_csc_slice in adj_csc_slices.items() @@ -315,7 +318,16 @@ def composition_count( _composition_count_worker( adj_csr, adj_csc_slice, - adj_repeated_slices[i] if metapaths else adj_csc_slice, + ( + # relevant slice of mask (with wraparound) + msk[ + :, + (i * chunk_size + np.arange(adj_csc_slice.shape[1])) + % msk.shape[1], + ] + if metapaths + else adj_csc_slice > 0 + ), i * chunk_size, ) for i, adj_csc_slice in adj_csc_slices.items() diff --git a/tests/test_edge_topology_toolbox.py b/tests/test_edge_topology_toolbox.py index 49bfa5e..849bd5c 100644 --- a/tests/test_edge_topology_toolbox.py +++ b/tests/test_edge_topology_toolbox.py @@ -23,13 +23,14 @@ def test_edge_metapath_count() -> None: - res = kgtt.edge_metapath_count() + res = kgtt.edge_metapath_count(composition_chunk_size=3) assert np.allclose(res["index"], [2, 2]) assert np.allclose(res["h"], [0, 0]) assert np.allclose(res["r"], [0, 0]) assert np.allclose(res["t"], [2, 2]) - assert np.allclose(res["r1"], [0, 1]) - assert np.allclose(res["r2"], [1, 1]) + assert set(zip(res["r1"].values.tolist(), res["r2"].values.tolist())) == set( + [(0, 1), (1, 1)] + ) assert np.allclose(res["n_triangles"], [1, 1]) @@ -71,7 +72,9 @@ def test_edge_degree_cardinality_summary() -> None: @pytest.mark.parametrize("return_metapath_list", [True, False]) def test_edge_pattern_summary(return_metapath_list: bool) -> None: # relation pattern symmetry - res = kgtt.edge_pattern_summary(return_metapath_list=return_metapath_list) + res = kgtt.edge_pattern_summary( + return_metapath_list=return_metapath_list, composition_chunk_size=3 + ) assert np.allclose( res["is_loop"], [False, False, False, False, False, False, True, True] ) @@ -96,7 +99,7 @@ def test_edge_pattern_summary(return_metapath_list: bool) -> None: assert np.allclose(res["n_triangles"], [0, 0, 2, 0, 0, 0, 0, 0]) assert np.allclose(res["n_undirected_triangles"], [3, 3, 2, 6, 2, 2, 0, 0]) if return_metapath_list: - assert res["metapath_list"][2] == ["0-1", "1-1"] + assert set(res["metapath_list"][2]) == set(["0-1", "1-1"]) def test_filter_relations() -> None: