diff --git a/besskge/dataset.py b/besskge/dataset.py index 9423ea4..f4e827a 100644 --- a/besskge/dataset.py +++ b/besskge/dataset.py @@ -36,6 +36,11 @@ class KGDataset: #: {part: int32[n_triple, {h,r,t}]} triples: Dict[str, NDArray[np.int32]] + #: IDs of the triples in KGDataset.triples wrt + #: the ordering in the original array/dataframe + #: from where the triples originate. + original_triple_ids: Dict[str, NDArray[np.int32]] + #: Entity labels by ID; str[n_entity] entity_dict: Optional[List[str]] = None @@ -90,6 +95,9 @@ def from_triples( and relations have already been assigned. Note that, if entities have types, entities of the same type need to have contiguous IDs. Triples are randomly split in train/validation/test sets. + The attribute `KGDataset.original_triple_ids` stores the IDs + of the triples in each split wrt the original ordering in `data`. + If a pre-defined train/validation/test split is wanted, the KGDataset class should be instantiated manually. @@ -114,22 +122,28 @@ class should be instantiated manually. num_valid = int(num_triples * split[1]) rng = np.random.default_rng(seed=seed) - rng.shuffle(data, axis=0) - - triples = dict() - triples["train"], triples["valid"], triples["test"] = np.split( - data, (num_train, num_train + num_valid), axis=0 + id_shuffle = rng.permutation(np.arange(num_triples)) + triple_ids = dict() + triple_ids["train"], triple_ids["valid"], triple_ids["test"] = np.split( + id_shuffle, (num_train, num_train + num_valid), axis=0 ) + triples = dict() + triples["train"] = data[triple_ids["train"]] + triples["valid"] = data[triple_ids["valid"]] + triples["test"] = data[triple_ids["test"]] - return cls( + ds = cls( n_entity=data[:, [0, 2]].max() + 1, n_relation_type=data[:, 1].max() + 1, entity_dict=entity_dict, relation_dict=relation_dict, type_offsets=type_offsets, triples=triples, + original_triple_ids=triple_ids, ) + return ds + @classmethod def from_dataframe( cls, @@ -219,6 +233,9 @@ def from_dataframe( relation_dict=relation_dict, type_offsets=type_offsets, triples=triples, + original_triple_ids={ + k: np.arange(v.shape[0]) for k, v in triples.items() + }, ) @classmethod @@ -280,6 +297,7 @@ def build_ogbl_biokg(cls, root: Path) -> "KGDataset": relation_dict=rel_dict, type_offsets=type_offsets, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=neg_heads, neg_tails=neg_tails, ) @@ -329,6 +347,7 @@ def build_ogbl_wikikg2(cls, root: Path) -> "KGDataset": relation_dict=rel_dict, type_offsets=None, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=neg_heads, neg_tails=neg_tails, ) diff --git a/tests/test_batch_sampler.py b/tests/test_batch_sampler.py index cfee463..91bd6ca 100644 --- a/tests/test_batch_sampler.py +++ b/tests/test_batch_sampler.py @@ -36,6 +36,7 @@ relation_dict=None, type_offsets=None, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=None, neg_tails=None, ) diff --git a/tests/test_bess.py b/tests/test_bess.py index f0521ea..ebf7869 100644 --- a/tests/test_bess.py +++ b/tests/test_bess.py @@ -79,6 +79,7 @@ def test_bess_inference( relation_dict=None, type_offsets=None, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=neg_heads, neg_tails=neg_tails, ) @@ -306,6 +307,7 @@ def test_bess_topk_prediction( relation_dict=None, type_offsets=None, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=neg_heads, neg_tails=neg_tails, ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 4f42f23..cd54962 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -51,6 +51,7 @@ def test_all_scores_pipeline( relation_dict=None, type_offsets=None, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, ) partition_mode = "h_shard" if corruption_scheme == "t" else "t_shard" diff --git a/tests/test_sharding.py b/tests/test_sharding.py index 557d725..495e735 100644 --- a/tests/test_sharding.py +++ b/tests/test_sharding.py @@ -34,6 +34,7 @@ relation_dict=None, type_offsets={str(i): o for i, o in enumerate(type_offsets)}, triples=triples, + original_triple_ids={k: np.arange(v.shape[0]) for k, v in triples.items()}, neg_heads=neg_heads, neg_tails=neg_tails, )