Skip to content

Commit

Permalink
add triple filtering to AllScoresPipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCatt91 committed Oct 26, 2023
1 parent 9dee8b1 commit 5726482
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 43 deletions.
80 changes: 67 additions & 13 deletions besskge/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,28 @@
High-level APIs for training/inference with BESS.
"""

from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
import poptorch
import torch
from numpy.typing import NDArray
from tqdm import tqdm

from besskge.batch_sampler import ShardedBatchSampler
from besskge.bess import AllScoresBESS
from besskge.metric import Evaluation
from besskge.negative_sampler import PlaceholderNegativeSampler
from besskge.scoring import BaseScoreFunction
from besskge.utils import get_entity_filter


class AllScoresPipeline(torch.nn.Module):
"""
Pipeline to compute scores of (h, r, ?) / (?, r, t)
queries against all entities in the KG, and related
prediction metrics.
Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities
in the KG, and related prediction metrics.
It supports filtering out the scores of specific completions that appear in a given
set of triples.
To be used in combination with a batch sampler based on a
"h_shard"/"t_shard"-partitioned triple set.
Expand All @@ -34,6 +37,7 @@ def __init__(
corruption_scheme: str,
score_fn: BaseScoreFunction,
evaluation: Optional[Evaluation] = None,
filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None,
return_scores: bool = False,
windows_size: int = 1000,
use_ipu_model: bool = False,
Expand All @@ -52,8 +56,13 @@ def __init__(
:param evaluation:
Evaluation module, for computing metrics.
Default: None.
:param filter_triples:
The set of all triples whose scores need to be filtered.
The triples passed here must have GLOBAL IDs for head/tail
entities. Default: None.
:param return_scores:
If True, store and return scores of all queries' completions.
If True, store and return scores of all queries' completions
(with filters applied, if specified).
For large number of queries/entities, this can cause the host
to go OOM.
Default: False.
Expand Down Expand Up @@ -94,6 +103,7 @@ def __init__(
self.evaluation = evaluation
self.return_scores = return_scores
self.window_size = windows_size
self.corruption_scheme = corruption_scheme
self.bess_module = AllScoresBESS(
self.candidate_sampler, self.score_fn, self.window_size
)
Expand All @@ -114,6 +124,36 @@ def __init__(
0,
poptorch.VariableRetrievalMode.OnePerGroup,
)
self.filter_triples: Optional[torch.Tensor] = None
if filter_triples:
# Reconstruct global IDs for all entities in triples
local_id_col = (
0 if self.batch_sampler.triple_partition_mode == "h_shard" else 2
)
triple_shard_offset = np.concatenate(
[np.array([0]), np.cumsum(batch_sampler.triple_counts)]
)
global_id_triples = []
for i in range(len(triple_shard_offset) - 1):
shard_triples = np.copy(
batch_sampler.triples[
triple_shard_offset[i] : triple_shard_offset[i + 1]
]
)
shard_triples[
:, local_id_col
] = self.bess_module.sharding.shard_and_idx_to_entity[i][
shard_triples[:, local_id_col]
]
global_id_triples.append(shard_triples)
self.triples = torch.from_numpy(np.concatenate(global_id_triples, axis=0))
self.filter_triples = torch.concat(
[
tr if isinstance(tr, torch.Tensor) else torch.from_numpy(tr)
for tr in filter_triples
],
dim=0,
)

def forward(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -179,19 +219,26 @@ def forward(self) -> Dict[str, Any]:
batch_scores_filt = batch_scores[triple_mask.flatten()][
:, np.unique(np.concatenate(batch_idx), return_index=True)[1]
][:, : self.bess_module.sharding.n_entity]
if self.return_scores:
scores.append(torch.clone(batch_scores_filt))
# Scores of positive triples
true_scores = batch_scores_filt[
torch.arange(batch_scores_filt.shape[0]),
ground_truth[triple_mask],
]
if self.filter_triples is not None:
# Filter for triples in batch
batch_filter = get_entity_filter(
self.triples[triple_id[triple_mask]],
self.filter_triples,
filter_mode=self.corruption_scheme,
)
batch_scores_filt[batch_filter[:, 0], batch_filter[:, 1]] = -torch.inf

if self.evaluation:
assert (
ground_truth is not None
), "Evaluation requires providing ground truth entities"
# Scores of true triples
true_scores = batch_scores_filt[
torch.arange(batch_scores_filt.shape[0]),
ground_truth[triple_mask],
]
# Mask scores of true triples to compute ranks
# If not already masked, mask scores of true triples
# to compute metrics
batch_scores_filt[
torch.arange(batch_scores_filt.shape[0]),
ground_truth[triple_mask],
Expand All @@ -202,6 +249,13 @@ def forward(self) -> Dict[str, Any]:
metrics.append(self.evaluation.dict_metrics_from_ranks(batch_ranks))
if self.evaluation.return_ranks:
ranks.append(batch_ranks)
if self.return_scores:
# Restore positive scores in the returned scores
batch_scores_filt[
torch.arange(batch_scores_filt.shape[0]),
ground_truth[triple_mask],
] = true_scores
scores.append(batch_scores_filt)

out = dict()
if scores:
Expand Down
36 changes: 36 additions & 0 deletions besskge/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,42 @@ def gather_indices(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
return x.view(-1, mask_size)


def get_entity_filter(
    triples: torch.Tensor, filter_triples: torch.Tensor, filter_mode: str
) -> torch.Tensor:
    """
    Compare two sets of triples: for each triple (h,r,t) in the first set, find
    the entities `e` such that (e,r,t) (or (h,r,e), depending on `filter_mode`)
    appears in the second set of triples.

    :param triples: shape (x, 3)
        The set of triples to construct filters for.
    :param filter_triples: shape (y, 3)
        The set of triples determining the head/tail entities to filter.
    :param filter_mode:
        Set to "h" to look for entities appearing as heads of the same (r,t) pair,
        or to "t" to look for entities appearing as tails of the same (h,r) pair.

    :raises ValueError:
        If `filter_mode` is neither "h" nor "t".

    :return: shape (z, 2)
        The sparse filters. Each row is a pair (i, j), with i the index
        of the triple in `triples` to which the filter applies and j the
        entity ID read from the matching row of `filter_triples`.
        Rows are ordered by i, then by position of the match in `filter_triples`.
    """
    # `ent_col` is the column that must MATCH between the two sets (the fixed
    # side of the query); the entity to filter is read from the opposite
    # column, `2 - ent_col`.
    if filter_mode == "t":
        ent_col = 0
    elif filter_mode == "h":
        ent_col = 2
    else:
        raise ValueError("`filter_mode` needs to be either 'h' or 't'")

    # Broadcast (x, 1) against (y,) to get dense (x, y) boolean match tables.
    same_relation = triples[:, 1].view(-1, 1) == filter_triples[:, 1]
    same_entity = triples[:, ent_col].view(-1, 1) == filter_triples[:, ent_col]

    # Rows of `matches` are (index in `triples`, index in `filter_triples`).
    matches = (same_entity & same_relation).nonzero(as_tuple=False)
    # Replace the `filter_triples` row index with the entity ID stored in
    # that row (the right-hand side is evaluated before the assignment).
    matches[:, 1] = filter_triples[matches[:, 1], 2 - ent_col]

    return matches


def complex_multiplication(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
"""
Batched complex multiplication.
Expand Down
101 changes: 71 additions & 30 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from besskge.pipeline import AllScoresPipeline
from besskge.scoring import ComplEx
from besskge.sharding import PartitionedTripleSet, Sharding
from besskge.utils import get_entity_filter

seed = 1234
n_entity = 5000
Expand All @@ -30,15 +31,19 @@
unsharded_entity_table = torch.randn(size=(n_entity, 2 * embedding_size))
relation_table = torch.randn(size=(n_relation_type, 2 * embedding_size))

test_triples_h = np.random.randint(n_entity, size=(n_test_triple,))
test_triples_t = np.random.randint(n_entity, size=(n_test_triple,))
test_triples_r = np.random.randint(n_relation_type, size=(n_test_triple,))
test_triples_h = np.random.choice(n_entity - 1, size=n_test_triple, replace=False)
test_triples_t = np.random.choice(n_entity - 1, size=n_test_triple, replace=False)
test_triples_r = np.random.randint(n_relation_type, size=n_test_triple)
triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)}


@pytest.mark.parametrize("corruption_scheme", ["h", "t"])
@pytest.mark.parametrize("compute_metrics", [True, False])
def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> None:
@pytest.mark.parametrize(
"filter_scores, extra_only", [(True, True), (True, False), (False, False)]
)
def test_all_scores_pipeline(
corruption_scheme: str, filter_scores: bool, extra_only: bool
) -> None:
ds = KGDataset(
n_entity=n_entity,
n_relation_type=n_relation_type,
Expand Down Expand Up @@ -76,18 +81,28 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N
return_triple_idx=True,
)

if compute_metrics:
evaluation = Evaluation(
["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True
evaluation = Evaluation(
["mrr", "hits@10"], mode="average", reduction="sum", return_ranks=True
)

ground_truth_col = 0 if corruption_scheme == "h" else 2
if filter_scores:
extra_filter_triples = np.copy(triples["test"])
extra_filter_triples[:, ground_truth_col] += 1
triples_to_filter = (
[extra_filter_triples]
if extra_only
else [triples["test"], extra_filter_triples]
)
else:
evaluation = None
triples_to_filter = None

pipeline = AllScoresPipeline(
test_bs,
corruption_scheme,
score_fn,
evaluation,
filter_triples=triples_to_filter, # type: ignore
return_scores=True,
windows_size=1000,
use_ipu_model=True,
Expand All @@ -99,35 +114,61 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N
ds.triples["test"][partitioned_triple_set.triple_sort_idx[out["triple_idx"]]]
)

# Real scores
# All scores, computed on CPU
if corruption_scheme == "t":
real_scores = score_fn.score_tails(
cpu_scores = score_fn.score_tails(
unsharded_entity_table[triple_reordered[:, 0]],
triple_reordered[:, 1],
unsharded_entity_table.unsqueeze(0),
)
else:
real_scores = score_fn.score_heads(
cpu_scores = score_fn.score_heads(
unsharded_entity_table.unsqueeze(0),
triple_reordered[:, 1],
unsharded_entity_table[triple_reordered[:, 2]],
)

assert_close(real_scores, out["scores"], atol=1e-3, rtol=1e-4)

if evaluation:
ground_truth_col = 0 if corruption_scheme == "h" else 2
real_scores_masked = torch.clone(real_scores)
pos_scores = real_scores_masked[
torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col]
]
real_scores_masked[
torch.arange(real_scores.shape[0]), triple_reordered[:, ground_truth_col]
] = -torch.inf

real_ranks = evaluation.ranks_from_scores(pos_scores, real_scores_masked)

# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(real_ranks - out["ranks"]) <= 1)
assert (real_ranks != out["ranks"]).sum() < n_test_triple / 100
pos_scores = score_fn.score_triple(
unsharded_entity_table[triple_reordered[:, 0]],
triple_reordered[:, 1],
unsharded_entity_table[triple_reordered[:, 2]],
).flatten()
# mask positive scores to compute metrics
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
] = -torch.inf
if filter_scores:
filter_triples = torch.from_numpy(
np.concatenate(triples_to_filter, axis=0) # type: ignore
)
tr_filter = get_entity_filter(
triple_reordered, filter_triples, corruption_scheme
)
cpu_scores[tr_filter[:, 0], tr_filter[:, 1]] = -torch.inf
# check filters (for convenience, here h/t entities are non-repeating)
if extra_only:
assert torch.all(tr_filter[:, 0] == torch.arange(n_test_triple))
assert torch.all(
tr_filter[:, 1] == triple_reordered[:, ground_truth_col] + 1
)
else:
assert torch.all(tr_filter[::2, 0] == tr_filter[1::2, 0])
assert torch.all(tr_filter[::2, 0] == torch.arange(n_test_triple))
assert torch.all(tr_filter[::2, 1] == triple_reordered[:, ground_truth_col])
assert torch.all(
tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1
)

cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores)

# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1)
assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100

# restore positive scores
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
] = pos_scores

assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4)

0 comments on commit 5726482

Please sign in to comment.