diff --git a/topobenchmark/nn/wrappers/graph/gnn_wrapper.py b/topobenchmark/nn/wrappers/graph/gnn_wrapper.py index fd479281..c0576823 100644 --- a/topobenchmark/nn/wrappers/graph/gnn_wrapper.py +++ b/topobenchmark/nn/wrappers/graph/gnn_wrapper.py @@ -27,7 +27,7 @@ def forward(self, batch): x_0 = self.backbone( batch.x_0, batch.edge_index, - batch.get("edge_weight", None), + edge_weight=batch.get("edge_weight", None), ) model_out = {"labels": batch.y, "batch_0": batch.batch_0}