diff --git a/besskge/dataset.py b/besskge/dataset.py index 8a8d1c6..5a7a098 100644 --- a/besskge/dataset.py +++ b/besskge/dataset.py @@ -123,8 +123,9 @@ class should be instantiated manually. id_shuffle, (num_train, num_train + num_valid), axis=0 ) triples = dict() - for split in ["train", "valid", "test"]: - triples[split] = data[triple_ids[split]] + triples["train"] = data[triple_ids["train"]] + triples["valid"] = data[triple_ids["valid"]] + triples["test"] = data[triple_ids["test"]] ds = cls( n_entity=data[:, [0, 2]].max() + 1, @@ -134,7 +135,7 @@ class should be instantiated manually. type_offsets=type_offsets, triples=triples, ) - ds.original_triple_ids = triple_ids + ds.original_triple_ids = triple_ids # type: ignore return ds