Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Aug 18, 2024
1 parent f785e10 commit 66ba68b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 36 deletions.
18 changes: 15 additions & 3 deletions examples/relbench_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
example command
"""Command to run the script:
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --epochs 10
python relbench_example.py --dataset rel-trial --task site-sponsor-run
--model hybridgnn --epochs 10
"""

import argparse
Expand Down Expand Up @@ -128,6 +128,10 @@
out_channels=1,
aggr=args.aggr,
norm="layer_norm",
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
).to(device)
elif args.model == "hybridgnn":
model = HybridGNN(
Expand All @@ -139,6 +143,10 @@
aggr="sum",
norm="layer_norm",
embedding_dim=64,
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
).to(device)
elif args.model == 'shallowrhsgnn':
model = ShallowRHSGNN(
Expand All @@ -150,6 +158,10 @@
aggr="sum",
norm="layer_norm",
embedding_dim=64,
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
).to(device)
else:
raise ValueError(f"Unsupported model type {args.model}.")
Expand Down
7 changes: 2 additions & 5 deletions hybridgnn/nn/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch
import torch_frame
Expand Down Expand Up @@ -54,10 +54,7 @@ def __init__(
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any],
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

Expand Down
7 changes: 2 additions & 5 deletions hybridgnn/nn/models/hybridgnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -29,10 +29,7 @@ def __init__(
aggr: str = 'sum',
norm: str = 'layer_norm',
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

Expand Down
9 changes: 3 additions & 6 deletions hybridgnn/nn/models/idgnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -27,11 +27,8 @@ def __init__(
out_channels: int,
aggr: str = 'sum',
norm: str = 'layer_norm',
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
torch_frame_model_cls: torch.nn.Module = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

Expand Down
7 changes: 2 additions & 5 deletions hybridgnn/nn/models/shallowrhsgnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -29,10 +29,7 @@ def __init__(
aggr: str = 'sum',
norm: str = 'layer_norm',
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.encoder = HeteroEncoder(
Expand Down
11 changes: 9 additions & 2 deletions test/nn/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@ def test_encoder(tmp_path):
# Ensure that full-batch model works as expected ##########################

encoder = HeteroEncoder(
64, node_to_col_names_dict, col_stats_dict,
stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT)
64,
node_to_col_names_dict,
col_stats_dict,
stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
)

x_dict = encoder(data.tf_dict)
assert 'product' in x_dict.keys()
Expand Down
51 changes: 41 additions & 10 deletions test/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,19 @@ def test_idgnn(tmp_path):
batch = next(iter(train_loader))

assert len(batch[task.dst_entity_table].batch) > 0
model = IDGNN(data=data, col_stats_dict=col_stats_dict, num_layers=2,
channels=64, out_channels=1, aggr="sum", norm="layer_norm")
model = IDGNN(
data=data,
col_stats_dict=col_stats_dict,
num_layers=2,
channels=64,
out_channels=1,
aggr="sum",
norm="layer_norm",
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
)
model.train()

out = model(batch, task.src_entity_table, task.dst_entity_table).flatten()
Expand Down Expand Up @@ -88,10 +99,20 @@ def test_hybridgnn(tmp_path):

channels = 16
embedding_dim = 8
model = HybridGNN(data=data, col_stats_dict=col_stats_dict,
num_nodes=train_table_input.num_dst_nodes, num_layers=2,
channels=channels, aggr="sum", norm="layer_norm",
embedding_dim=embedding_dim)
model = HybridGNN(
data=data,
col_stats_dict=col_stats_dict,
num_nodes=train_table_input.num_dst_nodes,
num_layers=2,
channels=channels,
aggr="sum",
norm="layer_norm",
embedding_dim=embedding_dim,
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
)
model.train()

logits = model(batch, task.src_entity_table, task.dst_entity_table)
Expand Down Expand Up @@ -135,10 +156,20 @@ def test_shallowrhsgnn(tmp_path):

channels = 16
embedding_dim = 8
model = ShallowRHSGNN(data=data, col_stats_dict=col_stats_dict,
num_nodes=train_table_input.num_dst_nodes,
num_layers=2, channels=channels, aggr="sum",
norm="layer_norm", embedding_dim=embedding_dim)
model = ShallowRHSGNN(
data=data,
col_stats_dict=col_stats_dict,
num_nodes=train_table_input.num_dst_nodes,
num_layers=2,
channels=channels,
aggr="sum",
norm="layer_norm",
embedding_dim=embedding_dim,
torch_frame_model_kwargs={
"channels": 128,
"num_layers": 4,
},
)
model.train()

logits = model(batch, task.src_entity_table, task.dst_entity_table)
Expand Down

0 comments on commit 66ba68b

Please sign in to comment.