diff --git a/train.py b/train.py index 7c12912..85b57d0 100644 --- a/train.py +++ b/train.py @@ -55,7 +55,7 @@ def prepare_data(graphs, args, test_graphs=None, max_nodes=0): else: train_idx = int(len(graphs) * args.train_ratio) train_graphs = graphs[:train_idx] - val_graphs = graph[train_idx:] + val_graphs = graphs[train_idx:] print( "Num training graphs: ", len(train_graphs),