diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 5205cc7c..01f5642a 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -18,7 +18,8 @@ from threading import Event, Thread from time import sleep import pickle -from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Tuple, Union, Callable +import random +from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Tuple, Union, Callable, List, Dict #import re #RE_SPLITTER = re.compile(r',(?![^\[]*\])') @@ -38,7 +39,7 @@ import pandas as pd from ..pyTigerGraphException import TigerGraphException -from .utilities import install_query_file, random_string, add_attribute +from .utilities import install_query_file, random_string, add_attribute, install_query_files __all__ = ["VertexLoader", "EdgeLoader", "NeighborLoader", "GraphLoader", "EdgeNeighborLoader", "NodePieceLoader", "HGTLoader"] @@ -89,7 +90,9 @@ def __init__( kafka_add_topic_per_epoch: bool = False, callback_fn: Callable = None, kafka_group_id: str = None, - kafka_topic: str = None + kafka_topic: str = None, + num_machines: int = 1, + num_segments: int = 20, ) -> None: """Base Class for data loaders. @@ -182,8 +185,6 @@ def __init__( self._downloader = None self._reader = None # Queues to store tasks and data - self._request_task_q = None - self._download_task_q = None self._read_task_q = None self._data_q = None self._kafka_topic = None @@ -229,14 +230,13 @@ def __init__( else: self._kafka_topic_base = self.loader_id + "_topic" self.num_batches = num_batches - self.output_format = output_format + self.output_format = output_format.lower() self.buffer_size = buffer_size self.timeout = timeout self._iterations = 0 self._iterator = False self.callback_fn = callback_fn self.distributed_query = distributed_query - self.num_heap_inserts = 10 # Kafka consumer and admin self.max_kafka_msg_size = Kafka_max_msg_size self.kafka_address_consumer = ( @@ -293,6 +293,8 @@ def __init__( ) # Initialize parameters for the query self._payload = {} + self._payload["num_machines"] = num_machines + self._payload["num_segments"] = num_segments if self.kafka_address_producer: self._payload["kafka_address"] = self.kafka_address_producer self._payload["kafka_topic_partitions"] = kafka_num_partitions @@ -548,86 +550,110 @@ def _request_kafka( tgraph.abortQuery(resp) @staticmethod - def _request_rest( + def _request_graph_rest( + tgraph: "TigerGraphConnection", + query_name: str, + read_task_q: Queue, + timeout: int = 600000, + payload: dict = {}, + ) -> NoReturn: + # Run query + resp = tgraph.runInstalledQuery( + query_name, params=payload, timeout=timeout, usePost=True + ) + # Put raw data into reading queue. 
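+        # Each element of `resp` is a dict carrying "vertex_batch",
+        # "edge_batch", and "seed" keys produced by the installed query;
+        # a trailing None is enqueued below as the end-of-stream sentinel.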
+ for i in resp: + read_task_q.put((i["vertex_batch"], i["edge_batch"], i["seed"])) + read_task_q.put(None) + + @staticmethod + def _request_unimode_rest( tgraph: "TigerGraphConnection", query_name: str, read_task_q: Queue, timeout: int = 600000, payload: dict = {}, - resp_type: 'Literal["both", "vertex", "edge"]' = "both", ) -> NoReturn: # Run query + #TODO: check what happens when the query times out resp = tgraph.runInstalledQuery( query_name, params=payload, timeout=timeout, usePost=True ) # Put raw data into reading queue for i in resp: - if resp_type == "both": - data = (i["vertex_batch"], i["edge_batch"]) - elif resp_type == "vertex": - data = i["vertex_batch"] - elif resp_type == "edge": - data = i["edge_batch"] - read_task_q.put(data) + read_task_q.put(i["data_batch"]) read_task_q.put(None) @staticmethod - def _download_from_kafka( + def _download_graph_kafka( exit_event: Event, read_task_q: Queue, - num_batches: int, - out_tuple: bool, - kafka_consumer: "KafkaConsumer", - max_wait_time: int = 300 + kafka_consumer: "KafkaConsumer" ) -> NoReturn: - delivered_batch = 0 + empty = False buffer = {} - wait_time = 0 - while (not exit_event.is_set()) and (wait_time < max_wait_time): - if delivered_batch == num_batches: - break + while (not exit_event.is_set()) and (not empty): resp = kafka_consumer.poll(1000) if not resp: - wait_time += 1 continue - wait_time = 0 for msgs in resp.values(): for message in msgs: key = message.key.decode("utf-8") - if out_tuple: - if key.startswith("vertex"): - companion_key = key.replace("vertex", "edge") - if companion_key in buffer: - read_task_q.put((message.value.decode("utf-8"), - buffer[companion_key])) - del buffer[companion_key] - delivered_batch += 1 - else: - buffer[key] = message.value.decode("utf-8") - elif key.startswith("edge"): - companion_key = key.replace("edge", "vertex") - if companion_key in buffer: - read_task_q.put((buffer[companion_key], - message.value.decode("utf-8"))) - del buffer[companion_key] - delivered_batch += 1 - else: - buffer[key] = message.value.decode("utf-8") + if key == "STOP": + read_task_q.put(None) + empty = True + break + if key.startswith("vertex"): + companion_key = key.replace("vertex", "edge") + if companion_key in buffer: + read_task_q.put((message.value.decode("utf-8"), + buffer[companion_key], + key.split("_", 2)[-1])) + del buffer[companion_key] else: - raise ValueError( - "Unrecognized key {} for messages in kafka".format(key) - ) + buffer[key] = message.value.decode("utf-8") + elif key.startswith("edge"): + companion_key = key.replace("edge", "vertex") + if companion_key in buffer: + read_task_q.put((buffer[companion_key], + message.value.decode("utf-8"), + key.split("_", 2)[-1])) + del buffer[companion_key] + else: + buffer[key] = message.value.decode("utf-8") + else: + warnings.warn( + "Unrecognized key {} for messages in kafka".format(key) + ) + if empty: + break + + @staticmethod + def _download_unimode_kafka( + exit_event: Event, + read_task_q: Queue, + kafka_consumer: "KafkaConsumer" + ) -> NoReturn: + empty = False + while (not exit_event.is_set()) and (not empty): + resp = kafka_consumer.poll(1000) + if not resp: + continue + for msgs in resp.values(): + for message in msgs: + key = message.key.decode("utf-8") + if key == "STOP": + read_task_q.put(None) + empty = True else: read_task_q.put(message.value.decode("utf-8")) - delivered_batch += 1 - read_task_q.put(None) @staticmethod - def _read_data( + def _read_graph_data( exit_event: Event, in_q: Queue, out_q: Queue, - in_format: str = "vertex", + 
batch_size: int, out_format: str = "dataframe", v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], @@ -639,21 +665,82 @@ def _read_data( e_attr_types: dict = {}, add_self_loop: bool = False, delimiter: str = "|", - reindex: bool = True, is_hetero: bool = False, callback_fn: Callable = None, + seed_type: str = "" ) -> NoReturn: - while not exit_event.is_set(): - raw = in_q.get() + # Import the right libraries based on output format + out_format = out_format.lower() + if out_format == "pyg" or out_format == "dgl": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + if out_format == "dgl": + try: + import dgl + except ImportError: + raise ImportError( + "DGL is not installed. Please install DGL to use DGL format." + ) + elif out_format == "pyg": + try: + import torch_geometric as pyg + except ImportError: + raise ImportError( + "PyG is not installed. Please install PyG to use PyG format." + ) + elif out_format.lower() == "spektral": + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "Tensorflow is not installed. Please install it to use spektral output." + ) + try: + import scipy + except ImportError: + raise ImportError( + "scipy is not installed. Please install it to use spektral output." + ) + try: + import spektral + except ImportError: + raise ImportError( + "Spektral is not installed. Please install it to use spektral output." + ) + # Get raw data from queue and parse + vertex_buffer = dict() + edge_buffer = dict() + buffer_size = 0 + seeds = set() + is_empty = False + last_batch = False + while (not exit_event.is_set()) and (not is_empty): + try: + raw = in_q.get(timeout=1) + except Empty: + continue if raw is None: - in_q.task_done() - out_q.put(None) - break + is_empty = True + if buffer_size > 0: + last_batch = True + else: + vertex_buffer.update({i.strip():"" for i in raw[0].strip().splitlines()}) + edge_buffer.update({i.strip():"" for i in raw[1].strip().splitlines()}) + seeds.add(raw[2]) + buffer_size += 1 + if (buffer_size < batch_size) and (not last_batch): + continue try: - data = BaseLoader._parse_data( - raw = raw, - in_format = in_format, - out_format = out_format, + if seed_type: + raw_data = (vertex_buffer.keys(), edge_buffer.keys(), seeds) + else: + raw_data = (vertex_buffer.keys(), edge_buffer.keys()) + data = BaseLoader._parse_graph_data_to_df( + raw = raw_data, v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -662,28 +749,349 @@ def _read_data( e_out_labels = e_out_labels, e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = add_self_loop, delimiter = delimiter, - reindex = reindex, primary_id = {}, is_hetero = is_hetero, - callback_fn = callback_fn + seed_type = seed_type + ) + if out_format == "dataframe" or out_format == "df": + vertices, edges = data + if not is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif out_format == "pyg": + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = 
v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + is_hetero = is_hetero, + torch = torch, + pyg = pyg + ) + elif out_format == "dgl": + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + is_hetero = is_hetero, + torch = torch, + dgl= dgl + ) + elif out_format == "spektral" and is_hetero==False: + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + is_hetero = is_hetero, + scipy = scipy, + spektral = spektral + ) + else: + raise NotImplementedError + if callback_fn: + data = callback_fn(data) + out_q.put(data) + except Exception as err: + warnings.warn("Error parsing a graph batch. Set logging level to ERROR for details.") + logger.error(err, exc_info=True) + logger.error("Error parsing data: {}".format((vertex_buffer, edge_buffer))) + logger.error("Parameters:\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( + out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + vertex_buffer.clear() + edge_buffer.clear() + seeds.clear() + buffer_size = 0 + out_q.put(None) + + @staticmethod + def _read_vertex_data( + exit_event: Event, + in_q: Queue, + out_q: Queue, + batch_size: int, + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + callback_fn: Callable = None + ) -> NoReturn: + buffer = [] + last_batch = False + is_empty = False + while (not exit_event.is_set()) and (not is_empty): + try: + raw = in_q.get(timeout=1) + except Empty: + continue + if raw is None: + is_empty = True + if len(buffer) > 0: + last_batch = True + else: + buffer.append(raw) + if (len(buffer) < batch_size) and (not last_batch): + continue + try: + data = BaseLoader._parse_vertex_data( + raw = buffer, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + delimiter = delimiter, + is_hetero = is_hetero ) + if not is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + else: + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") + if callback_fn: + data = callback_fn(data) out_q.put(data) except Exception as err: - warnings.warn("Error parsing a data batch. Set logging level to ERROR for details.") + warnings.warn("Error parsing a vertex batch. 
Set logging level to ERROR for details.") logger.error(err, exc_info=True) - logger.error("Error parsing data: {}".format(raw)) - logger.error("Parameters:\n in_format={}\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( - in_format, out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + logger.error("Error parsing data: {}".format(buffer)) + logger.error("Parameters:\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n delimiter={}\n".format( + v_in_feats, v_out_labels, v_extra_feats, v_attr_types, delimiter)) + buffer.clear() + out_q.put(None) - in_q.task_done() + @staticmethod + def _read_edge_data( + exit_event: Event, + in_q: Queue, + out_q: Queue, + batch_size: int, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + callback_fn: Callable = None + ) -> NoReturn: + buffer = [] + is_empty = False + last_batch = False + while (not exit_event.is_set()) and (not is_empty): + try: + raw = in_q.get(timeout=1) + except Empty: + continue + if raw is None: + is_empty = True + if len(buffer) > 0: + last_batch = True + else: + buffer.append(raw) + if (len(buffer) < batch_size) and (not last_batch): + continue + try: + data = BaseLoader._parse_edge_data( + raw = buffer, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + delimiter = delimiter, + is_hetero = is_hetero + ) + if not is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + else: + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") + if callback_fn: + data = callback_fn(data) + out_q.put(data) + except Exception as err: + warnings.warn("Error parsing an edge batch. Set logging level to ERROR for details.") + logger.error(err, exc_info=True) + logger.error("Error parsing data: {}".format(buffer)) + logger.error("Parameters:\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( + e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + buffer.clear() + out_q.put(None) @staticmethod - def _parse_data( - raw: Union[str, Tuple[str, str]], - in_format: 'Literal["vertex", "edge", "graph"]' = "vertex", - out_format: str = "dataframe", + def _parse_vertex_data( + raw: List[str], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + seeds: list = [] + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """Parse raw vertex data into dataframes. 
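+
+        Each element of `raw` is one delimiter-separated record, either
+        "vid|attr1|attr2|..." (homogeneous) or "vtype|vid|attr1|..."
+        (heterogeneous). When `seeds` is given, an `is_seed` column is
+        added to the resulting dataframe(s).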
+ """ + # Read in vertex CSVs as dataframes + # Each row is in format vid,v_in_feats,v_out_labels,v_extra_feats + # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats + v_file = (line.split(delimiter) for line in raw) + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame({ + "vid": list(seeds), + "is_seed": True + }) + if not is_hetero: + # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats + v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats + if seeds: + try: + v_attributes.remove("is_seed") + except ValueError: + pass + data = pd.DataFrame(v_file, columns=v_attributes, dtype="object") + for v_attr in v_extra_feats: + if v_attr_types.get(v_attr, "") == "MAP": + # I am sorry that this is this ugly... + data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + data = data.merge(seed_df, on="vid", how="left") + data.fillna({"is_seed": False}, inplace=True) + else: + v_file_dict = defaultdict(list) + for line in v_file: + v_file_dict[line[0]].append(line[1:]) + data = {} + for vtype in v_file_dict: + v_attributes = ["vid"] + \ + v_in_feats.get(vtype, []) + \ + v_out_labels.get(vtype, []) + \ + v_extra_feats.get(vtype, []) + if seeds: + try: + v_attributes.remove("is_seed") + except ValueError: + pass + data[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") + for v_attr in v_extra_feats.get(vtype, []): + if v_attr_types[vtype][v_attr] == "MAP": + # I am sorry that this is this ugly... + data[vtype][v_attr] = data[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + data[vtype] = data[vtype].merge(seed_df, on="vid", how="left") + data[vtype].fillna({"is_seed": False}, inplace=True) + return data + + @staticmethod + def _parse_edge_data( + raw: List[str], + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + seeds: list = [] + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """Parse raw edge data into dataframes. + """ + # Read in edge CSVs as dataframes + # Each row is in format source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats + # or etype,source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats + e_file = (line.split(delimiter) for line in raw) + if not is_hetero: + e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats + if seeds: + try: + e_attributes.remove("is_seed") + except ValueError: + pass + data = pd.DataFrame(e_file, columns=e_attributes, dtype="object") + for e_attr in e_extra_feats: + if e_attr_types.get(e_attr, "") == "MAP": + # I am sorry that this is this ugly... 
+ data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame.from_records( + [(i.split("_", 1)[0], i.rsplit("_", 1)[-1]) for i in seeds], + columns=["source", "target"]) + seed_df["is_seed"] = True + data = data.merge(seed_df, on=["source", "target"], how="left") + data.fillna({"is_seed": False}, inplace=True) + else: + e_file_dict = defaultdict(list) + for line in e_file: + e_file_dict[line[0]].append(line[1:]) + data = {} + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame.from_records( + [(i.split("_", 1)[0], i.split("_", 1)[1].rsplit("_", 1)[0], i.rsplit("_", 1)[-1]) for i in seeds], + columns=["source", "etype", "target"]) + seed_df["is_seed"] = True + for etype in e_file_dict: + e_attributes = ["source", "target"] + \ + e_in_feats.get(etype, []) + \ + e_out_labels.get(etype, []) + \ + e_extra_feats.get(etype, []) + if seeds: + try: + e_attributes.remove("is_seed") + except ValueError: + pass + data[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") + for e_attr in e_extra_feats.get(etype, []): + if e_attr_types[etype][e_attr] == "MAP": + # I am sorry that this is this ugly... + data[etype][e_attr] = data[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + tmp_df = seed_df[seed_df["etype"]==etype] + if len(tmp_df)>0: + data[etype] = data[etype].merge( + tmp_df[["source", "target", "is_seed"]], on=["source", "target"], how="left") + data[etype].fillna({"is_seed": False}, inplace=True) + else: + data[etype]["is_seed"] = False + return data + + @staticmethod + def _parse_graph_data_to_df( + raw: Tuple[List[str], List[str]], v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], v_extra_feats: Union[list, dict] = [], @@ -692,51 +1100,137 @@ def _parse_data( e_out_labels: Union[list, dict] = [], e_extra_feats: Union[list, dict] = [], e_attr_types: dict = {}, - add_self_loop: bool = False, delimiter: str = "|", - reindex: bool = True, primary_id: dict = {}, is_hetero: bool = False, - callback_fn: Callable = None, - ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame], "dgl.DGLGraph", "pyg.data.Data", "spektral.data.graph.Graph", - dict, Tuple[dict, dict], "pyg.data.HeteroData"]: - """Parse raw data into dataframes, DGL graphs, or PyG graphs. + seed_type: str = "" + ) -> Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]]: + """Parse raw data into dataframes. 
""" - def attr_to_tensor( + # Read in vertex and edge CSVs as dataframes + # A pair of in-memory CSVs (vertex, edge) + if len(raw) == 3: + v_file, e_file, seed_file = raw + else: + v_file, e_file = raw + seed_file = [] + vertices = BaseLoader._parse_vertex_data( + raw = v_file, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + delimiter = delimiter, + is_hetero = is_hetero, + seeds = seed_file if seed_type=="vertex" else [] + ) + if primary_id: + id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, + dtype="object") + if not is_hetero: + vertices = vertices.merge(id_map, on="vid") + v_extra_feats.append("primary_id") + else: + for vtype in vertices: + vertices[vtype] = vertices[vtype].merge(id_map, on="vid") + v_extra_feats[vtype].append("primary_id") + edges = BaseLoader._parse_edge_data( + raw = e_file, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + delimiter = delimiter, + is_hetero = is_hetero, + seeds = seed_file if seed_type=="edge" else [] + ) + return (vertices, edges) + + @staticmethod + def _attributes_to_np( attributes: list, attr_types: dict, df: pd.DataFrame - ) -> "torch.Tensor": - """Turn multiple columns of a dataframe into a tensor. - """ - x = [] - for col in attributes: - dtype = attr_types[col].lower() - if dtype.startswith("str"): - raise TypeError( - "String type not allowed for input and output features." - ) - if dtype.startswith("list"): - dtype2 = dtype.split(":")[1] - x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) - elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): - raise NotImplementedError( - "{} type not supported for input and output features yet.".format(dtype)) - elif dtype == "bool": - x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) - elif dtype == "uint": - # PyTorch only supports uint8. Need to convert it to int. - x.append(df[[col]].to_numpy().astype("int")) + ) -> np.ndarray: + """Turn multiple columns of a dataframe into a numpy array. + """ + x = [] + for col in attributes: + dtype = attr_types[col].lower() + if dtype.startswith("str"): + raise TypeError( + "String type not allowed for input and output features." + ) + if dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) + elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for input and output features yet.".format(dtype)) + elif dtype == "bool": + x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. 
+ x.append(df[[col]].to_numpy().astype("int")) + else: + x.append(df[[col]].to_numpy().astype(dtype)) + return np.hstack(x) + + @staticmethod + def _get_edgelist( + vertices: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + edges: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + is_hetero: bool, + e_attr_types: dict = {} + ): + if not is_hetero: + vertices["tmp_id"] = range(len(vertices)) + id_map = vertices[["vid", "tmp_id"]] + edgelist = edges[["source", "target"]].merge(id_map, left_on="source", right_on="vid") + edgelist = edgelist.merge(id_map, left_on="target", right_on="vid") + edgelist = edgelist[["tmp_id_x", "tmp_id_y"]] + else: + edgelist = {} + id_map = {} + for vtype in vertices: + vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) + id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] + for etype in edges: + source_type = e_attr_types[etype]["FromVertexTypeName"] + target_type = e_attr_types[etype]["ToVertexTypeName"] + if e_attr_types[etype]["IsDirected"] or source_type==target_type: + edgelist[etype] = edges[etype][["source", "target"]].merge(id_map[source_type], left_on="source", right_on="vid") + edgelist[etype] = edgelist[etype].merge(id_map[target_type], left_on="target", right_on="vid") + edgelist[etype] = edgelist[etype][["tmp_id_x", "tmp_id_y"]] else: - x.append(df[[col]].to_numpy().astype(dtype)) - if mode == "pyg" or mode == "dgl": - return torch.tensor(np.hstack(x)).squeeze(dim=1) - elif mode == "spektral": - try: - return np.squeeze(np.hstack(x), axis=1) #throws an error if axis isn't 1 - except: - return np.hstack(x) + subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") + if len(subdf1) < len(edges[etype]): + subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") + subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") + subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) + edges[etype] = subdf1 + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + return edgelist + @staticmethod + def _parse_df_to_pyg( + raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + add_self_loop: bool = False, + is_hetero: bool = False, + torch = None, + pyg = None + ) -> Union["pyg.data.Data", "pyg.data.HeteroData"]: + """Parse dataframes to PyG graphs. + """ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, - graph, is_hetero: bool, mode: str, feat_name: str, + graph, is_hetero: bool, feat_name: str, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: """Add multiple attributes as a single feature to edges or vertices. 
""" @@ -744,31 +1238,19 @@ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if not vetype: raise ValueError("Vertex or edge type required for heterogeneous graphs") # Focus on a specific type - if mode == "pyg": - if target == "edge": - data = graph[attr_types["FromVertexTypeName"], - vetype, - attr_types["ToVertexTypeName"]] - elif target == "vertex": - data = graph[vetype] - elif mode == "dgl": - if target == "edge": - data = graph.edges[vetype].data - elif target == "vertex": - data = graph.nodes[vetype].data + if target == "edge": + data = graph[attr_types["FromVertexTypeName"], + vetype, + attr_types["ToVertexTypeName"]] + elif target == "vertex": + data = graph[vetype] else: - if mode == "pyg" or mode == "spektral": - data = graph - elif mode == "dgl": - if target == "edge": - data = graph.edata - elif target == "vertex": - data = graph.ndata - - data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + data = graph + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + data[feat_name] = torch.tensor(array).squeeze(dim=1) def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, - graph, is_hetero: bool, mode: str, + graph, is_hetero: bool, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: """Add each attribute as a single feature to edges or vertices. """ @@ -776,411 +1258,292 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if not vetype: raise ValueError("Vertex or edge type required for heterogeneous graphs") # Focus on a specific type - if mode == "pyg": - if target == "edge": - data = graph[attr_types["FromVertexTypeName"], - vetype, - attr_types["ToVertexTypeName"]] - elif target == "vertex": - data = graph[vetype] - elif mode == "dgl": - if target == "edge": - data = graph.edges[vetype].data - elif target == "vertex": - data = graph.nodes[vetype].data + if target == "edge": + data = graph[attr_types["FromVertexTypeName"], + vetype, + attr_types["ToVertexTypeName"]] + elif target == "vertex": + data = graph[vetype] else: - if mode == "pyg" or mode == "spektral": - data = graph - elif mode == "dgl": - if target == "edge": - data = graph.edata - elif target == "vertex": - data = graph.ndata + data = graph for col in attr_names: dtype = attr_types[col].lower() if dtype.startswith("str") or dtype.startswith("map"): - if mode == "dgl": - if vetype is None: - # Homogeneous graph, add column directly to extra data - graph.extra_data[col] = attr_df[col].to_list() - elif vetype not in graph.extra_data: - # Hetero graph, vetype doesn't exist in extra data - graph.extra_data[vetype] = {} - graph.extra_data[vetype][col] = attr_df[col].to_list() - else: - # Hetero graph and vetype already exists - graph.extra_data[vetype][col] = attr_df[col].to_list() - elif mode == "pyg" or mode == "spektral": - data[col] = attr_df[col].to_list() + data[col] = attr_df[col].to_list() elif dtype.startswith("list"): dtype2 = dtype.split(":")[1] if dtype2.startswith("str"): - if mode == "dgl": - if vetype is None: - # Homogeneous graph, add column directly to extra data - graph.extra_data[col] = attr_df[col].str.split().to_list() - elif vetype not in graph.extra_data: - # Hetero graph, vetype doesn't exist in extra data - graph.extra_data[vetype] = {} - graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() - else: - # Hetero graph and vetype already exists - graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() - elif mode == "pyg" or mode == "spektral": - 
data[col] = attr_df[col].str.split().to_list() + data[col] = attr_df[col].str.split().to_list() else: - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col] - .str.split(expand=True) - .to_numpy() - .astype(dtype2) - ) - elif mode == "spektral": - data[col] = attr_df[col].str.split(expand=True).to_numpy().astype(dtype2) + data[col] = torch.tensor( + attr_df[col] + .str.split(expand=True) + .to_numpy() + .astype(dtype2) + ) elif dtype.startswith("set") or dtype.startswith("date"): raise NotImplementedError( "{} type not supported for extra features yet.".format(dtype)) elif dtype == "bool": - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype("int8").astype(dtype) - ) - elif mode == "spektral": - data[col] = attr_df[col].astype("int8").astype(dtype) + data[col] = torch.tensor( + attr_df[col].astype("int8").astype(dtype) + ) elif dtype == "uint": # PyTorch only supports uint8. Need to convert it to int. - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype("int") - ) - elif mode == "spektral": - data[col] = attr_df[col].astype(dtype) + data[col] = torch.tensor( + attr_df[col].astype("int") + ) else: - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype(dtype) - ) - elif mode == "spektral": - data[col] = attr_df[col].astype(dtype) + data[col] = torch.tensor( + attr_df[col].astype(dtype) + ) - # Read in vertex and edge CSVs as dataframes - vertices, edges = None, None - if in_format == "vertex": - # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats - if not is_hetero: - v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - v_file = (line.split(delimiter) for line in raw.split('\n') if line) - data = pd.DataFrame(v_file, columns=v_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for v_attr in v_attributes: - if v_attr_types.get(v_attr, "") == "MAP": - # I am sorry that this is this ugly... - data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - v_file = (line.split(delimiter) for line in raw.split('\n') if line) - v_file_dict = defaultdict(list) - for line in v_file: - v_file_dict[line[0]].append(line[1:]) - vertices = {} - for vtype in v_file_dict: - v_attributes = ["vid"] + \ - v_in_feats.get(vtype, []) + \ - v_out_labels.get(vtype, []) + \ - v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes) - for v_attr in v_extra_feats.get(vtype, []): - if v_attr_types[vtype][v_attr] == "MAP": - # I am sorry that this is this ugly... 
- vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - data = vertices - elif in_format == "edge": - # String of edges in format source_vid,target_vid - if not is_hetero: - e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - #file = "\n".join(x for x in raw.split("\n") if x.strip()) - #data = pd.read_table(io.StringIO(file), header=None, names=e_attributes, sep=delimiter) - e_file = (line.split(delimiter) for line in raw.split('\n') if line) - data = pd.DataFrame(e_file, columns=e_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for e_attr in e_attributes: - if e_attr_types.get(e_attr, "") == "MAP": - # I am sorry that this is this ugly... - data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - e_file = (line.split(delimiter) for line in raw.split('\n') if line) - e_file_dict = defaultdict(list) - for line in e_file: - e_file_dict[line[0]].append(line[1:]) - edges = {} - for etype in e_file_dict: - e_attributes = ["source", "target"] + \ - e_in_feats.get(etype, []) + \ - e_out_labels.get(etype, []) + \ - e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes) - for e_attr in e_extra_feats.get(etype, []): - if e_attr_types[etype][e_attr] == "MAP": - # I am sorry that this is this ugly... - edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - del e_file_dict, e_file - data = edges - elif in_format == "graph": - # A pair of in-memory CSVs (vertex, edge) - v_file, e_file = raw - if not is_hetero: - v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - #file = "\n".join(x for x in v_file.split("\n") if x.strip()) - v_file = (line.split(delimiter) for line in v_file.split('\n') if line) - vertices = pd.DataFrame(v_file, columns=v_attributes) - for column in vertices.columns: - vertices[column] = pd.to_numeric(vertices[column], errors="ignore") - for v_attr in v_extra_feats: - if v_attr_types[v_attr] == "MAP": - # I am sorry that this is this ugly... - vertices[v_attr] = vertices[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}) - vertices = vertices.merge(id_map.astype({"vid": vertices["vid"].dtype}), on="vid") - v_extra_feats.append("primary_id") - #file = "\n".join(x for x in e_file.split("\n") if x.strip()) - e_file = (line.split(delimiter) for line in e_file.split('\n') if line) - #edges = pd.read_table(io.StringIO(file), header=None, names=e_attributes, dtype="object", sep=delimiter) - edges = pd.DataFrame(e_file, columns=e_attributes) - for column in edges.columns: - edges[column] = pd.to_numeric(edges[column], errors="ignore") - for e_attr in e_attributes: - if e_attr_types.get(e_attr, "") == "MAP": - # I am sorry that this is this ugly... 
- edges[e_attr] = edges[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - v_file = (line.split(delimiter) for line in v_file.split('\n') if line) - v_file_dict = defaultdict(list) - for line in v_file: - v_file_dict[line[0]].append(line[1:]) - vertices = {} - for vtype in v_file_dict: - v_attributes = ["vid"] + \ - v_in_feats.get(vtype, []) + \ - v_out_labels.get(vtype, []) + \ - v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") - for v_attr in v_extra_feats.get(vtype, []): - if v_attr_types[vtype][v_attr] == "MAP": - # I am sorry that this is this ugly... - vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, - dtype="object") - for vtype in vertices: - vertices[vtype] = vertices[vtype].merge(id_map, on="vid") - v_extra_feats[vtype].append("primary_id") - del v_file_dict, v_file - e_file = (line.split(delimiter) for line in e_file.split('\n') if line) - e_file_dict = defaultdict(list) - for line in e_file: - e_file_dict[line[0]].append(line[1:]) - edges = {} - for etype in e_file_dict: - e_attributes = ["source", "target"] + \ - e_in_feats.get(etype, []) + \ - e_out_labels.get(etype, []) + \ - e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") - for e_attr in e_extra_feats.get(etype, []): - if e_attr_types[etype][e_attr] == "MAP": - # I am sorry that this is this ugly... - edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - del e_file_dict, e_file - data = (vertices, edges) - else: - raise NotImplementedError - # Convert dataframes into PyG or DGL graphs - if out_format.lower() == "pyg" or out_format.lower() == "dgl": - if vertices is None or edges is None: - raise ValueError( - "Spektral, PyG, or DGL format can only be used with (sub)graph loaders." - ) - try: - import torch - except ImportError: - raise ImportError( - "PyTorch is not installed. Please install it to use PyG or DGL output." - ) - if out_format.lower() == "dgl": - try: - import dgl - mode = "dgl" - except ImportError: - raise ImportError( - "DGL is not installed. Please install DGL to use DGL format." + # Convert dataframes into PyG graphs + # Reformat as a graph. + # Need to have a pair of tables for edges and vertices. 
+ vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) + if not is_hetero: + # Deal with edgelist first + edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) + if add_self_loop: + edgelist = pyg.utils.add_self_loops(edgelist)[0] + data = pyg.data.Data() + data["edge_index"] = edgelist + # Deal with edge attributes + if e_in_feats: + add_attributes(e_in_feats, e_attr_types, edges, + data, is_hetero, "edge_feat", "edge") + if e_out_labels: + add_attributes(e_out_labels, e_attr_types, edges, + data, is_hetero, "edge_label", "edge") + if e_extra_feats: + add_sep_attr(e_extra_feats, e_attr_types, edges, + data, is_hetero, "edge") + del edges + # Deal with vertex attributes next + if v_in_feats: + add_attributes(v_in_feats, v_attr_types, vertices, + data, is_hetero, "x", "vertex") + if v_out_labels: + add_attributes(v_out_labels, v_attr_types, vertices, + data, is_hetero, "y", "vertex") + if v_extra_feats: + add_sep_attr(v_extra_feats, v_attr_types, vertices, + data, is_hetero, "vertex") + del vertices + else: + # Heterogeneous graph + # Deal with edgelist first + for etype in edges: + edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) + if add_self_loop: + edgelist[etype] = pyg.utils.add_self_loops(edgelist[etype])[0] + data = pyg.data.HeteroData() + for etype in edgelist: + data[e_attr_types[etype]["FromVertexTypeName"], + etype, + e_attr_types[etype]["ToVertexTypeName"]].edge_index = edgelist[etype] + # Deal with edge attributes + if e_in_feats: + for etype in edges: + if etype not in e_in_feats: + continue + if e_in_feats[etype]: + add_attributes(e_in_feats[etype], e_attr_types[etype], edges[etype], + data, is_hetero, "edge_feat", "edge", etype) + if e_out_labels: + for etype in edges: + if etype not in e_out_labels: + continue + if e_out_labels[etype]: + add_attributes(e_out_labels[etype], e_attr_types[etype], edges[etype], + data, is_hetero, "edge_label", "edge", etype) + if e_extra_feats: + for etype in edges: + if etype not in e_extra_feats: + continue + if e_extra_feats[etype]: + add_sep_attr(e_extra_feats[etype], e_attr_types[etype], edges[etype], + data, is_hetero, "edge", etype) + del edges + # Deal with vertex attributes next + if v_in_feats: + for vtype in vertices: + if vtype not in v_in_feats: + continue + if v_in_feats[vtype]: + add_attributes(v_in_feats[vtype], v_attr_types[vtype], vertices[vtype], + data, is_hetero, "x", "vertex", vtype) + if v_out_labels: + for vtype in vertices: + if vtype not in v_out_labels: + continue + if v_out_labels[vtype]: + add_attributes(v_out_labels[vtype], v_attr_types[vtype], vertices[vtype], + data, is_hetero, "y", "vertex", vtype) + if v_extra_feats: + for vtype in vertices: + if vtype not in v_extra_feats: + continue + if v_extra_feats[vtype]: + add_sep_attr(v_extra_feats[vtype], v_attr_types[vtype], vertices[vtype], + data, is_hetero, "vertex", vtype) + del vertices + return data + + @staticmethod + def _parse_df_to_dgl( + raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + add_self_loop: bool = False, + is_hetero: bool = False, + torch = None, + dgl = None + ) -> Union["dgl.graph", "dgl.heterograph"]: + """Parse 
dataframes to DGL graphs. + """ + def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, feat_name: str, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add multiple attributes as a single feature to edges or vertices. + """ + if is_hetero: + if not vetype: + raise ValueError("Vertex or edge type required for heterogeneous graphs") + # Focus on a specific type + if target == "edge": + data = graph.edges[vetype].data + elif target == "vertex": + data = graph.nodes[vetype].data + else: + if target == "edge": + data = graph.edata + elif target == "vertex": + data = graph.ndata + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + data[feat_name] = torch.tensor(array).squeeze(dim=1) + + def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add each attribute as a single feature to edges or vertices. + """ + if is_hetero: + if not vetype: + raise ValueError("Vertex or edge type required for heterogeneous graphs") + # Focus on a specific type + if target == "edge": + data = graph.edges[vetype].data + elif target == "vertex": + data = graph.nodes[vetype].data + else: + if target == "edge": + data = graph.edata + elif target == "vertex": + data = graph.ndata + + for col in attr_names: + dtype = attr_types[col].lower() + if dtype.startswith("str") or dtype.startswith("map"): + if vetype is None: + # Homogeneous graph, add column directly to extra data + graph.extra_data[col] = attr_df[col].to_list() + elif vetype not in graph.extra_data: + # Hetero graph, vetype doesn't exist in extra data + graph.extra_data[vetype] = {} + graph.extra_data[vetype][col] = attr_df[col].to_list() + else: + # Hetero graph and vetype already exists + graph.extra_data[vetype][col] = attr_df[col].to_list() + elif dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + if dtype2.startswith("str"): + if vetype is None: + # Homogeneous graph, add column directly to extra data + graph.extra_data[col] = attr_df[col].str.split().to_list() + elif vetype not in graph.extra_data: + # Hetero graph, vetype doesn't exist in extra data + graph.extra_data[vetype] = {} + graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() + else: + # Hetero graph and vetype already exists + graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() + else: + data[col] = torch.tensor( + attr_df[col] + .str.split(expand=True) + .to_numpy() + .astype(dtype2) + ) + elif dtype.startswith("set") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for extra features yet.".format(dtype)) + elif dtype == "bool": + data[col] = torch.tensor( + attr_df[col].astype("int8").astype(dtype) ) - elif out_format.lower() == "pyg": - try: - from torch_geometric.data import Data as pygData - from torch_geometric.data import \ - HeteroData as pygHeteroData - from torch_geometric.utils import add_self_loops - mode = "pyg" - except ImportError: - raise ImportError( - "PyG is not installed. Please install PyG to use PyG format." + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. + data[col] = torch.tensor( + attr_df[col].astype("int") ) - elif out_format.lower() == "spektral": - if vertices is None or edges is None: - raise ValueError( - "Spektral, PyG, or DGL format can only be used with (sub)graph loaders.
- ) - try: - import tensorflow as tf - except ImportError: - raise ImportError( - "Tensorflow is not installed. Please install it to use spektral output." - ) - try: - import scipy - except ImportError: - raise ImportError( - "scipy is not installed. Please install it to use spektral output." - ) - try: - import spektral - mode = "spektral" - except ImportError: - raise ImportError( - "Spektral is not installed. Please install it to use spektral output." - ) - elif out_format.lower() == "dataframe": - if callback_fn: - return callback_fn(data) - else: - return data - else: - raise NotImplementedError + else: + data[col] = torch.tensor( + attr_df[col].astype(dtype) + ) + # Reformat as a graph. # Need to have a pair of tables for edges and vertices. + vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first - if reindex: - vertices["tmp_id"] = range(len(vertices)) - id_map = vertices[["vid", "tmp_id"]] - edges = edges.merge(id_map, left_on="source", right_on="vid") - edges.drop(columns=["source", "vid"], inplace=True) - edges = edges.merge(id_map, left_on="target", right_on="vid") - edges.drop(columns=["target", "vid"], inplace=True) - edgelist = edges[["tmp_id_x", "tmp_id_y"]] - else: - edgelist = edges[["source", "target"]] - - if mode == "dgl" or mode == "pyg": - edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) - if mode == "dgl": - data = dgl.graph(data=(edgelist[0], edgelist[1])) - if add_self_loop: - data = dgl.add_self_loop(data) - data.extra_data = {} - elif mode == "pyg": - data = pygData() - if add_self_loop: - edgelist = add_self_loops(edgelist)[0] - data["edge_index"] = edgelist - elif mode == "spektral": - n_edges = len(edgelist) - n_vertices = len(vertices) - adjacency_data = [1 for i in range(n_edges)] #spektral adjacency format requires weights for each edge to initialize - adjacency = scipy.sparse.coo_matrix((adjacency_data, (edgelist["tmp_id_x"], edgelist["tmp_id_y"])), shape=(n_vertices, n_vertices)) - if add_self_loop: - adjacency = spektral.utils.add_self_loops(adjacency, value=1) - edge_index = np.stack((adjacency.row, adjacency.col), axis=-1) - data = spektral.data.graph.Graph(A=adjacency) - del edgelist + edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) + data = dgl.graph(data=(edgelist[0], edgelist[1])) + if add_self_loop: + data = dgl.add_self_loop(data) + data.extra_data = {} # Deal with edge attributes if e_in_feats: add_attributes(e_in_feats, e_attr_types, edges, - data, is_hetero, mode, "edge_feat", "edge") - if mode == "spektral": - edge_data = data["edge_feat"] - edge_index, edge_data = spektral.utils.reorder(edge_index, edge_features=edge_data) - n_edges = len(edge_index) - data["e"] = np.array([[i] for i in edge_data]) #if something breaks when you add self-loops it's here - adjacency_data = [1 for i in range(n_edges)] - data["a"] = scipy.sparse.coo_matrix((adjacency_data, (edge_index[:, 0], edge_index[:, 1])), shape=(n_vertices, n_vertices)) - + data, is_hetero, "edge_feat", "edge") if e_out_labels: add_attributes(e_out_labels, e_attr_types, edges, - data, is_hetero, mode, "edge_label", "edge") + data, is_hetero, "edge_label", "edge") if e_extra_feats: add_sep_attr(e_extra_feats, e_attr_types, edges, - data, is_hetero, mode, "edge") + data, is_hetero, "edge") del edges # Deal with vertex attributes next if v_in_feats: add_attributes(v_in_feats, v_attr_types, vertices, - data, is_hetero, mode, "x", "vertex") + data, is_hetero, "x", 
"vertex") if v_out_labels: add_attributes(v_out_labels, v_attr_types, vertices, - data, is_hetero, mode, "y", "vertex") + data, is_hetero, "y", "vertex") if v_extra_feats: add_sep_attr(v_extra_feats, v_attr_types, vertices, - data, is_hetero, mode, "vertex") + data, is_hetero, "vertex") del vertices else: # Heterogeneous graph # Deal with edgelist first - edgelist = {} - if reindex: - id_map = {} - for vtype in vertices: - vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) - id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] - for etype in edges: - source_type = e_attr_types[etype]["FromVertexTypeName"] - target_type = e_attr_types[etype]["ToVertexTypeName"] - if e_attr_types[etype]["IsDirected"] or source_type==target_type: - edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - edges[etype].drop(columns=["source", "vid"], inplace=True) - edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") - edges[etype].drop(columns=["target", "vid"], inplace=True) - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - subdf1.drop(columns=["source", "vid"], inplace=True) - subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") - subdf1.drop(columns=["target", "vid"], inplace=True) - if len(subdf1) < len(edges[etype]): - subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") - subdf2.drop(columns=["target", "vid"], inplace=True) - subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") - subdf2.drop(columns=["source", "vid"], inplace=True) - subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) - edges[etype] = subdf1 - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - for etype in edges: - edgelist[etype] = edges[etype][["source", "target"]] for etype in edges: edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) - if mode == "dgl": - data = dgl.heterograph({ - (e_attr_types[etype]["FromVertexTypeName"], etype, e_attr_types[etype]["ToVertexTypeName"]): (edgelist[etype][0], edgelist[etype][1]) for etype in edgelist}) - if add_self_loop: - data = dgl.add_self_loop(data) - data.extra_data = {} - elif mode == "pyg": - data = pygHeteroData() - for etype in edgelist: - if add_self_loop: - edgelist[etype] = add_self_loops(edgelist[etype])[0] - data[e_attr_types[etype]["FromVertexTypeName"], - etype, - e_attr_types[etype]["ToVertexTypeName"]].edge_index = edgelist[etype] - elif mode == "spektral": - raise NotImplementedError - del edgelist + data = dgl.heterograph({ + (e_attr_types[etype]["FromVertexTypeName"], etype, e_attr_types[etype]["ToVertexTypeName"]): (edgelist[etype][0], edgelist[etype][1]) for etype in edgelist}) + if add_self_loop: + data = dgl.add_self_loop(data) + data.extra_data = {} # Deal with edge attributes if e_in_feats: for etype in edges: @@ -1188,21 +1551,21 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, continue if e_in_feats[etype]: add_attributes(e_in_feats[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge_feat", "edge", etype) + data, is_hetero, "edge_feat", "edge", etype) if e_out_labels: for etype in edges: if etype not in e_out_labels: continue if e_out_labels[etype]: add_attributes(e_out_labels[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge_label", "edge", etype) + data, is_hetero, "edge_label", "edge", etype) if 
e_extra_feats: for etype in edges: if etype not in e_extra_feats: continue if e_extra_feats[etype]: add_sep_attr(e_extra_feats[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge", etype) + data, is_hetero, "edge", etype) del edges # Deal with vertex attributes next if v_in_feats: @@ -1211,43 +1574,152 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, continue if v_in_feats[vtype]: add_attributes(v_in_feats[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "x", "vertex", vtype) + data, is_hetero, "x", "vertex", vtype) if v_out_labels: for vtype in vertices: if vtype not in v_out_labels: continue if v_out_labels[vtype]: add_attributes(v_out_labels[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "y", "vertex", vtype) + data, is_hetero, "y", "vertex", vtype) if v_extra_feats: for vtype in vertices: if vtype not in v_extra_feats: continue if v_extra_feats[vtype]: add_sep_attr(v_extra_feats[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "vertex", vtype) + data, is_hetero, "vertex", vtype) + del vertices + return data + + @staticmethod + def _parse_df_to_spektral( + raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + add_self_loop: bool = False, + is_hetero: bool = False, + scipy = None, + spektral = None + ) -> "spektral.data.graph.Graph": + """Parse dataframes to Spektral graphs. + """ + def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, feat_name: str, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add multiple attributes as a single feature to edges or vertices. + """ + data = graph + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + try: + array = np.squeeze(array, axis=1) + except: + pass + data[feat_name] = array + + def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add each attribute as a single feature to edges or vertices. + """ + data = graph + for col in attr_names: + dtype = attr_types[col].lower() + if dtype.startswith("str") or dtype.startswith("map"): + data[col] = attr_df[col].to_list() + elif dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + if dtype2.startswith("str"): + data[col] = attr_df[col].str.split().to_list() + else: + data[col] = attr_df[col].str.split(expand=True).to_numpy().astype(dtype2) + elif dtype.startswith("set") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for extra features yet.".format(dtype)) + elif dtype == "bool": + data[col] = attr_df[col].astype("int8").astype(dtype) + else: + data[col] = attr_df[col].astype(dtype) + + # Reformat as a graph. + # Need to have a pair of tables for edges and vertices. 
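+        # Spektral expects the adjacency as a scipy sparse matrix; it is
+        # built below as a COO matrix with weight 1 per edge, and edge
+        # features (if any) are reordered to match the adjacency's ordering.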
+ vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) + if not is_hetero: + # Deal with edgelist first + n_edges = len(edgelist) + n_vertices = len(vertices) + adjacency_data = [1 for i in range(n_edges)] # Spektral's adjacency format requires a weight for every edge at initialization + adjacency = scipy.sparse.coo_matrix((adjacency_data, (edgelist["tmp_id_x"], edgelist["tmp_id_y"])), shape=(n_vertices, n_vertices)) + if add_self_loop: + adjacency = spektral.utils.add_self_loops(adjacency, value=1) + edge_index = np.stack((adjacency.row, adjacency.col), axis=-1) + data = spektral.data.graph.Graph(A=adjacency) + del edgelist + # Deal with edge attributes + if e_in_feats: + add_attributes(e_in_feats, e_attr_types, edges, + data, is_hetero, "edge_feat", "edge") + edge_data = data["edge_feat"] + edge_index, edge_data = spektral.utils.reorder(edge_index, edge_features=edge_data) + n_edges = len(edge_index) + data["e"] = np.array([[i] for i in edge_data]) # NOTE: most likely failure point when add_self_loop is set, since added self-loops have no corresponding edge features + adjacency_data = [1 for i in range(n_edges)] + data["a"] = scipy.sparse.coo_matrix((adjacency_data, (edge_index[:, 0], edge_index[:, 1])), shape=(n_vertices, n_vertices)) + if e_out_labels: + add_attributes(e_out_labels, e_attr_types, edges, + data, is_hetero, "edge_label", "edge") + if e_extra_feats: + add_sep_attr(e_extra_feats, e_attr_types, edges, + data, is_hetero, "edge") + del edges + # Deal with vertex attributes next + if v_in_feats: + add_attributes(v_in_feats, v_attr_types, vertices, + data, is_hetero, "x", "vertex") + if v_out_labels: + add_attributes(v_out_labels, v_attr_types, vertices, + data, is_hetero, "y", "vertex") + if v_extra_feats: + add_sep_attr(v_extra_feats, v_attr_types, vertices, + data, is_hetero, "vertex") del vertices - if callback_fn: - return callback_fn(data) else: - return data + # Heterogeneous graph + # Deal with edgelist first + raise NotImplementedError + return data - def _start_request(self, out_tuple: bool, resp_type: str): + def _start_request(self, is_graph: bool): # If using kafka if self.kafka_address_consumer: # Generate topic self._set_kafka_topic() # Start consumer thread - self._downloader = Thread( - target=self._download_from_kafka, - args=( - self._exit_event, - self._read_task_q, - self.num_batches, - out_tuple, - self._kafka_consumer, - ), - ) + if is_graph: + self._downloader = Thread( + target=self._download_graph_kafka, + kwargs=dict( + exit_event = self._exit_event, + read_task_q = self._read_task_q, + kafka_consumer = self._kafka_consumer, + ), + ) + else: + self._downloader = Thread( + target=self._download_unimode_kafka, + kwargs=dict( + exit_event = self._exit_event, + read_task_q = self._read_task_q, + kafka_consumer = self._kafka_consumer + ), + ) self._downloader.start() # Start requester thread if not self.kafka_skip_produce: @@ -1264,23 +1736,33 @@ def _start_request(self, out_tuple: bool, resp_type: str): self._requester.start() else: # Otherwise, use rest api - self._requester = Thread( - target=self._request_rest, - args=( - self._graph, - self.query_name, - self._read_task_q, - self.timeout, - self._payload, - resp_type, - ), - ) + if is_graph: + self._requester = Thread( + target=self._request_graph_rest, + kwargs=dict( + tgraph = self._graph, + query_name = self.query_name, + read_task_q = self._read_task_q, + timeout = self.timeout, + payload = self._payload + ), + ) + else: + self._requester = Thread( + target=self._request_unimode_rest, + kwargs=dict( + tgraph = 
self._graph, + query_name = self.query_name, + read_task_q = self._read_task_q, + timeout = self.timeout, + payload = self._payload + ), + ) self._requester.start() def _start(self) -> None: # This is a template. Implement your own logics here. # Create task and result queues - self._request_task_q = Queue() self._read_task_q = Queue() self._data_q = Queue(self._buffer_size) self._exit_event = Event() @@ -1346,41 +1828,38 @@ def data(self) -> Any: return self def _reset(self, theend=False) -> None: - logging.debug("Resetting the loader") + logger.debug("Resetting data loader") if self._exit_event: self._exit_event.set() - if self._request_task_q: - self._request_task_q.put(None) - if self._download_task_q: - self._download_task_q.put(None) + logger.debug("Set exit event") if self._read_task_q: while True: try: - self._read_task_q.get(block=False) + self._read_task_q.get(timeout=1) except Empty: break - self._read_task_q.put(None) + logger.debug("Emptied read task queue") if self._data_q: while True: try: - self._data_q.get(block=False) + self._data_q.get(timeout=1) except Empty: break + logger.debug("Emptied data queue") if self._requester: self._requester.join() + logger.debug("Stopped requester thread") if self._downloader: self._downloader.join() + logger.debug("Stopped downloader thread") if self._reader: self._reader.join() - del self._request_task_q, self._download_task_q, self._read_task_q, self._data_q + logger.debug("Stopped reader thread") + del self._read_task_q, self._data_q self._exit_event = None self._requester, self._downloader, self._reader = None, None, None - self._request_task_q, self._download_task_q, self._read_task_q, self._data_q = ( - None, - None, - None, - None, - ) + self._read_task_q, self._data_q = None, None + logger.debug("Deleted all queues and threads") if theend: if self._kafka_topic and self._kafka_consumer: self._kafka_consumer.unsubscribe() @@ -1392,6 +1871,7 @@ def _reset(self, theend=False) -> None: raise TigerGraphException( "Failed to delete topic {}".format(del_res["topic"]) ) + logger.debug("Finished with Kafka. 
Reached the end.") else: if self.delete_epoch_topic and self._kafka_admin: if self._kafka_topic and self._kafka_consumer: @@ -1403,7 +1883,8 @@ def _reset(self, theend=False) -> None: "Failed to delete topic {}".format(del_res["topic"]) ) self._kafka_topic = None - logging.debug("Successfully reset the loader") + logger.debug("Finished with Kafka") + logger.debug("Reset data loader successfully") def _generate_attribute_string(self, schema_type, attr_names, attr_types) -> str: if schema_type.lower() == "vertex": @@ -1630,6 +2111,7 @@ def __init__( self._etypes = list(self._e_schema.keys()) self._vtypes = sorted(self._vtypes) self._etypes = sorted(self._etypes) + # Resolve seeds if v_seed_types: if isinstance(v_seed_types, list): self._seed_types = v_seed_types @@ -1641,10 +2123,15 @@ def __init__( self._seed_types = list(filter_by.keys()) else: self._seed_types = self._vtypes + if set(self._seed_types) - set(self._vtypes): + raise ValueError("Seed type has to be one of the vertex types to retrieve") - # Resolve seeds if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -1659,17 +2146,11 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. + self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches # Initialize parameters for the query - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches self._payload["num_neighbors"] = num_neighbors self._payload["num_hops"] = num_hops - self._payload["num_heap_inserts"] = self.num_heap_inserts if filter_by: if isinstance(filter_by, str): self._payload["filter_by"] = filter_by @@ -1715,19 +2196,23 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_seed += "END" - print_query_other += "END" + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "1\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") 
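# For illustration, one rendered seed branch could look roughly like the following,
# assuming a hypothetical vertex type "Paper" and whatever attribute expression
# _generate_attribute_string produces:
#   IF s.type == "Paper" THEN
#       @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter
#                     + <printed attributes> + delimiter + "1\n")
# The trailing "1" flags seed vertices; the parallel template below writes "0"
# for vertices that are merely part of a sampled neighborhood.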
+ print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + print_query_other += """ + END""" query_replace["{SEEDVERTEXATTRS}"] = print_query_seed query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Multiple edge types @@ -1739,44 +2224,36 @@ def _install_query(self, force: bool = False): + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_query += """ + END""" query_replace["{EDGEATTRS}"] = print_query else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")'.format( - print_attr - ) - query_replace["{SEEDVERTEXATTRS}"] = print_query - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")'.format( - print_attr - ) - query_replace["{OTHERVERTEXATTRS}"] = print_query - else: - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + "1\\n")' - query_replace["{SEEDVERTEXATTRS}"] = print_query - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + "0\\n")' - query_replace["{OTHERVERTEXATTRS}"] = print_query - # Ignore edge types - e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats - e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query = '@@v_batch += (stringify(getvid(s)) {} + delimiter + "1\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{SEEDVERTEXATTRS}"] = print_query + print_query = '@@v_batch += (stringify(getvid(s)) {} + delimiter + "0\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{OTHERVERTEXATTRS}"] = print_query + # Ignore edge types + e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats + e_attr_types = next(iter(self._e_schema.values())) + print_attr 
= self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")'.format( + " + delimiter + " + print_attr if e_attr_names else "" + ) query_replace["{EDGEATTRS}"] = print_query # Install query query_path = os.path.join( @@ -1785,15 +2262,21 @@ def _install_query(self, force: bool = False): "dataloaders", "neighbor_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "neighbor_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. if not self.is_hetero: @@ -1810,26 +2293,25 @@ def _start(self) -> None: v_attr_types[vtype]["is_seed"] = "bool" e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() @@ -1866,7 +2348,6 @@ def fetch(self, vertices: list) -> None: _payload = {} _payload["v_types"] = self._payload["v_types"] _payload["e_types"] = self._payload["e_types"] - _payload["num_batches"] = 1 _payload["num_neighbors"] = self._payload["num_neighbors"] _payload["num_hops"] = self._payload["num_hops"] _payload["delimiter"] = self._payload["delimiter"] @@ -1892,11 +2373,15 @@ def fetch(self, vertices: list) -> None: v_attr_types[vtype]["is_seed"] = "bool" v_attr_types[vtype]["primary_id"] = "str" e_attr_types = self._e_schema - i = resp[0] - data = self._parse_data( - raw = (i["vertex_batch"], i["edge_batch"]), - in_format = "graph", - out_format = self.output_format, + vertex_batch = set() + edge_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.update(i["vertex_batch"].splitlines()) + edge_batch.update(i["edge_batch"].splitlines()) + data = self._parse_graph_data_to_df( + raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, v_out_labels = self.v_out_labels, v_extra_feats = v_extra_feats, @@ -1905,14 +2390,119 @@ def fetch(self, vertices: list) -> None: e_out_labels = self.e_out_labels, e_extra_feats = self.e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = self.add_self_loop, delimiter 
= self.delimiter, - reindex = True, primary_id = i["pids"], is_hetero = self.is_hetero, - callback_fn = self.callback_fn ) - # Return data + if self.output_format == "dataframe" or self.output_format == "df": + vertices, edges = data + if not self.is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif self.output_format == "pyg": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import torch_geometric as pyg + except ImportError: + raise ImportError( + "PyG is not installed. Please install PyG to use PyG format." + ) + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + pyg = pyg + ) + elif self.output_format == "dgl": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import dgl + except ImportError: + raise ImportError( + "DGL is not installed. Please install DGL to use DGL format." + ) + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + dgl = dgl + ) + elif self.output_format == "spektral" and not self.is_hetero: + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "TensorFlow is not installed. Please install it to use Spektral output." + ) + try: + import scipy + except ImportError: + raise ImportError( + "SciPy is not installed. Please install it to use Spektral output." + ) + try: + import spektral + except ImportError: + raise ImportError( + "Spektral is not installed. Please install it to use Spektral output." 
+ ) + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + scipy = scipy, + spektral = spektral + ) + else: + raise NotImplementedError + if self.callback_fn: + data = self.callback_fn(data) return data @@ -2099,19 +2689,32 @@ def __init__( self._etypes = sorted(self._etypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: - num_edges = sum(self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] for e_type in self._etypes) + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._etypes) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp + if num_edges==0: + raise ValueError("Cannot find any edge as seed. Please check your configuration and data. 
If they all look good, please use batch_size instead of num_batches or refresh metadata following https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_parameters_15") + self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches # Initialize the exporter - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches if filter_by: self._payload["filter_by"] = filter_by self._payload["shuffle"] = shuffle @@ -2133,31 +2736,65 @@ def _install_query(self, force: bool = False): if isinstance(self.attributes, dict): # Multiple edge types - print_query = "" + print_query_kafka = "" + print_query_http = "" for idx, etype in enumerate(self._etypes): e_attr_names = self.attributes.get(etype, []) e_attr_types = self._e_schema[etype] if e_attr_names: print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) + print_query_http += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", etype, print_attr) + print_query_kafka += """ + {} e.type == "{}" THEN + STRING msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", etype, print_attr) else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" - query_replace["{EDGEATTRS}"] = print_query + print_query_http += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", etype) + print_query_kafka += """ + {} e.type == "{}" THEN + STRING msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", etype) + print_query_http += "\ + END" + print_query_kafka += "\ + END" else: # Ignore edge types e_attr_names = self.attributes e_attr_types = next(iter(self._e_schema.values())) if e_attr_names: print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( + print_query_http = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( print_attr ) + print_query_kafka 
= """ + STRING msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format(print_attr) else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' - query_replace["{EDGEATTRS}"] = print_query + print_query_http = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""" + query_replace["{EDGEATTRSHTTP}"] = print_query_http + query_replace["{EDGEATTRSKAFKA}"] = print_query_kafka # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -2169,11 +2806,11 @@ def _install_query(self, force: bool = False): def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "edge") + self._start_request(False) # Start reading thread. if not self.is_hetero: @@ -2181,27 +2818,20 @@ def _start(self) -> None: else: e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "edge", - self.output_format, - [], - [], - [], - {}, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - e_attr_types, - False, - self.delimiter, - False, - self.is_hetero, - self.callback_fn - ), + target=self._read_edge_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + e_in_feats = self.attributes, + e_out_labels = {} if self.is_hetero else [], + e_extra_feats = {} if self.is_hetero else [], + e_attr_types = e_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn + ) ) self._reader.start() @@ -2246,7 +2876,7 @@ class VertexLoader(BaseLoader): print("----Batch {}: Shape {}----".format(i, batch.shape)) print(batch.head(1)) <1> ---- - <1> Since the example does not provide an output format, the output format defaults to panda frames, have access to the methods of panda frame instances. + <1> The output format is Pandas dataframe. 
-- Output:: + @@ -2399,7 +3029,11 @@ def __init__( self._vtypes = sorted(self._vtypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size num_vertices_by_type = self._graph.getVertexCount(self._vtypes) if filter_by: num_vertices = sum( @@ -2408,20 +3042,14 @@ def __init__( ) else: num_vertices = sum(num_vertices_by_type.values()) - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. + self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches - self._payload["num_batches"] = self.num_batches if filter_by: self._payload["filter_by"] = filter_by - if batch_size: - self._payload["batch_size"] = batch_size self._payload["shuffle"] = shuffle self._payload["delimiter"] = delimiter self._payload["v_types"] = self._vtypes self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts # Install query self.query_name = self._install_query() @@ -2437,31 +3065,65 @@ def _install_query(self, force: bool = False) -> str: if isinstance(self.attributes, dict): # Multiple vertex types - print_query = "" + print_query_kafka = "" + print_query_http = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = self.attributes.get(vtype, []) v_attr_types = self._v_schema[vtype] if v_attr_names: print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, print_attr) + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", vtype, print_attr) else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", vtype) + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", vtype) + print_query_http += "\ + END" + print_query_kafka += "\ + END" else: # Ignore vertex types v_attr_names = self.attributes v_attr_types = next(iter(self._v_schema.values())) if 
v_attr_names: print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( + print_query_http = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( print_attr ) + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format(print_attr) else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_query_http = '@@v_batch += (stringify(getvid(s)) + "\\n")' + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""" + query_replace["{VERTEXATTRSHTTP}"] = print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -2473,11 +3135,11 @@ def _install_query(self, force: bool = False) -> str: def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "vertex") + self._start_request(False) # Start reading thread. 
if not self.is_hetero: @@ -2485,27 +3147,20 @@ def _start(self) -> None: else: v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "vertex", - self.output_format, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - v_attr_types, - [], - [], - [], - {}, - False, - self.delimiter, - False, - self.is_hetero, - self.callback_fn - ), + target=self._read_vertex_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + v_in_feats = self.attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn + ) ) self._reader.start() @@ -2729,24 +3384,41 @@ def __init__( self._etypes = sorted(self._etypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: - # TODO: get edge count with filter - raise NotImplementedError + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._etypes) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp + self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches - self._payload["num_batches"] = self.num_batches if filter_by: - self._payload["filter_by"] = filter_by + if isinstance(filter_by, str): + self._payload["filter_by"] = filter_by + else: + attr = set(filter_by.values()) + if len(attr) != 1: + raise NotImplementedError("Filtering by different attributes for different edge types is not supported. 
Please use the same attribute for different types.") + self._payload["filter_by"] = attr.pop() self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["e_types"] = self._etypes self._payload["delimiter"] = self.delimiter - self._payload["num_heap_inserts"] = self.num_heap_inserts # Output self.add_self_loop = add_self_loop # Install query @@ -2767,9 +3439,11 @@ def _install_query(self, force: bool = False) -> str: md5.update(json.dumps(query_suffix).encode()) query_replace = {"{QUERYSUFFIX}": md5.hexdigest()} + print_vertex_attr = "" + print_edge_http = "" + print_edge_kafka = "" if isinstance(self.v_in_feats, dict): # Multiple vertex types - print_query = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( self.v_in_feats.get(vtype, []) @@ -2777,17 +3451,16 @@ def _install_query(self, force: bool = False) -> str: + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_vertex_attr += """ + {} s.type == "{}" THEN + ret = (s.type + delimiter + stringify(getvid(s)) {}+ "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + {}".format(print_attr) if v_attr_names else "") + print_vertex_attr += """ + END""" # Multiple edge types - print_query = "" for idx, etype in enumerate(self._etypes): e_attr_names = ( self.e_in_feats.get(etype, []) @@ -2795,38 +3468,51 @@ def _install_query(self, force: bool = False) -> str: + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" - query_replace["{EDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_edge_http += """ + {} e.type == "{}" THEN + STRING e_msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> e_msg)"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_edge_kafka += """ + {} e.type == "{}" THEN + STRING e_msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), e_msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge data for edge 
" + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_edge_http += """ + END""" + print_edge_kafka += """ + END""" else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_vertex_attr += """ + ret = (stringify(getvid(s)) {}+ "\\n")"""\ + .format("+ delimiter + " + print_attr if v_attr_names else "") # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' - query_replace["{EDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_edge_http += """ + STRING e_msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> e_msg)"""\ + .format("+ delimiter + " + print_attr if e_attr_names else "") + print_edge_kafka += """ + STRING e_msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + INT kafka_errcode2 = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), e_msg), + IF kafka_errcode2!=0 THEN + @@kafka_error += ("Error sending edge data for edge " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode2) + "\\n") + END""".format("+ delimiter + " + print_attr if e_attr_names else "") + query_replace["{VERTEXATTRS}"] = print_vertex_attr + query_replace["{EDGEATTRSKAFKA}"] = print_edge_kafka + query_replace["{EDGEATTRSHTTP}"] = print_edge_http + # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -2834,15 +3520,21 @@ def _install_query(self, force: bool = False) -> str: "dataloaders", "graph_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "graph_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. 
if not self.is_hetero: @@ -2852,26 +3544,25 @@ def _start(self) -> None: v_attr_types = self._v_schema e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - self.v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = self.v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() @@ -2972,7 +3663,9 @@ def __init__( kafka_add_topic_per_epoch: bool = False, callback_fn: Callable = None, kafka_group_id: str = None, - kafka_topic: str = None + kafka_topic: str = None, + num_machines: int = 1, + num_segments: int = 20 ) -> None: """NO DOC""" @@ -3017,7 +3710,9 @@ def __init__( kafka_add_topic_per_epoch, callback_fn, kafka_group_id, - kafka_topic + kafka_topic, + num_machines, + num_segments ) # Resolve attributes is_hetero = any(map(lambda x: isinstance(x, dict), @@ -3051,29 +3746,47 @@ def __init__( self._vtypes = sorted(self._vtypes) self._etypes = sorted(self._etypes) # Resolve seeds - self._seed_types = self._etypes if ((not filter_by) or isinstance(filter_by, str)) else list(filter_by.keys()) - if not(filter_by) and e_seed_types: - if isinstance(e_seed_types, str): - self._seed_types = [e_seed_types] - elif isinstance(e_seed_types, list): + if e_seed_types: + if isinstance(e_seed_types, list): self._seed_types = e_seed_types + elif isinstance(e_seed_types, str): + self._seed_types = [e_seed_types] else: - raise TigerGraphException("e_seed_types must be type list or string.") + raise TigerGraphException("e_seed_types must be either of type list or string.") + elif isinstance(filter_by, dict): + self._seed_types = list(filter_by.keys()) + else: + self._seed_types = self._etypes + if set(self._seed_types) - set(self._etypes): + raise ValueError("Seed type has to be one of the edge types to retrieve") # Resolve number of batches if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: - num_edges = sum(self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] for e_type in self._seed_types) + num_edges = 0 + for e_type in self._seed_types: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._seed_types) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. 
- self.num_batches = num_batches + num_edges = 0 + for e_type in self._seed_types: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp + if num_edges==0: + raise ValueError("Cannot find any edge as seed. Please check the configuration and the data. If they all look right, please use batch_size instead of num_batches or refresh metadata following https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_parameters_15") + self.batch_size = math.ceil(num_edges / num_batches) + self.num_batches = num_batches # Initialize parameters for the query - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches self._payload["num_neighbors"] = num_neighbors self._payload["num_hops"] = num_hops self._payload["delimiter"] = delimiter @@ -3111,7 +3824,8 @@ def _install_query(self, force: bool = False): if self.is_hetero: # Multiple vertex types - print_query = "" + print_query_seed = "" + print_query_other = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( self.v_in_feats.get(vtype, []) @@ -3119,18 +3833,26 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed += """ + {} s.type == "{}" THEN + @@v_batch += (s->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (tmp_seed->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + print_query_other += """ + END""" + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Multiple edge types - print_query_seed = "" - print_query_other = "" + print_query = "" + print_query_kafka = "" for idx, etype in enumerate(self._etypes): e_attr_names = ( self.e_in_feats.get(etype, []) @@ -3138,52 +3860,50 @@ def _install_query(self, force: bool = False): + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query_seed += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - print_query_other += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query_seed += '{} e.type 
== "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query_other += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query_seed += "END" - print_query_other += "END" - query_replace["{SEEDEDGEATTRS}"] = print_query_seed - query_replace["{OTHEREDGEATTRS}"] = print_query_other + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query += """ + {} e.type == "{}" THEN + @@e_batch += (tmp_seed->(e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"))"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_query_kafka += """ + {} e.type == "{}" THEN + SET tmp_e = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n", ""), + tmp_e_batch = tmp_e_batch UNION tmp_e"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_query += """ + END""" + print_query_kafka += """ + END""" + query_replace["{EDGEATTRS}"] = print_query + query_replace["{EDGEATTRSKAFKA}"] = print_query_kafka else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( - print_attr - ) - query_replace["{VERTEXATTRS}"] = print_query - else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed = '@@v_batch += (s->(stringify(getvid(s)) {}+ "\\n"))'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + print_query_other = '@@v_batch += (tmp_seed->(stringify(getvid(s)) {}+ "\\n"))'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "1\\n")'.format( - print_attr - ) - query_replace["{SEEDEDGEATTRS}"] = print_query - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "0\\n")'.format( - print_attr - ) - query_replace["{OTHEREDGEATTRS}"] = print_query - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "1\\n")' - query_replace["{SEEDEDGEATTRS}"] = print_query - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "0\\n")' - query_replace["{OTHEREDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query = '@@e_batch += (tmp_seed->(stringify(getvid(s)) + delimiter + 
stringify(getvid(t)) {} + "\\n"))'.format( + "+ delimiter + " + print_attr if e_attr_names else "" + ) + query_replace["{EDGEATTRS}"] = print_query + print_query = """SET tmp_e = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n", ""), + tmp_e_batch = tmp_e_batch UNION tmp_e""".format( + "+ delimiter + " + print_attr if e_attr_names else "" + ) + query_replace["{EDGEATTRSKAFKA}"] = print_query # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -3195,11 +3915,11 @@ def _install_query(self, force: bool = False): def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. if not self.is_hetero: @@ -3216,26 +3936,26 @@ def _start(self) -> None: e_attr_types[etype]["is_seed"] = "bool" v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - self.v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = self.v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn, + seed_type = "edge" ), ) self._reader.start() @@ -3408,7 +4128,12 @@ def __init__( else: self._seed_types = self._vtypes self._target_v_types = self._vtypes + if batch_size: + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -3424,12 +4149,9 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. + self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches self.filter_by = filter_by - self._payload["num_batches"] = self.num_batches if filter_by: if isinstance(filter_by, str): self._payload["filter_by"] = filter_by @@ -3437,8 +4159,7 @@ def __init__( attr = set(filter_by.values()) if len(attr) != 1: raise NotImplementedError("Filtering by different attributes for different vertex types is not supported. 
Please use the same attribute for different types.") - if batch_size: - self._payload["batch_size"] = batch_size + self._payload["filter_by"] = attr.pop() self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["seed_types"] = self._seed_types @@ -3450,7 +4171,6 @@ def __init__( self._payload["clear_cache"] = clear_cache self._payload["delimiter"] = delimiter self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts self._payload["num_edge_batches"] = self.num_edge_batches if e_types: self._payload["e_types"] = e_types @@ -3465,7 +4185,7 @@ def __init__( for v_type in self._vtypes: if anchor_attribute not in self._v_schema[v_type].keys(): to_change.append(v_type) - if to_change != []: + if to_change: print("Adding anchor attribute") ret = add_attribute(self._graph, "VERTEX", "BOOL", anchor_attribute, to_change, global_change=global_schema_change) print(ret) @@ -3476,7 +4196,7 @@ def __init__( if anchor_cache_attr not in self._v_schema[v_type].keys(): # add anchor cache attribute to_change.append(v_type) - if to_change != []: + if to_change: print("Adding anchor cache attribute") ret = add_attribute(self._graph, "VERTEX", "MAP", anchor_cache_attr, to_change, global_change=global_schema_change) print(ret) @@ -3558,34 +4278,49 @@ def _install_query(self, force: bool = False) -> str: if isinstance(self.attributes, dict): # Multiple vertex types + print_query_kafka = "" + print_query_http = "" print_query = "" for idx, vtype in enumerate(self._seed_types): v_attr_names = self.attributes.get(vtype, []) query_suffix.extend(v_attr_names) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_http += """ + END""" + print_query_kafka += """ + END""" query_suffix = list(dict.fromkeys(query_suffix)) else: # Ignore vertex types v_attr_names = self.attributes query_suffix.extend(v_attr_names) v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = 
self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_http = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("+ delimiter + " + print_attr if v_attr_names else "") + query_replace["{VERTEXATTRSHTTP}"] = print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka md5 = hashlib.md5() query_suffix.extend([self.distributed_query]) md5.update(json.dumps(query_suffix).encode()) @@ -3629,6 +4364,7 @@ def processRelContext(row): context = [self.idToIdx[str(x)] for x in context][:self._payload["max_rel_context"]] context = context + [self.idToIdx["PAD"] for x in range(len(context), self._payload["max_rel_context"])] return context + def processAnchors(row): try: ancs = row.split(" ")[:-1] @@ -3643,6 +4379,7 @@ def processAnchors(row): dists += [self.idToIdx["PAD"] for x in range(len(dists), self._payload["max_anchors"])] toks += [self.idToIdx["PAD"] for x in range(len(toks), self._payload["max_anchors"])] return {"ancs":toks, "dists": dists} + if self.is_hetero: for v_type in data.keys(): data[v_type]["relational_context"] = data[v_type]["relational_context"].apply(lambda x: processRelContext(x)) @@ -3665,11 +4402,11 @@ def processAnchors(row): def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "vertex") + self._start_request(False) # Start reading thread. 
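
The `processAnchors` helper above turns the query's serialized anchor string into fixed-width model inputs. A minimal sketch of that padding scheme, assuming the "anchor:distance" token format implied by the parsing code; the vocabulary and sizes here are made up for illustration and this is not the library code:

    # Hypothetical re-creation of the fixed-width anchor padding. Each
    # "anchor:distance" token is mapped through the id vocabulary, then both
    # lists are truncated/padded to max_anchors with the PAD index.
    def pad_anchors(raw: str, id_to_idx: dict, max_anchors: int) -> dict:
        toks, dists = [], []
        for pair in raw.split(" ")[:-1]:  # trailing delimiter leaves an empty token
            anc, dist = pair.split(":")
            toks.append(id_to_idx.get(anc, id_to_idx["PAD"]))
            dists.append(int(dist))
        pad = id_to_idx["PAD"]
        toks = (toks + [pad] * max_anchors)[:max_anchors]
        dists = (dists + [pad] * max_anchors)[:max_anchors]
        return {"ancs": toks, "dists": dists}

    >>> pad_anchors("a:1 b:2 ", {"a": 0, "b": 1, "PAD": 2}, 4)
    {'ancs': [0, 1, 2, 2], 'dists': [1, 2, 2, 2]}
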
if not self.is_hetero: @@ -3677,26 +4414,19 @@ def _start(self) -> None: else: v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "vertex", - self.output_format, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - v_attr_types, - [], - [], - [], - {}, - False, - self.delimiter, - False, - self.is_hetero, - self.nodepiece_process + target=self._read_vertex_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + v_in_feats = self.attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.nodepiece_process ), ) self._reader.start() @@ -3750,28 +4480,28 @@ def fetch(self, vertices: list) -> None: v_attr_types = next(iter(self._v_schema.values())) else: v_attr_types = self._v_schema - if self.is_hetero: - data = self._parse_data(resp[0]["vertex_batch"], - v_in_feats=attributes, - v_out_labels = {}, - v_extra_feats = {}, - v_attr_types=v_attr_types, - reindex=False, - delimiter = self.delimiter, - is_hetero=self.is_hetero, - primary_id=resp[0]["pids"], - callback_fn=self.nodepiece_process) + vertex_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.add(i["data_batch"]) + data = BaseLoader._parse_vertex_data( + raw = vertex_batch, + v_in_feats = attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero + ) + if not self.is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") else: - data = self._parse_data(resp[0]["vertex_batch"], - v_in_feats=attributes, - v_out_labels = [], - v_extra_feats = [], - v_attr_types=v_attr_types, - reindex=False, - delimiter = self.delimiter, - is_hetero=self.is_hetero, - primary_id=resp[0]["pids"], - callback_fn=self.nodepiece_process) + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") + data = self.nodepiece_process(data) return data def precompute(self) -> None: @@ -3960,9 +4690,15 @@ def __init__( self._seed_types = list(filter_by.keys()) else: self._seed_types = self._vtypes + if set(self._seed_types) - set(self._vtypes): + raise ValueError("Seed type has to be one of the vertex types to retrieve") if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -3977,12 +4713,9 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. 
+ self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches # Initialize parameters for the query - self._payload["num_batches"] = self.num_batches self._payload["num_hops"] = num_hops if filter_by: if isinstance(filter_by, str): @@ -3992,15 +4725,12 @@ def __init__( if len(attr) != 1: raise NotImplementedError("Filtering by different attributes for different vertex types is not supported. Please use the same attribute for different types.") self._payload["filter_by"] = attr.pop() - if batch_size: - self._payload["batch_size"] = batch_size self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["e_types"] = self._etypes self._payload["seed_types"] = self._seed_types self._payload["delimiter"] = self.delimiter self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts # Output self.add_self_loop = add_self_loop # Install query @@ -4033,19 +4763,23 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_seed += "END" - print_query_other += "END" + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "1\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + print_query_other += """ + END""" query_replace["{SEEDVERTEXATTRS}"] = print_query_seed query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Generate select for each type of neighbors @@ -4065,28 +4799,30 @@ def _install_query(self, force: bool = False): e_attr_types = self._e_schema[etype] if vtype!=e_attr_types["FromVertexTypeName"] and vtype!=e_attr_types["ToVertexTypeName"]: continue - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if eidx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if eidx==0 
else "ELSE IF", etype) + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")"""\ + .format("IF" if eidx==0 else "ELSE IF", + etype, + "+ delimiter + " + print_attr if e_attr_names else "") eidx += 1 if print_query: - print_query += "END" - print_select += """seed{} = SELECT t - FROM seeds:s -(e_types:e)- {}:t - SAMPLE {} EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - @@printed_edges += e, - {} - END; - """.format(vidx, vtype, self.num_neighbors[vtype], print_query) + print_query += """ + END""" + print_select += """ + seed{} = SELECT t + FROM seeds:s -(e_types:e)- {}:t + SAMPLE {} EDGE WHEN s.outdegree() >= 1 + ACCUM + IF NOT @@printed_edges.contains(e) THEN + @@printed_edges += e, + {} + END;""".format(vidx, vtype, self.num_neighbors[vtype], print_query) seeds.append("seed{}".format(vidx)) vidx += 1 - print_select += "seeds = {};".format(" UNION ".join(seeds)) + print_select += """ + seeds = {};""".format(" UNION ".join(seeds)) query_replace["{SELECTNEIGHBORS}"] = print_select # Install query query_path = os.path.join( @@ -4095,15 +4831,21 @@ def _install_query(self, force: bool = False): "dataloaders", "hgt_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "hgt_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. 
if not self.is_hetero: @@ -4120,26 +4862,25 @@ def _start(self) -> None: v_attr_types[vtype]["is_seed"] = "bool" e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() @@ -4176,7 +4917,6 @@ def fetch(self, vertices: list) -> None: _payload = {} _payload["v_types"] = self._payload["v_types"] _payload["e_types"] = self._payload["e_types"] - _payload["num_batches"] = 1 _payload["num_hops"] = self._payload["num_hops"] _payload["delimiter"] = self._payload["delimiter"] _payload["input_vertices"] = [] @@ -4201,11 +4941,15 @@ def fetch(self, vertices: list) -> None: v_attr_types[vtype]["is_seed"] = "bool" v_attr_types[vtype]["primary_id"] = "str" e_attr_types = self._e_schema - i = resp[0] - data = self._parse_data( - raw = (i["vertex_batch"], i["edge_batch"]), - in_format = "graph", - out_format = self.output_format, + vertex_batch = set() + edge_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.update(i["vertex_batch"].splitlines()) + edge_batch.update(i["edge_batch"].splitlines()) + data = self._parse_graph_data_to_df( + raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, v_out_labels = self.v_out_labels, v_extra_feats = v_extra_feats, @@ -4214,12 +4958,118 @@ def fetch(self, vertices: list) -> None: e_out_labels = self.e_out_labels, e_extra_feats = self.e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = self.add_self_loop, delimiter = self.delimiter, - reindex = True, primary_id = i["pids"], is_hetero = self.is_hetero, - callback_fn = self.callback_fn ) + if self.output_format == "dataframe" or self.output_format== "df": + vertices, edges = data + if not self.is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif self.output_format == "pyg": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import torch_geometric as pyg + except ImportError: + raise ImportError( + "PyG is not installed. Please install PyG to use PyG format." 
+ ) + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + pyg = pyg + ) + elif self.output_format == "dgl": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import dgl + except ImportError: + raise ImportError( + "DGL is not installed. Please install DGL to use DGL format." + ) + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + dgl= dgl + ) + elif self.output_format == "spektral" and self.is_hetero==False: + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "Tensorflow is not installed. Please install it to use spektral output." + ) + try: + import scipy + except ImportError: + raise ImportError( + "scipy is not installed. Please install it to use spektral output." + ) + try: + import spektral + except ImportError: + raise ImportError( + "Spektral is not installed. Please install it to use spektral output." + ) + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + scipy = scipy, + spektral = spektral + ) + else: + raise NotImplementedError + if self.callback_fn: + data = self.callback_fn(data) # Return data return data diff --git a/pyTigerGraph/gds/gds.py b/pyTigerGraph/gds/gds.py index 55d597c1..6a67a123 100644 --- a/pyTigerGraph/gds/gds.py +++ b/pyTigerGraph/gds/gds.py @@ -938,7 +938,9 @@ def edgeNeighborLoader( timeout: int = 300000, callback_fn: Callable = None, reinstall_query: bool = False, - distributed_query: bool = False + distributed_query: bool = False, + num_machines: int = 1, + num_segments: int = 20 ) -> EdgeNeighborLoader: """Returns an `EdgeNeighborLoader` instance. 
An `EdgeNeighborLoader` instance performs neighbor sampling from all edges in the graph in batches in the following manner: @@ -1098,7 +1100,9 @@ def edgeNeighborLoader( "delimiter": delimiter, "timeout": timeout, "callback_fn": callback_fn, - "distributed_query": distributed_query + "distributed_query": distributed_query, + "num_machines": num_machines, + "num_segments": num_segments } if self.kafkaConfig: params.update(self.kafkaConfig) @@ -1130,7 +1134,9 @@ def edgeNeighborLoader( "delimiter": delimiter, "timeout": timeout, "callback_fn": callback_fn, - "distributed_query": distributed_query + "distributed_query": distributed_query, + "num_machines": num_machines, + "num_segments": num_segments } if self.kafkaConfig: params.update(self.kafkaConfig) diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql index fc0085f8..6f0ce37e 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql @@ -1,12 +1,11 @@ CREATE QUERY edge_loader_{QUERYSUFFIX}( - INT batch_size, - INT num_batches=1, - BOOL shuffle=FALSE, STRING filter_by, SET e_types, STRING delimiter, + BOOL shuffle=FALSE, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -41,198 +40,72 @@ CREATE QUERY edge_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - MapAccum @@edges_sampled; - SetAccum @valid_v_out; - SetAccum @valid_v_in; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - # Shuffle vertex ID if needed start = {ANY}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s -(e_types:e)- :t + WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") + POST-ACCUM s.@tmp_id = getvid(s) + POST-ACCUM t.@tmp_id = getvid(t); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - - SumAccum @@num_edges; - IF filter_by IS NOT NULL THEN - res = SELECT s - FROM start:s -(e_types:e)- :t WHERE e.getAttr(filter_by, "BOOL") - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - ELSE - res = SELECT s - FROM start:s -(e_types:e)- :t - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - END; - INT batch_s; - IF batch_size IS NULL THEN - batch_s = ceil((@@num_edges/2)/num_batches); - ELSE - batch_s = batch_size; + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + 
LIMIT 1; END; # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@e_batch; - SetAccum @@seeds; - SetAccum @@targets; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - start = {ANY}; - IF filter_by IS NOT NULL THEN - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE e.getAttr(filter_by, "BOOL") - AND - ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(e_types:e)- :t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {EDGEATTRS}, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END - POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear() - POST-ACCUM - t.@valid_v_in.clear(), t.@valid_v_out.clear(); - ELSE - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(e_types:e)- :t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE t in @@targets AND s IN t.@valid_v_in 
AND t IN s.@valid_v_out - ACCUM - {EDGEATTRS}, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END - POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear() - POST-ACCUM - t.@valid_v_in.clear(), t.@valid_v_out.clear(); + # If using kafka to export + IF kafka_address != "" THEN + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s -(e_types:e)- :t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + {EDGEATTRSKAFKA} + LIMIT 1; END; - # Export batch - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); - IF kafka_errcode != 0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); END; - ELSE - # Add to response - PRINT @@e_batch AS edge_batch; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode != 0 THEN + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); + IF kafka_errcode!=0 THEN @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; PRINT @@kafka_error as kafkaError; + # Else return as http response + ELSE + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@e_batch; + res = SELECT s + FROM seeds:s -(e_types:e)- :t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + {EDGEATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@e_batch DO + PRINT i as data_batch; + END; + END; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql index 2948ad81..b6b2fb08 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql @@ -1,6 +1,4 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( - INT batch_size, - INT num_batches=1, INT num_neighbors=10, INT num_hops=2, BOOL shuffle=FALSE, @@ -9,8 +7,11 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( SET e_types, SET seed_types, STRING delimiter, + INT num_chunks=2, + INT num_machines=1, + INT num_segments=20, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -49,227 +50,163 @@ 
CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; SumAccum @@kafka_error; - UINT producer; - MapAccum @@edges_sampled; - SetAccum @valid_v_out; - SetAccum @valid_v_in; + SetAccum @seeds; + MapAccum> @@mid_to_vid; # This tmp accumulator maps machine ID to the smallest vertex ID on the machine. + MapAccum @@mid_to_producer; + SumAccum @kafka_producer_id; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed start = {v_types}; + # Filter seeds if needed + start = SELECT s + FROM start:s -(seed_types:e)- v_types:t + WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") + POST-ACCUM s.@tmp_id = getvid(s) + POST-ACCUM t.@tmp_id = getvid(t); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = start.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); + FROM start:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; END; - SumAccum @@num_edges; - IF filter_by IS NOT NULL THEN - res = SELECT s - FROM start:s -(seed_types:e)- v_types:t WHERE e.getAttr(filter_by, "BOOL") + # If using kafka to export + IF kafka_address != "" THEN + # We generate a vertex set that contains exactly one vertex per machine. 
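
Concretely, the selection below decodes machine placement from TigerGraph's internal vertex ids: the five bits just above the `num_segments` low-order bits are taken as the segment group and folded over `num_machines`. The `ACCUM` records each machine's smallest vid in `@@mid_to_vid`, and the `HAVING` clause keeps only the vertex whose vid equals that minimum, so exactly one representative vertex (and hence one Kafka producer) survives per machine. A one-line sketch of the mapping; the bit layout is the patch's assumption, not a documented guarantee:

    # (vid >> num_segments) & 31 extracts 5 segment bits; % spreads them over machines.
    # In Python, >> binds tighter than &, matching the GSQL expression below.
    def machine_of(vid: int, num_segments: int = 20, num_machines: int = 1) -> int:
        return (vid >> num_segments & 31) % num_machines
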
+ machine_set = + SELECT s + FROM start:s ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - ELSE - res = SELECT s - FROM start:s -(seed_types:e)- v_types:t + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + @@mid_to_vid += (mid -> getvid(s)) + HAVING @@mid_to_vid.get((getvid(s) >> num_segments & 31) % num_machines) == getvid(s); + @@mid_to_vid.clear(); + # Initialize Kafka producer per machine + res = SELECT s + FROM machine_set:s ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - END; - INT batch_s; - IF batch_size IS NULL THEN - batch_s = ceil((@@num_edges/2)/num_batches); - ELSE - batch_s = batch_size; + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal), + @@mid_to_producer += (mid -> producer); + res = SELECT s + FROM start:s + ACCUM + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + s.@kafka_producer_id += @@mid_to_producer.get(mid); END; - - # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SetAccum @@vertices; - SumAccum @@e_batch; - SumAccum @@v_batch; - SetAccum @@printed_edges; - SetAccum @@seeds; - SetAccum @@targets; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - - start = {v_types}; - IF filter_by IS NOT NULL THEN - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE e.getAttr(filter_by, "BOOL") - AND - ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(seed_types:e)- v_types:t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {SEEDEDGEATTRS}, - @@printed_edges += e, - @@vertices += s, - @@vertices += t, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - ELSE - @@edges_sampled += 
(((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE) - END; - ELSE - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(seed_types:e)- v_types:t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {SEEDEDGEATTRS}, - @@printed_edges += e, - @@vertices += s, - @@vertices += t, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END; - END; - # Get seed vertices - v_in_batch = @@vertices; - seeds = - SELECT s - FROM v_in_batch:s + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum> @@v_batch; + MapAccum> @@e_batch; + + # Collect neighborhood data for each vertex + seed1 = SELECT s + FROM start:s -(seed_types:e)- v_types:t + WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ; + seed2 = SELECT t + FROM start:s -(seed_types:e)- v_types:t + WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ; + seeds = seed1 UNION seed2; + seeds = SELECT s + FROM seeds:s POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear(), - {VERTEXATTRS}; - # Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO + s.@seeds += s, + {SEEDVERTEXATTRS}; + FOREACH hop IN RANGE[1, num_hops] DO seeds = SELECT t - FROM seeds:s -(e_types:e)- v_types:t - SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - {OTHEREDGEATTRS}, - @@printed_edges += e - END; - attr = - SELECT s - FROM seeds:s + FROM seeds:s -(e_types:e)- v_types:t + SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 + ACCUM + t.@seeds += s.@seeds, + FOREACH tmp_seed in s.@seeds DO + {EDGEATTRS} + END; + seeds = SELECT s + FROM seeds:s POST-ACCUM - IF NOT @@vertices.contains(s) THEN - {VERTEXATTRS}, - @@vertices += s + FOREACH tmp_seed in s.@seeds DO + {OTHERVERTEXATTRS} END; - END; - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, 
batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; + END; + # Clear all accums + all_v = {v_types}; + res = SELECT s + FROM all_v:s + POST-ACCUM s.@seeds.clear() + LIMIT 1; + + # Generate output for each edge + # If use kafka to export + IF kafka_address != "" THEN + res = SELECT s + FROM seed1:s -(seed_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + INT part_num = (getvid(s)+getvid(t))%kafka_topic_partitions, + STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)), + SET tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t), + INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for "+batch_id+": "+stringify(kafka_errcode) + "\n") + END, + SET tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t), + {EDGEATTRSKAFKA}, + kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for "+batch_id+ ": "+ stringify(kafka_errcode) + "\n") + END + LIMIT 1; + # Else return as http response ELSE - # Add to response - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; + MapAccum @@v_data; + MapAccum @@e_data; + res = SELECT s + FROM seed1:s -(seed_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)), + SET tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t), + @@v_data += (batch_id -> stringify(tmp_v_batch)), + SET tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t), + {EDGEATTRSKAFKA}, + @@e_data += (batch_id -> stringify(tmp_e_batch)) + LIMIT 1; + + FOREACH (k,v) IN @@v_data DO + PRINT v as vertex_batch, @@e_data.get(k) as edge_batch, k AS seed; + END; END; END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); - END; + + IF kafka_address != "" THEN + res = SELECT s + FROM machine_set:s + WHERE (getvid(s) >> num_segments & 31) % num_machines == 0 + ACCUM + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, i, "STOP", ""), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n") + END + END; + + res = SELECT s + FROM machine_set:s + ACCUM + INT kafka_errcode = close_kafka_producer(s.@kafka_producer_id, kafka_timeout), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n") + END; PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file diff --git 
a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql
new file mode 100644
index 00000000..23f06ff5
--- /dev/null
+++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql
@@ -0,0 +1,46 @@
+CREATE QUERY edge_nei_loader_sub_{QUERYSUFFIX} (VERTEX u, VERTEX v, STRING delimiter, INT num_hops, INT num_neighbors, SET<STRING> e_types, SET<STRING> v_types, STRING seed_type)
+RETURNS (ListAccum<STRING>)
+SYNTAX V1
+{
+    SumAccum<STRING> @@v_batch;
+    SumAccum<STRING> @@e_batch;
+    SetAccum<VERTEX> @@printed_vertices;
+    SetAccum<EDGE> @@printed_edges;
+    ListAccum<STRING> @@ret;
+
+    source = {u};
+    res = SELECT s
+          FROM source:s -(seed_type:e)- v_types:t
+          WHERE t==v
+          ACCUM
+              @@printed_edges += e,
+              {SEEDEDGEATTRS};
+
+    start = {u,v};
+    res = SELECT s
+          FROM start:s
+          POST-ACCUM
+              @@printed_vertices += s,
+              {VERTEXATTRS};
+
+    FOREACH i IN RANGE[1, num_hops] DO
+        start = SELECT t
+                FROM start:s -(e_types:e)- v_types:t
+                SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1
+                ACCUM
+                    IF NOT @@printed_edges.contains(e) THEN
+                        @@printed_edges += e,
+                        {OTHEREDGEATTRS}
+                    END;
+        start = SELECT s
+                FROM start:s
+                POST-ACCUM
+                    IF NOT @@printed_vertices.contains(s) THEN
+                        @@printed_vertices += s,
+                        {VERTEXATTRS}
+                    END;
+    END;
+    @@ret += @@v_batch;
+    @@ret += @@e_batch;
+    RETURN @@ret;
+}
diff --git a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql
index b4035b90..fca2a83e 100644
--- a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql
+++ b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql
@@ -1,12 +1,12 @@
 CREATE QUERY graph_loader_{QUERYSUFFIX}(
-    INT num_batches=1,
-    BOOL shuffle=FALSE,
     STRING filter_by,
     SET<STRING> v_types,
     SET<STRING> e_types,
     STRING delimiter,
+    BOOL shuffle=FALSE,
+    INT num_chunks=2,
     STRING kafka_address="",
-    STRING kafka_topic,
+    STRING kafka_topic="",
     INT kafka_topic_partitions=1,
     STRING kafka_max_size="104857600",
     INT kafka_timeout=300000,
@@ -40,93 +40,83 @@ CREATE QUERY graph_loader_{QUERYSUFFIX}(
    sasl_password   : SASL password for Kafka.
    ssl_ca_location : Path to CA certificate for verifying the Kafka broker key.
 */
-    INT num_vertices;
-    INT kafka_errcode;
     SumAccum<INT> @tmp_id;
-    SumAccum<STRING> @@kafka_error;
-    UINT producer;
 
-    # Initialize Kafka producer
-    IF kafka_address != "" THEN
-        producer = init_kafka_producer(
-            kafka_address, kafka_max_size, security_protocol,
-            sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
-            ssl_certificate_location, ssl_key_location, ssl_key_password,
-            ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
-            sasl_kerberos_keytab, sasl_kerberos_principal);
-    END;
-
-    # Shuffle vertex ID if needed
     start = {v_types};
+    # Filter seeds if needed
+    seeds = SELECT s
+            FROM start:s -(e_types:e)- v_types:t
+            WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL")
+            POST-ACCUM s.@tmp_id = getvid(s)
+            POST-ACCUM t.@tmp_id = getvid(t);
+    # Shuffle vertex ID if needed
     IF shuffle THEN
-        IF filter_by IS NOT NULL THEN
-            start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL");
-        END;
-        num_vertices = start.size();
+        INT num_vertices = seeds.size();
         res = SELECT s
-              FROM start:s
-              POST-ACCUM s.@tmp_id = floor(rand()*num_vertices);
-    ELSE
-        res = SELECT s
-              FROM start:s
-              POST-ACCUM s.@tmp_id = getvid(s);
+              FROM seeds:s
+              POST-ACCUM s.@tmp_id = floor(rand()*num_vertices)
+              LIMIT 1;
     END;
 
     # Generate batches
-    FOREACH batch_id IN RANGE[0, num_batches-1] DO
-        SetAccum<VERTEX> @@vertices;
-        SumAccum<STRING> @@e_batch;
-        SumAccum<STRING> @@v_batch;
-
-        start = {v_types};
-        IF filter_by IS NOT NULL THEN
-            res =
-                SELECT s
-                FROM start:s -(e_types:e)- v_types:t
-                WHERE e.getAttr(filter_by, "BOOL") and ((s.@tmp_id+t.@tmp_id)*(s.@tmp_id+t.@tmp_id+1)/2+t.@tmp_id)%num_batches==batch_id
-                ACCUM
-                    {EDGEATTRS},
-                    @@vertices += s,
-                    @@vertices += t;
-        ELSE
-            res =
-                SELECT s
-                FROM start:s -(e_types:e)- v_types:t
-                WHERE ((s.@tmp_id+t.@tmp_id)*(s.@tmp_id+t.@tmp_id+1)/2+t.@tmp_id)%num_batches==batch_id
-                ACCUM
-                    {EDGEATTRS},
-                    @@vertices += s,
-                    @@vertices += t;
+    # If using kafka to export
+    IF kafka_address != "" THEN
+        SumAccum<STRING> @@kafka_error;
+
+        # Initialize Kafka producer
+        UINT producer = init_kafka_producer(
+            kafka_address, kafka_max_size, security_protocol,
+            sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
+            ssl_certificate_location, ssl_key_location, ssl_key_password,
+            ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
+            sasl_kerberos_keytab, sasl_kerberos_principal);
+
+        FOREACH chunk IN RANGE[0, num_chunks-1] DO
+            res = SELECT s
+                  FROM seeds:s -(e_types:e)- v_types:t
+                  WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
+                  ACCUM
+                      STRING s_msg = graph_loader_sub_{QUERYSUFFIX}(s, delimiter),
+                      STRING t_msg = graph_loader_sub_{QUERYSUFFIX}(t, delimiter),
+                      INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), s_msg+t_msg),
+                      IF kafka_errcode!=0 THEN
+                          @@kafka_error += ("Error sending vertex data for edge " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\n")
+                      END,
+                      {EDGEATTRSKAFKA}
+                  LIMIT 1;
         END;
-
-        # Get vertex attributes
-        v_in_batch = @@vertices;
-        attr =
-            SELECT s
-            FROM v_in_batch:s
-            POST-ACCUM
-                {VERTEXATTRS};
-        IF kafka_address != "" THEN
-            # Write to kafka
-            kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch);
-            IF kafka_errcode!=0 THEN
-                @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n");
-            END;
-            kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch);
+        FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO
+            INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", "");
             IF kafka_errcode!=0 THEN
-                @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n");
+                @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n");
             END;
-        ELSE
-            # Add to response
-            PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch;
         END;
-    END;
-    IF kafka_address != "" THEN
-        kafka_errcode = close_kafka_producer(producer, kafka_timeout);
+
+        INT kafka_errcode = close_kafka_producer(producer, kafka_timeout);
         IF kafka_errcode!=0 THEN
             @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n");
         END;
         PRINT @@kafka_error as kafkaError;
+    # Else return as http response
+    ELSE
+        FOREACH chunk IN RANGE[0, num_chunks-1] DO
+            MapAccum<STRING, STRING> @@v_batch;
+            MapAccum<STRING, STRING> @@e_batch;
+
+            res = SELECT s
+                  FROM seeds:s -(e_types:e)- v_types:t
+                  WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
+                  ACCUM
+                      STRING s_msg = graph_loader_sub_{QUERYSUFFIX}(s, delimiter),
+                      STRING t_msg = graph_loader_sub_{QUERYSUFFIX}(t, delimiter),
+                      @@v_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> s_msg+t_msg),
+                      {EDGEATTRSHTTP}
+                  LIMIT 1;
+
+            FOREACH (k,v) IN @@v_batch DO
+                PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch;
+            END;
+        END;
     END;
 }
\ No newline at end of file
diff --git a/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql
new file mode 100644
index 00000000..11eb6993
--- /dev/null
+++ b/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql
@@ -0,0 +1,11 @@
+CREATE QUERY graph_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter)
+RETURNS (STRING)
+{
+    STRING ret;
+    start = {v};
+    res = SELECT s
+          FROM start:s
+          POST-ACCUM
+              {VERTEXATTRS};
+    RETURN ret;
+}
diff --git a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql
index 4e43e28f..401a2a6b 100644
--- a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql
+++ b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql
@@ -1,7 +1,5 @@
 CREATE QUERY hgt_loader_{QUERYSUFFIX}(
     SET<VERTEX> input_vertices,
-    INT batch_size,
-    INT num_batches=1,
     INT num_hops=2,
     BOOL shuffle=FALSE,
     STRING filter_by,
@@ -9,8 +7,9 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}(
     SET<STRING> e_types,
     SET<STRING> seed_types,
     STRING delimiter,
+    INT num_chunks=2,
     STRING kafka_address="",
-    STRING kafka_topic,
+    STRING kafka_topic="",
     INT kafka_topic_partitions=1,
     STRING kafka_max_size="104857600",
     INT kafka_timeout=300000,
@@ -25,8 +24,7 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}(
     STRING ssl_endpoint_identification_algorithm="",
     STRING sasl_kerberos_service_name="",
     STRING sasl_kerberos_keytab="",
-    STRING sasl_kerberos_principal="",
-    INT num_heap_inserts = 10
+    STRING sasl_kerberos_principal=""
 ) SYNTAX V1 {
     /* This query generates the neighborhood subgraphs of given seed vertices (i.e., `input_vertices`).
@@ -51,142 +49,138 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}(
    sasl_password   : SASL password for Kafka.
    ssl_ca_location : Path to CA certificate for verifying the Kafka broker key.
*/ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed + # If getting all vertices of given types IF input_vertices.size()==0 THEN start = {seed_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; ELSE res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; END; - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate subgraphs - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SumAccum @@e_batch; - SetAccum @@printed_vertices; - SetAccum @@printed_edges; - SetAccum @@seeds; - # Get seeds - IF input_vertices.size()==0 THEN - start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - @@printed_vertices += s, - {SEEDVERTEXATTRS}; - END; - # Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO - {SELECTNEIGHBORS} - attr = SELECT s - FROM seeds:s - POST-ACCUM - IF NOT @@printed_vertices.contains(s) THEN - @@printed_vertices += s, - {OTHERVERTEXATTRS} - END; - END; + # Export data + # If using kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( 
+ kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + is_first = False + ELSE + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END + END + END + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; - ELSE - MapAccum @@id_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s); - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch, @@id_map AS pids; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; END; - END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; - PRINT @@kafka_error as kafkaError; + # Else get given vertices. 
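
Before the given-vertices branch that follows, note the pattern both export paths above share: seeds are processed in `num_chunks` passes keyed by `s.@tmp_id % num_chunks`, so each pass only materializes a fraction of the batches in accumulators at a time. A toy sketch of the partition property; the names are illustrative only:

    # Every seed falls in exactly one chunk, and the chunks together cover all seeds.
    num_chunks = 2
    seeds = list(range(10))  # stand-ins for @tmp_id values
    chunks = [[s for s in seeds if s % num_chunks == c] for c in range(num_chunks)]
    assert sorted(sum(chunks, [])) == seeds
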
+ ELSE + MapAccum @@v_batch; + MapAccum @@e_batch; + MapAccum @@id_map; + + seeds = input_vertices; + res = SELECT s + FROM seeds:s + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END, + @@id_map += (getvid(s) -> s) + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + + FOREACH hop IN RANGE[1, num_hops] DO + seeds = SELECT t + FROM seeds:s -(e_types:e)- v_types:t + POST-ACCUM + @@id_map += (getvid(t) -> t); + END; + PRINT @@id_map AS pids; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql new file mode 100644 index 00000000..88fa5558 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql @@ -0,0 +1,32 @@ +CREATE QUERY hgt_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter, INT num_hops, SET e_types, SET v_types) +RETURNS (ListAccum) +SYNTAX V1 +{ + SumAccum @@v_batch; + SumAccum @@e_batch; + SetAccum @@printed_vertices; + SetAccum @@printed_edges; + ListAccum @@ret; + + seeds = {v}; + res = SELECT s + FROM seeds:s + POST-ACCUM + @@printed_vertices += s, + {SEEDVERTEXATTRS}; + + FOREACH i IN RANGE[1, num_hops] DO + {SELECTNEIGHBORS} + + seeds = SELECT s + FROM seeds:s + POST-ACCUM + IF NOT @@printed_vertices.contains(s) THEN + @@printed_vertices += s, + {OTHERVERTEXATTRS} + END; + END; + @@ret += @@v_batch; + @@ret += @@e_batch; + RETURN @@ret; +} diff --git a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql index 8be04db6..d1cd3add 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql @@ -1,7 +1,5 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( SET input_vertices, - INT batch_size, - INT num_batches=1, INT num_neighbors=10, INT num_hops=2, BOOL shuffle=FALSE, @@ -10,8 +8,9 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( SET e_types, SET seed_types, STRING delimiter, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -26,8 +25,7 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( STRING ssl_endpoint_identification_algorithm="", STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", - STRING sasl_kerberos_principal="", - INT num_heap_inserts = 10 + STRING sasl_kerberos_principal="" ) SYNTAX V1 { /* This query generates the neighborhood subgraphs of given seed vertices (i.e., `input_vertices`). @@ -55,148 +53,138 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. 
*/ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed + # If getting all vertices of given types IF input_vertices.size()==0 THEN start = {seed_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; ELSE res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate subgraphs - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SumAccum @@e_batch; - SetAccum @@printed_vertices; - SetAccum @@printed_edges; - SetAccum @@seeds; - # Get seeds - IF input_vertices.size()==0 THEN - start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - # Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO - seeds = SELECT t - FROM seeds:s -(e_types:e)- v_types:t - SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - {EDGEATTRS}, - @@printed_edges += e - END; - attr = SELECT s - FROM seeds:s - POST-ACCUM - IF NOT @@printed_vertices.contains(s) THEN - {OTHERVERTEXATTRS}, - @@printed_vertices += s - END; + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; END; + + # Export data + # If using kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += 
("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + is_first = False + ELSE + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END + END + END + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; - ELSE - MapAccum @@id_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s); - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch, @@id_map AS pids; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; END; - END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; - PRINT @@kafka_error as kafkaError; + # Else get given vertices. 
+ ELSE + MapAccum @@v_batch; + MapAccum @@e_batch; + MapAccum @@id_map; + + seeds = input_vertices; + res = SELECT s + FROM seeds:s + POST-ACCUM + LIST msg = neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END, + @@id_map += (getvid(s) -> s) + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + + FOREACH hop IN RANGE[1, num_hops] DO + seeds = SELECT t + FROM seeds:s -(e_types:e)- v_types:t + POST-ACCUM + @@id_map += (getvid(t) -> t); + END; + PRINT @@id_map AS pids; END; -} \ No newline at end of file +} diff --git a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql new file mode 100644 index 00000000..c944aca5 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql @@ -0,0 +1,38 @@ +CREATE QUERY neighbor_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter, INT num_hops, INT num_neighbors, SET e_types, SET v_types) +RETURNS (ListAccum) +SYNTAX V1 +{ + SumAccum @@v_batch; + SumAccum @@e_batch; + SetAccum @@printed_vertices; + SetAccum @@printed_edges; + ListAccum @@ret; + + start = {v}; + res = SELECT s + FROM start:s + POST-ACCUM + @@printed_vertices += s, + {SEEDVERTEXATTRS}; + + FOREACH i IN RANGE[1, num_hops] DO + start = SELECT t + FROM start:s -(e_types:e)- v_types:t + SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 + ACCUM + IF NOT @@printed_edges.contains(e) THEN + @@printed_edges += e, + {EDGEATTRS} + END; + start = SELECT s + FROM start:s + POST-ACCUM + IF NOT @@printed_vertices.contains(s) THEN + @@printed_vertices += s, + {OTHERVERTEXATTRS} + END; + END; + @@ret += @@v_batch; + @@ret += @@e_batch; + RETURN @@ret; +} diff --git a/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql index 36f0ee1f..1c99da63 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql @@ -12,12 +12,11 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( INT max_distance, INT max_anchors, INT max_rel_context, - INT batch_size, - INT num_batches=1, BOOL shuffle=FALSE, STRING delimiter, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -33,50 +32,31 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", STRING sasl_kerberos_principal="", - INT num_heap_inserts=10, INT num_edge_batches=10 ) SYNTAX v1{ TYPEDEF TUPLE Distance_Tuple; - TYPEDEF TUPLE ID_Tuple; INT num_vertices; - INT kafka_errcode; - INT batch_s; SumAccum @tmp_id; - SumAccum @@kafka_error; - SetAccum @next_pass, @to_pass, @received; HeapAccum (max_anchors, distance ASC) @token_heap; SumAccum @rel_context_set; SumAccum @ancs; - OrAccum @prev_sampled; OrAccum @heapFull; - MapAccum> @@token_count; MapAccum> @conv_map; BOOL cache_empty = FALSE; INT distance; - UINT producer; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - 
ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; start = {v_types}; # Perform fetch operation if desired IF clear_cache THEN - res = SELECT s FROM start:s POST-ACCUM s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map; + res = SELECT s FROM start:s POST-ACCUM s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map; END; IF input_vertices.size() != 0 AND NOT compute_all THEN - seeds = {input_vertices}; - res = SELECT s FROM seeds:s -(e_types)- v_types:t + seeds = {input_vertices}; + res = SELECT s FROM seeds:s -(e_types)- v_types:t ACCUM IF s.{ANCHOR_CACHE_ATTRIBUTE}.size() != 0 THEN - FOREACH (key, val) IN s.{ANCHOR_CACHE_ATTRIBUTE} DO # s.{ANCHOR_CACHE_ATTRIBUTE} should be changed to getAttr() when supported + FOREACH (key, val) IN s.{ANCHOR_CACHE_ATTRIBUTE} DO # s.ANCHOR_CACHE_ATTRIBUTE should be changed to getAttr() when supported s.@token_heap += Distance_Tuple(key, val) END ELSE @@ -91,6 +71,7 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( ELSE cache_empty = TRUE; END; + IF cache_empty THEN # computing all, shuffle vertices if needed ancs = SELECT s FROM start:s @@ -127,107 +108,105 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( END; END; END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SetAccum @@printed_vertices; - SetAccum @@seeds; - # Get batch seeds - IF input_vertices.size()==0 THEN + + # Get batch seeds + IF input_vertices.size()==0 THEN start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - @@printed_vertices += s; - END; - ELSE + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + ELSE start = input_vertices; seeds = SELECT s - FROM start:s - ACCUM @@printed_vertices += s; - END; - # Get relational context - - IF max_rel_context > 0 THEN + FROM start:s; + END; + + # Get relational context + IF max_rel_context > 0 THEN seeds = SELECT s FROM seeds:s -(e_types:e)- v_types:t - SAMPLE max_rel_context EDGE WHEN s.outdegree() >= max_rel_context - ACCUM s.@rel_context_set += e.type +" "; - END; + SAMPLE max_rel_context EDGE WHEN s.outdegree() >= max_rel_context + ACCUM s.@rel_context_set += e.type +" "; + END; - res = SELECT s FROM seeds:s - POST-ACCUM - FOREACH tup IN s.@token_heap DO + res = SELECT s + FROM seeds:s + POST-ACCUM + FOREACH tup IN s.@token_heap DO s.@ancs += stringify(tup.v_id)+":"+stringify(tup.distance)+" ", IF use_cache AND cache_empty THEN - s.@conv_map += (tup.v_id -> tup.distance) + s.@conv_map += (tup.v_id -> tup.distance) END - END, - IF 
(use_cache AND cache_empty) OR precompute THEN + END, + IF (use_cache AND cache_empty) OR precompute THEN s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map - END, - {VERTEXATTRS}; - IF NOT precompute THEN # No Output if precomputing - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; - ELSE # HTTP mode - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch; - ELSE + END; + + IF NOT precompute THEN # No Output if precomputing + # If getting all vertices of given types + IF input_vertices.size()==0 THEN + IF kafka_address != "" THEN + SumAccum @@kafka_error; + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSKAFKA} + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; + END; + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + END; + + PRINT @@kafka_error as kafkaError; + ELSE # HTTP mode + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@v_batch; + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; + END; + END; + ELSE # Else get given vertices + ListAccum @@v_batch; MapAccum @@id_map; MapAccum @@type_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s), @@type_map += (getvid(s) -> s.type); - PRINT @@v_batch AS vertex_batch, @@id_map AS pids, @@type_map AS types; - END; - END; - END; - END; - - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + + res = SELECT s + FROM seeds:s + POST-ACCUM + {VERTEXATTRSHTTP}, + @@id_map += (getvid(s) -> s), + @@type_map += (getvid(s) -> s.type); + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; + PRINT @@id_map AS pids, @@type_map AS types; END; - PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index 76ffaddd..25b8df47 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -1,13 +1,12 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( SET input_vertices, - INT batch_size, - INT num_batches=1, - BOOL shuffle=FALSE, STRING filter_by, SET 
v_types, STRING delimiter, + BOOL shuffle=FALSE, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -22,8 +21,7 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( STRING ssl_endpoint_identification_algorithm="", STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", - STRING sasl_kerberos_principal="", - INT num_heap_inserts = 10 + STRING sasl_kerberos_principal="" ) SYNTAX V2 { /* This query generates batches of vertices. If `input_vertices` is given, it will generate @@ -45,108 +43,90 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - # Shuffle vertex ID if needed - start = {v_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; - IF shuffle THEN - num_vertices = start.size(); - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SetAccum @@seeds; - IF input_vertices.size()==0 THEN - start = {v_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {VERTEXATTRS}; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {VERTEXATTRS}; - END; + # If getting all vertices of given types + IF input_vertices.size()==0 THEN + start = {v_types}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed + IF shuffle THEN + INT num_vertices = seeds.size(); + res = SELECT s + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - {VERTEXATTRS}; + res = SELECT s + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; END; + # Export data + # If using 
kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSKAFKA} + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; + END; + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - PRINT @@v_batch AS vertex_batch; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@v_batch; + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; + END; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + # Else get given vertices. + ELSE + ListAccum @@v_batch; + start = input_vertices; + res = SELECT s + FROM start:s + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; + FOREACH i IN @@v_batch DO + PRINT i as data_batch; END; - PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/trainer.py b/pyTigerGraph/gds/trainer.py index f44ca992..03e650c0 100644 --- a/pyTigerGraph/gds/trainer.py +++ b/pyTigerGraph/gds/trainer.py @@ -16,6 +16,7 @@ import time import os import warnings +import math class BaseCallback(): """Base class for training callbacks. 
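The `train()` hunk further below replaces the single step budget with independent epoch
and step caps, using `math.inf` as the "unbounded" sentinel. A minimal sketch of that
termination rule (a toy loop under the assumption that `loader` yields one batch per
item; `run` is a hypothetical stand-in, not the Trainer API):

import math

def run(loader, num_epochs=None, max_num_steps=None):
    # Same precedence as the patched train(): epochs win, then steps, else one epoch.
    if num_epochs:
        epochs, steps = num_epochs, math.inf
    elif max_num_steps:
        epochs, steps = math.inf, max_num_steps
    else:
        epochs, steps = 1, math.inf
    cur_step, cur_epoch = 0, 0
    while cur_step < steps and cur_epoch < epochs:
        for _batch in loader:
            if cur_step >= steps:
                break               # the step cap can end an epoch early
            cur_step += 1
        cur_epoch += 1              # epochs are now tracked explicitly
    return cur_step, cur_epoch

# run(range(10), max_num_steps=25) -> (25, 3): stops during the third epoch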
@@ -145,7 +146,7 @@ def on_train_step_end(self, trainer):
             trainer.update_train_step_metrics(metric.get_metrics())
             metric.reset_metrics()
         trainer.update_train_step_metrics({"global_step": trainer.cur_step})
-        trainer.update_train_step_metrics({"epoch": int(trainer.cur_step/trainer.train_loader.num_batches)})
+        trainer.update_train_step_metrics({"epoch": trainer.cur_epoch})
 
     def on_eval_start(self, trainer):
         """NO DOC"""
@@ -209,7 +210,7 @@ def on_epoch_start(self, trainer):
                 self.epoch_bar = self.tqdm(desc="Epochs", total=trainer.num_epochs)
             else:
                 self.epoch_bar = self.tqdm(desc="Training Steps", total=trainer.max_num_steps)
-        if not(self.batch_bar):
+        if self.batch_bar is None:
             self.batch_bar = self.tqdm(desc="Training Batches", total=trainer.train_loader.num_batches)
 
     def on_train_step_end(self, trainer):
@@ -217,20 +218,20 @@
         logger = logging.getLogger(__name__)
         logger.info("train_step:"+str(trainer.get_train_step_metrics()))
         if self.tqdm:
-            if self.batch_bar:
+            if self.batch_bar is not None:
                 self.batch_bar.update(1)
 
     def on_eval_start(self, trainer):
         """NO DOC"""
         trainer.reset_eval_metrics()
         if self.tqdm:
-            if not(self.valid_bar):
+            if self.valid_bar is None:
                 self.valid_bar = self.tqdm(desc="Eval Batches", total=trainer.eval_loader.num_batches)
 
     def on_eval_step_end(self, trainer):
         """NO DOC"""
         if self.tqdm:
-            if self.valid_bar:
+            if self.valid_bar is not None:
                 self.valid_bar.update(1)
 
     def on_eval_end(self, trainer):
@@ -239,7 +240,7 @@
         logger.info("evaluation:"+str(trainer.get_eval_metrics()))
         trainer.model.train()
         if self.tqdm:
-            if self.valid_bar:
+            if self.valid_bar is not None:
                 self.valid_bar.close()
                 self.valid_bar = None
 
@@ -248,7 +249,7 @@ def on_epoch_end(self, trainer):
         if self.tqdm:
             if self.epoch_bar:
                 self.epoch_bar.update(1)
-            if self.batch_bar:
+            if self.batch_bar is not None:
                 self.batch_bar.close()
                 self.batch_bar = None
         trainer.eval()
@@ -407,12 +408,17 @@ def train(self, num_epochs=None, max_num_steps=None):
                 Defaults to the length of the `training_dataloader`
         """
         if num_epochs:
-            self.max_num_steps = self.train_loader.num_batches * num_epochs
-        else:
+            self.max_num_steps = math.inf
+            self.num_epochs = num_epochs
+        elif max_num_steps:
             self.max_num_steps = max_num_steps
-        self.num_epochs = num_epochs
+            self.num_epochs = math.inf
+        else:
+            self.max_num_steps = math.inf
+            self.num_epochs = 1
         self.cur_step = 0
-        while self.cur_step < self.max_num_steps:
+        self.cur_epoch = 0
+        while self.cur_step < self.max_num_steps and self.cur_epoch < self.num_epochs:
             for callback in self.callbacks:
                 callback.on_epoch_start(trainer=self)
             for batch in self.train_loader:
@@ -432,7 +438,7 @@
                 self.cur_step += 1
                 for callback in self.callbacks:
                     callback.on_train_step_end(trainer=self)
-
+            self.cur_epoch += 1
             for callback in self.callbacks:
                 callback.on_epoch_end(trainer=self)
 
diff --git a/pyTigerGraph/gds/utilities.py b/pyTigerGraph/gds/utilities.py
index 5089f7ec..66d49c28 100644
--- a/pyTigerGraph/gds/utilities.py
+++ b/pyTigerGraph/gds/utilities.py
@@ -7,7 +7,7 @@ import re
 import string
 from os.path import join as pjoin
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union, List
 from urllib.parse import urlparse
 
 if TYPE_CHECKING:
@@ -166,6 +166,75 @@ def install_query_file(
     return query_name
 
 
+def install_query_files(
+    conn: "TigerGraphConnection",
+    file_paths: List[str],
+    replace: dict = None,
+    distributed: List[bool] = [],
+    force: bool = False,
+) -> str:
+    queries_to_install = []
+    last_query = ""
+    for idx, file_path in enumerate(file_paths):
+        # Read the first line of the file to get the query name. The first line should
+        # be something like CREATE QUERY query_name (...
+        with open(file_path) as infile:
+            firstline = infile.readline()
+        try:
+            query_name = re.search(r"QUERY (.+?)\(", firstline).group(1).strip()
+        except AttributeError:
+            raise ValueError(
+                "Cannot parse the query file. It should start with CREATE QUERY ..."
+            )
+        # If a suffix is to be added to the query name
+        if replace and ("{QUERYSUFFIX}" in replace):
+            query_name = query_name.replace("{QUERYSUFFIX}", replace["{QUERYSUFFIX}"])
+        last_query = query_name
+        # If the query is already installed, skip it unless force install.
+        is_installed, is_enabled = is_query_installed(conn, query_name, return_status=True)
+        if is_installed:
+            if force or (not is_enabled):
+                query = "USE GRAPH {}\nDROP QUERY {}\n".format(conn.graphname, query_name)
+                resp = conn.gsql(query)
+                if "Successfully dropped queries" not in resp:
+                    raise ConnectionError(resp)
+            else:
+                continue
+        # Otherwise, install the query from file
+        with open(file_path) as infile:
+            query = infile.read()
+        # Replace placeholders with actual content if given
+        if replace:
+            for placeholder in replace:
+                query = query.replace(placeholder, replace[placeholder])
+        if distributed and distributed[idx]:
+            query = query.replace("CREATE QUERY", "CREATE DISTRIBUTED QUERY")
+        logger.debug(query)
+        query = (
+            "USE GRAPH {}\n".format(conn.graphname)
+            + query
+            + "\n"
+        )
+        resp = conn.gsql(query)
+        if "Successfully created queries" not in resp:
+            raise ConnectionError(resp)
+        queries_to_install.append(query_name)
+    if queries_to_install:
+        query = (
+            "USE GRAPH {}\n".format(conn.graphname)
+            + "INSTALL QUERY {}\n".format(",".join(queries_to_install))
+        )
+        print(
+            "Installing and optimizing queries. It might take a minute or two."
+        )
+        resp = conn.gsql(query)
+        if "Query installation finished" not in resp:
+            raise ConnectionError(resp)
+        else:
+            print("Query installation finished.")
+    return last_query
+
+
 def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = None, attr_name:Union[str, dict] = None, schema_name:list = None, global_change:bool = False):
     ''' If the current attribute is not already added to the schema, it will create the schema job to do that.
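The helper above creates its inputs in order and issues one INSTALL QUERY for the whole
list, so a subquery can be created before the main query that references it. A usage
sketch (hypothetical file paths and suffix; `conn` is an existing TigerGraphConnection):

from pyTigerGraph.gds.utilities import install_query_files

# Sub-query first so the main loader can reference it; here only the main query
# is created as distributed. The return value is the last query's resolved name.
loader_query = install_query_files(
    conn,
    ["neighbor_loader_sub.gsql", "neighbor_loader.gsql"],  # hypothetical paths
    replace={"{QUERYSUFFIX}": "abc123"},                   # hypothetical suffix
    distributed=[False, True],
    force=False,
)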
diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index d2987c8d..68ad7bed 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -1,7 +1,7 @@ import io import unittest -from queue import Queue -from threading import Event +from queue import Queue, Empty +from threading import Event, Thread import pandas as pd import torch @@ -143,150 +143,168 @@ def test_read_vertex(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "vertex", - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - delimiter="|" - ) + raw = ["99|1 0 0 1 |1|0|1\n", + "8|1 0 0 1 |1|1|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_vertex_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() truth = pd.read_csv( - io.StringIO(raw), + io.StringIO("".join(raw)), header=None, names=["vid", "x", "y", "train_mask", "is_seed"], sep=self.loader.delimiter ) assert_frame_equal(data, truth) - data = data_q.get() - self.assertIsNone(data) def test_read_vertex_callback(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "vertex", - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - callback_fn=lambda x: 1, - delimiter="|" - ) + raw = ["99|1 0 0 1 |1|0|1\n", + "8|1 0 0 1 |1|1|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_vertex_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + delimiter = "|", + callback_fn = lambda x: 1 + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(1, data) def test_read_edge(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "edge", - "dataframe", - [], - [], - [], - {}, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + raw = ["1|2|0.1|2021|1|0\n", + "2|1|1.5|2020|0|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_edge_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() truth = 
pd.read_csv( - io.StringIO(raw), + io.StringIO("".join(raw)), header=None, names=["source", "target", "x", "time", "y", "is_train"], sep=self.loader.delimiter, ) assert_frame_equal(data, truth) - data = data_q.get() - self.assertIsNone(data) def test_read_edge_callback(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "edge", - "dataframe", - [], - [], - [], - {}, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - callback_fn=lambda x: 1, - delimiter="|" - ) + raw = ["1|2|0.1|2021|1|0\n", + "2|1|1.5|2020|0|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_edge_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + callback_fn=lambda x: 1 + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(data, 1) - def test_read_graph_out_df(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n", - "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0\n 8|1 0 0 1 |1|1\n ", + "1|2|0.1|2021|1|0\n 2|1|1.5|2020|0|1\n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() vertices = pd.read_csv( io.StringIO(raw[0]), header=None, - names=["vid", "x", "y", "train_mask", "is_seed"], + names=["vid", "x", "y", "train_mask"], sep=self.loader.delimiter ) + vertices["is_seed"] = [True, False] edges = pd.read_csv( io.StringIO(raw[1]), header=None, @@ -295,75 +313,84 @@ def test_read_graph_out_df(self): ) assert_frame_equal(data[0], vertices) assert_frame_equal(data[1], edges) - data = data_q.get() - self.assertIsNone(data) - def test_read_graph_out_df_callback(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n", - "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0|1\n 8|1 0 0 1 |1|1|1\n ", + "1|2|0.1|2021|1|0\n 2|1|1.5|2020|0|1\n ", + "" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - 
["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - callback_fn=lambda x: (1, 2), - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + callback_fn = lambda x: (1, 2), + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(data[0], 1) self.assertEqual(data[1], 2) - def test_read_graph_out_pyg(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0|a b \n 8|99|1.5|2020|0|1|c d \n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch( @@ -378,42 +405,48 @@ def test_read_graph_out_pyg(self): assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) self.assertListEqual(data["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_out_dgl(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0|a b \n 8|99|1.5|2020|0|1|c d \n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "dgl", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - 
delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, DGLGraph) assert_close_torch(data.edges(), (torch.tensor([0, 1]), torch.tensor([1, 0]))) assert_close_torch( @@ -428,8 +461,6 @@ def test_read_graph_out_dgl(self): assert_close_torch(data.ndata["is_seed"], torch.tensor([True, False])) self.assertListEqual(data.extra_data["name"], ["Alex", "Bill"]) self.assertListEqual(data.extra_data["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_parse_error(self): read_task_q = Queue() @@ -440,101 +471,114 @@ def test_read_graph_parse_error(self): "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "dgl", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter="|" - ) - data = data_q.get() - self.assertIsNone(data) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|" + ) + ) + thread.start() + with self.assertRaises(Empty): + data = data_q.get(timeout=1) + exit_event.set() + thread.join() def test_read_graph_no_attr(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = ("99|1\n8|0\n", "99|8\n8|99\n") + raw = ("99\n 8\n ", "99|8\n 8|99\n ", "99") read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - [], - [], - ["is_seed"], - { - "x": "INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - [], - [], - [], - {}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_extra_feats = ["is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + 
exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_no_edge(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", "", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) self.assertListEqual(list(data["edge_index"].shape), [2,0]) self.assertListEqual(list(data["edge_feat"].shape), [0,2]) @@ -545,63 +589,65 @@ def test_read_graph_no_edge(self): assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_out_pyg(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", - "Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2\nWork|2|8\n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", + "Colleague|99|8|0.1|2021|1|0\n Colleague|8|99|1.5|2020|0|1\n Work|99|2\n Work|2|8\n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - {"People": ["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"]}, + 
e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", - }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True, + seed_type = "vertex" + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygHeteroData) assert_close_torch( data["Colleague"]["edge_index"], torch.tensor([[0, 1], [1, 0]]) @@ -626,63 +672,65 @@ def test_read_hetero_graph_out_pyg(self): assert_close_torch( data["Work"]["edge_index"], torch.tensor([[0, 1], [0, 0]]) ) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_no_attr(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1\nPeople|8|0\nCompany|2|0\n", - "Colleague|99|8\nColleague|8|99\nWork|99|2\nWork|2|8\n", + "People|99\n People|8\n Company|2\n ", + "Colleague|99|8\n Colleague|8|99\n Work|99|2\n Work|2|8\n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - {"People": [], "Company": []}, - {"People": [], "Company": []}, - {"People": ["is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": [], "Work": []}, - {"Colleague": [], "Work": []}, - {"Colleague": [], "Work": []}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": [], "Company": []}, + v_out_labels = {"People": [], "Company": []}, + v_extra_feats = {"People": ["is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": [], "Work": []}, + e_out_labels = {"Colleague": [], "Work": []}, + e_extra_feats = {"Colleague": [], "Work": []}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True, + seed_type = "vertex" + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, 
pygHeteroData) assert_close_torch( data["Colleague"]["edge_index"], torch.tensor([[0, 1], [1, 0]]) @@ -692,63 +740,65 @@ def test_read_hetero_graph_no_attr(self): ) assert_close_torch(data["People"]["is_seed"], torch.tensor([True, False])) assert_close_torch(data["Company"]["is_seed"], torch.tensor([False])) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_no_edge(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", "", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - {"People": ["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"]}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", - }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True, + seed_type = "vertex" + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygHeteroData) self.assertNotIn("Colleague", data) assert_close_torch( @@ -763,64 +813,66 @@ def test_read_hetero_graph_no_edge(self): ) assert_close_torch(data["Company"]["is_seed"], torch.tensor([False])) self.assertNotIn("Work", data) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_out_dgl(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", - "Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2|a b \nWork|2|8|c d \n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", + "Colleague|99|8|0.1|2021|1|0\n Colleague|8|99|1.5|2020|0|1\n Work|99|2|a b \n Work|2|8|c d \n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "dgl", - {"People": 
["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"], "Work": ["category"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"], "Work": ["category"]}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False, + "category": "LIST:STRING"} }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - "category": "LIST:STRING" - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True, + seed_type = "vertex" + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, DGLGraph) assert_close_torch( data.edges(etype="Colleague"), (torch.tensor([0, 1]), torch.tensor([1, 0])) @@ -846,42 +898,48 @@ def test_read_hetero_graph_out_dgl(self): data.edges(etype="Work"), (torch.tensor([0, 1]), torch.tensor([0, 0])) ) self.assertListEqual(data.extra_data["Work"]["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_bool_label(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - "99|8|0.1|2021|1|0\n8|99|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0\n 8|99|1.5|2020|0|1\n ", + "99" ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_data( - exit_event, - read_task_q, - data_q, - "graph", - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "BOOL", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "DOUBLE", "time": "INT", "y": "BOOL", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "BOOL", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + 
e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "BOOL", "is_train": "BOOL"}, + delimiter = "|", + seed_type = "vertex" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch( @@ -895,8 +953,6 @@ def test_read_bool_label(self): assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) - data = data_q.get() - self.assertIsNone(data) if __name__ == "__main__": diff --git a/tests/test_gds_EdgeLoader.py b/tests/test_gds_EdgeLoader.py index 5e22671c..eac8075d 100644 --- a/tests/test_gds_EdgeLoader.py +++ b/tests/test_gds_EdgeLoader.py @@ -7,7 +7,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSEdgeLoader(unittest.TestCase): +class TestGDSEdgeLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -17,44 +17,47 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_whole_edgelist(self): loader = EdgeLoader( graph=self.conn, num_batches=1, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False, kafka_address="kafka:9092", ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + self.assertEqual(data.shape[0], 10556) + self.assertEqual(data.shape[1], 2) def test_iterate_attr(self): loader = EdgeLoader( @@ -62,19 +65,46 @@ def test_iterate_attr(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data.head()) + self.assertIsInstance(data, DataFrame) + self.assertIn("time", data) + self.assertIn("is_train", data) + num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) + self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + + def test_iterate_attr_multichar_delimiter(self): + loader = EdgeLoader( + graph=self.conn, + attributes=["time", "is_train"], + batch_size=1024, + shuffle=True, + kafka_address="kafka:9092", + delimiter="|$" + ) + num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + 
batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_sasl_plaintext(self): loader = EdgeLoader( @@ -130,6 +160,111 @@ def test_sasl_ssl(self): # TODO: test filter_by +class TestGDSHeteroEdgeLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeLoader( + graph=self.conn, + batch_size=1024, + shuffle=False, + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_as_homo(self): + loader = EdgeLoader( + graph=self.conn, + batch_size=1024, + shuffle=False, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data.head()) + self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + + def test_iterate_hetero(self): + loader = EdgeLoader( + graph=self.conn, + attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + batch_size=200, + shuffle=True, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) + + def test_iterate_hetero_multichar_delimiter(self): + loader = EdgeLoader( + graph=self.conn, + attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + batch_size=200, + shuffle=True, + delimiter="|$", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) + + class TestGDSEdgeLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -140,41 +275,45 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - 
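The expected batch counts in these assertions follow directly from the sizes the tests themselves pin down: Cora has 10556 edges, so a batch size of 1024 yields ceil(10556/1024) = 11 batches, the first 10 full and the last holding the remainder. A quick check of that arithmetic:

import math

num_edges, batch_size = 10556, 1024            # sizes asserted in the Cora tests
num_batches = math.ceil(num_edges / batch_size)
assert num_batches == 11
assert num_edges - (num_batches - 1) * batch_size == 316  # last, partial batch
# which is exactly what the loops check: every batch but the last is full,
# and the last is at most batch_size.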
buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_whole_edgelist(self): loader = EdgeLoader( graph=self.conn, num_batches=1, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False, ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + self.assertEqual(data.shape[0], 10556) + self.assertEqual(data.shape[1], 2) + def test_iterate_attr(self): loader = EdgeLoader( @@ -182,18 +321,21 @@ def test_iterate_attr(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_attr_multichar_delimiter(self): loader = EdgeLoader( @@ -201,19 +343,22 @@ def test_iterate_attr_multichar_delimiter(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) # TODO: test filter_by @@ -228,87 +373,110 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 6) + self.assertIsNone(loader.num_batches) def test_iterate_as_homo(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_hetero(self): loader = EdgeLoader( graph=self.conn, attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, batch_size=200, - shuffle=True, # Needed to get around VID 
distribution issues - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertEqual(len(data), 2) - self.assertIsInstance(data["v0v0"], DataFrame) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIsInstance(data["v2v0"], DataFrame) - self.assertIn("is_val", data["v2v0"]) - self.assertIn("is_train", data["v2v0"]) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) def test_iterate_hetero_multichar_delimiter(self): loader = EdgeLoader( graph=self.conn, attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, batch_size=200, - shuffle=True, # Needed to get around VID distribution issues - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - if num_batches == 0: - self.assertEqual(data["v0v0"].shape[0]+data["v2v0"].shape[0], 200) - self.assertEqual(len(data), 2) - self.assertIsInstance(data["v0v0"], DataFrame) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIsInstance(data["v2v0"], DataFrame) - self.assertIn("is_val", data["v2v0"]) - self.assertIn("is_train", data["v2v0"]) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSEdgeLoader("test_init")) - suite.addTest(TestGDSEdgeLoader("test_iterate")) - suite.addTest(TestGDSEdgeLoader("test_whole_edgelist")) - suite.addTest(TestGDSEdgeLoader("test_iterate_attr")) + suite.addTest(TestGDSEdgeLoaderKafka("test_init")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate")) + suite.addTest(TestGDSEdgeLoaderKafka("test_whole_edgelist")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate_attr")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate_attr_multichar_delimiter")) # suite.addTest(TestGDSEdgeLoader("test_sasl_plaintext")) # suite.addTest(TestGDSEdgeLoader("test_sasl_ssl")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_as_homo")) + 
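The multi-character delimiter tests exercise the same iteration logic with `delimiter="|$"` in place of the default `"|"`. Python's str.split accepts multi-character separators natively, which is presumably what keeps this a pass-through on the client side; a small sketch:

# str.split handles multi-character separators, so parsing "|$"-delimited
# records needs no special casing compared with the single-char default.
row = "99|$8|$2021.0|$1"
fields = row.split("|$")
assert fields == ["99", "8", "2021.0", "1"]
# An unusual separator like "|$" is one way callers can avoid collisions
# with "|" characters that may appear inside attribute values.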
suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_hetero")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_hetero_multichar_delimiter")) suite.addTest(TestGDSEdgeLoaderREST("test_init")) suite.addTest(TestGDSEdgeLoaderREST("test_iterate")) suite.addTest(TestGDSEdgeLoaderREST("test_whole_edgelist")) diff --git a/tests/test_gds_EdgeNeighborLoader.py b/tests/test_gds_EdgeNeighborLoader.py index a284689a..96c9e936 100644 --- a/tests/test_gds_EdgeNeighborLoader.py +++ b/tests/test_gds_EdgeNeighborLoader.py @@ -2,6 +2,7 @@ from pyTigerGraphUnitTest import make_connection from torch_geometric.data import Data as pygData +from torch_geometric.data import HeteroData as pygHeteroData from pyTigerGraph.gds.dataloaders import EdgeNeighborLoader from pyTigerGraph.gds.utilities import is_query_installed @@ -12,7 +13,7 @@ class TestGDSEdgeNeighborLoaderKafka(unittest.TestCase): def setUpClass(cls): cls.conn = make_connection(graphname="Cora") - def test_iterate_pyg(self): + def test_init(self): loader = EdgeNeighborLoader( graph=self.conn, v_in_feats=["x"], @@ -23,12 +24,57 @@ def test_iterate_pyg(self): shuffle=False, filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats=["x"], + e_extra_feats=["is_train"], + batch_size=1024, + num_neighbors=10, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygData) + self.assertIn("x", data) + self.assertIn("is_seed", data) + self.assertIn("is_train", data) + self.assertGreater(data["x"].shape[0], 0) + self.assertGreater(data["edge_index"].shape[1], 0) + num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) + self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + + def test_iterate_pyg_distributed(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats=["x"], + e_extra_feats=["is_train"], + batch_size=1024, + num_neighbors=10, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092", + distributed_query=True + ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -38,7 +84,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_sasl_ssl(self): loader = EdgeNeighborLoader( @@ -92,12 +142,9 @@ def test_init(self): shuffle=False, filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = EdgeNeighborLoader( @@ -107,14 +154,12 @@ def test_iterate_pyg(self): batch_size=1024, num_neighbors=10, num_hops=2, - shuffle=False, + shuffle=True, filter_by=None, output_format="PyG", - add_self_loop=False, - 
loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -124,7 +169,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_spektral(self): loader = EdgeNeighborLoader( @@ -154,13 +203,202 @@ def test_iterate_spektral(self): self.assertEqual(num_batches, 11) +class TestGDSHeteroEdgeNeighborLoaderREST(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + + +class TestGDSHeteroEdgeNeighborLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + 
output_format="PyG", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + + def test_iterate_pyg_distributed(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092", + distributed_query=True + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + + if __name__ == "__main__": suite = unittest.TestSuite() + suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg_distributed")) # suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_pyg")) # suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_spektral")) - + suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_init")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_iterate_pyg")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg_distributed")) runner = 
unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_GDS.py b/tests/test_gds_GDS.py index b247ac05..f6364f88 100644 --- a/tests/test_gds_GDS.py +++ b/tests/test_gds_GDS.py @@ -26,7 +26,7 @@ def test_neighborLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertEqual(loader.batch_size, 16) def test_neighborLoader_multiple_filters(self): loaders = self.conn.gds.neighborLoader( @@ -60,7 +60,7 @@ def test_graphLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_vertexLoader(self): loader = self.conn.gds.vertexLoader( @@ -72,7 +72,7 @@ def test_vertexLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertEqual(loader.batch_size, 16) def test_edgeLoader(self): loader = self.conn.gds.edgeLoader( @@ -83,7 +83,7 @@ def test_edgeLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_edgeNeighborLoader(self): loader = self.conn.gds.edgeNeighborLoader( @@ -100,7 +100,7 @@ def test_edgeNeighborLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_configureKafka(self): self.conn.gds.configureKafka(kafka_address="kafka:9092") diff --git a/tests/test_gds_GraphLoader.py b/tests/test_gds_GraphLoader.py index b06d94aa..ef3c2c07 100644 --- a/tests/test_gds_GraphLoader.py +++ b/tests/test_gds_GraphLoader.py @@ -9,7 +9,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSGraphLoader(unittest.TestCase): +class TestGDSGraphLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -22,15 +22,11 @@ def test_init(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -40,14 +36,11 @@ def test_iterate_pyg(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -56,8 +49,12 @@ def test_iterate_pyg(self): self.assertIn("train_mask", data) self.assertIn("val_mask", data) self.assertIn("test_mask", data) + batch_sizes.append(data["edge_index"].shape[1]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_df(self): loader = GraphLoader( @@ -67,25 +64,28 @@ def test_iterate_df(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - 
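The switch from `assertEqual(loader.num_batches, ...)` to `assertEqual(loader.batch_size, ...)` here, and to `assertIsNone(loader.num_batches)` in the loader suites, reflects that the loaders no longer precompute a batch count at construction; it only becomes known after a full pass has streamed. A plausible shape for that contract, sketched with hypothetical names rather than the real implementation:

# Hypothetical sketch of the lazy batch-count contract the tests encode:
# batch_size is known at construction, num_batches only after one full pass.
class LazyLoader:
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.num_batches = None          # unknown until a full iteration

    def __iter__(self):
        count = 0
        for i in range(0, len(self.data), self.batch_size):
            count += 1
            yield self.data[i:i + self.batch_size]
        self.num_batches = count         # now the count is exact

loader = LazyLoader(list(range(10556)), 1024)
assert loader.num_batches is None        # what test_init asserts
batches = list(loader)
assert loader.num_batches == 11          # what test_iterate implies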
loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data) + # print(num_batches, data, flush=True) self.assertIsInstance(data[0], DataFrame) - self.assertIsInstance(data[1], DataFrame) self.assertIn("x", data[0].columns) self.assertIn("y", data[0].columns) self.assertIn("train_mask", data[0].columns) self.assertIn("val_mask", data[0].columns) self.assertIn("test_mask", data[0].columns) + self.assertIsInstance(data[1], DataFrame) + self.assertIn("source", data[1]) + self.assertIn("target", data[1]) + batch_sizes.append(data[1].shape[0]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_edge_attr(self): loader = GraphLoader( @@ -97,11 +97,7 @@ def test_edge_attr(self): e_extra_feats=["is_train"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 @@ -207,12 +203,9 @@ def test_init(self): shuffle=True, filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -222,13 +215,10 @@ def test_iterate_pyg(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -238,7 +228,11 @@ def test_iterate_pyg(self): self.assertIn("val_mask", data) self.assertIn("test_mask", data) num_batches += 1 + batch_sizes.append(data["edge_index"].shape[1]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_df(self): loader = GraphLoader( @@ -248,13 +242,10 @@ def test_iterate_df(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data[0], DataFrame) @@ -264,8 +255,14 @@ def test_iterate_df(self): self.assertIn("train_mask", data[0].columns) self.assertIn("val_mask", data[0].columns) self.assertIn("test_mask", data[0].columns) + self.assertIn("source", data[1]) + self.assertIn("target", data[1]) + batch_sizes.append(data[1].shape[0]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_edge_attr(self): loader = GraphLoader( @@ -277,11 +274,7 @@ def test_edge_attr(self): e_extra_feats=["is_train"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4 ) num_batches = 0 for data in loader: @@ -327,11 +320,7 @@ def test_iterate_spektral(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 for data in loader: @@ -360,14 +349,10 
@@ def test_init(self): v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=1024, shuffle=False, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 6) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -376,26 +361,36 @@ def test_iterate_pyg(self): "v1": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - batch_size=1024, + batch_size=300, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertIn("x", data["v0"]) - self.assertIn("y", data["v0"]) - self.assertIn("train_mask", data["v0"]) - self.assertIn("val_mask", data["v0"]) - self.assertIn("test_mask", data["v0"]) - self.assertIn("x", data["v1"]) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) num_batches += 1 self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) def test_iterate_df(self): loader = GraphLoader( @@ -404,29 +399,41 @@ def test_iterate_df(self): "v1": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - batch_size=1024, + batch_size=300, shuffle=False, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data[0]["v0"], DataFrame) - self.assertIsInstance(data[0]["v1"], DataFrame) - self.assertIsInstance(data[1]["v0v0"], DataFrame) - self.assertIsInstance(data[1]["v1v1"], DataFrame) - self.assertIn("x", data[0]["v0"].columns) - self.assertIn("y", data[0]["v0"].columns) - self.assertIn("train_mask", data[0]["v0"].columns) - self.assertIn("val_mask", data[0]["v0"].columns) - self.assertIn("test_mask", data[0]["v0"].columns) - self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0" in data[0] or "v1" in data[0]) + if "v0" in data[0]: + self.assertIsInstance(data[0]["v0"], DataFrame) + self.assertIn("x", data[0]["v0"].columns) + self.assertIn("y", data[0]["v0"].columns) + self.assertIn("train_mask", data[0]["v0"].columns) + self.assertIn("val_mask", data[0]["v0"].columns) + self.assertIn("test_mask", data[0]["v0"].columns) + if "v1" in data[0]: + self.assertIsInstance(data[0]["v1"], DataFrame) + self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0v0" in data[1] or "v1v1" in data[1]) + batchsize = 0 + if "v0v0" in data[1]: + self.assertIsInstance(data[1]["v0v0"], DataFrame) + batchsize += data[1]["v0v0"].shape[0] + 
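With batch_size=300 on a small heterogeneous graph, a batch may contain edges of only one type, so the rewritten assertions guard every per-type check behind a membership test and require only that at least one expected type is present. The accumulation pattern, reduced to plain dicts so it runs without torch_geometric:

# Plain-dict stand-in for a heterogeneous batch: only the types present
# in this batch appear as keys, mirroring data.node_types / data.edge_types.
batch = {"v0v0": 120, "v1v1": 180}               # hypothetical per-type edge counts

expected_types = ("v0v0", "v1v1")
assert any(t in batch for t in expected_types)   # at least one type present

batchsize = 0
for etype in expected_types:
    if etype in batch:                           # guard: the type may be absent
        batchsize += batch[etype]
assert batchsize <= 300                          # total never exceeds batch_size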
self.assertEqual(data[1]["v0v0"].shape[1], 2) + if "v1v1" in data[1]: + self.assertIsInstance(data[1]["v1v1"], DataFrame) + batchsize += data[1]["v1v1"].shape[0] + self.assertEqual(data[1]["v1v1"].shape[1], 2) + batch_sizes.append(batchsize) num_batches += 1 self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) def test_edge_attr(self): loader = GraphLoader( @@ -437,38 +444,197 @@ def test_edge_attr(self): v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, e_extra_feats={"v0v0": ["is_train", "is_val"], "v1v1": ["is_train", "is_val"]}, + batch_size=300, + shuffle=False, + output_format="PyG", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertIn("is_train", data["v0", "v0v0", "v0"]) + self.assertIn("is_val", data["v0", "v0v0", "v0"]) + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertIn("is_train", data["v1", "v1v1", "v1"]) + self.assertIn("is_val", data["v1", "v1v1", "v1"]) + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + +class TestGDSHeteroGraphLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=1024, shuffle=False, - filter_by=None, + output_format="dataframe", + kafka_address="kafka:9092", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=300, + shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4 + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + 
if ('v1', 'v1v1', 'v1') in data.edge_types: + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + def test_iterate_df(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=300, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertTrue("v0" in data[0] or "v1" in data[0]) + if "v0" in data[0]: + self.assertIsInstance(data[0]["v0"], DataFrame) + self.assertIn("x", data[0]["v0"].columns) + self.assertIn("y", data[0]["v0"].columns) + self.assertIn("train_mask", data[0]["v0"].columns) + self.assertIn("val_mask", data[0]["v0"].columns) + self.assertIn("test_mask", data[0]["v0"].columns) + if "v1" in data[0]: + self.assertIsInstance(data[0]["v1"], DataFrame) + self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0v0" in data[1] or "v1v1" in data[1]) + batchsize = 0 + if "v0v0" in data[1]: + self.assertIsInstance(data[1]["v0v0"], DataFrame) + batchsize += data[1]["v0v0"].shape[0] + self.assertEqual(data[1]["v0v0"].shape[1], 2) + if "v1v1" in data[1]: + self.assertIsInstance(data[1]["v1v1"], DataFrame) + batchsize += data[1]["v1v1"].shape[0] + self.assertEqual(data[1]["v1v1"].shape[1], 2) + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + def test_edge_attr(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + e_extra_feats={"v0v0": ["is_train", "is_val"], + "v1v1": ["is_train", "is_val"]}, + batch_size=300, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertIn("x", data["v0"]) - self.assertIn("y", data["v0"]) - self.assertIn("train_mask", data["v0"]) - self.assertIn("val_mask", data["v0"]) - self.assertIn("test_mask", data["v0"]) - self.assertIn("x", data["v1"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIn("is_train", data["v1v1"]) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_val", data["v1v1"]) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertIn("is_train", data["v0", "v0v0", "v0"]) + self.assertIn("is_val", data["v0", "v0v0", "v0"]) + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertIn("is_train", data["v1", "v1v1", "v1"]) + self.assertIn("is_val", data["v1", "v1v1", "v1"]) + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] 
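In "dataframe" mode the heterogeneous loader yields a (vertex_frames, edge_frames) pair of dicts keyed by type, and with no edge attributes requested each edge frame carries exactly the two endpoint columns. A sketch of the shape the assertions expect, with made-up values:

import pandas as pd

# Hypothetical batch in "dataframe" output mode: dict-of-frames per element.
vertex_frames = {"v0": pd.DataFrame({"x": [0.1, 0.2], "y": [1, 0]})}
edge_frames = {"v0v0": pd.DataFrame({"source": ["1", "2"], "target": ["2", "3"]})}
data = (vertex_frames, edge_frames)

assert "v0" in data[0] and "x" in data[0]["v0"].columns
assert data[1]["v0v0"].shape[1] == 2      # only source/target when no attrs
batchsize = sum(df.shape[0] for df in data[1].values())
assert batchsize <= 300                   # edge rows bounded by batch_size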
+ batch_sizes.append(batchsize) num_batches += 1 - self.assertEqual(num_batches, 2) + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSGraphLoader("test_init")) - suite.addTest(TestGDSGraphLoader("test_iterate_pyg")) - suite.addTest(TestGDSGraphLoader("test_iterate_df")) - suite.addTest(TestGDSGraphLoader("test_edge_attr")) + suite.addTest(TestGDSGraphLoaderKafka("test_init")) + suite.addTest(TestGDSGraphLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSGraphLoaderKafka("test_iterate_df")) + suite.addTest(TestGDSGraphLoaderKafka("test_edge_attr")) # suite.addTest(TestGDSGraphLoader("test_sasl_plaintext")) # suite.addTest(TestGDSGraphLoader("test_sasl_ssl")) suite.addTest(TestGDSGraphLoaderREST("test_init")) @@ -480,6 +646,10 @@ def test_edge_attr(self): suite.addTest(TestGDSHeteroGraphLoaderREST("test_iterate_pyg")) suite.addTest(TestGDSHeteroGraphLoaderREST("test_iterate_df")) suite.addTest(TestGDSHeteroGraphLoaderREST("test_edge_attr")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_iterate_df")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_edge_attr")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_HGTLoader.py b/tests/test_gds_HGTLoader.py index 15c51820..6278f8a8 100644 --- a/tests/test_gds_HGTLoader.py +++ b/tests/test_gds_HGTLoader.py @@ -12,8 +12,6 @@ class TestGDSHGTLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="hetero") - splitter = cls.conn.gds.vertexSplitter(v_types=["v2"], train_mask=0.3) - splitter.run() def test_init(self): loader = HGTLoader( @@ -26,12 +24,9 @@ def test_init(self): num_hops=2, shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 18) + self.assertIsNone(loader.num_batches) def test_whole_graph_df(self): loader = HGTLoader( @@ -44,14 +39,11 @@ def test_whole_graph_df(self): num_hops=2, shuffle=False, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data - self.assertTupleEqual(data[0]["v0"].shape, (76, 7)) - self.assertTupleEqual(data[0]["v1"].shape, (110, 3)) - self.assertTupleEqual(data[0]["v2"].shape, (100, 3)) + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) self.assertTrue( data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 ) @@ -82,9 +74,6 @@ def test_whole_graph_pyg(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) @@ -129,23 +118,62 @@ def test_iterate_pyg(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - num_batches=6, + v_seed_types=["v2"], + batch_size=16, num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - filter_by= {"v2": "train_mask"} ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, 
data) self.assertIsInstance(data, pygHeteroData) self.assertGreater(data["v2"]["x"].shape[0], 0) self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) num_batches += 1 - self.assertEqual(num_batches, 6) + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = HGTLoader( @@ -155,29 +183,195 @@ def test_fetch(self): v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=16, - num_hops=1, + num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( - [{"primary_id": "13", "type": "v2"}, {"primary_id": "28", "type": "v2"}] + [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] ) - self.assertIn("13", data["v2"]["primary_id"]) - self.assertIn("28", data["v2"]["primary_id"]) - for i, d in enumerate(data["v2"]["primary_id"]): - if d == "13" or d == "28": - self.assertTrue(data["v2"]["is_seed"][i].item()) + self.assertIn("primary_id", data["v0"]) + self.assertGreater(data["v0"]["x"].shape[0], 2) + self.assertGreater(data["v0v0"]["edge_index"].shape[1], 0) + self.assertIn("10", data["v0"]["primary_id"]) + self.assertIn("55", data["v0"]["primary_id"]) + for i, d in enumerate(data["v0"]["primary_id"]): + if d == "10" or d == "55": + self.assertTrue(data["v0"]["is_seed"][i].item()) else: - self.assertFalse(data["v2"]["is_seed"][i].item()) - # self.assertGreaterEqual(len(data["v0"]["primary_id"]), 2) - # self.assertGreaterEqual(len(data["v1"]["primary_id"]), 2) - # print("v0", data["v0"]["primary_id"]) - # print("v1", data["v1"]["primary_id"]) - # print("v2", data["v2"]["primary_id"]) - # print(data) + self.assertFalse(data["v0"]["is_seed"][i].item()) + + +class TestGDSHGTLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = 
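The HGT iteration tests replace the old filter_by/num_batches setup with v_seed_types=["v2"] and batch_size=16; the expected 7 batches (6 full plus a remainder) are consistent with the 100 v2 vertices the whole-graph PyG test reports. The arithmetic:

import math

v2_seeds, batch_size = 100, 16            # v2 count from the whole-graph test
num_batches = math.ceil(v2_seeds / batch_size)
assert num_batches == 7
assert v2_seeds - 6 * batch_size == 4     # last batch holds the 4 leftovers
# matching the loop: six batches of 16 seeds each, then a final batch <= 16.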
make_connection(graphname="hetero") + + def test_init(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=16, + num_hops=2, + shuffle=True, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_whole_graph_df(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_hops=2, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) + self.assertTrue( + data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 + ) + self.assertTrue( + data[1]["v1v1"].shape[0] > 0 and data[1]["v1v1"].shape[0] <= 1044 + ) + self.assertTrue( + data[1]["v1v2"].shape[0] > 0 and data[1]["v1v2"].shape[0] <= 1038 + ) + self.assertTrue( + data[1]["v2v0"].shape[0] > 0 and data[1]["v2v0"].shape[0] <= 943 + ) + self.assertTrue( + data[1]["v2v1"].shape[0] > 0 and data[1]["v2v1"].shape[0] <= 959 + ) + self.assertTrue( + data[1]["v2v2"].shape[0] > 0 and data[1]["v2v2"].shape[0] <= 966 + ) + + def test_whole_graph_pyg(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) + self.assertEqual(data["v0"]["y"].shape[0], 76) + self.assertEqual(data["v0"]["train_mask"].shape[0], 76) + self.assertEqual(data["v0"]["test_mask"].shape[0], 76) + self.assertEqual(data["v0"]["val_mask"].shape[0], 76) + self.assertEqual(data["v0"]["is_seed"].shape[0], 76) + self.assertTupleEqual(data["v1"]["x"].shape, (110, 57)) + self.assertEqual(data["v1"]["is_seed"].shape[0], 110) + self.assertTupleEqual(data["v2"]["x"].shape, (100, 48)) + self.assertEqual(data["v2"]["is_seed"].shape[0], 100) + self.assertTrue( + data["v0v0"]["edge_index"].shape[1] > 0 + and data["v0v0"]["edge_index"].shape[1] <= 710 + ) + self.assertTrue( + data["v1v1"]["edge_index"].shape[1] > 0 + and data["v1v1"]["edge_index"].shape[1] <= 1044 + ) + self.assertTrue( + data["v1v2"]["edge_index"].shape[1] > 0 + and data["v1v2"]["edge_index"].shape[1] <= 1038 + ) + self.assertTrue( + data["v2v0"]["edge_index"].shape[1] > 0 + and data["v2v0"]["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data["v2v1"]["edge_index"].shape[1] > 0 + and data["v2v1"]["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data["v2v2"]["edge_index"].shape[1] > 0 + and data["v2v2"]["edge_index"].shape[1] <= 966 + ) + + def test_iterate_pyg(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 2, "v1": 2, "v2": 2}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], + 
batch_size=16, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + num_batches += 1 + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) if __name__ == "__main__": @@ -187,6 +381,10 @@ def test_fetch(self): suite.addTest(TestGDSHGTLoaderREST("test_whole_graph_pyg")) suite.addTest(TestGDSHGTLoaderREST("test_iterate_pyg")) suite.addTest(TestGDSHGTLoaderREST("test_fetch")) + suite.addTest(TestGDSHGTLoaderKafka("test_init")) + suite.addTest(TestGDSHGTLoaderKafka("test_whole_graph_df")) + suite.addTest(TestGDSHGTLoaderKafka("test_whole_graph_pyg")) + suite.addTest(TestGDSHGTLoaderKafka("test_iterate_pyg")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_NeighborLoader.py b/tests/test_gds_NeighborLoader.py index 2bb9a7c0..2a3f4ac8 100644 --- a/tests/test_gds_NeighborLoader.py +++ b/tests/test_gds_NeighborLoader.py @@ -26,13 +26,10 @@ def test_init(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = NeighborLoader( @@ -46,27 +43,27 @@ def test_iterate_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) - for epoch in range(2): - with self.subTest(i=epoch): - 
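HGTLoader takes num_neighbors as a per-vertex-type budget rather than a single fanout, which is what lets the tests cap each sampled edge set independently. A rough client-side sketch of type-budgeted sampling under that reading (a simplification over hypothetical data, not the server-side GSQL):

import random

# Rough sketch of per-type neighbor budgets, as in num_neighbors={"v0": 3, ...}.
# neighbors maps a vertex to (neighbor_id, neighbor_type) pairs; hypothetical data.
budgets = {"v0": 3, "v1": 5, "v2": 10}
neighbors = {"seed": [(f"n{i}", random.choice(["v0", "v1", "v2"])) for i in range(40)]}

sampled = {}
for vtype, budget in budgets.items():
    candidates = [n for n, t in neighbors["seed"] if t == vtype]
    # keep at most `budget` neighbors of this type for the seed vertex
    sampled[vtype] = random.sample(candidates, min(budget, len(candidates)))

assert all(len(v) <= budgets[t] for t, v in sampled.items())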
num_batches = 0 - for data in loader: - # print(num_batches, data) - self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - self.assertGreater(data["x"].shape[0], 0) - self.assertGreater(data["edge_index"].shape[1], 0) - num_batches += 1 - self.assertEqual(num_batches, 9) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygData) + self.assertIn("x", data) + self.assertIn("y", data) + self.assertIn("train_mask", data) + self.assertIn("val_mask", data) + self.assertIn("test_mask", data) + self.assertIn("is_seed", data) + self.assertGreater(data["x"].shape[0], 0) + self.assertGreater(data["edge_index"].shape[1], 0) + num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_iterate_stop_pyg(self): loader = NeighborLoader( @@ -80,9 +77,6 @@ def test_iterate_stop_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) for epoch in range(2): @@ -109,33 +103,6 @@ def test_iterate_stop_pyg(self): rq_id = self.conn.getRunningQueries()["results"] self.assertEqual(len(rq_id), 0) - def test_whole_graph_pyg(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", - ) - data = loader.data - # print(data) - self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - def test_edge_attr(self): loader = NeighborLoader( graph=self.conn, @@ -150,14 +117,12 @@ def test_edge_attr(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) for epoch in range(2): with self.subTest(i=epoch): num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -170,7 +135,11 @@ def test_edge_attr(self): self.assertIn("edge_feat", data) self.assertIn("is_train", data) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_sasl_plaintext(self): loader = NeighborLoader( @@ -307,12 +276,9 @@ def test_init(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = NeighborLoader( @@ -326,11 +292,9 @@ def test_iterate_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: 
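Throughout the rewritten neighbor-loader tests, the per-batch seed count is recovered from the is_seed mask rather than from loader bookkeeping: summing a boolean tensor counts its True entries. For example:

import torch

is_seed = torch.tensor([True, False, True, True, False])
assert int(is_seed.sum()) == 3            # True entries count as 1 under sum
# The tests append int(data["is_seed"].sum()) per batch and then check that
# every batch but the last equals batch_size.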
# print(num_batches, data) self.assertIsInstance(data, pygData) @@ -343,33 +307,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) - - def test_whole_graph_pyg(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - ) - data = loader.data - # print(data) - self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_edge_attr(self): loader = NeighborLoader( @@ -385,13 +327,11 @@ def test_edge_attr(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) for epoch in range(2): with self.subTest(i=epoch): num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -404,7 +344,11 @@ def test_edge_attr(self): self.assertIn("edge_feat", data) self.assertIn("is_train", data) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = NeighborLoader( @@ -415,26 +359,26 @@ def test_fetch(self): batch_size=16, num_neighbors=10, num_hops=2, - shuffle=True, + shuffle=False, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [ - {"primary_id": "100", "type": "Paper"}, + {"primary_id": "60", "type": "Paper"}, {"primary_id": "55", "type": "Paper"}, ] ) + # print(data) + # print(data["primary_id"]) + # print(data["is_seed"]) self.assertIn("primary_id", data) self.assertGreater(data["x"].shape[0], 2) self.assertGreater(data["edge_index"].shape[1], 0) - self.assertIn("100", data["primary_id"]) + self.assertIn("60", data["primary_id"]) self.assertIn("55", data["primary_id"]) for i, d in enumerate(data["primary_id"]): - if d == "100" or d == "55": + if d == "60" or d == "55": self.assertTrue(data["is_seed"][i].item()) else: self.assertFalse(data["is_seed"][i].item()) @@ -452,23 +396,20 @@ def test_fetch_delimiter(self): delimiter="$|", filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [ - {"primary_id": "100", "type": "Paper"}, + {"primary_id": "60", "type": "Paper"}, {"primary_id": "55", "type": "Paper"}, ] ) self.assertIn("primary_id", data) self.assertGreater(data["x"].shape[0], 2) self.assertGreater(data["edge_index"].shape[1], 0) - self.assertIn("100", data["primary_id"]) + self.assertIn("60", data["primary_id"]) self.assertIn("55", data["primary_id"]) for i, d in enumerate(data["primary_id"]): - if d == "100" or d == "55": + if d == "60" or d == "55": self.assertTrue(data["is_seed"][i].item()) else: self.assertFalse(data["is_seed"][i].item()) @@ -485,9 +426,6 @@ def test_iterate_spektral(self): shuffle=True, filter_by="train_mask", 
output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4 ) num_batches = 0 for data in loader: @@ -502,32 +440,6 @@ def test_iterate_spektral(self): num_batches += 1 self.assertEqual(num_batches, 9) - def test_whole_graph_spektral(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4, - ) - data = loader.data - # print(data) - # self.assertIsInstance(data, spData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - def test_reinstall_query(self): loader = NeighborLoader( graph=self.conn, @@ -540,9 +452,6 @@ def test_reinstall_query(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) query_name = loader.query_name @@ -565,12 +474,9 @@ def test_init(self): num_hops=2, shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 18) + self.assertIsNone(loader.num_batches) def test_whole_graph_df(self): loader = NeighborLoader( @@ -583,14 +489,11 @@ def test_whole_graph_df(self): num_hops=2, shuffle=False, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data - self.assertTupleEqual(data[0]["v0"].shape, (76, 7)) - self.assertTupleEqual(data[0]["v1"].shape, (110, 3)) - self.assertTupleEqual(data[0]["v2"].shape, (100, 3)) + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) self.assertTrue( data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 ) @@ -621,9 +524,6 @@ def test_whole_graph_pyg(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data # print(data) @@ -668,57 +568,75 @@ def test_iterate_pyg(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], batch_size=16, num_neighbors=10, num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertGreater(data["v0"]["x"].shape[0], 0) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) - self.assertEqual( - data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] - ) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) - self.assertGreater(data["v1"]["x"].shape[0], 0) - self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) self.assertGreater(data["v2"]["x"].shape[0], 0) self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + 
batch_sizes.append(int(data["v2"]["is_seed"].sum())) + if "v1" in data.node_types: + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + if "v0" in data.node_types: + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) self.assertTrue( - data["v0v0"]["edge_index"].shape[1] > 0 - and data["v0v0"]["edge_index"].shape[1] <= 710 - ) - self.assertTrue( - data["v1v1"]["edge_index"].shape[1] > 0 - and data["v1v1"]["edge_index"].shape[1] <= 1044 - ) - self.assertTrue( - data["v1v2"]["edge_index"].shape[1] > 0 - and data["v1v2"]["edge_index"].shape[1] <= 1038 - ) - self.assertTrue( - data["v2v0"]["edge_index"].shape[1] > 0 - and data["v2v0"]["edge_index"].shape[1] <= 943 - ) - self.assertTrue( - data["v2v1"]["edge_index"].shape[1] > 0 - and data["v2v1"]["edge_index"].shape[1] <= 959 - ) - self.assertTrue( - data["v2v2"]["edge_index"].shape[1] > 0 - and data["v2v2"]["edge_index"].shape[1] <= 966 - ) + ('v0', 'v0v0', 'v0') in data.edge_types or + ('v1', 'v1v1', 'v1') in data.edge_types or + ('v1', 'v1v2', 'v2') in data.edge_types or + ('v2', 'v2v0', 'v0') in data.edge_types or + ('v2', 'v2v1', 'v1') in data.edge_types or + ('v2', 'v2v2', 'v2') in data.edge_types) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + if ('v2', 'v2v0', 'v0') in data.edge_types: + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + if ('v2', 'v2v1', 'v1') in data.edge_types: + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) num_batches += 1 - self.assertEqual(num_batches, 18) + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_iterate_pyg_multichar_delimiter(self): loader = NeighborLoader( @@ -726,58 +644,76 @@ def test_iterate_pyg_multichar_delimiter(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], batch_size=16, num_neighbors=10, num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - 
self.assertGreater(data["v0"]["x"].shape[0], 0)
-            self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0])
-            self.assertEqual(
-                data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0]
-            )
-            self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0])
-            self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0])
-            self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0])
-            self.assertGreater(data["v1"]["x"].shape[0], 0)
-            self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0])
             self.assertGreater(data["v2"]["x"].shape[0], 0)
             self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0])
+            batch_sizes.append(int(data["v2"]["is_seed"].sum()))
+            if "v1" in data.node_types:
+                self.assertGreater(data["v1"]["x"].shape[0], 0)
+                self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0])
+            if "v0" in data.node_types:
+                self.assertGreater(data["v0"]["x"].shape[0], 0)
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0])
+                self.assertEqual(
+                    data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0]
+                )
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0])
             self.assertTrue(
-                data["v0v0"]["edge_index"].shape[1] > 0
-                and data["v0v0"]["edge_index"].shape[1] <= 710
-            )
-            self.assertTrue(
-                data["v1v1"]["edge_index"].shape[1] > 0
-                and data["v1v1"]["edge_index"].shape[1] <= 1044
-            )
-            self.assertTrue(
-                data["v1v2"]["edge_index"].shape[1] > 0
-                and data["v1v2"]["edge_index"].shape[1] <= 1038
-            )
-            self.assertTrue(
-                data["v2v0"]["edge_index"].shape[1] > 0
-                and data["v2v0"]["edge_index"].shape[1] <= 943
-            )
-            self.assertTrue(
-                data["v2v1"]["edge_index"].shape[1] > 0
-                and data["v2v1"]["edge_index"].shape[1] <= 959
-            )
-            self.assertTrue(
-                data["v2v2"]["edge_index"].shape[1] > 0
-                and data["v2v2"]["edge_index"].shape[1] <= 966
-            )
+                ('v0', 'v0v0', 'v0') in data.edge_types or
+                ('v1', 'v1v1', 'v1') in data.edge_types or
+                ('v1', 'v1v2', 'v2') in data.edge_types or
+                ('v2', 'v2v0', 'v0') in data.edge_types or
+                ('v2', 'v2v1', 'v1') in data.edge_types or
+                ('v2', 'v2v2', 'v2') in data.edge_types)
+            if ('v0', 'v0v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710
+                )
+            if ('v1', 'v1v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044
+                )
+            if ('v1', 'v1v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038
+                )
+            if ('v2', 'v2v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943
+                )
+            if ('v2', 'v2v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959
+                )
+            if ('v2', 'v2v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966
+                )
             num_batches += 1
-        self.assertEqual(num_batches, 18)
+        self.assertEqual(num_batches, 7)
+        for i in batch_sizes[:-1]:
+            
self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = NeighborLoader( @@ -790,9 +726,6 @@ def test_fetch(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] @@ -820,9 +753,6 @@ def test_fetch_delimiter(self): shuffle=False, output_format="PyG", delimiter="$|", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] @@ -849,9 +779,6 @@ def test_metadata(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) test = (["v0", "v1", "v2"], @@ -865,19 +792,281 @@ def test_metadata(self): metadata = loader.metadata() self.assertEqual(test, metadata) + +class TestGDSHeteroNeighborLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=16, + num_neighbors=10, + num_hops=2, + shuffle=True, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_whole_graph_df(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) + self.assertTrue( + data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 + ) + self.assertTrue( + data[1]["v1v1"].shape[0] > 0 and data[1]["v1v1"].shape[0] <= 1044 + ) + self.assertTrue( + data[1]["v1v2"].shape[0] > 0 and data[1]["v1v2"].shape[0] <= 1038 + ) + self.assertTrue( + data[1]["v2v0"].shape[0] > 0 and data[1]["v2v0"].shape[0] <= 943 + ) + self.assertTrue( + data[1]["v2v1"].shape[0] > 0 and data[1]["v2v1"].shape[0] <= 959 + ) + self.assertTrue( + data[1]["v2v2"].shape[0] > 0 and data[1]["v2v2"].shape[0] <= 966 + ) + + def test_whole_graph_pyg(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) + self.assertEqual(data["v0"]["y"].shape[0], 76) + self.assertEqual(data["v0"]["train_mask"].shape[0], 76) + self.assertEqual(data["v0"]["test_mask"].shape[0], 76) + self.assertEqual(data["v0"]["val_mask"].shape[0], 76) + self.assertEqual(data["v0"]["is_seed"].shape[0], 76) + self.assertTupleEqual(data["v1"]["x"].shape, (110, 57)) + self.assertEqual(data["v1"]["is_seed"].shape[0], 110) + self.assertTupleEqual(data["v2"]["x"].shape, (100, 48)) + self.assertEqual(data["v2"]["is_seed"].shape[0], 100) + 
self.assertTrue(
+            data["v0v0"]["edge_index"].shape[1] > 0
+            and data["v0v0"]["edge_index"].shape[1] <= 710
+        )
+        self.assertTrue(
+            data["v1v1"]["edge_index"].shape[1] > 0
+            and data["v1v1"]["edge_index"].shape[1] <= 1044
+        )
+        self.assertTrue(
+            data["v1v2"]["edge_index"].shape[1] > 0
+            and data["v1v2"]["edge_index"].shape[1] <= 1038
+        )
+        self.assertTrue(
+            data["v2v0"]["edge_index"].shape[1] > 0
+            and data["v2v0"]["edge_index"].shape[1] <= 943
+        )
+        self.assertTrue(
+            data["v2v1"]["edge_index"].shape[1] > 0
+            and data["v2v1"]["edge_index"].shape[1] <= 959
+        )
+        self.assertTrue(
+            data["v2v2"]["edge_index"].shape[1] > 0
+            and data["v2v2"]["edge_index"].shape[1] <= 966
+        )
+
+    def test_iterate_pyg(self):
+        loader = NeighborLoader(
+            graph=self.conn,
+            v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]},
+            v_out_labels={"v0": ["y"]},
+            v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]},
+            v_seed_types=["v2"],
+            batch_size=16,
+            num_neighbors=10,
+            num_hops=2,
+            shuffle=False,
+            output_format="PyG",
+            kafka_address="kafka:9092"
+        )
+        num_batches = 0
+        batch_sizes = []
+        for data in loader:
+            # print(num_batches, data)
+            self.assertIsInstance(data, pygHeteroData)
+            self.assertGreater(data["v2"]["x"].shape[0], 0)
+            self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0])
+            batch_sizes.append(int(data["v2"]["is_seed"].sum()))
+            if "v1" in data.node_types:
+                self.assertGreater(data["v1"]["x"].shape[0], 0)
+                self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0])
+            if "v0" in data.node_types:
+                self.assertGreater(data["v0"]["x"].shape[0], 0)
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0])
+                self.assertEqual(
+                    data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0]
+                )
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0])
+            self.assertTrue(
+                ('v0', 'v0v0', 'v0') in data.edge_types or
+                ('v1', 'v1v1', 'v1') in data.edge_types or
+                ('v1', 'v1v2', 'v2') in data.edge_types or
+                ('v2', 'v2v0', 'v0') in data.edge_types or
+                ('v2', 'v2v1', 'v1') in data.edge_types or
+                ('v2', 'v2v2', 'v2') in data.edge_types)
+            if ('v0', 'v0v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710
+                )
+            if ('v1', 'v1v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044
+                )
+            if ('v1', 'v1v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038
+                )
+            if ('v2', 'v2v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943
+                )
+            if ('v2', 'v2v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959
+                )
+            if ('v2', 'v2v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966
+                )
+            num_batches += 1
+        self.assertEqual(num_batches, 7)
+        for i in batch_sizes[:-1]:
+            self.assertEqual(i, 16)
+        self.assertLessEqual(batch_sizes[-1], 16)
+
+    def 
test_iterate_pyg_multichar_delimiter(self):
+        loader = NeighborLoader(
+            graph=self.conn,
+            v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]},
+            v_out_labels={"v0": ["y"]},
+            v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]},
+            v_seed_types=["v2"],
+            batch_size=16,
+            num_neighbors=10,
+            num_hops=2,
+            shuffle=False,
+            output_format="PyG",
+            delimiter="|$",
+            kafka_address="kafka:9092"
+        )
+        num_batches = 0
+        batch_sizes = []
+        for data in loader:
+            # print(num_batches, data)
+            self.assertIsInstance(data, pygHeteroData)
+            self.assertGreater(data["v2"]["x"].shape[0], 0)
+            self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0])
+            batch_sizes.append(int(data["v2"]["is_seed"].sum()))
+            if "v1" in data.node_types:
+                self.assertGreater(data["v1"]["x"].shape[0], 0)
+                self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0])
+            if "v0" in data.node_types:
+                self.assertGreater(data["v0"]["x"].shape[0], 0)
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0])
+                self.assertEqual(
+                    data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0]
+                )
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0])
+                self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0])
+            self.assertTrue(
+                ('v0', 'v0v0', 'v0') in data.edge_types or
+                ('v1', 'v1v1', 'v1') in data.edge_types or
+                ('v1', 'v1v2', 'v2') in data.edge_types or
+                ('v2', 'v2v0', 'v0') in data.edge_types or
+                ('v2', 'v2v1', 'v1') in data.edge_types or
+                ('v2', 'v2v2', 'v2') in data.edge_types)
+            if ('v0', 'v0v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710
+                )
+            if ('v1', 'v1v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044
+                )
+            if ('v1', 'v1v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038
+                )
+            if ('v2', 'v2v0', 'v0') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943
+                )
+            if ('v2', 'v2v1', 'v1') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959
+                )
+            if ('v2', 'v2v2', 'v2') in data.edge_types:
+                self.assertTrue(
+                    data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0
+                    and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966
+                )
+            num_batches += 1
+        self.assertEqual(num_batches, 7)
+        for i in batch_sizes[:-1]:
+            self.assertEqual(i, 16)
+        self.assertLessEqual(batch_sizes[-1], 16)
+
+
 if __name__ == "__main__":
     suite = unittest.TestSuite()
     suite.addTest(TestGDSNeighborLoaderKafka("test_init"))
     suite.addTest(TestGDSNeighborLoaderKafka("test_iterate_pyg"))
     suite.addTest(TestGDSNeighborLoaderKafka("test_iterate_stop_pyg"))
-    suite.addTest(TestGDSNeighborLoaderKafka("test_whole_graph_pyg"))
     suite.addTest(TestGDSNeighborLoaderKafka("test_edge_attr"))
     suite.addTest(TestGDSNeighborLoaderKafka("test_distributed_loaders"))
     # suite.addTest(TestGDSNeighborLoaderKafka("test_sasl_plaintext"))
     # suite.addTest(TestGDSNeighborLoaderKafka("test_sasl_ssl"))
     suite.addTest(TestGDSNeighborLoaderREST("test_init"))
suite.addTest(TestGDSNeighborLoaderREST("test_iterate_pyg")) - suite.addTest(TestGDSNeighborLoaderREST("test_whole_graph_pyg")) suite.addTest(TestGDSNeighborLoaderREST("test_edge_attr")) suite.addTest(TestGDSNeighborLoaderREST("test_fetch")) suite.addTest(TestGDSNeighborLoaderREST("test_fetch_delimiter")) @@ -890,6 +1079,11 @@ def test_metadata(self): suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch_delimiter")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_metadata")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_whole_graph_df")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_whole_graph_pyg")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_iterate_pyg_multichar_delimiter")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_NodePieceLoader.py b/tests/test_gds_NodePieceLoader.py index e34e7a84..8074f990 100644 --- a/tests/test_gds_NodePieceLoader.py +++ b/tests/test_gds_NodePieceLoader.py @@ -8,7 +8,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSNodePieceLoader(unittest.TestCase): +class TestGDSNodePieceLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -18,16 +18,14 @@ def test_init(self): graph=self.conn, v_feats=["x", "y", "train_mask", "val_mask", "test_mask"], compute_anchors=True, - anchor_percentage=0.5, batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + anchor_percentage=0.5, + kafka_address="kafka:9092" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = NodePieceLoader( @@ -38,13 +36,12 @@ def test_iterate(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) @@ -53,8 +50,12 @@ def test_iterate(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = NodePieceLoader( @@ -64,8 +65,6 @@ def test_all_vertices(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) data = loader.data @@ -78,6 +77,7 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) def test_sasl_plaintext(self): loader = NodePieceLoader( @@ -158,11 +158,9 @@ def test_init(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) 
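# --- Editor's aside (illustration only, not part of this patch) ---
# The same batch-size check recurs throughout the tests in this patch: every
# batch except the last should contain exactly `batch_size` items, and the
# last batch may hold a smaller remainder. A minimal sketch of that invariant
# as a reusable helper; the name `assert_batch_sizes` is hypothetical, not a
# pyTigerGraph API:
def assert_batch_sizes(testcase, batch_sizes, batch_size):
    # Every batch except the last must be full...
    for size in batch_sizes[:-1]:
        testcase.assertEqual(size, batch_size)
    # ...and the last batch holds the (possibly smaller) remainder.
    if batch_sizes:
        testcase.assertLessEqual(batch_sizes[-1], batch_size)
# --- end aside ---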
def test_iterate(self): loader = NodePieceLoader( @@ -173,12 +171,11 @@ def test_iterate(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) @@ -187,8 +184,12 @@ def test_iterate(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = NodePieceLoader( @@ -198,8 +199,6 @@ def test_all_vertices(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) data = loader.data # print(data) @@ -211,6 +210,7 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) class TestGDSHeteroNodePieceLoaderREST(unittest.TestCase): @@ -228,11 +228,9 @@ def test_init(self): batch_size=20, shuffle=True, filter_by=None, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 10) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = NodePieceLoader( @@ -244,23 +242,32 @@ def test_iterate(self): batch_size=20, shuffle=True, filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data["v0"], DataFrame) - self.assertIsInstance(data["v1"], DataFrame) - self.assertIn("x", data["v0"].columns) - self.assertIn("relational_context", data["v0"].columns) - self.assertIn("anchors", data["v0"].columns) - self.assertIn("y", data["v0"].columns) - self.assertIn("x", data["v1"].columns) - self.assertIn("relational_context", data["v1"].columns) - self.assertIn("anchors", data["v1"].columns) + batchsize = 0 + self.assertTrue(("v0" in data) or ("v1" in data)) + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + self.assertIn("relational_context", data["v1"].columns) + self.assertIn("anchors", data["v1"].columns) + batchsize += data["v1"].shape[0] num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) def test_all_vertices(self): loader = NodePieceLoader( @@ -272,8 +279,89 @@ def test_all_vertices(self): num_batches=1, shuffle=False, filter_by=None, - loader_id=None, - buffer_size=4, + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 6)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + 
self.assertIn("anchors", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v1"].columns) + + +class TestGDSHeteroNodePieceLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = NodePieceLoader( + graph=self.conn, + compute_anchors=True, + anchor_percentage=0.5, + v_feats={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + filter_by=None, + kafka_address="kafka:9092", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate(self): + loader = NodePieceLoader( + compute_anchors=True, + anchor_percentage=0.5, + graph=self.conn, + v_feats={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + filter_by=None, + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + self.assertTrue(("v0" in data) or ("v1" in data)) + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + self.assertIn("relational_context", data["v1"].columns) + self.assertIn("anchors", data["v1"].columns) + batchsize += data["v1"].shape[0] + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) + + def test_all_vertices(self): + loader = NodePieceLoader( + graph=self.conn, + compute_anchors=True, + anchor_percentage=0.5, + v_feats={"v0": ["x", "y"], + "v1": ["x"]}, + num_batches=1, + shuffle=False, + filter_by=None, + kafka_address="kafka:9092", ) data = loader.data # print(data) @@ -290,17 +378,20 @@ def test_all_vertices(self): if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSNodePieceLoader("test_init")) - suite.addTest(TestGDSNodePieceLoader("test_iterate")) - suite.addTest(TestGDSNodePieceLoader("test_all_vertices")) - #suite.addTest(TestGDSNodePieceLoader("test_sasl_plaintext")) - # suite.addTest(TestGDSNodePieceLoader("test_sasl_ssl")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_init")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_iterate")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_all_vertices")) + #suite.addTest(TestGDSNodePieceLoaderKafka("test_sasl_plaintext")) + #suite.addTest(TestGDSNodePieceLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSNodePieceLoaderREST("test_init")) suite.addTest(TestGDSNodePieceLoaderREST("test_iterate")) suite.addTest(TestGDSNodePieceLoaderREST("test_all_vertices")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_init")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_iterate")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_all_vertices")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_iterate")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_all_vertices")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_VertexLoader.py b/tests/test_gds_VertexLoader.py index 3b76890d..e92a0d77 
100644 --- a/tests/test_gds_VertexLoader.py +++ b/tests/test_gds_VertexLoader.py @@ -7,7 +7,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSVertexLoader(unittest.TestCase): +class TestGDSVertexLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -19,12 +19,10 @@ def test_init(self): batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = VertexLoader( @@ -33,21 +31,25 @@ def test_iterate(self): batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[1], 6) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = VertexLoader( @@ -56,8 +58,6 @@ def test_all_vertices(self): num_batches=1, shuffle=False, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) data = loader.data @@ -68,6 +68,29 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) + + def test_all_vertices_multichar_delimiter(self): + loader = VertexLoader( + graph=self.conn, + attributes=["x", "y", "train_mask", "val_mask", "test_mask"], + num_batches=1, + shuffle=False, + filter_by="train_mask", + delimiter="$|", + kafka_address="kafka:9092", + ) + data = loader.data + # print(data) + self.assertIsInstance(data, DataFrame) + self.assertIn("x", data.columns) + self.assertIn("y", data.columns) + self.assertIn("train_mask", data.columns) + self.assertIn("val_mask", data.columns) + self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def test_sasl_plaintext(self): loader = VertexLoader( @@ -126,6 +149,97 @@ def test_sasl_ssl(self): num_batches += 1 self.assertEqual(num_batches, 9) + +class TestGDSHeteroVertexLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", 
data["v0"].columns) + batchsize += data["v0"].shape[0] + self.assertEqual(data["v0"].shape[1], 3) + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + batchsize += data["v1"].shape[0] + self.assertEqual(data["v1"].shape[1], 2) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) + + def test_all_vertices(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + num_batches=1, + shuffle=False, + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 3)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertTupleEqual(data["v1"].shape, (110, 2)) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + + def test_all_vertices_multichar_delimiter(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + num_batches=1, + shuffle=False, + delimiter="|$", + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 3)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertTupleEqual(data["v1"].shape, (110, 2)) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + + class TestGDSVertexLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -137,12 +251,10 @@ def test_init(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], batch_size=16, shuffle=True, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = VertexLoader( @@ -150,21 +262,25 @@ def test_iterate(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], batch_size=16, shuffle=True, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[1], 6) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = VertexLoader( @@ -172,9 +288,7 @@ def test_all_vertices(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], num_batches=1, shuffle=False, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) data = loader.data # print(data) @@ -184,6 +298,8 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def 
test_all_vertices_multichar_delimiter(self): loader = VertexLoader( @@ -192,8 +308,6 @@ def test_all_vertices_multichar_delimiter(self): num_batches=1, shuffle=False, filter_by="train_mask", - loader_id=None, - buffer_size=4, delimiter="$|" ) data = loader.data @@ -204,6 +318,8 @@ def test_all_vertices_multichar_delimiter(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def test_string_attr(self): conn = make_connection(graphname="Social") @@ -212,14 +328,13 @@ def test_string_attr(self): graph=conn, attributes=["age", "state"], num_batches=1, - shuffle=False, - loader_id=None, - buffer_size=4, + shuffle=False ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) self.assertEqual(data.shape[0], 7) + self.assertEqual(data.shape[1], 3) self.assertIn("age", data.columns) self.assertIn("state", data.columns) @@ -235,13 +350,10 @@ def test_init(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, batch_size=20, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 10) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = VertexLoader( @@ -249,21 +361,31 @@ def test_iterate(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, batch_size=20, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data["v0"], DataFrame) - self.assertIsInstance(data["v1"], DataFrame) - self.assertIn("x", data["v0"].columns) - self.assertIn("y", data["v0"].columns) - self.assertIn("x", data["v1"].columns) + batchsize = 0 + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + self.assertEqual(data["v0"].shape[1], 3) + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + batchsize += data["v1"].shape[0] + self.assertEqual(data["v1"].shape[1], 2) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) def test_all_vertices(self): loader = VertexLoader( @@ -271,10 +393,7 @@ def test_all_vertices(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, num_batches=1, - shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False ) data = loader.data # print(data) @@ -293,9 +412,6 @@ def test_all_vertices_multichar_delimiter(self): "v1": ["x"]}, num_batches=1, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, delimiter="|$" ) data = loader.data @@ -311,11 +427,6 @@ def test_all_vertices_multichar_delimiter(self): if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSVertexLoader("test_init")) - suite.addTest(TestGDSVertexLoader("test_iterate")) - suite.addTest(TestGDSVertexLoader("test_all_vertices")) - # suite.addTest(TestGDSVertexLoader("test_sasl_plaintext")) - # suite.addTest(TestGDSVertexLoader("test_sasl_ssl")) suite.addTest(TestGDSVertexLoaderREST("test_init")) suite.addTest(TestGDSVertexLoaderREST("test_iterate")) 
suite.addTest(TestGDSVertexLoaderREST("test_all_vertices"))
@@ -325,6 +436,15 @@ def test_all_vertices_multichar_delimiter(self):
     suite.addTest(TestGDSHeteroVertexLoaderREST("test_iterate"))
     suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices"))
     suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices_multichar_delimiter"))
-
+    suite.addTest(TestGDSVertexLoaderKafka("test_init"))
+    suite.addTest(TestGDSVertexLoaderKafka("test_iterate"))
+    suite.addTest(TestGDSVertexLoaderKafka("test_all_vertices"))
+    suite.addTest(TestGDSVertexLoaderKafka("test_all_vertices_multichar_delimiter"))
+    # suite.addTest(TestGDSVertexLoaderKafka("test_sasl_plaintext"))
+    # suite.addTest(TestGDSVertexLoaderKafka("test_sasl_ssl"))
+    suite.addTest(TestGDSHeteroVertexLoaderKafka("test_init"))
+    suite.addTest(TestGDSHeteroVertexLoaderKafka("test_iterate"))
+    suite.addTest(TestGDSHeteroVertexLoaderKafka("test_all_vertices"))
+    suite.addTest(TestGDSHeteroVertexLoaderKafka("test_all_vertices_multichar_delimiter"))
     runner = unittest.TextTestRunner(verbosity=2, failfast=True)
     runner.run(suite)
diff --git a/tests/test_gds_featurizer.py b/tests/test_gds_featurizer.py
index b873d404..d6266c93 100644
--- a/tests/test_gds_featurizer.py
+++ b/tests/test_gds_featurizer.py
@@ -23,7 +23,11 @@ def test_get_db_version(self):
         major_ver, minor_ver, patch_ver = self.featurizer._get_db_version()
         self.assertIsNotNone(int(major_ver))
         self.assertIsNotNone(int(minor_ver))
-        self.assertIsNotNone(int(patch_ver))
+        try:
+            patch_ver = int(patch_ver)
+        except (TypeError, ValueError):
+            pass
+        self.assertIsNotNone(patch_ver)
         self.assertIsInstance(self.featurizer.algo_ver, str)

     def test_get_algo_dict(self):
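# --- Editor's aside (illustration only, not part of this patch) ---
# The try/except added to test_get_db_version above exists because the patch
# component of a reported version is not guaranteed to be numeric (a suffix
# such as "0-preview" would make int() raise ValueError). A minimal sketch of
# the same tolerant parse, assuming a "major.minor.patch" version string;
# `parse_version` is a hypothetical helper, not part of pyTigerGraph:
def parse_version(version: str):
    major, minor, patch = version.split(".", 2)
    try:
        patch = int(patch)  # numeric patch, e.g. "3.9.2" -> 2
    except ValueError:
        pass                # keep a non-numeric patch (e.g. "0-preview") as str
    return int(major), int(minor), patch
# Example: parse_version("3.9.0-preview") -> (3, 9, "0-preview")
# --- end aside ---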