From 5d8ffb18911160bc7a05061e0db9239cb1231af7 Mon Sep 17 00:00:00 2001 From: yiweny Date: Fri, 2 Aug 2024 22:16:23 +0000 Subject: [PATCH 1/3] include shallowrhs to benchmark track --- .../relbench_link_prediction_benchmark.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 8fa42d8..7cb322a 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -30,26 +30,26 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN from hybridgnn.utils import GloveTextEmbedding TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] LINK_PREDICTION_METRIC = "link_prediction_map" parser = argparse.ArgumentParser() -parser.add_argument("--dataset", type=str, default="rel-trial") -parser.add_argument("--task", type=str, default="condition-sponsor-run") +parser.add_argument("--dataset", type=str, default="rel-stack") +parser.add_argument("--task", type=str, default="user-post-comment") parser.add_argument( "--model", type=str, default="hybridgnn", - choices=["hybridgnn", "idgnn"], + choices=["hybridgnn", "idgnn", "shallowrhs"], ) -parser.add_argument("--epochs", type=int, default=20) -parser.add_argument("--num_trials", type=int, default=10, +parser.add_argument("--epochs", type=int, default=1) +parser.add_argument("--num_trials", type=int, default=2, help="Number of Optuna-based hyper-parameter tuning.") parser.add_argument( - "--num_repeats", type=int, default=5, + "--num_repeats", type=int, default=2, help="Number of repeated training and eval on the best config.") parser.add_argument("--eval_epochs_interval", type=int, default=1) parser.add_argument("--num_layers", type=int, default=2) @@ -102,7 +102,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]] if args.model == "idgnn": model_search_space = { @@ -115,7 +115,7 @@ "gamma_rate": [0.9, 0.95, 1.], } model_cls = IDGNN -elif args.model == "hybridgnn": +elif args.model in ["hybridgnn", "shallowrhsgnn"]: model_search_space = { "channels": [64, 128, 256], "embedding_dim": [64, 128, 256], @@ -126,7 +126,7 @@ "base_lr": [0.001, 0.01], "gamma_rate": [0.9, 0.95, 1.], } - model_cls = HybridGNN + model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) def train( @@ -164,7 +164,7 @@ def train( loss = F.binary_cross_entropy_with_logits(out, target) numel = out.numel() - else: + elif args.model in ["hybridgnn", "shallowrhsgnn"]: logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) @@ -205,7 +205,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) - elif args.model == "hybridgnn": + elif args.model in ["hybridgnn", "shallowrhsgnn"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() @@ -248,7 +248,7 @@ def train_and_eval_with_cfg( persistent_workers=args.num_workers > 0, ) - if args.model == "hybridgnn": + if args.model in ["hybridgnn", "shallowrhsgnn"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 From a6d081a01dd37bacfffcb40890f462d6f5645d6d Mon Sep 17 00:00:00 2001 From: yiweny Date: Fri, 2 Aug 2024 22:17:08 +0000 Subject: [PATCH 2/3] fix numbers --- benchmark/relbench_link_prediction_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 7cb322a..579357f 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -45,11 +45,11 @@ default="hybridgnn", choices=["hybridgnn", "idgnn", "shallowrhs"], ) -parser.add_argument("--epochs", type=int, default=1) -parser.add_argument("--num_trials", type=int, default=2, +parser.add_argument("--epochs", type=int, default=20) +parser.add_argument("--num_trials", type=int, default=10, help="Number of Optuna-based hyper-parameter tuning.") parser.add_argument( - "--num_repeats", type=int, default=2, + "--num_repeats", type=int, default=5, help="Number of repeated training and eval on the best config.") parser.add_argument("--eval_epochs_interval", type=int, default=1) parser.add_argument("--num_layers", type=int, default=2) From 47285f4a380c14ca676722ce462f9e02774e60bb Mon Sep 17 00:00:00 2001 From: yiweny Date: Fri, 2 Aug 2024 22:30:07 +0000 Subject: [PATCH 3/3] fix script --- benchmark/relbench_link_prediction_benchmark.py | 4 ++-- result/rel-stack_user-post-comment_hybridgnn | Bin 0 -> 2264 bytes 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 result/rel-stack_user-post-comment_hybridgnn diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 579357f..5b40c23 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -42,8 +42,8 @@ parser.add_argument( "--model", type=str, - default="hybridgnn", - choices=["hybridgnn", "idgnn", "shallowrhs"], + default="shallowrhsgnn", + choices=["hybridgnn", "idgnn", "shallowrhsgnn"], ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--num_trials", type=int, default=10, diff --git a/result/rel-stack_user-post-comment_hybridgnn b/result/rel-stack_user-post-comment_hybridgnn new file mode 100644 index 0000000000000000000000000000000000000000..9ea5d4a56ae7419ad16a984add801299664f6a99 GIT binary patch literal 2264 zcmb7GO>7%g5Z?T!?q5nm(~#2g6B1MEzs89ZNO1@xhAc_VN|jbcuv)KQysPYw=k0S6 z2L%N|QmWpNh(Jh4h!Yo1NFdP)wB~>WNO0l64RPfF5(jQD@7Zyj1Xp3CwY_iNeDlq` znR(92ZS9iO-7OuGsFai_u_Hdyjb*Lm6B;RcK8qNh;}Dl=W;sW#e8F{Nd7bIe;<7zA z6PIrMPuXnU04b~@2T7*|A6hn`HLP~wD@XMSgEsXTzSWsPyV#7nW#~}Ppz07c=u|s! z-SP6ohAv@VCkx$bC+bMiGfW?j$t`LKU)_?UF>2|y4l2cPUip#;*3b=;XnBi5pE`tS zG1GHMtZc0kHx?Dj7^X8vT*K7Jn(h>B;=}y{n-afdGp(pI6V8fERac9?r-O>a+ zcI)k%Fs#5xSZ)(oQxINUT%7+0MSRE-gU55X%=791S_C_lEj9R{@@eJW;LG&cU}uf0 z!F!cA(s0Q{w@)Z=nL#8cj!lHlc>KGuN6i6K1QIimU0i`gjolSJwqZtr zXVp`?%&=B1S2LCh@Z1K>hDF=tbcQA5Tk8bo7Nlw`6?DhZHD1_|$)tb(p34uSp>x>I znoZ&POxpeP)wh2A=DX4!yl7Sd|H1~$hgCk13{!Vqf;}l?=kP#@lOy@Obx0*q=04FWGG~@UAe)&6`S$FzB#BflSuQVc zK0B)Vmon3@Ony9!euM|+G1Lo7B{RMD{$HoQ#t{_F%3a;-zjWXJ776hu+jOqRq`|#u zcF(z*D>LHJJfU+l{mlleZ82sO?tSa(Zv%X^@u#`>h;;FY-dAwGSe`pKGuX^GAN8Xl z?)#??3)guE@~64)DXH&>zN38j$MDL*k!-Ew81P)p%FoU89HJ5*A01DQrBb7*$>ex4 zo`{bnM&m-R#3#naQxmC)L^6>|ah7cB`X6Cr$e*m-!pYWezQ(hQa>YKW-#+&|uibm_ zAa>))#P;Fxd+KP|^;$Rx-kAExKJcH93|>3xgV2rnn%W1gV38e-_kW*cR_^S)TtmFj rwnv+HGnd`(*ZFmT@;5JP6tv@i4Wt!59#FVCD|ZNOJ6iGm)xG}$bY_dn literal 0 HcmV?d00001