Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce negative sampling ratio #279

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
introduce negative sampling ratio
Yiwen Yuan committed Nov 24, 2024
commit 826a69e35856af66c3428f64785993110d7de0d8
6 changes: 3 additions & 3 deletions examples/gnn_link.py
Original file line number Diff line number Diff line change
@@ -26,15 +26,15 @@
from relbench.tasks import get_task

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="rel-hm")
parser.add_argument("--task", type=str, default="user-item-purchase")
parser.add_argument("--dataset", type=str, default="rel-trial")
parser.add_argument("--task", type=str, default="site-sponsor-run")
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--num_layers", type=int, default=4)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="uniform")
# Use the same seed time across the mini-batch and share the negatives
5 changes: 4 additions & 1 deletion examples/model.py
Original file line number Diff line number Diff line change
@@ -85,9 +85,12 @@ def forward(
) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)
seed_time_dict = {}
for node_type in batch:
seed_time_dict[node_type] = batch[node_type].seed_time

rel_time_dict = self.temporal_encoder(
seed_time, batch.time_dict, batch.batch_dict
seed_time_dict, batch.time_dict, batch.batch_dict
)

for node_type, rel_time in rel_time_dict.items():
7 changes: 6 additions & 1 deletion relbench/modeling/loader.py
Original file line number Diff line number Diff line change
@@ -181,6 +181,8 @@ class LinkNeighborLoader(DataLoader):
(default: :obj:`None`)
share_same_time (bool): Whether to share the seed time within mini-batch
or not (default: :obj:`False`)
neg_ratio (float): How many negs per pos does the model sample.
(default: :obj:`1.0`)
"""

def __init__(
@@ -195,6 +197,7 @@ def __init__(
subgraph_type: Union[SubgraphType, str] = "directional",
temporal_strategy: str = "uniform",
time_attr: Optional[str] = None,
neg_ratio: Optional[int] = 1,
**kwargs,
):
node_sampler = NeighborSampler(
@@ -213,9 +216,11 @@ def __init__(
self.num_dst_nodes = num_dst_nodes
self.src_time = src_time
self.share_same_time = share_same_time
self.neg_ratio = neg_ratio

kwargs.pop("dataset", None)
kwargs.pop("collate_fn", None)
kwargs.pop("neg_ratio", None)
if share_same_time:
kwargs.pop("sampler", None)
kwargs["batch_sampler"] = TimestampSampler(
@@ -255,7 +260,7 @@ def collate_fn(
src_indices = index[:, 0].contiguous()
pos_dst_indices = index[:, 1].contiguous()
time = index[:, 2].contiguous()
neg_dst_indices = torch.randint(0, self.num_dst_nodes, size=(len(src_indices),))
neg_dst_indices = torch.randint(0, self.num_dst_nodes, size=(len(src_indices) * self.neg_ratio,))
src_out = self.src_loader.get_neighbors(
NodeSamplerInput(
input_id=src_indices,
3 changes: 2 additions & 1 deletion relbench/modeling/nn.py
Original file line number Diff line number Diff line change
@@ -105,13 +105,14 @@ def reset_parameters(self):

def forward(
self,
seed_time: Tensor,
seed_time_dict: Dict[NodeType, Tensor],
time_dict: Dict[NodeType, Tensor],
batch_dict: Dict[NodeType, Tensor],
) -> Dict[NodeType, Tensor]:
out_dict: Dict[NodeType, Tensor] = {}

for node_type, time in time_dict.items():
seed_time = seed_time_dict[node_type]
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days.