diff --git a/graphein/ml/conversion.py b/graphein/ml/conversion.py index 010f418e..f9ad2ca4 100644 --- a/graphein/ml/conversion.py +++ b/graphein/ml/conversion.py @@ -278,10 +278,14 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data: for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): for key, value in feat_dict.items(): if str(key) in self.columns: - data[str(key)] = ( - list(value) if i == 0 else data[str(key)] + list(value) - ) - + #data[str(key)] = ( + # list(value) if i == 0 else data[str(key)] + list(value) + #) + if i == 0: + data[str(key)] = [] + data[str(key)].append(list(value)) + + # Add graph-level features for feat_name in G.graph: if str(feat_name) in self.columns: