Add a configurable negative-sampling ratio for link-prediction training.

  * examples/gnn_link.py: change the default dataset/task to rel-trial /
    site-sponsor-run, raise the default --num_layers to 4, and add a
    --neg_ratio flag (alias --neg-ratio, default 2) that is passed to the
    loader constructed in that script.
  * relbench/modeling/loader.py: LinkNeighborLoader accepts neg_ratio and
    draws len(src_indices) * neg_ratio random negative destination indices
    per mini-batch.
  * relbench/modeling/nn.py: when a batch vector refers to more seeds than
    seed_time has entries, seed_time[0] is broadcast to the required length.

NOTE(review): broadcasting seed_time[0] is only correct when every seed in
the mini-batch shares one timestamp (presumably the share_same_time mode).
Confirm, or repeat the per-positive times for the negatives where the
negative batch is built.

diff --git a/examples/gnn_link.py b/examples/gnn_link.py
index 96805e73..c16740e3 100644
--- a/examples/gnn_link.py
+++ b/examples/gnn_link.py
@@ -26,15 +26,15 @@
 from relbench.tasks import get_task
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--dataset", type=str, default="rel-hm")
-parser.add_argument("--task", type=str, default="user-item-purchase")
+parser.add_argument("--dataset", type=str, default="rel-trial")
+parser.add_argument("--task", type=str, default="site-sponsor-run")
 parser.add_argument("--lr", type=float, default=0.001)
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
 parser.add_argument("--batch_size", type=int, default=512)
 parser.add_argument("--channels", type=int, default=128)
 parser.add_argument("--aggr", type=str, default="sum")
-parser.add_argument("--num_layers", type=int, default=2)
+parser.add_argument("--num_layers", type=int, default=4)
 parser.add_argument("--num_neighbors", type=int, default=128)
 parser.add_argument("--temporal_strategy", type=str, default="uniform")
 # Use the same seed time across the mini-batch and share the negatives
@@ -42,6 +42,10 @@
 parser.add_argument(
     "--no-share_same_time", dest="share_same_time", action="store_false"
 )
+# Number of negative destination nodes sampled per positive link.
+parser.add_argument(
+    "--neg_ratio", "--neg-ratio", type=int, default=2
+)
 # Whether to use shallow embedding on dst nodes or not.
 parser.add_argument("--use_shallow", action="store_true", default=True)
 parser.add_argument("--no-use_shallow", dest="use_shallow", action="store_false")
@@ -104,6 +108,7 @@
     # if share_same_time is True, we use sampler, so shuffle must be set False
     shuffle=not args.share_same_time,
     num_workers=args.num_workers,
+    neg_ratio=args.neg_ratio,
 )
 
 eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {}
diff --git a/relbench/modeling/loader.py b/relbench/modeling/loader.py
index 941d5722..bb97083a 100644
--- a/relbench/modeling/loader.py
+++ b/relbench/modeling/loader.py
@@ -181,6 +181,8 @@ class LinkNeighborLoader(DataLoader):
             (default: :obj:`None`)
         share_same_time (bool): Whether to share the seed time within mini-batch
             or not (default: :obj:`False`)
+        neg_ratio (int): Number of negative destination nodes sampled per
+            positive link. (default: :obj:`1`)
     """
 
     def __init__(
@@ -195,6 +197,7 @@ def __init__(
         subgraph_type: Union[SubgraphType, str] = "directional",
         temporal_strategy: str = "uniform",
         time_attr: Optional[str] = None,
+        neg_ratio: int = 1,
         **kwargs,
     ):
         node_sampler = NeighborSampler(
@@ -213,9 +216,10 @@ def __init__(
         self.num_dst_nodes = num_dst_nodes
         self.src_time = src_time
         self.share_same_time = share_same_time
+        self.neg_ratio = neg_ratio
 
         kwargs.pop("dataset", None)
         kwargs.pop("collate_fn", None)
         if share_same_time:
             kwargs.pop("sampler", None)
             kwargs["batch_sampler"] = TimestampSampler(
@@ -255,7 +259,7 @@ def collate_fn(
         src_indices = index[:, 0].contiguous()
         pos_dst_indices = index[:, 1].contiguous()
         time = index[:, 2].contiguous()
-        neg_dst_indices = torch.randint(0, self.num_dst_nodes, size=(len(src_indices),))
+        neg_dst_indices = torch.randint(0, self.num_dst_nodes, size=(len(src_indices) * self.neg_ratio,))
         src_out = self.src_loader.get_neighbors(
             NodeSamplerInput(
                 input_id=src_indices,
diff --git a/relbench/modeling/nn.py b/relbench/modeling/nn.py
index 40d26987..e68b6f72 100644
--- a/relbench/modeling/nn.py
+++ b/relbench/modeling/nn.py
@@ -112,6 +112,11 @@ def forward(
         out_dict: Dict[NodeType, Tensor] = {}
 
         for node_type, time in time_dict.items():
+            batch_index = batch_dict[node_type]
+            num_seeds = int(batch_index.max()) + 1 if batch_index.numel() else 0
+            if num_seeds > len(seed_time):
+                # Seed ids run past seed_time: reuse its first entry for all.
+                seed_time = seed_time[:1].expand(num_seeds)
             rel_time = seed_time[batch_dict[node_type]] - time
             rel_time = rel_time / (60 * 60 * 24)  # Convert seconds to days.
 