From 0c08b8dedf30593025289bc716893600b29b4801 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Mon, 19 Aug 2024 09:47:25 +0000 Subject: [PATCH] address PR comments --- docs/source/notebooks/ogb_biokg_demo.ipynb | 2 +- src/kg_topology_toolbox/topology_toolbox.py | 18 ++---------- src/kg_topology_toolbox/utils.py | 32 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/docs/source/notebooks/ogb_biokg_demo.ipynb b/docs/source/notebooks/ogb_biokg_demo.ipynb index 6ac2db3..bb73448 100644 --- a/docs/source/notebooks/ogb_biokg_demo.ipynb +++ b/docs/source/notebooks/ogb_biokg_demo.ipynb @@ -31,7 +31,7 @@ "source": [ "import sys\n", "!{sys.executable} -m pip uninstall -y kg_topology_toolbox\n", - "!pip install -q git+https://github.com/graphcore-research/kg-topology-toolbox.git@refactor_kgtt --no-cache-dir\n", + "!pip install -q git+https://github.com/graphcore-research/kg-topology-toolbox.git --no-cache-dir\n", "!pip install -q jupyter ipywidgets ogb seaborn" ] }, diff --git a/src/kg_topology_toolbox/topology_toolbox.py b/src/kg_topology_toolbox/topology_toolbox.py index e6516bc..8d29e96 100644 --- a/src/kg_topology_toolbox/topology_toolbox.py +++ b/src/kg_topology_toolbox/topology_toolbox.py @@ -5,16 +5,15 @@ Topology toolbox main functionalities """ -import warnings from functools import cache import numpy as np import pandas as pd -from pandas.api.types import is_integer_dtype from scipy.sparse import coo_array from kg_topology_toolbox.utils import ( aggregate_by_relation, + check_kg_df_structure, composition_count, jaccard_similarity, node_degrees_and_rels, @@ -49,22 +48,11 @@ def __init__( The name of the column with the IDs of tail entities. Default: "t". """ - for col_name in [head_column, relation_column, tail_column]: - if col_name in kg_df.columns: - if not is_integer_dtype(kg_df[col_name]): - raise TypeError( - f"Column {col_name} needs to be of an integer dtype" - ) - else: - raise ValueError(f"DataFrame {kg_df} has no column named {col_name}") + check_kg_df_structure(kg_df, head_column, relation_column, tail_column) + self.df = kg_df[[head_column, relation_column, tail_column]].rename( columns={head_column: "h", relation_column: "r", tail_column: "t"} ) - if self.df.duplicated().any(): - warnings.warn( - "The Knowledge Graph contains duplicated edges" - " -- some functionalities may produce incorrect results" - ) self.n_entity = self.df[["h", "t"]].max().max() + 1 self.n_rel = self.df.r.max() + 1 diff --git a/src/kg_topology_toolbox/utils.py b/src/kg_topology_toolbox/utils.py index 58d3f37..d3a3d55 100644 --- a/src/kg_topology_toolbox/utils.py +++ b/src/kg_topology_toolbox/utils.py @@ -4,15 +4,47 @@ Utility functions """ +import warnings from collections.abc import Iterable from multiprocessing import Pool import numpy as np import pandas as pd from numpy.typing import NDArray +from pandas.api.types import is_integer_dtype from scipy.sparse import coo_array, csc_array, csr_array +def check_kg_df_structure(kg_df: pd.DataFrame, h: str, r: str, t: str) -> None: + """ + Utility to perform sanity checks on the structure of the provided DataFrame, + to ensure that it encodes a Knowledge Graph in a compatible way. + + :param kg_df: + The Knowledge Graph DataFrame. + :param h: + The name of the column with the IDs of head entities. + :param r: + The name of the column with the IDs of relation types. + :param t: + The name of the column with the IDs of tail entities. + + """ + # check h,r,t columns are present and of an integer type + for col_name in [h, r, t]: + if col_name in kg_df.columns: + if not is_integer_dtype(kg_df[col_name]): + raise TypeError(f"Column {col_name} needs to be of an integer dtype") + else: + raise ValueError(f"DataFrame {kg_df} has no column named {col_name}") + # check there are no duplicated (h,r,t) triples + if kg_df[[h, r, t]].duplicated().any(): + warnings.warn( + "The Knowledge Graph contains duplicated edges" + " -- some functionalities may produce incorrect results" + ) + + def node_degrees_and_rels( df: pd.DataFrame, column: str, n_entity: int, return_relation_list: bool ) -> pd.DataFrame: