diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 345118f..d3f3a96 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,7 +14,7 @@ jobs: ci: runs-on: ubuntu-latest container: graphcore/pytorch:3.3.0-ubuntu-20.04-20230703 - timeout-minutes: 15 + timeout-minutes: 25 steps: - uses: actions/checkout@v3 - name: Install dev-requirements diff --git a/besskge/__init__.py b/besskge/__init__.py index d22e654..8627096 100644 --- a/besskge/__init__.py +++ b/besskge/__init__.py @@ -45,6 +45,7 @@ def load_custom_ops_so() -> None: loss, metric, negative_sampler, + pipeline, scoring, sharding, utils, diff --git a/besskge/bess.py b/besskge/bess.py index 866dca9..798feed 100644 --- a/besskge/bess.py +++ b/besskge/bess.py @@ -707,22 +707,22 @@ def forward( and respective scores are kept to be used in the next iteration, while the rest are discarded. - :param relation: shape: (shard_bs,) + :param relation: shape: (1, shard_bs,) Relation indices. - :param head: shape: (shard_bs,) + :param head: shape: (1, shard_bs,) Head indices, if known. Default: None. - :param tail: shape: (shard_bs,) + :param tail: shape: (1, shard_bs,) Tail indices, if known. Default: None. - :param negative: shape: (n_shard, B, padded_negative) + :param negative: shape: (1, n_shard, B, padded_negative) Candidates to score against the queries. It can be the same set for all queries (B=1), or specific for each query in the batch (B=shard_bs). If None, score each query against all entities in the knowledge graph. Default: None. - :param triple_mask: shape: (shard_bs,) + :param triple_mask: shape: (1, shard_bs,) Mask to filter the triples in the micro-batch before computing metrics. Default: None. - :param negative_mask: shape: (n_shard, B, padded_negative) + :param negative_mask: shape: (1, n_shard, B, padded_negative) If candidates are provided, mask to discard padding negatives when computing best completions. Requires the use of :code:`mask_on_gather=True` in the candidate @@ -919,3 +919,144 @@ def loop_body( ) return out_dict + + +class AllScoresBESS(torch.nn.Module): + """ + Distributed scoring of (h, r, ?) or (?, r, t) queries against + the entities in the knowledge graph, returning all scores to + host in blocks, based on the BESS :cite:p:`BESS` + inference scheme. + To be used in combination with a batch sampler based on a + "h_shard"/"t_shard"-partitioned triple set. + Since each iteration on IPU computes only part of the scores + (based on the size of the sliding window), metrics should be + computed on host after aggregating data (see + :class:`besskge.pipeline.AllScoresPipeline`). + + Only to be used for inference. + """ + + def __init__( + self, + candidate_sampler: PlaceholderNegativeSampler, + score_fn: BaseScoreFunction, + window_size: int = 1000, + ) -> None: + """ + Initialize AllScores BESS-KGE module. + + :param candidate_sampler: + :class:`besskge.negative_sampler.PlaceholderNegativeSampler` class, + specifying corruption scheme. + :param score_fn: + Scoring function. + :param window_size: + Size of the sliding window, namely the number of negative entities + scored against each query at each step on IPU and returned to host. + Should be decreased with large batch sizes, to avoid an OOM error. + Default: 1000. + """ + super().__init__() + self.sharding = score_fn.sharding + self.score_fn = score_fn + self.negative_sampler = candidate_sampler + self.window_size = window_size + + if not score_fn.negative_sample_sharing: + raise ValueError("AllScoresBESS requires using negative sample sharing") + + if self.negative_sampler.corruption_scheme not in ["h", "t"]: + raise ValueError("AllScoresBESS only support 'h', 't' corruption scheme") + + if not isinstance(self.negative_sampler, PlaceholderNegativeSampler): + raise ValueError( + "AllScoresBESS requires a `PlaceholderNegativeSampler`" + " candidate_sampler" + ) + + self.entity_embedding = self.score_fn.entity_embedding + self.entity_embedding_size: int = self.entity_embedding.shape[-1] + + self.candidate = torch.arange(self.window_size, dtype=torch.int32) + self.n_step = int( + np.ceil(self.sharding.max_entity_per_shard / self.window_size) + ) + + def forward( + self, + step: torch.Tensor, + relation: torch.Tensor, + head: Optional[torch.Tensor] = None, + tail: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward step. + + Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the + device where they are gathered, then scores for the same query against + candidates in different shards are collected together via an AllToAll. + + :param step: + The index of the block (of size self.window_size) of entities + on each IPU to score against queries. + :param relation: shape: (1, shard_bs,) + Relation indices. + :param head: shape: (1, shard_bs,) + Head indices, if known. Default: None. + :param tail: shape: (1, shard_bs,) + Tail indices, if known. Default: None. + + :return: + The scores for the completions. + """ + + relation = relation.squeeze(0) + if head is not None: + head = head.squeeze(0) + if tail is not None: + tail = tail.squeeze(0) + + n_shard = self.sharding.n_shard + shard_bs = relation.shape[0] + + relation_all = all_gather(relation, n_shard) + if self.negative_sampler.corruption_scheme == "h": + tail_embedding = self.entity_embedding[tail] + tail_embedding_all = all_gather(tail_embedding, n_shard) + elif self.negative_sampler.corruption_scheme == "t": + head_embedding = self.entity_embedding[head] + head_embedding_all = all_gather(head_embedding, n_shard) + + # Local indices of the entities to score against queries + ent_slice = torch.minimum( + step * self.window_size + + torch.arange(self.window_size, device=relation.device), + torch.tensor(self.sharding.max_entity_per_shard - 1), + ) + negative_embedding = self.entity_embedding[ent_slice] + + if self.negative_sampler.corruption_scheme == "h": + scores = self.score_fn.score_heads( + negative_embedding, + relation_all.flatten(end_dim=1), + tail_embedding_all.flatten(end_dim=1), + ) + elif self.negative_sampler.corruption_scheme == "t": + scores = self.score_fn.score_tails( + head_embedding_all.flatten(end_dim=1), + relation_all.flatten(end_dim=1), + negative_embedding, + ) + + # Send back queries to original shard + scores = ( + all_to_all( + scores.reshape(n_shard, shard_bs, self.window_size), + n_shard, + ) + .transpose(0, 1) + .flatten(start_dim=1) + ) # shape (bs, n_shard * ws) + + return scores diff --git a/besskge/pipeline.py b/besskge/pipeline.py new file mode 100644 index 0000000..49e8507 --- /dev/null +++ b/besskge/pipeline.py @@ -0,0 +1,298 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +""" +High-level APIs for training/inference with BESS. +""" + +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import poptorch +import torch +from numpy.typing import NDArray +from tqdm import tqdm + +from besskge.batch_sampler import ShardedBatchSampler +from besskge.bess import AllScoresBESS +from besskge.metric import Evaluation +from besskge.negative_sampler import PlaceholderNegativeSampler +from besskge.scoring import BaseScoreFunction +from besskge.utils import get_entity_filter + + +class AllScoresPipeline(torch.nn.Module): + """ + Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities + in the KG, and related prediction metrics. + It supports filtering out the scores of specific completions that appear in a given + set of triples. + + To be used in combination with a batch sampler based on a + "h_shard"/"t_shard"-partitioned triple set. + """ + + def __init__( + self, + batch_sampler: ShardedBatchSampler, + corruption_scheme: str, + score_fn: BaseScoreFunction, + evaluation: Optional[Evaluation] = None, + filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None, + return_scores: bool = False, + return_topk: bool = False, + k: int = 10, + window_size: int = 1000, + use_ipu_model: bool = False, + ) -> None: + """ + Initialize pipeline. + + :param batch_sampler: + Batch sampler, based on a + "h_shard"/"t_shard"-partitioned triple set. + :param corruption_scheme: + Set to "t" to score (h, r, ?) completions, or to + "h" to score (?, r, t) completions. + :param score_fn: + The trained scoring function. + :param evaluation: + Evaluation module, for computing metrics. + Default: None. + :param filter_triples: + The set of all triples whose scores need to be filtered. + The triples passed here must have GLOBAL IDs for head/tail + entities. Default: None. + :param return_scores: + If True, store and return scores of all queries' completions + (with filters applied, if specified). + For large number of queries/entities, this can cause the host + to go OOM. + Default: False. + :param return_topk: + If True, return for each query the global IDs of the most likely + completions, after filtering out the scores of `filter_triples`. + Default: False. + :param k: + If `return_topk` is set to True, for each query return the + top-k most likely predictions (after filtering). Default: 10. + :param window_size: + Size of the sliding window, namely the number of negative entities + scored against each query at each step on IPU and returned to host. + Should be decreased with large batch sizes, to avoid an OOM error. + Default: 1000. + :param use_ipu_model: + Run pipeline on IPU Model instead of actual hardware. Default: False. + """ + super().__init__() + self.batch_sampler = batch_sampler + if not (evaluation or return_scores): + raise ValueError( + "Nothing to return. Provide `evaluation` or set `return_scores=True`" + ) + if corruption_scheme not in ["h", "t"]: + raise ValueError("corruption_scheme needs to be either 'h' or 't'") + if ( + corruption_scheme == "h" + and self.batch_sampler.triple_partition_mode != "t_shard" + ): + raise ValueError( + "Corruption scheme 'h' requires 't-shard'-partitioned triples" + ) + elif ( + corruption_scheme == "t" + and self.batch_sampler.triple_partition_mode != "h_shard" + ): + raise ValueError( + "Corruption scheme 't' requires 'h-shard'-partitioned triples" + ) + self.candidate_sampler = PlaceholderNegativeSampler( + corruption_scheme=corruption_scheme + ) + self.score_fn = score_fn + self.evaluation = evaluation + self.return_scores = return_scores + self.return_topk = return_topk + self.k = k + self.window_size = window_size + self.corruption_scheme = corruption_scheme + self.bess_module = AllScoresBESS( + self.candidate_sampler, self.score_fn, self.window_size + ) + + inf_options = poptorch.Options() + inf_options.replication_factor = self.bess_module.sharding.n_shard + inf_options.deviceIterations(self.batch_sampler.batches_per_step) + inf_options.outputMode(poptorch.OutputMode.All) + if use_ipu_model: + inf_options.useIpuModel(True) + self.dl = self.batch_sampler.get_dataloader(options=inf_options, shuffle=False) + + self.poptorch_module = poptorch.inferenceModel( + self.bess_module, options=inf_options + ) + self.poptorch_module.entity_embedding.replicaGrouping( + poptorch.CommGroupType.NoGrouping, + 0, + poptorch.VariableRetrievalMode.OnePerGroup, + ) + self.filter_triples: Optional[torch.Tensor] = None + if filter_triples: + # Reconstruct global IDs for all entities in triples + local_id_col = ( + 0 if self.batch_sampler.triple_partition_mode == "h_shard" else 2 + ) + triple_shard_offset = np.concatenate( + [np.array([0]), np.cumsum(batch_sampler.triple_counts)] + ) + global_id_triples = [] + for i in range(len(triple_shard_offset) - 1): + shard_triples = np.copy( + batch_sampler.triples[ + triple_shard_offset[i] : triple_shard_offset[i + 1] + ] + ) + shard_triples[ + :, local_id_col + ] = self.bess_module.sharding.shard_and_idx_to_entity[i][ + shard_triples[:, local_id_col] + ] + global_id_triples.append(shard_triples) + self.triples = torch.from_numpy(np.concatenate(global_id_triples, axis=0)) + self.filter_triples = torch.concat( + [ + tr if isinstance(tr, torch.Tensor) else torch.from_numpy(tr) + for tr in filter_triples + ], + dim=0, + ) + + def forward(self) -> Dict[str, Any]: + """ + Compute scores of all completions and (possibly) metrics. + + :return: + Scores, metrics, and (if provided in batch sampler) IDs + of inference triples (wrt partitioned_triple_set.triples) + to order results. + """ + scores = [] + ids = [] + metrics = [] + ranks = [] + topk_ids = [] + n_triple = 0 + for batch in tqdm(iter(self.dl)): + triple_mask = batch.pop("triple_mask") + if ( + self.candidate_sampler.corruption_scheme == "h" + and "head" in batch.keys() + ): + ground_truth = batch.pop("head") + elif ( + self.candidate_sampler.corruption_scheme == "t" + and "tail" in batch.keys() + ): + ground_truth = batch.pop("tail") + if self.batch_sampler.return_triple_idx: + triple_id = batch.pop("triple_idx") + ids.append(triple_id[triple_mask]) + n_triple += triple_mask.sum() + + batch_res = [] + batch_idx = [] + for i in range(self.bess_module.n_step): + step = ( + torch.tensor([i], dtype=torch.int32) + .broadcast_to( + ( + self.bess_module.sharding.n_shard + * self.batch_sampler.batches_per_step, + 1, + ) + ) + .contiguous() + ) + ent_slice = torch.minimum( + i * self.bess_module.window_size + + torch.arange(self.bess_module.window_size), + torch.tensor(self.bess_module.sharding.max_entity_per_shard - 1), + ) + # Global indices of entities scored in the step + batch_idx.append( + self.bess_module.sharding.shard_and_idx_to_entity[ + :, ent_slice + ].flatten() + ) + inp = {k: v.flatten(end_dim=1) for k, v in batch.items()} + inp.update(dict(step=step)) + batch_res.append(self.poptorch_module(**inp)) + batch_scores = torch.concat(batch_res, dim=-1) + # Filter out padding scores + batch_scores_filt = batch_scores[triple_mask.flatten()][ + :, np.unique(np.concatenate(batch_idx), return_index=True)[1] + ][:, : self.bess_module.sharding.n_entity] + if ground_truth is not None: + # Scores of positive triples + true_scores = batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] + if self.filter_triples is not None: + # Filter for triples in batch + batch_filter = get_entity_filter( + self.triples[triple_id[triple_mask]], + self.filter_triples, + filter_mode=self.corruption_scheme, + ) + batch_scores_filt[batch_filter[:, 0], batch_filter[:, 1]] = -torch.inf + + if self.evaluation: + assert ( + ground_truth is not None + ), "Evaluation requires providing ground truth entities" + # If not already masked, mask scores of true triples + # to compute metrics + batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] = -torch.inf + batch_ranks = self.evaluation.ranks_from_scores( + true_scores, batch_scores_filt + ) + metrics.append(self.evaluation.dict_metrics_from_ranks(batch_ranks)) + if self.evaluation.return_ranks: + ranks.append(batch_ranks) + if ground_truth is not None: + # Restore positive scores in the returned scores + batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] = true_scores + if self.return_scores: + scores.append(batch_scores_filt) + if self.return_topk: + topk_ids.append(torch.topk(batch_scores_filt, k=self.k, dim=-1).indices) + + out = dict() + if scores: + out["scores"] = torch.concat(scores, dim=0) + if topk_ids: + out["topk_global_id"] = torch.concat(topk_ids, dim=0) + if ids: + out["triple_idx"] = torch.concat(ids, dim=0) + if self.evaluation: + final_metrics = dict() + for m in metrics[0].keys(): + # Reduce metrics over all batches + final_metrics[m] = self.evaluation.reduction( + torch.concat([met[m].reshape(-1) for met in metrics]) + ) + out["metrics"] = final_metrics # type: ignore + # Average metrics over all triples + out["metrics_avg"] = { + m: v.sum() / n_triple for m, v in final_metrics.items() + } # type: ignore + if ranks: + out["ranks"] = torch.concat(ranks, dim=0) + + return out diff --git a/besskge/utils.py b/besskge/utils.py index 31bd6a9..50f68c0 100644 --- a/besskge/utils.py +++ b/besskge/utils.py @@ -33,6 +33,42 @@ def gather_indices(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor: return x.view(-1, mask_size) +def get_entity_filter( + triples: torch.Tensor, filter_triples: torch.Tensor, filter_mode: str +) -> torch.Tensor: + """ + Compare two sets of triples: for each triple (h,r,t) in the first set, find + the entities `e` such that (e,r,t) (or (h,r,e), depending on `filter_mode`) + appears in the second set of triples. + + :param triples: shape (x, 3) + The set of triples to construct filters for. + :param filter_triples: shape (y, 3) + The set of triples determining the head/tail entities to filter. + :param filter_mode: + Set to "h" to look for entities appearing as heads of the same (r,t) pair, + or to "t" to look for entities appearing as tails of the same (h,r) pair. + + :return: shape (z, 2) + The sparse filters. Each row is given by a tuple (i, j), with i the index + of the triple in `triples` to which the filter applies to and j the global + ID of the entity to filter. + """ + if filter_mode == "t": + ent_col = 0 + elif filter_mode == "h": + ent_col = 2 + else: + raise ValueError("`filter_mode` needs to be either 'h' or 't'") + relation_filter = (filter_triples[:, 1]) == triples[:, 1].view(-1, 1) + entity_filter = (filter_triples[:, ent_col]) == triples[:, ent_col].view(-1, 1) + + filter = (entity_filter & relation_filter).nonzero(as_tuple=False) + filter[:, 1] = filter_triples[:, 2 - ent_col].view(1, -1)[:, filter[:, 1]] + + return filter + + def complex_multiplication(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor: """ Batched complex multiplication. diff --git a/docs/source/API_reference.rst b/docs/source/API_reference.rst index 563839f..5179b0a 100644 --- a/docs/source/API_reference.rst +++ b/docs/source/API_reference.rst @@ -6,6 +6,7 @@ BESS-KGE API Reference :template: module.rst :recursive: + besskge.pipeline besskge.dataset besskge.sharding besskge.batch_sampler diff --git a/requirements-dev.txt b/requirements-dev.txt index b66e578..1b22695 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,7 @@ flake8==6.0.0 isort==5.12.0 mypy==1.0.1 pandas-stubs==2.0.1.230501 +tqdm-stubs==0.2.1 types-requests==2.28.11.17 pytest==7.2.1 pytest-cov==4.0.0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..4f42f23 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import numpy as np +import pytest +import torch +from torch.testing import assert_close + +from besskge.batch_sampler import RigidShardedBatchSampler +from besskge.dataset import KGDataset +from besskge.metric import Evaluation +from besskge.negative_sampler import PlaceholderNegativeSampler +from besskge.pipeline import AllScoresPipeline +from besskge.scoring import ComplEx +from besskge.sharding import PartitionedTripleSet, Sharding +from besskge.utils import get_entity_filter + +seed = 1234 +n_entity = 5000 +n_relation_type = 50 +n_shard = 4 +n_test_triple = 1000 +batches_per_step = 2 +shard_bs = 400 +embedding_size = 128 + +np.random.seed(seed) +torch.manual_seed(seed) + +sharding = Sharding.create(n_entity, n_shard, seed=seed) + +unsharded_entity_table = torch.randn(size=(n_entity, 2 * embedding_size)) +relation_table = torch.randn(size=(n_relation_type, 2 * embedding_size)) + +test_triples_h = np.random.choice(n_entity - 1, size=n_test_triple, replace=False) +test_triples_t = np.random.choice(n_entity - 1, size=n_test_triple, replace=False) +test_triples_r = np.random.randint(n_relation_type, size=n_test_triple) +triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)} + + +@pytest.mark.parametrize("corruption_scheme", ["h", "t"]) +@pytest.mark.parametrize( + "filter_scores, extra_only", [(True, True), (True, False), (False, False)] +) +def test_all_scores_pipeline( + corruption_scheme: str, filter_scores: bool, extra_only: bool +) -> None: + ds = KGDataset( + n_entity=n_entity, + n_relation_type=n_relation_type, + entity_dict=None, + relation_dict=None, + type_offsets=None, + triples=triples, + ) + + partition_mode = "h_shard" if corruption_scheme == "t" else "t_shard" + partitioned_triple_set = PartitionedTripleSet.create_from_dataset( + ds, "test", sharding, partition_mode=partition_mode + ) + + score_fn = ComplEx( + negative_sample_sharing=True, + sharding=sharding, + n_relation_type=ds.n_relation_type, + embedding_size=embedding_size, + entity_initializer=unsharded_entity_table, + relation_initializer=relation_table, + ) + placeholder_ns = PlaceholderNegativeSampler( + corruption_scheme=corruption_scheme, seed=seed + ) + + test_bs = RigidShardedBatchSampler( + partitioned_triple_set=partitioned_triple_set, + negative_sampler=placeholder_ns, + shard_bs=shard_bs, + batches_per_step=batches_per_step, + seed=seed, + hrt_freq_weighting=False, + duplicate_batch=False, + return_triple_idx=True, + ) + + evaluation = Evaluation( + ["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True + ) + + ground_truth_col = 0 if corruption_scheme == "h" else 2 + if filter_scores: + extra_filter_triples = np.copy(triples["test"]) + extra_filter_triples[:, ground_truth_col] += 1 + triples_to_filter = ( + [extra_filter_triples] + if extra_only + else [triples["test"], extra_filter_triples] + ) + else: + triples_to_filter = None + + pipeline = AllScoresPipeline( + test_bs, + corruption_scheme, + score_fn, + evaluation, + filter_triples=triples_to_filter, # type: ignore + return_scores=True, + return_topk=True, + k=10, + window_size=1000, + use_ipu_model=True, + ) + out = pipeline() + + # Shuffle triples in same order of out["scores"] + triple_reordered = torch.from_numpy( + ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_idx"]]] + ) + + # All scores, computed on CPU + if corruption_scheme == "t": + cpu_scores = score_fn.score_tails( + unsharded_entity_table[triple_reordered[:, 0]], + triple_reordered[:, 1], + unsharded_entity_table.unsqueeze(0), + ) + else: + cpu_scores = score_fn.score_heads( + unsharded_entity_table.unsqueeze(0), + triple_reordered[:, 1], + unsharded_entity_table[triple_reordered[:, 2]], + ) + + pos_scores = score_fn.score_triple( + unsharded_entity_table[triple_reordered[:, 0]], + triple_reordered[:, 1], + unsharded_entity_table[triple_reordered[:, 2]], + ).flatten() + # mask positive scores to compute metrics + cpu_scores[ + torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] = -torch.inf + if filter_scores: + filter_triples = torch.from_numpy( + np.concatenate(triples_to_filter, axis=0) # type: ignore + ) + tr_filter = get_entity_filter( + triple_reordered, filter_triples, corruption_scheme + ) + cpu_scores[tr_filter[:, 0], tr_filter[:, 1]] = -torch.inf + # check filters (for convenience, here h/t entities are non-repeating) + if extra_only: + assert torch.all(tr_filter[:, 0] == torch.arange(n_test_triple)) + assert torch.all( + tr_filter[:, 1] == triple_reordered[:, ground_truth_col] + 1 + ) + else: + assert torch.all(tr_filter[::2, 0] == tr_filter[1::2, 0]) + assert torch.all(tr_filter[::2, 0] == torch.arange(n_test_triple)) + assert torch.all(tr_filter[::2, 1] == triple_reordered[:, ground_truth_col]) + assert torch.all( + tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1 + ) + + cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores) + # we allow for a off-by-one rank difference on at most 1% of triples, + # due to rounding differences in CPU vs IPU score computations + assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1) + assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100 + + # restore positive scores + cpu_scores[ + torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] = pos_scores + + assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4) + assert torch.all( + torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices == out["topk_global_id"] + )