From 4c3a92e50c73bf65cd6a2072c7b19cdf75e909bf Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Tue, 24 Oct 2023 13:27:11 +0000 Subject: [PATCH 01/11] implement AllScoresBESS and high-level pipeline --- besskge/__init__.py | 1 + besskge/bess.py | 153 ++++++++++++++++++++++++- besskge/pipeline.py | 206 ++++++++++++++++++++++++++++++++++ docs/source/API_reference.rst | 1 + requirements-dev.txt | 1 + tests/test_pipeline.py | 127 +++++++++++++++++++++ 6 files changed, 483 insertions(+), 6 deletions(-) create mode 100644 besskge/pipeline.py create mode 100644 tests/test_pipeline.py diff --git a/besskge/__init__.py b/besskge/__init__.py index d22e654..8627096 100644 --- a/besskge/__init__.py +++ b/besskge/__init__.py @@ -45,6 +45,7 @@ def load_custom_ops_so() -> None: loss, metric, negative_sampler, + pipeline, scoring, sharding, utils, diff --git a/besskge/bess.py b/besskge/bess.py index 866dca9..798feed 100644 --- a/besskge/bess.py +++ b/besskge/bess.py @@ -707,22 +707,22 @@ def forward( and respective scores are kept to be used in the next iteration, while the rest are discarded. - :param relation: shape: (shard_bs,) + :param relation: shape: (1, shard_bs,) Relation indices. - :param head: shape: (shard_bs,) + :param head: shape: (1, shard_bs,) Head indices, if known. Default: None. - :param tail: shape: (shard_bs,) + :param tail: shape: (1, shard_bs,) Tail indices, if known. Default: None. - :param negative: shape: (n_shard, B, padded_negative) + :param negative: shape: (1, n_shard, B, padded_negative) Candidates to score against the queries. It can be the same set for all queries (B=1), or specific for each query in the batch (B=shard_bs). If None, score each query against all entities in the knowledge graph. Default: None. - :param triple_mask: shape: (shard_bs,) + :param triple_mask: shape: (1, shard_bs,) Mask to filter the triples in the micro-batch before computing metrics. Default: None. - :param negative_mask: shape: (n_shard, B, padded_negative) + :param negative_mask: shape: (1, n_shard, B, padded_negative) If candidates are provided, mask to discard padding negatives when computing best completions. Requires the use of :code:`mask_on_gather=True` in the candidate @@ -919,3 +919,144 @@ def loop_body( ) return out_dict + + +class AllScoresBESS(torch.nn.Module): + """ + Distributed scoring of (h, r, ?) or (?, r, t) queries against + the entities in the knowledge graph, returning all scores to + host in blocks, based on the BESS :cite:p:`BESS` + inference scheme. + To be used in combination with a batch sampler based on a + "h_shard"/"t_shard"-partitioned triple set. + Since each iteration on IPU computes only part of the scores + (based on the size of the sliding window), metrics should be + computed on host after aggregating data (see + :class:`besskge.pipeline.AllScoresPipeline`). + + Only to be used for inference. + """ + + def __init__( + self, + candidate_sampler: PlaceholderNegativeSampler, + score_fn: BaseScoreFunction, + window_size: int = 1000, + ) -> None: + """ + Initialize AllScores BESS-KGE module. + + :param candidate_sampler: + :class:`besskge.negative_sampler.PlaceholderNegativeSampler` class, + specifying corruption scheme. + :param score_fn: + Scoring function. + :param window_size: + Size of the sliding window, namely the number of negative entities + scored against each query at each step on IPU and returned to host. + Should be decreased with large batch sizes, to avoid an OOM error. + Default: 1000. + """ + super().__init__() + self.sharding = score_fn.sharding + self.score_fn = score_fn + self.negative_sampler = candidate_sampler + self.window_size = window_size + + if not score_fn.negative_sample_sharing: + raise ValueError("AllScoresBESS requires using negative sample sharing") + + if self.negative_sampler.corruption_scheme not in ["h", "t"]: + raise ValueError("AllScoresBESS only support 'h', 't' corruption scheme") + + if not isinstance(self.negative_sampler, PlaceholderNegativeSampler): + raise ValueError( + "AllScoresBESS requires a `PlaceholderNegativeSampler`" + " candidate_sampler" + ) + + self.entity_embedding = self.score_fn.entity_embedding + self.entity_embedding_size: int = self.entity_embedding.shape[-1] + + self.candidate = torch.arange(self.window_size, dtype=torch.int32) + self.n_step = int( + np.ceil(self.sharding.max_entity_per_shard / self.window_size) + ) + + def forward( + self, + step: torch.Tensor, + relation: torch.Tensor, + head: Optional[torch.Tensor] = None, + tail: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward step. + + Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the + device where they are gathered, then scores for the same query against + candidates in different shards are collected together via an AllToAll. + + :param step: + The index of the block (of size self.window_size) of entities + on each IPU to score against queries. + :param relation: shape: (1, shard_bs,) + Relation indices. + :param head: shape: (1, shard_bs,) + Head indices, if known. Default: None. + :param tail: shape: (1, shard_bs,) + Tail indices, if known. Default: None. + + :return: + The scores for the completions. + """ + + relation = relation.squeeze(0) + if head is not None: + head = head.squeeze(0) + if tail is not None: + tail = tail.squeeze(0) + + n_shard = self.sharding.n_shard + shard_bs = relation.shape[0] + + relation_all = all_gather(relation, n_shard) + if self.negative_sampler.corruption_scheme == "h": + tail_embedding = self.entity_embedding[tail] + tail_embedding_all = all_gather(tail_embedding, n_shard) + elif self.negative_sampler.corruption_scheme == "t": + head_embedding = self.entity_embedding[head] + head_embedding_all = all_gather(head_embedding, n_shard) + + # Local indices of the entities to score against queries + ent_slice = torch.minimum( + step * self.window_size + + torch.arange(self.window_size, device=relation.device), + torch.tensor(self.sharding.max_entity_per_shard - 1), + ) + negative_embedding = self.entity_embedding[ent_slice] + + if self.negative_sampler.corruption_scheme == "h": + scores = self.score_fn.score_heads( + negative_embedding, + relation_all.flatten(end_dim=1), + tail_embedding_all.flatten(end_dim=1), + ) + elif self.negative_sampler.corruption_scheme == "t": + scores = self.score_fn.score_tails( + head_embedding_all.flatten(end_dim=1), + relation_all.flatten(end_dim=1), + negative_embedding, + ) + + # Send back queries to original shard + scores = ( + all_to_all( + scores.reshape(n_shard, shard_bs, self.window_size), + n_shard, + ) + .transpose(0, 1) + .flatten(start_dim=1) + ) # shape (bs, n_shard * ws) + + return scores diff --git a/besskge/pipeline.py b/besskge/pipeline.py new file mode 100644 index 0000000..3af955f --- /dev/null +++ b/besskge/pipeline.py @@ -0,0 +1,206 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +""" +High-level APIs for training/inference with BESS. +""" + +from typing import Any, Dict, Optional + +import numpy as np +import poptorch +import torch +from tqdm import tqdm + +from besskge.batch_sampler import ShardedBatchSampler +from besskge.bess import AllScoresBESS +from besskge.metric import Evaluation +from besskge.negative_sampler import PlaceholderNegativeSampler +from besskge.scoring import BaseScoreFunction + + +class AllScoresPipeline(torch.nn.Module): + """ + Pipeline to compute scores of (h, r, ?) / (?, r, t) + queries against all entities in the KG, and related + prediction metrics. + + To be used in combination with a batch sampler based on a + "h_shard"/"t_shard"-partitioned triple set. + """ + + def __init__( + self, + batch_sampler: ShardedBatchSampler, + corruption_scheme: str, + score_fn: BaseScoreFunction, + evaluation: Optional[Evaluation] = None, + windows_size: int = 1000, + ) -> None: + """ + Initialize pipeline. + + :param batch_sampler: + Batch sampler, based on a + "h_shard"/"t_shard"-partitioned triple set. + :param corruption_scheme: + Set to "t" to score (h, r, ?) completions, or to + "h" to score (?, r, t) completions. + :param score_fn: + The trained scoring function. + :param evaluation: + Evaluation module, for computing metrics. + Default: None. + :param windows_size: + Size of the sliding window, namely the number of negative entities + scored against each query at each step on IPU and returned to host. + Should be decreased with large batch sizes, to avoid an OOM error. + Default: 1000. + """ + super().__init__() + self.batch_sampler = batch_sampler + if corruption_scheme not in ["h", "t"]: + raise ValueError("corruption_scheme needs to be either 'h' or 't'") + if ( + corruption_scheme == "h" + and self.batch_sampler.triple_partition_mode != "t_shard" + ): + raise ValueError( + "Corruption scheme 'h' requires 't-shard'-partitioned triples" + ) + elif ( + corruption_scheme == "t" + and self.batch_sampler.triple_partition_mode != "h_shard" + ): + raise ValueError( + "Corruption scheme 't' requires 'h-shard'-partitioned triples" + ) + self.candidate_sampler = PlaceholderNegativeSampler( + corruption_scheme=corruption_scheme + ) + self.score_fn = score_fn + self.evaluation = evaluation + self.window_size = windows_size + self.bess_module = AllScoresBESS( + self.candidate_sampler, self.score_fn, self.window_size + ) + + inf_options = poptorch.Options() + inf_options.replication_factor = self.bess_module.sharding.n_shard + inf_options.deviceIterations(self.batch_sampler.batches_per_step) + inf_options.outputMode(poptorch.OutputMode.All) + self.dl = self.batch_sampler.get_dataloader(options=inf_options, shuffle=False) + + self.poptorch_module = poptorch.inferenceModel( + self.bess_module, options=inf_options + ) + self.poptorch_module.entity_embedding.replicaGrouping( + poptorch.CommGroupType.NoGrouping, + 0, + poptorch.VariableRetrievalMode.OnePerGroup, + ) + + def forward(self) -> Dict[str, Any]: + """ + Compute scores of all completions and (possibly) metrics. + + :return: + Scores, metrics, and (if provided in batch sampler) IDs + of inference triples (wrt partitioned_triple_set.triples) + to order results. + """ + scores = [] + ids = [] + metrics = [] + ranks = [] + n_triple = 0 + for batch in tqdm(iter(self.dl)): + triple_mask = batch.pop("triple_mask") + if ( + self.candidate_sampler.corruption_scheme == "h" + and "head" in batch.keys() + ): + ground_truth = batch.pop("head") + elif ( + self.candidate_sampler.corruption_scheme == "t" + and "tail" in batch.keys() + ): + ground_truth = batch.pop("tail") + if self.batch_sampler.return_triple_idx: + triple_id = batch.pop("triple_idx") + ids.append(triple_id[triple_mask]) + n_triple += triple_mask.sum() + + batch_res = [] + batch_idx = [] + for i in range(self.bess_module.n_step): + step = ( + torch.tensor([i], dtype=torch.int32) + .broadcast_to( + ( + self.bess_module.sharding.n_shard + * self.batch_sampler.batches_per_step, + 1, + ) + ) + .contiguous() + ) + ent_slice = torch.minimum( + i * self.bess_module.window_size + + torch.arange(self.bess_module.window_size), + torch.tensor(self.bess_module.sharding.max_entity_per_shard - 1), + ) + # Global indices of entities scored in the step + batch_idx.append( + self.bess_module.sharding.shard_and_idx_to_entity[ + :, ent_slice + ].flatten() + ) + inp = {k: v.flatten(end_dim=1) for k, v in batch.items()} + inp.update(dict(step=step)) + batch_res.append(self.poptorch_module(**inp)) + batch_scores = torch.concat(batch_res, dim=-1) + # Filter out padding scores + batch_scores_filt = batch_scores[triple_mask.flatten()][ + :, np.unique(np.concatenate(batch_idx), return_index=True)[1] + ][:, : self.bess_module.sharding.n_entity] + scores.append(torch.clone(batch_scores_filt)) + + if self.evaluation: + assert ( + ground_truth is not None + ), "Evaluation requires providing ground truth entities" + # Scores of true triples + true_scores = batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] + # Mask scores of true triples to compute ranks + batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] = -torch.inf + batch_ranks = self.evaluation.ranks_from_scores( + true_scores, batch_scores_filt + ) + metrics.append(self.evaluation.dict_metrics_from_ranks(batch_ranks)) + if self.evaluation.return_ranks: + ranks.append(batch_ranks) + + out = dict(scores=torch.concat(scores, dim=0)) + if ids: + out["triple_ids"] = torch.concat(ids, dim=0) + if self.evaluation: + concat_metrics = dict() + for m in metrics[0].keys(): + concat_metrics[m] = torch.concat( + [met[m].reshape(-1) for met in metrics] + ) + out["metrics"] = concat_metrics # type: ignore + # Average metrics over all triples + out["metrics_avg"] = { + m: v.sum() / n_triple for m, v in concat_metrics.items() + } # type: ignore + if ranks: + out["ranks"] = torch.concat(ranks, dim=0) + + return out diff --git a/docs/source/API_reference.rst b/docs/source/API_reference.rst index 563839f..5179b0a 100644 --- a/docs/source/API_reference.rst +++ b/docs/source/API_reference.rst @@ -6,6 +6,7 @@ BESS-KGE API Reference :template: module.rst :recursive: + besskge.pipeline besskge.dataset besskge.sharding besskge.batch_sampler diff --git a/requirements-dev.txt b/requirements-dev.txt index b66e578..1b22695 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,7 @@ flake8==6.0.0 isort==5.12.0 mypy==1.0.1 pandas-stubs==2.0.1.230501 +tqdm-stubs==0.2.1 types-requests==2.28.11.17 pytest==7.2.1 pytest-cov==4.0.0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..0f2c447 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import numpy as np +import pytest +import torch +from torch.testing import assert_close + +from besskge.batch_sampler import RigidShardedBatchSampler +from besskge.dataset import KGDataset +from besskge.metric import Evaluation +from besskge.negative_sampler import PlaceholderNegativeSampler +from besskge.pipeline import AllScoresPipeline +from besskge.scoring import ComplEx +from besskge.sharding import PartitionedTripleSet, Sharding + +seed = 1234 +n_entity = 20000 +n_relation_type = 50 +n_shard = 4 +n_test_triple = 5000 +batches_per_step = 2 +shard_bs = 400 +embedding_size = 128 + +np.random.seed(seed) +torch.manual_seed(seed) + +sharding = Sharding.create(n_entity, n_shard, seed=seed) + +unsharded_entity_table = torch.randn(size=(n_entity, 2 * embedding_size)) +relation_table = torch.randn(size=(n_relation_type, 2 * embedding_size)) + +test_triples_h = np.random.randint(n_entity, size=(n_test_triple,)) +test_triples_t = np.random.randint(n_entity, size=(n_test_triple,)) +test_triples_r = np.random.randint(n_relation_type, size=(n_test_triple,)) +triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)} + + +@pytest.mark.parametrize("corruption_scheme", ["h", "t"]) +@pytest.mark.parametrize("compute_metrics", [True, False]) +def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> None: + ds = KGDataset( + n_entity=n_entity, + n_relation_type=n_relation_type, + entity_dict=None, + relation_dict=None, + type_offsets=None, + triples=triples, + ) + + partition_mode = "h_shard" if corruption_scheme == "t" else "t_shard" + partitioned_triple_set = PartitionedTripleSet.create_from_dataset( + ds, "test", sharding, partition_mode=partition_mode + ) + + score_fn = ComplEx( + negative_sample_sharing=True, + sharding=sharding, + n_relation_type=ds.n_relation_type, + embedding_size=embedding_size, + entity_initializer=unsharded_entity_table, + relation_initializer=relation_table, + ) + placeholder_ns = PlaceholderNegativeSampler( + corruption_scheme=corruption_scheme, seed=seed + ) + + test_bs = RigidShardedBatchSampler( + partitioned_triple_set=partitioned_triple_set, + negative_sampler=placeholder_ns, + shard_bs=shard_bs, + batches_per_step=batches_per_step, + seed=seed, + hrt_freq_weighting=False, + duplicate_batch=False, + return_triple_idx=True, + ) + + if compute_metrics: + evaluation = Evaluation( + ["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True + ) + else: + evaluation = None + + pipeline = AllScoresPipeline( + test_bs, corruption_scheme, score_fn, evaluation, windows_size=1000 + ) + out = pipeline() + + # Shuffle triples in same order of out["scores"] + triple_reordered = torch.from_numpy( + ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_ids"]]] + ) + + # Real scores + if corruption_scheme == "t": + real_scores = score_fn.score_tails( + unsharded_entity_table[triple_reordered[:, 0]], + triple_reordered[:, 1], + unsharded_entity_table.unsqueeze(0), + ) + else: + real_scores = score_fn.score_heads( + unsharded_entity_table.unsqueeze(0), + triple_reordered[:, 1], + unsharded_entity_table[triple_reordered[:, 2]], + ) + + assert_close(real_scores, out["scores"], atol=1e-3, rtol=1e-4) + + if evaluation: + ground_truth_col = 0 if corruption_scheme == "h" else 2 + real_scores_masked = torch.clone(real_scores) + pos_scores = real_scores_masked[ + torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] + real_scores_masked[ + torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] = -torch.inf + + real_ranks = evaluation.ranks_from_scores(pos_scores, real_scores_masked) + + # we allow for a off-by-one rank difference on at most 1% of triples, + # due to rounding differences in CPU vs IPU score computations + assert torch.all(torch.abs(real_ranks - out["ranks"]) <= 1) + assert (real_ranks != out["ranks"]).sum() < n_test_triple / 100 From 282a0eaa4d834cfa923a8ca155ef8d5df4186993 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Tue, 24 Oct 2023 13:33:07 +0000 Subject: [PATCH 02/11] notation consistency --- besskge/pipeline.py | 2 +- tests/test_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 3af955f..89a9680 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -188,7 +188,7 @@ def forward(self) -> Dict[str, Any]: out = dict(scores=torch.concat(scores, dim=0)) if ids: - out["triple_ids"] = torch.concat(ids, dim=0) + out["triple_idx"] = torch.concat(ids, dim=0) if self.evaluation: concat_metrics = dict() for m in metrics[0].keys(): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 0f2c447..bb20071 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -90,7 +90,7 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N # Shuffle triples in same order of out["scores"] triple_reordered = torch.from_numpy( - ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_ids"]]] + ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_idx"]]] ) # Real scores From 674fd996ab9e76dbde3cf4c2fd04eba27d3b68da Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Tue, 24 Oct 2023 13:52:06 +0000 Subject: [PATCH 03/11] IpuModel for CI --- besskge/pipeline.py | 4 ++++ tests/test_pipeline.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 89a9680..e839d11 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -35,6 +35,7 @@ def __init__( score_fn: BaseScoreFunction, evaluation: Optional[Evaluation] = None, windows_size: int = 1000, + use_ipu_model: bool = False, ) -> None: """ Initialize pipeline. @@ -55,6 +56,8 @@ def __init__( scored against each query at each step on IPU and returned to host. Should be decreased with large batch sizes, to avoid an OOM error. Default: 1000. + :param use_ipu_model: + Run pipeline on IPU Model instead of actual hardware. Default: False. """ super().__init__() self.batch_sampler = batch_sampler @@ -88,6 +91,7 @@ def __init__( inf_options.replication_factor = self.bess_module.sharding.n_shard inf_options.deviceIterations(self.batch_sampler.batches_per_step) inf_options.outputMode(poptorch.OutputMode.All) + inf_options.useIpuModel(use_ipu_model) self.dl = self.batch_sampler.get_dataloader(options=inf_options, shuffle=False) self.poptorch_module = poptorch.inferenceModel( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index bb20071..675c75a 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -84,7 +84,12 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N evaluation = None pipeline = AllScoresPipeline( - test_bs, corruption_scheme, score_fn, evaluation, windows_size=1000 + test_bs, + corruption_scheme, + score_fn, + evaluation, + windows_size=1000, + use_ipu_model=True, ) out = pipeline() From da0ca26503d068182f9770ac353468c214ba1d12 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Tue, 24 Oct 2023 14:10:16 +0000 Subject: [PATCH 04/11] increase CI timeout --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 345118f..d3f3a96 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,7 +14,7 @@ jobs: ci: runs-on: ubuntu-latest container: graphcore/pytorch:3.3.0-ubuntu-20.04-20230703 - timeout-minutes: 15 + timeout-minutes: 25 steps: - uses: actions/checkout@v3 - name: Install dev-requirements From fbe91e9f8f8be2452171d1cebf2713138969b4d7 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Tue, 24 Oct 2023 14:26:52 +0000 Subject: [PATCH 05/11] reduce CI memory requirements --- tests/test_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 675c75a..33cd705 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -14,10 +14,10 @@ from besskge.sharding import PartitionedTripleSet, Sharding seed = 1234 -n_entity = 20000 +n_entity = 5000 n_relation_type = 50 n_shard = 4 -n_test_triple = 5000 +n_test_triple = 1000 batches_per_step = 2 shard_bs = 400 embedding_size = 128 From c6325c49be358f253d81831b37683b86be8bbd5a Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Wed, 25 Oct 2023 10:27:27 +0000 Subject: [PATCH 06/11] make return all scores optional, to avoid host OOM --- besskge/pipeline.py | 18 ++++++++++++++++-- tests/test_pipeline.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index e839d11..d27c240 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -34,6 +34,7 @@ def __init__( corruption_scheme: str, score_fn: BaseScoreFunction, evaluation: Optional[Evaluation] = None, + return_scores: bool = False, windows_size: int = 1000, use_ipu_model: bool = False, ) -> None: @@ -51,6 +52,11 @@ def __init__( :param evaluation: Evaluation module, for computing metrics. Default: None. + :param return_scores: + If True, store and return scores of all queries' completions. + For large number of queries/entities, this can cause the host + to go OOM. + Default: False. :param windows_size: Size of the sliding window, namely the number of negative entities scored against each query at each step on IPU and returned to host. @@ -61,6 +67,10 @@ def __init__( """ super().__init__() self.batch_sampler = batch_sampler + if not (evaluation or return_scores): + raise ValueError( + "Nothing to return. Provide `evaluation` or set" " `return_scores=True`" + ) if corruption_scheme not in ["h", "t"]: raise ValueError("corruption_scheme needs to be either 'h' or 't'") if ( @@ -82,6 +92,7 @@ def __init__( ) self.score_fn = score_fn self.evaluation = evaluation + self.return_scores = return_scores self.window_size = windows_size self.bess_module = AllScoresBESS( self.candidate_sampler, self.score_fn, self.window_size @@ -167,7 +178,8 @@ def forward(self) -> Dict[str, Any]: batch_scores_filt = batch_scores[triple_mask.flatten()][ :, np.unique(np.concatenate(batch_idx), return_index=True)[1] ][:, : self.bess_module.sharding.n_entity] - scores.append(torch.clone(batch_scores_filt)) + if self.return_scores: + scores.append(torch.clone(batch_scores_filt)) if self.evaluation: assert ( @@ -190,7 +202,9 @@ def forward(self) -> Dict[str, Any]: if self.evaluation.return_ranks: ranks.append(batch_ranks) - out = dict(scores=torch.concat(scores, dim=0)) + out = dict() + if scores: + out["scores"] = torch.concat(scores, dim=0) if ids: out["triple_idx"] = torch.concat(ids, dim=0) if self.evaluation: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 33cd705..7de30c5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -88,6 +88,7 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N corruption_scheme, score_fn, evaluation, + return_scores=True, windows_size=1000, use_ipu_model=True, ) From 2cb2b17cddc8938335a39b103d74b7dddfd9c3a0 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Wed, 25 Oct 2023 10:28:27 +0000 Subject: [PATCH 07/11] fix typo --- besskge/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index d27c240..d77b330 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -69,7 +69,7 @@ def __init__( self.batch_sampler = batch_sampler if not (evaluation or return_scores): raise ValueError( - "Nothing to return. Provide `evaluation` or set" " `return_scores=True`" + "Nothing to return. Provide `evaluation` or set `return_scores=True`" ) if corruption_scheme not in ["h", "t"]: raise ValueError("corruption_scheme needs to be either 'h' or 't'") From 9dee8b153de1888e5794c9688ccdc442e46b0d8d Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Wed, 25 Oct 2023 10:33:59 +0000 Subject: [PATCH 08/11] fix typo --- besskge/pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index d77b330..73e6832 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -102,7 +102,8 @@ def __init__( inf_options.replication_factor = self.bess_module.sharding.n_shard inf_options.deviceIterations(self.batch_sampler.batches_per_step) inf_options.outputMode(poptorch.OutputMode.All) - inf_options.useIpuModel(use_ipu_model) + if use_ipu_model: + inf_options.useIpuModel(True) self.dl = self.batch_sampler.get_dataloader(options=inf_options, shuffle=False) self.poptorch_module = poptorch.inferenceModel( From 57264829852a117b07dc92a3e32a63ff7abe4936 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Thu, 26 Oct 2023 10:30:09 +0000 Subject: [PATCH 09/11] add triple filtering to AllScoresPipeline --- besskge/pipeline.py | 80 ++++++++++++++++++++++++++------ besskge/utils.py | 36 +++++++++++++++ tests/test_pipeline.py | 101 +++++++++++++++++++++++++++++------------ 3 files changed, 174 insertions(+), 43 deletions(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 73e6832..3396df7 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -4,11 +4,12 @@ High-level APIs for training/inference with BESS. """ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Union import numpy as np import poptorch import torch +from numpy.typing import NDArray from tqdm import tqdm from besskge.batch_sampler import ShardedBatchSampler @@ -16,13 +17,15 @@ from besskge.metric import Evaluation from besskge.negative_sampler import PlaceholderNegativeSampler from besskge.scoring import BaseScoreFunction +from besskge.utils import get_entity_filter class AllScoresPipeline(torch.nn.Module): """ - Pipeline to compute scores of (h, r, ?) / (?, r, t) - queries against all entities in the KG, and related - prediction metrics. + Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities + in the KG, and related prediction metrics. + It supports filtering out the scores of specific completions that appear in a given + set of triples. To be used in combination with a batch sampler based on a "h_shard"/"t_shard"-partitioned triple set. @@ -34,6 +37,7 @@ def __init__( corruption_scheme: str, score_fn: BaseScoreFunction, evaluation: Optional[Evaluation] = None, + filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None, return_scores: bool = False, windows_size: int = 1000, use_ipu_model: bool = False, @@ -52,8 +56,13 @@ def __init__( :param evaluation: Evaluation module, for computing metrics. Default: None. + :param filter_triples: + The set of all triples whose scores need to be filtered. + The triples passed here must have GLOBAL IDs for head/tail + entities. Default: None. :param return_scores: - If True, store and return scores of all queries' completions. + If True, store and return scores of all queries' completions + (with filters applied, if specified). For large number of queries/entities, this can cause the host to go OOM. Default: False. @@ -94,6 +103,7 @@ def __init__( self.evaluation = evaluation self.return_scores = return_scores self.window_size = windows_size + self.corruption_scheme = corruption_scheme self.bess_module = AllScoresBESS( self.candidate_sampler, self.score_fn, self.window_size ) @@ -114,6 +124,36 @@ def __init__( 0, poptorch.VariableRetrievalMode.OnePerGroup, ) + self.filter_triples: Optional[torch.Tensor] = None + if filter_triples: + # Reconstruct global IDs for all entities in triples + local_id_col = ( + 0 if self.batch_sampler.triple_partition_mode == "h_shard" else 2 + ) + triple_shard_offset = np.concatenate( + [np.array([0]), np.cumsum(batch_sampler.triple_counts)] + ) + global_id_triples = [] + for i in range(len(triple_shard_offset) - 1): + shard_triples = np.copy( + batch_sampler.triples[ + triple_shard_offset[i] : triple_shard_offset[i + 1] + ] + ) + shard_triples[ + :, local_id_col + ] = self.bess_module.sharding.shard_and_idx_to_entity[i][ + shard_triples[:, local_id_col] + ] + global_id_triples.append(shard_triples) + self.triples = torch.from_numpy(np.concatenate(global_id_triples, axis=0)) + self.filter_triples = torch.concat( + [ + tr if isinstance(tr, torch.Tensor) else torch.from_numpy(tr) + for tr in filter_triples + ], + dim=0, + ) def forward(self) -> Dict[str, Any]: """ @@ -179,19 +219,26 @@ def forward(self) -> Dict[str, Any]: batch_scores_filt = batch_scores[triple_mask.flatten()][ :, np.unique(np.concatenate(batch_idx), return_index=True)[1] ][:, : self.bess_module.sharding.n_entity] - if self.return_scores: - scores.append(torch.clone(batch_scores_filt)) + # Scores of positive triples + true_scores = batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] + if self.filter_triples is not None: + # Filter for triples in batch + batch_filter = get_entity_filter( + self.triples[triple_id[triple_mask]], + self.filter_triples, + filter_mode=self.corruption_scheme, + ) + batch_scores_filt[batch_filter[:, 0], batch_filter[:, 1]] = -torch.inf if self.evaluation: assert ( ground_truth is not None ), "Evaluation requires providing ground truth entities" - # Scores of true triples - true_scores = batch_scores_filt[ - torch.arange(batch_scores_filt.shape[0]), - ground_truth[triple_mask], - ] - # Mask scores of true triples to compute ranks + # If not already masked, mask scores of true triples + # to compute metrics batch_scores_filt[ torch.arange(batch_scores_filt.shape[0]), ground_truth[triple_mask], @@ -202,6 +249,13 @@ def forward(self) -> Dict[str, Any]: metrics.append(self.evaluation.dict_metrics_from_ranks(batch_ranks)) if self.evaluation.return_ranks: ranks.append(batch_ranks) + if self.return_scores: + # Restore positive scores in the returned scores + batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] = true_scores + scores.append(batch_scores_filt) out = dict() if scores: diff --git a/besskge/utils.py b/besskge/utils.py index 31bd6a9..50f68c0 100644 --- a/besskge/utils.py +++ b/besskge/utils.py @@ -33,6 +33,42 @@ def gather_indices(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor: return x.view(-1, mask_size) +def get_entity_filter( + triples: torch.Tensor, filter_triples: torch.Tensor, filter_mode: str +) -> torch.Tensor: + """ + Compare two sets of triples: for each triple (h,r,t) in the first set, find + the entities `e` such that (e,r,t) (or (h,r,e), depending on `filter_mode`) + appears in the second set of triples. + + :param triples: shape (x, 3) + The set of triples to construct filters for. + :param filter_triples: shape (y, 3) + The set of triples determining the head/tail entities to filter. + :param filter_mode: + Set to "h" to look for entities appearing as heads of the same (r,t) pair, + or to "t" to look for entities appearing as tails of the same (h,r) pair. + + :return: shape (z, 2) + The sparse filters. Each row is given by a tuple (i, j), with i the index + of the triple in `triples` to which the filter applies to and j the global + ID of the entity to filter. + """ + if filter_mode == "t": + ent_col = 0 + elif filter_mode == "h": + ent_col = 2 + else: + raise ValueError("`filter_mode` needs to be either 'h' or 't'") + relation_filter = (filter_triples[:, 1]) == triples[:, 1].view(-1, 1) + entity_filter = (filter_triples[:, ent_col]) == triples[:, ent_col].view(-1, 1) + + filter = (entity_filter & relation_filter).nonzero(as_tuple=False) + filter[:, 1] = filter_triples[:, 2 - ent_col].view(1, -1)[:, filter[:, 1]] + + return filter + + def complex_multiplication(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor: """ Batched complex multiplication. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7de30c5..e38434f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,6 +12,7 @@ from besskge.pipeline import AllScoresPipeline from besskge.scoring import ComplEx from besskge.sharding import PartitionedTripleSet, Sharding +from besskge.utils import get_entity_filter seed = 1234 n_entity = 5000 @@ -30,15 +31,19 @@ unsharded_entity_table = torch.randn(size=(n_entity, 2 * embedding_size)) relation_table = torch.randn(size=(n_relation_type, 2 * embedding_size)) -test_triples_h = np.random.randint(n_entity, size=(n_test_triple,)) -test_triples_t = np.random.randint(n_entity, size=(n_test_triple,)) -test_triples_r = np.random.randint(n_relation_type, size=(n_test_triple,)) +test_triples_h = np.random.choice(n_entity - 1, size=n_test_triple, replace=False) +test_triples_t = np.random.choice(n_entity - 1, size=n_test_triple, replace=False) +test_triples_r = np.random.randint(n_relation_type, size=n_test_triple) triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)} @pytest.mark.parametrize("corruption_scheme", ["h", "t"]) -@pytest.mark.parametrize("compute_metrics", [True, False]) -def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> None: +@pytest.mark.parametrize( + "filter_scores, extra_only", [(True, True), (True, False), (False, False)] +) +def test_all_scores_pipeline( + corruption_scheme: str, filter_scores: bool, extra_only: bool +) -> None: ds = KGDataset( n_entity=n_entity, n_relation_type=n_relation_type, @@ -76,18 +81,28 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N return_triple_idx=True, ) - if compute_metrics: - evaluation = Evaluation( - ["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True + evaluation = Evaluation( + ["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True + ) + + ground_truth_col = 0 if corruption_scheme == "h" else 2 + if filter_scores: + extra_filter_triples = np.copy(triples["test"]) + extra_filter_triples[:, ground_truth_col] += 1 + triples_to_filter = ( + [extra_filter_triples] + if extra_only + else [triples["test"], extra_filter_triples] ) else: - evaluation = None + triples_to_filter = None pipeline = AllScoresPipeline( test_bs, corruption_scheme, score_fn, evaluation, + filter_triples=triples_to_filter, # type: ignore return_scores=True, windows_size=1000, use_ipu_model=True, @@ -99,35 +114,61 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_idx"]]] ) - # Real scores + # All scores, computed on CPU if corruption_scheme == "t": - real_scores = score_fn.score_tails( + cpu_scores = score_fn.score_tails( unsharded_entity_table[triple_reordered[:, 0]], triple_reordered[:, 1], unsharded_entity_table.unsqueeze(0), ) else: - real_scores = score_fn.score_heads( + cpu_scores = score_fn.score_heads( unsharded_entity_table.unsqueeze(0), triple_reordered[:, 1], unsharded_entity_table[triple_reordered[:, 2]], ) - assert_close(real_scores, out["scores"], atol=1e-3, rtol=1e-4) - - if evaluation: - ground_truth_col = 0 if corruption_scheme == "h" else 2 - real_scores_masked = torch.clone(real_scores) - pos_scores = real_scores_masked[ - torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col] - ] - real_scores_masked[ - torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col] - ] = -torch.inf - - real_ranks = evaluation.ranks_from_scores(pos_scores, real_scores_masked) - - # we allow for a off-by-one rank difference on at most 1% of triples, - # due to rounding differences in CPU vs IPU score computations - assert torch.all(torch.abs(real_ranks - out["ranks"]) <= 1) - assert (real_ranks != out["ranks"]).sum() < n_test_triple / 100 + pos_scores = score_fn.score_triple( + unsharded_entity_table[triple_reordered[:, 0]], + triple_reordered[:, 1], + unsharded_entity_table[triple_reordered[:, 2]], + ).flatten() + # mask positive scores to compute metrics + cpu_scores[ + torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] = -torch.inf + if filter_scores: + filter_triples = torch.from_numpy( + np.concatenate(triples_to_filter, axis=0) # type: ignore + ) + tr_filter = get_entity_filter( + triple_reordered, filter_triples, corruption_scheme + ) + cpu_scores[tr_filter[:, 0], tr_filter[:, 1]] = -torch.inf + # check filters (for convenience, here h/t entities are non-repeating) + if extra_only: + assert torch.all(tr_filter[:, 0] == torch.arange(n_test_triple)) + assert torch.all( + tr_filter[:, 1] == triple_reordered[:, ground_truth_col] + 1 + ) + else: + assert torch.all(tr_filter[::2, 0] == tr_filter[1::2, 0]) + assert torch.all(tr_filter[::2, 0] == torch.arange(n_test_triple)) + assert torch.all(tr_filter[::2, 1] == triple_reordered[:, ground_truth_col]) + assert torch.all( + tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1 + ) + + cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores) + + # we allow for a off-by-one rank difference on at most 1% of triples, + # due to rounding differences in CPU vs IPU score computations + assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1) + assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100 + + # restore positive scores + cpu_scores[ + torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col] + ] = pos_scores + + assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4) From 247d88b4c3a38d816eb66f409e6887d111671d5f Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Fri, 27 Oct 2023 10:36:01 +0000 Subject: [PATCH 10/11] add return_topk option to AllScoresPipeline --- besskge/pipeline.py | 41 ++++++++++++++++++++++++++++++----------- tests/test_pipeline.py | 6 +++++- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 3396df7..040ea14 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -39,6 +39,8 @@ def __init__( evaluation: Optional[Evaluation] = None, filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None, return_scores: bool = False, + return_topk: bool = False, + k: int = 10, windows_size: int = 1000, use_ipu_model: bool = False, ) -> None: @@ -66,6 +68,13 @@ def __init__( For large number of queries/entities, this can cause the host to go OOM. Default: False. + :param return_topk: + If True, return for each query the global IDs of the most likely + completions, after filtering out the scores of `filter_triples`. + Default: False. + :param k: + If `return_topk` is set to True, for each query return the + top-k most likely predictions (after filtering). Default: 10. :param windows_size: Size of the sliding window, namely the number of negative entities scored against each query at each step on IPU and returned to host. @@ -102,6 +111,8 @@ def __init__( self.score_fn = score_fn self.evaluation = evaluation self.return_scores = return_scores + self.return_topk = return_topk + self.k = k self.window_size = windows_size self.corruption_scheme = corruption_scheme self.bess_module = AllScoresBESS( @@ -168,6 +179,7 @@ def forward(self) -> Dict[str, Any]: ids = [] metrics = [] ranks = [] + topk_ids = [] n_triple = 0 for batch in tqdm(iter(self.dl)): triple_mask = batch.pop("triple_mask") @@ -219,11 +231,12 @@ def forward(self) -> Dict[str, Any]: batch_scores_filt = batch_scores[triple_mask.flatten()][ :, np.unique(np.concatenate(batch_idx), return_index=True)[1] ][:, : self.bess_module.sharding.n_entity] - # Scores of positive triples - true_scores = batch_scores_filt[ - torch.arange(batch_scores_filt.shape[0]), - ground_truth[triple_mask], - ] + if ground_truth is not None: + # Scores of positive triples + true_scores = batch_scores_filt[ + torch.arange(batch_scores_filt.shape[0]), + ground_truth[triple_mask], + ] if self.filter_triples is not None: # Filter for triples in batch batch_filter = get_entity_filter( @@ -249,29 +262,35 @@ def forward(self) -> Dict[str, Any]: metrics.append(self.evaluation.dict_metrics_from_ranks(batch_ranks)) if self.evaluation.return_ranks: ranks.append(batch_ranks) - if self.return_scores: + if ground_truth is not None: # Restore positive scores in the returned scores batch_scores_filt[ torch.arange(batch_scores_filt.shape[0]), ground_truth[triple_mask], ] = true_scores + if self.return_scores: scores.append(batch_scores_filt) + if self.return_topk: + topk_ids.append(torch.topk(batch_scores_filt, k=self.k, dim=-1).indices) out = dict() if scores: out["scores"] = torch.concat(scores, dim=0) + if topk_ids: + out["topk_global_id"] = torch.concat(topk_ids, dim=0) if ids: out["triple_idx"] = torch.concat(ids, dim=0) if self.evaluation: - concat_metrics = dict() + final_metrics = dict() for m in metrics[0].keys(): - concat_metrics[m] = torch.concat( - [met[m].reshape(-1) for met in metrics] + # Reduce metrics over all batches + final_metrics[m] = self.evaluation.reduction( + torch.concat([met[m].reshape(-1) for met in metrics]) ) - out["metrics"] = concat_metrics # type: ignore + out["metrics"] = final_metrics # type: ignore # Average metrics over all triples out["metrics_avg"] = { - m: v.sum() / n_triple for m, v in concat_metrics.items() + m: v.sum() / n_triple for m, v in final_metrics.items() } # type: ignore if ranks: out["ranks"] = torch.concat(ranks, dim=0) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e38434f..cd336df 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -104,6 +104,8 @@ def test_all_scores_pipeline( evaluation, filter_triples=triples_to_filter, # type: ignore return_scores=True, + return_topk=True, + k=10, windows_size=1000, use_ipu_model=True, ) @@ -160,7 +162,6 @@ def test_all_scores_pipeline( ) cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores) - # we allow for a off-by-one rank difference on at most 1% of triples, # due to rounding differences in CPU vs IPU score computations assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1) @@ -172,3 +173,6 @@ def test_all_scores_pipeline( ] = pos_scores assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4) + assert torch.all( + torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices == out["topk_global_id"] + ) From d30fff80be7b2188b676db4f470637b0d5fadd43 Mon Sep 17 00:00:00 2001 From: Alberto Cattaneo Date: Fri, 27 Oct 2023 10:52:39 +0000 Subject: [PATCH 11/11] fix typo --- besskge/pipeline.py | 6 +++--- tests/test_pipeline.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 040ea14..49e8507 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -41,7 +41,7 @@ def __init__( return_scores: bool = False, return_topk: bool = False, k: int = 10, - windows_size: int = 1000, + window_size: int = 1000, use_ipu_model: bool = False, ) -> None: """ @@ -75,7 +75,7 @@ def __init__( :param k: If `return_topk` is set to True, for each query return the top-k most likely predictions (after filtering). Default: 10. - :param windows_size: + :param window_size: Size of the sliding window, namely the number of negative entities scored against each query at each step on IPU and returned to host. Should be decreased with large batch sizes, to avoid an OOM error. @@ -113,7 +113,7 @@ def __init__( self.return_scores = return_scores self.return_topk = return_topk self.k = k - self.window_size = windows_size + self.window_size = window_size self.corruption_scheme = corruption_scheme self.bess_module = AllScoresBESS( self.candidate_sampler, self.score_fn, self.window_size diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cd336df..4f42f23 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -106,7 +106,7 @@ def test_all_scores_pipeline( return_scores=True, return_topk=True, k=10, - windows_size=1000, + window_size=1000, use_ipu_model=True, ) out = pipeline()