Add in edge weights to preprocessor (#148)
* Added in colab data loader

* Support both type and weight

* Not working code

* Works on the custom dataset

* Ran tox

* Fixed flake8 issues

* Hopefully fixed linter issues

* Wrote torch tests

* Added in pandas converter tests

* Performed linter checks

* Run check lint outside of the container

* working with pandas example

* Work on OGBL loader

* Passed all tests

* Ran autoformat with updated tox

* Fixed test_generate issue

---------

Co-authored-by: Devesh Sarda <[email protected]>
Co-authored-by: Jason Mohoney <[email protected]>
3 people authored Nov 12, 2023
1 parent 841145c commit 483277a
Showing 36 changed files with 2,232 additions and 196 deletions.
3 changes: 2 additions & 1 deletion docs/examples/config/nc_custom.rst
@@ -117,7 +117,8 @@ Let's borrow the provided ``examples/python/custom_nc_graphsage.py`` and modify
     output_dir=self.output_directory,
     train_edges=self.input_edge_list_file,
     num_partitions=num_partitions,
-    columns=[0, 1],
+    src_column = 0,
+    dst_column = 1,
     remap_ids=remap_ids,
     sequential_train_nodes=sequential_train_nodes,
     delim=",",
3 changes: 2 additions & 1 deletion docs/examples/python/lp_custom.rst
@@ -66,7 +66,8 @@ Making a new dataset class requires writing two methods:
 converter = converter(
     output_dir=self.output_directory,
     train_edges=self.input_train_edges_file,
-    columns = [0,1], # col 0 is src and col 1 is dst node in input csv
+    src_column = 0, # col 0 is src and col 1 is dst node in input csv
+    dst_column = 1,
     delim=",", # CSV delimiter is ","
     splits = splits, # Splitting the data into train, valid and test
     remap_ids=remap_ids # Remapping the raw entity ids into random integers
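Note: with this change, a weighted custom dataset can also hand the converter an edge weight column. A minimal sketch of such a preprocess call (the TorchEdgeListConverter import path, the file paths, and the convert() call are assumptions based on the Marius examples, not shown in this diff):

from marius.tools.preprocess.converters.torch_converter import TorchEdgeListConverter

# Hypothetical input CSV with columns: src, dst, weight
converter = TorchEdgeListConverter(
    output_dir="datasets/my_graph/",            # assumed output location
    train_edges="datasets/my_graph/edges.csv",  # hypothetical input file
    src_column=0,                               # src node id column
    dst_column=1,                               # dst node id column
    edge_weight_column=2,                       # per-edge weights (new in this PR)
    delim=",",
    remap_ids=True,
)
converter.convert()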
2 changes: 1 addition & 1 deletion examples/docker/cpu_ubuntu/dockerfile
@@ -25,7 +25,7 @@ RUN sh cmake-3.20.0-linux-x86_64.sh --skip-license --prefix=/opt/cmake/
 RUN ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
 
 # install pytorch
-RUN python3 -m pip install torch==2.0.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+RUN python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
 RUN mkdir /working_dir
 WORKDIR /working_dir
3 changes: 2 additions & 1 deletion examples/python/custom_lp.py
@@ -38,7 +38,8 @@ def preprocess(self, remap_ids=True, splits=None):
 converter = converter(
     output_dir=self.output_directory,
     train_edges=self.input_train_edges_file,
-    columns=[0, 1],  # col 0 is src and col 1 is dst node in input csv
+    src_column=0,  # col 0 is src and col 1 is dst node in input csv
+    dst_column=1,
     delim=",",  # CSV delimiter is ","
     splits=splits,  # Splitting the data into train, valid and test
     remap_ids=remap_ids,  # Remapping the raw entity ids into random integers
3 changes: 2 additions & 1 deletion examples/python/custom_nc_graphsage.py
@@ -112,7 +112,8 @@ def preprocess(
     output_dir=self.output_directory,
     train_edges=self.input_edge_list_file,
     num_partitions=num_partitions,
-    columns=[0, 1],
+    src_column=0,
+    dst_column=1,
     remap_ids=remap_ids,
     sequential_train_nodes=sequential_train_nodes,
     delim=",",
1 change: 1 addition & 0 deletions src/cpp/src/common/util.cpp
@@ -6,6 +6,7 @@
 
 #include <unistd.h>
 
+#include <fstream>
 #include <iostream>
 
 #include "reporting/logger.h"
3 changes: 2 additions & 1 deletion src/cpp/src/reporting/reporting.cpp
@@ -1,9 +1,10 @@
 //
 // Created by Jason Mohoney on 8/24/21.
 //
 
 #include "reporting/reporting.h"
 
+#include <fstream>
+
 #include "configuration/constants.h"
 #include "reporting/logger.h"
2 changes: 2 additions & 0 deletions src/cpp/src/storage/buffer.cpp
@@ -8,8 +8,10 @@
 #include <fcntl.h>
 #include <unistd.h>
 
+#include <fstream>
 #include <functional>
 #include <future>
+#include <iostream>
 #include <shared_mutex>
 
 #include "configuration/constants.h"
7 changes: 7 additions & 0 deletions src/python/tools/configuration/constants.py
@@ -14,6 +14,7 @@ class PathConstants:
     node_mapping_file: str = "node_mapping.txt"
     relation_mapping_file: str = "relation_mapping.txt"
     edge_file_name: str = "edges"
+    edge_weight_file_name: str = "edges_weights"
     node_file_name: str = "nodes"
     features_file_name: str = "features"
     labels_file_name: str = "labels"
@@ -23,8 +24,14 @@ class PathConstants:
     file_ext: str = ".bin"
 
     train_edges_path: str = edges_directory + training_file_prefix + edge_file_name + file_ext
+    train_edges_weights_path: str = edges_directory + training_file_prefix + edge_weight_file_name + file_ext
+
     valid_edges_path: str = edges_directory + validation_file_prefix + edge_file_name + file_ext
+    valid_edges_weights_path: str = edges_directory + validation_file_prefix + edge_weight_file_name + file_ext
+
     test_edges_path: str = edges_directory + test_file_prefix + edge_file_name + file_ext
+    test_edges_weights_path: str = edges_directory + test_file_prefix + edge_weight_file_name + file_ext
+
     train_edge_buckets_path: str = edges_directory + training_file_prefix + partition_offsets_file
     valid_edge_buckets_path: str = edges_directory + validation_file_prefix + partition_offsets_file
     test_edge_buckets_path: str = edges_directory + test_file_prefix + partition_offsets_file
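For illustration (not part of the diff), the new constants compose on-disk paths the same way as the existing ones. A sketch, assuming edges_directory = "edges/" and training_file_prefix = "train_" as defined elsewhere in this class:

# Assumed values of the surrounding constants; only
# edge_weight_file_name and file_ext appear in the diff above.
edges_directory = "edges/"
training_file_prefix = "train_"
edge_weight_file_name = "edges_weights"
file_ext = ".bin"

train_edges_weights_path = edges_directory + training_file_prefix + edge_weight_file_name + file_ext
print(train_edges_weights_path)  # edges/train_edges_weights.bin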
47 changes: 41 additions & 6 deletions src/python/tools/marius_preprocess.py
@@ -11,6 +11,7 @@
     ogb_mag240m,
     ogb_wikikg90mv2,
     ogbl_citation2,
+    ogbl_collab,
     ogbl_ppa,
     ogbl_wikikg2,
     ogbn_arxiv,
@@ -91,13 +92,39 @@ def set_args():
     )
 
     parser.add_argument(
-        "--columns",
-        metavar="columns",
-        nargs="*",
+        "--src_column",
+        metavar="src_column",
         required=False,
         type=int,
-        default=[0, 1, 2],
-        help="List of column ids of input delimited files which denote the src node, edge-type, and dst node of edges.",
+        default=None,
+        help="The column id of the src column",
     )
 
+    parser.add_argument(
+        "--dst_column",
+        metavar="dst_column",
+        required=False,
+        type=int,
+        default=None,
+        help="The column id of the dst column",
+    )
+
+    parser.add_argument(
+        "--edge_type_column",
+        metavar="edge_type_column",
+        required=False,
+        type=int,
+        default=None,
+        help="The column id which denotes the edge type column",
+    )
+
+    parser.add_argument(
+        "--edge_weight_column",
+        metavar="edge_weight_column",
+        required=False,
+        type=int,
+        default=None,
+        help="The column id which denotes the edge weight column",
+    )
+
     return parser
@@ -106,6 +133,8 @@ def set_args():
 def main():
     parser = set_args()
     args = parser.parse_args()
+    if args.dataset == "custom" and (args.src_column is None or args.dst_column is None):
+        parser.error("When using a custom dataset, src column and dst column must be specified")
 
     if args.output_directory == "":
         args.output_directory = args.dataset
@@ -127,10 +156,12 @@ def main():
         "OGBN_PAPERS100M": ogbn_papers100m.OGBNPapers100M,
         "OGB_WIKIKG90MV2": ogb_wikikg90mv2.OGBWikiKG90Mv2,
         "OGB_MAG240M": ogb_mag240m.OGBMag240M,
+        "OGBL_COLLAB": ogbl_collab.OGBLCollab,
     }
 
     dataset = dataset_dict.get(args.dataset.upper())
     if dataset is not None:
+        print("Using existing dataset of", args.dataset.upper())
         dataset = dataset(args.output_directory, spark=args.spark)
         dataset.download(args.overwrite)
         dataset.preprocess(
@@ -140,6 +171,7 @@ def main():
             sequential_train_nodes=args.sequential_train_nodes,
             partitioned_eval=args.partitioned_eval,
         )
+
     else:
         print("Preprocess custom dataset")
 
@@ -157,7 +189,10 @@ def main():
             splits=args.dataset_split,
             partitioned_eval=args.partitioned_eval,
             sequential_train_nodes=args.sequential_train_nodes,
-            columns=args.columns,
+            src_column=args.src_column,
+            dst_column=args.dst_column,
+            edge_type_column=args.edge_type_column,
+            edge_weight_column=args.edge_weight_column,
         )
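A hypothetical invocation of the updated parser (only the four column flags appear in this diff; the --dataset spelling and the options a real run would additionally need are assumptions):

# Sketch: exercising the new column flags; a real preprocessing run
# would also need the input edge files and an output directory.
parser = set_args()
args = parser.parse_args(
    ["--dataset", "custom", "--src_column", "0", "--dst_column", "1", "--edge_weight_column", "2"]
)
print(args.src_column, args.dst_column, args.edge_weight_column)  # 0 1 2

Omitting --src_column or --dst_column for a custom dataset now fails fast through parser.error rather than downstream in the converter.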
@@ -5,25 +5,26 @@
 import torch  # isort:skip
 
 
-def dataframe_to_tensor(input_dataframe):
-    np_array = input_dataframe.to_dask_array().compute()
-    return torch.from_numpy(np_array)
+def dataframe_to_tensor(df):
+    return torch.tensor(df.to_numpy())
 
 
-def partition_edges(edges, num_nodes, num_partitions):
+def partition_edges(edges, num_nodes, num_partitions, edge_weights=None):
     partition_size = int(np.ceil(num_nodes / num_partitions))
 
     src_partitions = torch.div(edges[:, 0], partition_size, rounding_mode="trunc")
     dst_partitions = torch.div(edges[:, -1], partition_size, rounding_mode="trunc")
 
     _, dst_args = torch.sort(dst_partitions, stable=True)
     _, src_args = torch.sort(src_partitions[dst_args], stable=True)
+    sort_order = dst_args[src_args]
 
-    edges = edges[dst_args[src_args]]
-    edge_bucket_ids = torch.div(edges, partition_size, rounding_mode="trunc")
+    edges = edges[sort_order]
+    if edge_weights is not None:
+        edge_weights = edge_weights[sort_order]
 
+    edge_bucket_ids = torch.div(edges, partition_size, rounding_mode="trunc")
     offsets = np.zeros([num_partitions, num_partitions], dtype=int)
 
     unique_src, num_source = torch.unique_consecutive(edge_bucket_ids[:, 0], return_counts=True)
 
     num_source_offsets = torch.cumsum(num_source, 0) - num_source
@@ -42,7 +43,7 @@ def partition_edges(edges, num_nodes, num_partitions):
 
     offsets = list(offsets.flatten())
 
-    return edges, offsets
+    return edges, offsets, edge_weights
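The two stable sorts above implement a lexicographic (src bucket, dst bucket) ordering of the edge list, and naming the composed permutation sort_order is what lets the exact same reordering be applied to the weights, so each weight stays attached to its edge. A self-contained sketch with synthetic data (not from this diff):

import torch

# 4 nodes, 2 partitions -> partition_size = 2
edges = torch.tensor([[3, 0], [0, 3], [2, 1], [1, 2]])
weights = torch.tensor([3.0, 1.0, 9.0, 5.0])
partition_size = 2

src_part = torch.div(edges[:, 0], partition_size, rounding_mode="trunc")
dst_part = torch.div(edges[:, -1], partition_size, rounding_mode="trunc")

# Stable sort by dst bucket, then stable sort by src bucket:
# a lexicographic (src bucket, dst bucket) order.
_, dst_args = torch.sort(dst_part, stable=True)
_, src_args = torch.sort(src_part[dst_args], stable=True)
sort_order = dst_args[src_args]

# One permutation reorders both tensors, keeping them aligned.
print(edges[sort_order].tolist())    # [[0, 3], [1, 2], [3, 0], [2, 1]]
print(weights[sort_order].tolist())  # [1.0, 5.0, 3.0, 9.0]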


class TorchPartitioner(Partitioner):
@@ -51,19 +52,42 @@ def __init__(self, partitioned_evaluation):
 
         self.partitioned_evaluation = partitioned_evaluation
 
-    def partition_edges(self, train_edges_tens, valid_edges_tens, test_edges_tens, num_nodes, num_partitions):
-        """ """
-
-        train_edges_tens, train_offsets = partition_edges(train_edges_tens, num_nodes, num_partitions)
+    def partition_edges(
+        self, train_edges_tens, valid_edges_tens, test_edges_tens, num_nodes, num_partitions, edge_weights=None
+    ):
+        # Extract the edge weights
+        train_edge_weights, valid_edge_weights, test_edge_weights = None, None, None
+        if edge_weights is not None:
+            train_edge_weights, valid_edge_weights, test_edge_weights = (
+                edge_weights[0],
+                edge_weights[1],
+                edge_weights[2],
+            )
+
+        train_edges_tens, train_offsets, train_edge_weights = partition_edges(
+            train_edges_tens, num_nodes, num_partitions, edge_weights=train_edge_weights
+        )

         valid_offsets = None
         test_offsets = None
 
         if self.partitioned_evaluation:
             if valid_edges_tens is not None:
-                valid_edges_tens, valid_offsets = partition_edges(valid_edges_tens, num_nodes, num_partitions)
+                valid_edges_tens, valid_offsets, valid_edge_weights = partition_edges(
+                    valid_edges_tens, num_nodes, num_partitions, edge_weights=valid_edge_weights
+                )
 
             if test_edges_tens is not None:
-                test_edges_tens, test_offsets = partition_edges(test_edges_tens, num_nodes, num_partitions)
+                test_edges_tens, test_offsets, test_edge_weights = partition_edges(
+                    test_edges_tens, num_nodes, num_partitions, edge_weights=test_edge_weights
+                )
 
-        return train_edges_tens, train_offsets, valid_edges_tens, valid_offsets, test_edges_tens, test_offsets
+        return (
+            train_edges_tens,
+            train_offsets,
+            valid_edges_tens,
+            valid_offsets,
+            test_edges_tens,
+            test_offsets,
+            [train_edge_weights, valid_edge_weights, test_edge_weights],
+        )
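A hedged usage sketch of the new signature (synthetic tensors; assumes TorchPartitioner is importable as in the Marius source tree):

import torch

# Synthetic inputs: 6 nodes, 2 partitions, weighted train edges only.
train_edges = torch.tensor([[0, 5], [4, 1], [2, 3]])
train_weights = torch.tensor([1.0, 0.5, 2.0])

partitioner = TorchPartitioner(partitioned_evaluation=False)
(
    train_edges,
    train_offsets,
    valid_edges,
    valid_offsets,
    test_edges,
    test_offsets,
    all_weights,  # [train, valid, test] weights, reordered with their edges
) = partitioner.partition_edges(
    train_edges,
    None,  # no validation edges in this sketch
    None,  # no test edges
    num_nodes=6,
    num_partitions=2,
    edge_weights=[train_weights, None, None],
)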