diff --git a/examples/relbench_example.py b/examples/relbench_example.py index 3727147..54c2b40 100644 --- a/examples/relbench_example.py +++ b/examples/relbench_example.py @@ -1,7 +1,7 @@ -""" -example command +"""Command to run the script: -python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --epochs 10 +python relbench_example.py --dataset rel-trial --task site-sponsor-run + --model hybridgnn --epochs 10 """ import argparse @@ -128,6 +128,10 @@ out_channels=1, aggr=args.aggr, norm="layer_norm", + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, ).to(device) elif args.model == "hybridgnn": model = HybridGNN( @@ -139,6 +143,10 @@ aggr="sum", norm="layer_norm", embedding_dim=64, + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, ).to(device) elif args.model == 'shallowrhsgnn': model = ShallowRHSGNN( @@ -150,6 +158,10 @@ aggr="sum", norm="layer_norm", embedding_dim=64, + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, ).to(device) else: raise ValueError(f"Unsupported model type {args.model}.") diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 022b30a..9844956 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch import torch_frame @@ -54,10 +54,7 @@ def __init__( node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any], torch_frame_model_cls=ResNet, - torch_frame_model_kwargs: Dict[str, Any] = { - "channels": 128, - "num_layers": 4, - }, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() diff --git a/hybridgnn/nn/models/hybridgnn.py b/hybridgnn/nn/models/hybridgnn.py index dad8815..a9164af 100644 --- a/hybridgnn/nn/models/hybridgnn.py +++ b/hybridgnn/nn/models/hybridgnn.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from torch import Tensor @@ -29,10 +29,7 @@ def __init__( aggr: str = 'sum', norm: str = 'layer_norm', torch_frame_model_cls=ResNet, - torch_frame_model_kwargs: Dict[str, Any] = { - "channels": 128, - "num_layers": 4, - }, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() diff --git a/hybridgnn/nn/models/idgnn.py b/hybridgnn/nn/models/idgnn.py index 9e25623..3891e9a 100644 --- a/hybridgnn/nn/models/idgnn.py +++ b/hybridgnn/nn/models/idgnn.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from torch import Tensor @@ -27,11 +27,8 @@ def __init__( out_channels: int, aggr: str = 'sum', norm: str = 'layer_norm', - torch_frame_model_cls=ResNet, - torch_frame_model_kwargs: Dict[str, Any] = { - "channels": 128, - "num_layers": 4, - }, + torch_frame_model_cls: torch.nn.Module = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() diff --git a/hybridgnn/nn/models/shallowrhsgnn.py b/hybridgnn/nn/models/shallowrhsgnn.py index 56e8261..d36bd92 100644 --- a/hybridgnn/nn/models/shallowrhsgnn.py +++ b/hybridgnn/nn/models/shallowrhsgnn.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import torch from torch import Tensor @@ -29,10 +29,7 @@ def __init__( aggr: str = 'sum', norm: str = 'layer_norm', torch_frame_model_cls=ResNet, - torch_frame_model_kwargs: Dict[str, Any] = { - "channels": 128, - "num_layers": 4, - }, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() self.encoder = HeteroEncoder( diff --git a/test/nn/test_encoder.py b/test/nn/test_encoder.py index 28609d8..b726c1e 100644 --- a/test/nn/test_encoder.py +++ b/test/nn/test_encoder.py @@ -27,8 +27,15 @@ def test_encoder(tmp_path): # Ensure that full-batch model works as expected ########################## encoder = HeteroEncoder( - 64, node_to_col_names_dict, col_stats_dict, - stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT) + 64, + node_to_col_names_dict, + col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, + ) x_dict = encoder(data.tf_dict) assert 'product' in x_dict.keys() diff --git a/test/nn/test_model.py b/test/nn/test_model.py index 92e797d..0bd98f9 100644 --- a/test/nn/test_model.py +++ b/test/nn/test_model.py @@ -45,8 +45,19 @@ def test_idgnn(tmp_path): batch = next(iter(train_loader)) assert len(batch[task.dst_entity_table].batch) > 0 - model = IDGNN(data=data, col_stats_dict=col_stats_dict, num_layers=2, - channels=64, out_channels=1, aggr="sum", norm="layer_norm") + model = IDGNN( + data=data, + col_stats_dict=col_stats_dict, + num_layers=2, + channels=64, + out_channels=1, + aggr="sum", + norm="layer_norm", + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, + ) model.train() out = model(batch, task.src_entity_table, task.dst_entity_table).flatten() @@ -88,10 +99,20 @@ def test_hybridgnn(tmp_path): channels = 16 embedding_dim = 8 - model = HybridGNN(data=data, col_stats_dict=col_stats_dict, - num_nodes=train_table_input.num_dst_nodes, num_layers=2, - channels=channels, aggr="sum", norm="layer_norm", - embedding_dim=embedding_dim) + model = HybridGNN( + data=data, + col_stats_dict=col_stats_dict, + num_nodes=train_table_input.num_dst_nodes, + num_layers=2, + channels=channels, + aggr="sum", + norm="layer_norm", + embedding_dim=embedding_dim, + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, + ) model.train() logits = model(batch, task.src_entity_table, task.dst_entity_table) @@ -135,10 +156,20 @@ def test_shallowrhsgnn(tmp_path): channels = 16 embedding_dim = 8 - model = ShallowRHSGNN(data=data, col_stats_dict=col_stats_dict, - num_nodes=train_table_input.num_dst_nodes, - num_layers=2, channels=channels, aggr="sum", - norm="layer_norm", embedding_dim=embedding_dim) + model = ShallowRHSGNN( + data=data, + col_stats_dict=col_stats_dict, + num_nodes=train_table_input.num_dst_nodes, + num_layers=2, + channels=channels, + aggr="sum", + norm="layer_norm", + embedding_dim=embedding_dim, + torch_frame_model_kwargs={ + "channels": 128, + "num_layers": 4, + }, + ) model.train() logits = model(batch, task.src_entity_table, task.dst_entity_table)