diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index e9a8847..3419e6c 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -2,11 +2,16 @@ # -*- coding: utf-8 -*- import logging from collections import defaultdict -from typing import Any, DefaultDict, List, Set, Union +from typing import Any, DefaultDict, Dict, List, Set, Union + +try: + # https://github.com/arangoml/pyg-adapter/issues/4 + from cudf import DataFrame +except ModuleNotFoundError: + from pandas import DataFrame from arango.database import Database from arango.graph import Graph as ADBGraph -from pandas import DataFrame from torch import Tensor, cat, tensor from torch_geometric.data import Data, HeteroData from torch_geometric.data.storage import EdgeStorage, NodeStorage @@ -107,13 +112,14 @@ def arangodb_to_pyg( The current supported **metagraph** values are: 1) str: The name of the ArangoDB attribute that stores your PyG-ready data - 2) Dict[str, Callable[[pandas.DataFrame], torch.Tensor] | None]: + 2) Dict[str, Callable[[(pandas | cudf).DataFrame], torch.Tensor] | None]: A dictionary mapping ArangoDB attributes to a callable Python Class (i.e has a `__call__` function defined), or to None (if the ArangoDB attribute is already a list of numerics). - 3) Callable[[pandas.DataFrame], torch.Tensor]: A user-defined function for - custom behaviour. NOTE: The function return type MUST be a tensor. + 3) Callable[[(pandas | cudf).DataFrame], torch.Tensor]: + A user-defined function for custom behaviour. + NOTE: The function return type MUST be a tensor. 1) .. code-block:: python @@ -196,7 +202,7 @@ def udf_v1_x(v1_df): } The metagraph above provides an interface for a user-defined function to - build a PyG-ready Tensor from a Pandas DataFrame equivalent to the + build a PyG-ready Tensor from a DataFrame equivalent to the associated ArangoDB collection. """ logger.debug(f"--arangodb_to_pyg('{name}')--") @@ -318,7 +324,7 @@ def pyg_to_arangodb( metagraph: PyGMetagraph = {}, explicit_metagraph: bool = True, overwrite_graph: bool = False, - use_original_adb_keys: bool = False, # TODO: explain + preserve_adb_keys: bool = False, # TODO: explain **import_options: Any, ) -> ADBGraph: """Create an ArangoDB graph from a PyG graph. @@ -353,8 +359,9 @@ def pyg_to_arangodb( 2) List[str]: A list of ArangoDB attribute names that will break down your tensor data to have one ArangoDB attribute per tensor value. - 3) Callable[[torch.Tensor], pandas.DataFrame]: A user-defined function for - custom behaviour. NOTE: The function return type MUST be a DataFrame. + 3) Callable[[torch.Tensor], (pandas | cudf).DataFrame]: + A user-defined function for custom behaviour. + NOTE: The function return type MUST be a DataFrame. 1) Here is an example entry for parameter **metagraph**: .. code-block:: python @@ -363,7 +370,7 @@ def v2_x_to_pandas_dataframe(t: Tensor): # The parameter **t** is the tensor representing # the feature matrix 'x' of the 'v2' node type. - df = pandas.DataFrame(columns=["v2_features"]) + df = (pandas | cudf).DataFrame(columns=["v2_features"]) df["v2_features"] = t.tolist() # do more things with df["v2_features"] here ... return df @@ -421,19 +428,19 @@ def v2_x_to_pandas_dataframe(t: Tensor): name, edge_definitions, orphan_collections ) - # TODO: explain - if use_original_adb_keys: + pyg_map = self.pyg_map[name] + if preserve_adb_keys: if self.adb_map[name] == {}: msg = f""" - Parameter **use_original_adb_keys** was enabled, + Parameter **preserve_adb_keys** was enabled, but no ArangoDB Map was found for graph {name} in **self.adb_map**. """ raise ValueError(msg) + # Build the reverse map for k, map in self.adb_map[name].items(): - reverse_map = {pyg_id: adb_id for adb_id, pyg_id in map.items()} - self.pyg_map[name][k].update(reverse_map) + pyg_map[k].update({pyg_id: adb_id for adb_id, pyg_id in map.items()}) # Define PyG data properties node_data: NodeStorage @@ -442,19 +449,31 @@ def v2_x_to_pandas_dataframe(t: Tensor): n_meta = metagraph.get("nodeTypes", {}) for n_type in node_types: node_data = pyg_g if is_homogeneous else pyg_g[n_type] + num_nodes = node_data.num_nodes - df = DataFrame(index=range(node_data.num_nodes)) - df["_id"] = ( - df.index.map(self.pyg_map[name][n_type]) - if use_original_adb_keys - else n_type + "/" + df.index.astype(str) - ) + df = DataFrame(index=range(num_nodes)) + if preserve_adb_keys: + num_node_keys = len(pyg_map[n_type]) + + if num_nodes != num_node_keys: + msg = f""" + {num_nodes} does not match + number of node keys in pyg_map + ({num_node_keys}) for {n_type} + """ + raise ValueError(msg) + + df["_id"] = df.index.map(pyg_map[n_type]) + else: + df["_key"] = df.index.astype(str) meta = n_meta.get(n_type, {}) - for k, t in node_data.items(): - if type(t) is Tensor and len(t) == node_data.num_nodes: - v = meta.get(k, str(k)) - df = df.join(self.__build_dataframe_from_tensor(t, k, v)) + df = self.__finish_adb_dataframe( + df, + meta, + node_data, + list(meta.keys() if explicit_metagraph else node_data.keys()), + ) if type(self.__cntrl) is not ADBPyG_Controller: f = lambda n: self.__cntrl._prepare_pyg_node(n, n_type) @@ -469,30 +488,21 @@ def v2_x_to_pandas_dataframe(t: Tensor): columns = ["_from", "_to"] df = DataFrame(zip(*(edge_data.edge_index.tolist())), columns=columns) - - if use_original_adb_keys: - df["_id"] = df.index.map(self.pyg_map[name][e_type]) - - df["_from"] = ( - df["_from"].map(self.pyg_map[name][from_col]) - if use_original_adb_keys - else from_col + "/" + df["_from"].astype(str) - ) - - df["_to"] = ( - df["_to"].map(self.pyg_map[name][to_col]) - if use_original_adb_keys - else to_col + "/" + df["_to"].astype(str) - ) + if preserve_adb_keys: + df["_id"] = df.index.map(pyg_map[e_type]) + df["_from"] = df["_from"].map(pyg_map[from_col]) + df["_to"] = df["_to"].map(pyg_map[to_col]) + else: + df["_from"] = from_col + "/" + df["_from"].astype(str) + df["_to"] = to_col + "/" + df["_to"].astype(str) meta = e_meta.get(e_type, {}) - for k, t in edge_data.items(): - if k == "edge_index": - continue - - if type(t) is Tensor and len(t) == edge_data.num_edges: - v = meta.get(k, str(k)) - df = df.join(self.__build_dataframe_from_tensor(t, k, v)) + df = self.__finish_adb_dataframe( + df, + meta, + edge_data, + list(meta.keys() if explicit_metagraph else edge_data.keys()), + ) if type(self.__cntrl) is not ADBPyG_Controller: f = lambda e: self.__cntrl._prepare_pyg_edge(e, e_type) @@ -573,7 +583,7 @@ def __fetch_adb_docs( self, col: str, empty_meta: bool, query_options: Any ) -> DataFrame: """Fetches ArangoDB documents within a collection. Returns the - documents in a Pandas DataFrame. + documents in a DataFrame. :param col: The ArangoDB collection. :type col: str @@ -583,8 +593,8 @@ def __fetch_adb_docs( :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. :type query_options: Any - :return: A Pandas DataFrame representing the ArangoDB documents. - :rtype: pandas.DataFrame + :return: A DataFrame representing the ArangoDB documents. + :rtype: (pandas | cudf).DataFrame """ # Only return the entire document if **empty_meta** is False data = "{_id: doc._id, _from: doc._from, _to: doc._to}" if empty_meta else "doc" @@ -637,11 +647,11 @@ def __build_tensor_from_dataframe( meta_key: str, meta_val: ADBMetagraphValues, ) -> Tensor: - """Constructs a PyG-ready Tensor from a Pandas Dataframe, based on + """Constructs a PyG-ready Tensor from a DataFrame, based on the nature of the user-defined metagraph. - :param adb_df: The Pandas Dataframe representing ArangoDB data. - :type adb_df: pandas.DataFrame + :param adb_df: The DataFrame representing ArangoDB data. + :type adb_df: (pandas | cudf).DataFrame :param meta_key: The current ArangoDB-PyG metagraph key :type meta_key: str :param meta_val: The value mapped to **meta_key** to @@ -682,13 +692,51 @@ def __build_tensor_from_dataframe( raise ADBMetagraphError(f"Invalid {meta_val} type") # pragma: no cover + def __finish_adb_dataframe( + self, + df: DataFrame, + meta: Dict[Any, PyGMetagraphValues], + pyg_data: Union[NodeStorage, EdgeStorage], + pyg_keys: List[Any], + ) -> DataFrame: + """A helper method to complete the ArangoDB Dataframe for the given + collection. Is responsible for creating DataFrames from PyG tensors, + and appending them to the main dataframe **df**. + + :param df: The main ArangoDB DataFrame containing (at minimum) + the vertex/edge _id or _key attribute. + :type df: (pandas | cudf).DataFrame + :param meta: The metagraph associated to the + current PyG node or edge type. + :type meta: Dict[Any, adbpyg_adapter.typings.PyGMetagraphValues] + :param pyg_data: The NodeStorage or EdgeStorage of the current + PyG node or edge type. + :type pyg_data: torch_geometric.data.storage.(NodeStorage | EdgeStorage) + :param pyg_keys: The set of PyG NodeStorage or EdgeStorage keys, retrieved + either from the **meta** parameter (if **explicit_metagraph** is True), + or from the **pyg_data** parameter (if **explicit_metagraph** is False). + :type pyg_keys: List[Any] + :return: The completed DataFrame for the (soon-to-be) ArangoDB collection. + :rtype: (pandas | cudf).DataFrame + """ + for k in pyg_keys: + if k == "edge_index": + continue + + t = pyg_data[k] + if type(t) is Tensor and len(t) == len(df): + v = meta.get(k, str(k)) + df = df.join(self.__build_dataframe_from_tensor(t, k, v)) + + return df + def __build_dataframe_from_tensor( self, pyg_tensor: Tensor, meta_key: Any, meta_val: PyGMetagraphValues, ) -> DataFrame: - """Builds a Pandas DataFrame from PyG Tensor, based on + """Builds a DataFrame from PyG Tensor, based on the nature of the user-defined metagraph. :param pyg_tensor: The Tensor representing PyG data. @@ -696,11 +744,11 @@ def __build_dataframe_from_tensor( :param meta_key: The current PyG-ArangoDB metagraph key :type meta_key: Any :param meta_val: The value mapped to the PyG-ArangoDB metagraph key to - help convert **tensor** into a Pandas Dataframe. + help convert **tensor** into a DataFrame. e.g the value of `metagraph['nodeTypes']['users']['x']`. :type meta_val: adbpyg_adapter.typings.PyGMetagraphValues - :return: A Pandas DataFrame equivalent to the Tensor - :rtype: pandas.DataFrame + :return: A DataFrame equivalent to the Tensor + :rtype: (pandas | cudf).DataFrame :raise adbpyg_adapter.exceptions.PyGMetagraphError: If invalid **meta_val**. """ logger.debug(f"__build_dataframe_from_tensor(df, '{meta_key}', {meta_val})")