diff --git a/src/kg_topology_toolbox/utils.py b/src/kg_topology_toolbox/utils.py index 0160d1c..8194572 100644 --- a/src/kg_topology_toolbox/utils.py +++ b/src/kg_topology_toolbox/utils.py @@ -194,14 +194,15 @@ def _composition_count_worker( n_rels = adj_csr.shape[0] // n_nodes adj_2hop = adj_csr @ adj_csc adj_composition = (adj_2hop.tocsc() * (adj_mask > 0)).tocoo() - col_shift = adj_composition.col + tail_shift if n_rels > 1: + h, r1 = np.divmod(adj_composition.row, n_rels) + r2, t = np.divmod(adj_composition.col + tail_shift, n_nodes) df_composition = pd.DataFrame( dict( - h=adj_composition.row // n_rels, - t=col_shift % n_nodes, - r1=adj_composition.row % n_rels, - r2=col_shift // n_nodes, + h=h, + t=t, + r1=r1, + r2=r2, n_triangles=adj_composition.data, ) ) @@ -209,7 +210,7 @@ def _composition_count_worker( df_composition = pd.DataFrame( dict( h=adj_composition.row, - t=col_shift, + t=adj_composition.col + tail_shift, n_triangles=adj_composition.data, ) )