Skip to content

Commit

Permalink
Merge pull request #15 from kumo-ai/yyuan/include-shallowrhs-to-bench…
Browse files Browse the repository at this point in the history
…mark

include shallowrhs to benchmark script
  • Loading branch information
andyhuang-kumo authored Aug 2, 2024
2 parents 7594074 + 47285f4 commit 78e11cd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions benchmark/relbench_link_prediction_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from hybridgnn.nn.models import IDGNN, HybridGNN
from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
from hybridgnn.utils import GloveTextEmbedding

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
LINK_PREDICTION_METRIC = "link_prediction_map"

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="rel-trial")
parser.add_argument("--task", type=str, default="condition-sponsor-run")
parser.add_argument("--dataset", type=str, default="rel-stack")
parser.add_argument("--task", type=str, default="user-post-comment")
parser.add_argument(
"--model",
type=str,
default="hybridgnn",
choices=["hybridgnn", "idgnn"],
default="shallowrhsgnn",
choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=10,
Expand Down Expand Up @@ -102,7 +102,7 @@
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

model_cls: Type[Union[IDGNN, HybridGNN]]
model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]

if args.model == "idgnn":
model_search_space = {
Expand All @@ -115,7 +115,7 @@
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = IDGNN
elif args.model == "hybridgnn":
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
model_search_space = {
"channels": [64, 128, 256],
"embedding_dim": [64, 128, 256],
Expand All @@ -126,7 +126,7 @@
"base_lr": [0.001, 0.01],
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = HybridGNN
model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)


def train(
Expand Down Expand Up @@ -164,7 +164,7 @@ def train(

loss = F.binary_cross_entropy_with_logits(out, target)
numel = out.numel()
else:
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
Expand Down Expand Up @@ -205,7 +205,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
device=out.device)
scores[batch[task.dst_entity_table].batch,
batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
elif args.model == "hybridgnn":
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
# Get ground-truth
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
Expand Down Expand Up @@ -248,7 +248,7 @@ def train_and_eval_with_cfg(
persistent_workers=args.num_workers > 0,
)

if args.model == "hybridgnn":
if args.model in ["hybridgnn", "shallowrhsgnn"]:
model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
elif args.model == "idgnn":
model_cfg["out_channels"] = 1
Expand Down
Binary file added result/rel-stack_user-post-comment_hybridgnn
Binary file not shown.

0 comments on commit 78e11cd

Please sign in to comment.