diff --git a/torch_geometric/utils/rag/feature_store.py b/torch_geometric/utils/rag/feature_store.py index bbbc2bf23b0e..9e4f1d33c0bb 100644 --- a/torch_geometric/utils/rag/feature_store.py +++ b/torch_geometric/utils/rag/feature_store.py @@ -88,9 +88,11 @@ def load_subgraph( node_id = sample.node print("sampled node id in load subgraph=", node_id) print("sampled src ids in load subgraph=", sample.row) - print("num src ids not in node ids inside `load_subgraph`=",sum([src_id not in node_id for src_id in sample.row])) + print("num src ids not in node ids inside `load_subgraph`=", + sum([src_id not in node_id for src_id in sample.row])) print("sampled dst ids in load subgraph=", sample.col) - print("num dst ids not in node ids inside `load_subgraph`=",sum([dst_id not in node_id for dst_id in sample.col])) + print("num dst ids not in node ids inside `load_subgraph`=", + sum([dst_id not in node_id for dst_id in sample.col])) edge_id = sample.edge edge_index = torch.stack((sample.row, sample.col), dim=0) x = self.x[node_id]