Commit

limit batch size
Yiwen Yuan committed Aug 11, 2024
1 parent acb5ff2 commit f785e10
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions benchmark/relbench_link_prediction_benchmark.py
@@ -126,7 +126,7 @@
     "norm": ["layer_norm", "batch_norm"]
 }
 train_search_space = {
-    "batch_size": [256, 512, 1024],
+    "batch_size": [256, 512],
     "base_lr": [0.001, 0.01],
     "gamma_rate": [0.8, 1.],
 }
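
The hunk above narrows the batch_size candidates in train_search_space from [256, 512, 1024] to [256, 512], in line with the commit message. As a minimal, hypothetical sketch of how such a search-space dict can be turned into a single training configuration (the sample_cfg helper and the random sampling below are assumptions for illustration, not the benchmark's actual tuner):

import random
from typing import Any, Dict, List

def sample_cfg(search_space: Dict[str, List[Any]]) -> Dict[str, Any]:
    # Draw one candidate value per hyperparameter; a real tuner would
    # explore the space more systematically than a single random draw.
    return {name: random.choice(values) for name, values in search_space.items()}

train_search_space = {
    "batch_size": [256, 512],  # 1024 dropped by this commit ("limit batch size")
    "base_lr": [0.001, 0.01],
    "gamma_rate": [0.8, 1.],
}

train_cfg = sample_cfg(train_search_space)
# e.g. {"batch_size": 512, "base_lr": 0.001, "gamma_rate": 0.8}
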
@@ -258,12 +258,11 @@ def train_and_eval_with_cfg(
     model_cfg["out_channels"] = 1
     encoder_model_kwargs = {
         "channels": model_cfg["encoder_channels"],
-        "num_layers": model_cfg["num_layers"]
+        "num_layers": model_cfg["encoder_layers"]
     }
-    model_cfg.pop("channels")
-    model_cfg.pop("num_layers")
+    model_kwargs = {k: v for k,v in model_cfg.items() if k not in ["encoder_channels", "encoder_layers"]}
     # Use model_cfg to set up training procedure
-    model = model_cls(**model_cfg, data=data, col_stats_dict=col_stats_dict,
+    model = model_cls(**model_kwargs, data=data, col_stats_dict=col_stats_dict,
                       num_layers=args.num_layers,
                       torch_frame_model_kwargs=encoder_model_kwargs).to(device)
     model.reset_parameters()
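
The second hunk changes how the sampled config is split between the tabular encoder and the model: the encoder depth is now read from model_cfg["encoder_layers"] rather than model_cfg["num_layers"], and the in-place model_cfg.pop(...) calls are replaced by building a filtered model_kwargs dict. A self-contained sketch of that splitting pattern, with the key names taken from the diff but illustrative values (the split_cfg helper itself is hypothetical):

from typing import Any, Dict, Tuple

def split_cfg(model_cfg: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Keys consumed by the tabular encoder (passed as torch_frame_model_kwargs in the diff).
    encoder_model_kwargs = {
        "channels": model_cfg["encoder_channels"],
        "num_layers": model_cfg["encoder_layers"],
    }
    # Everything else is forwarded to the model constructor, leaving model_cfg intact;
    # the old code instead mutated model_cfg with pop() and read "num_layers"
    # rather than "encoder_layers".
    model_kwargs = {
        k: v for k, v in model_cfg.items()
        if k not in ["encoder_channels", "encoder_layers"]
    }
    return encoder_model_kwargs, model_kwargs

# Illustrative values; the real ones come from the sampled search spaces.
example_cfg = {
    "encoder_channels": 128,
    "encoder_layers": 2,
    "channels": 128,
    "norm": "layer_norm",
    "out_channels": 1,
}
encoder_model_kwargs, model_kwargs = split_cfg(example_cfg)
# encoder_model_kwargs == {"channels": 128, "num_layers": 2}
# model_kwargs == {"channels": 128, "norm": "layer_norm", "out_channels": 1}
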
