AllScoresBESS and high-level pipeline (#33)
* implement AllScoresBESS and high-level pipeline

* notation consistency

* IpuModel for CI

* increase CI timeout

* reduce CI memory requirements

* make return all scores optional, to avoid host OOM

* fix typo

* fix typo

* add triple filtering to AllScoresPipeline

* add return_topk option to AllScoresPipeline

* fix typo
AlCatt91 authored Oct 27, 2023
1 parent 1086c6c commit 0fdb741
Showing 8 changed files with 663 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -14,7 +14,7 @@ jobs:
ci:
runs-on: ubuntu-latest
container: graphcore/pytorch:3.3.0-ubuntu-20.04-20230703
- timeout-minutes: 15
+ timeout-minutes: 25
steps:
- uses: actions/checkout@v3
- name: Install dev-requirements
1 change: 1 addition & 0 deletions besskge/__init__.py
@@ -45,6 +45,7 @@ def load_custom_ops_so() -> None:
loss,
metric,
negative_sampler,
+ pipeline,
scoring,
sharding,
utils,
153 changes: 147 additions & 6 deletions besskge/bess.py
@@ -707,22 +707,22 @@ def forward(
and respective scores are kept to be used in the next iteration, while
the rest are discarded.
- :param relation: shape: (shard_bs,)
+ :param relation: shape: (1, shard_bs,)
Relation indices.
- :param head: shape: (shard_bs,)
+ :param head: shape: (1, shard_bs,)
Head indices, if known. Default: None.
- :param tail: shape: (shard_bs,)
+ :param tail: shape: (1, shard_bs,)
Tail indices, if known. Default: None.
- :param negative: shape: (n_shard, B, padded_negative)
+ :param negative: shape: (1, n_shard, B, padded_negative)
Candidates to score against the queries.
It can be the same set for all queries (B=1),
or specific for each query in the batch (B=shard_bs).
If None, score each query against all entities in the knowledge
graph. Default: None.
- :param triple_mask: shape: (shard_bs,)
+ :param triple_mask: shape: (1, shard_bs,)
Mask to filter the triples in the micro-batch
before computing metrics. Default: None.
- :param negative_mask: shape: (n_shard, B, padded_negative)
+ :param negative_mask: shape: (1, n_shard, B, padded_negative)
If candidates are provided, mask to discard padding
negatives when computing best completions.
Requires the use of :code:`mask_on_gather=True` in the candidate
@@ -919,3 +919,144 @@ def loop_body(
)

return out_dict


class AllScoresBESS(torch.nn.Module):
"""
Distributed scoring of (h, r, ?) or (?, r, t) queries against
the entities in the knowledge graph, returning all scores to
host in blocks, based on the BESS :cite:p:`BESS`
inference scheme.
To be used in combination with a batch sampler based on an
"h_shard"/"t_shard"-partitioned triple set.
Since each iteration on IPU computes only part of the scores
(based on the size of the sliding window), metrics should be
computed on host after aggregating data (see
:class:`besskge.pipeline.AllScoresPipeline`).
Only to be used for inference.
"""

def __init__(
self,
candidate_sampler: PlaceholderNegativeSampler,
score_fn: BaseScoreFunction,
window_size: int = 1000,
) -> None:
"""
Initialize AllScores BESS-KGE module.
:param candidate_sampler:
:class:`besskge.negative_sampler.PlaceholderNegativeSampler` class,
specifying corruption scheme.
:param score_fn:
Scoring function.
:param window_size:
Size of the sliding window, namely the number of negative entities
scored against each query at each step on IPU and returned to host.
Should be decreased with large batch sizes, to avoid an OOM error.
Default: 1000.
"""
super().__init__()
self.sharding = score_fn.sharding
self.score_fn = score_fn
self.negative_sampler = candidate_sampler
self.window_size = window_size

if not score_fn.negative_sample_sharing:
raise ValueError("AllScoresBESS requires using negative sample sharing")

if self.negative_sampler.corruption_scheme not in ["h", "t"]:
raise ValueError("AllScoresBESS only support 'h', 't' corruption scheme")

if not isinstance(self.negative_sampler, PlaceholderNegativeSampler):
raise ValueError(
"AllScoresBESS requires a `PlaceholderNegativeSampler`"
" candidate_sampler"
)

self.entity_embedding = self.score_fn.entity_embedding
self.entity_embedding_size: int = self.entity_embedding.shape[-1]

self.candidate = torch.arange(self.window_size, dtype=torch.int32)
self.n_step = int(
np.ceil(self.sharding.max_entity_per_shard / self.window_size)
)
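# Illustrative example (numbers assumed, not from this commit): with
# max_entity_per_shard = 40_000 and window_size = 1_000, each IPU needs
# n_step = ceil(40_000 / 1_000) = 40 sliding-window steps to cover its shard.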

def forward(
self,
step: torch.Tensor,
relation: torch.Tensor,
head: Optional[torch.Tensor] = None,
tail: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward step.
Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the
device where they are gathered, then scores for the same query against
candidates in different shards are collected together via an AllToAll.
:param step:
The index of the block (of size self.window_size) of entities
on each IPU to score against queries.
:param relation: shape: (1, shard_bs,)
Relation indices.
:param head: shape: (1, shard_bs,)
Head indices, if known. Default: None.
:param tail: shape: (1, shard_bs,)
Tail indices, if known. Default: None.
:return:
The scores for the completions.
"""

relation = relation.squeeze(0)
if head is not None:
head = head.squeeze(0)
if tail is not None:
tail = tail.squeeze(0)

n_shard = self.sharding.n_shard
shard_bs = relation.shape[0]

relation_all = all_gather(relation, n_shard)
if self.negative_sampler.corruption_scheme == "h":
tail_embedding = self.entity_embedding[tail]
tail_embedding_all = all_gather(tail_embedding, n_shard)
elif self.negative_sampler.corruption_scheme == "t":
head_embedding = self.entity_embedding[head]
head_embedding_all = all_gather(head_embedding, n_shard)
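# After the all_gather each shard holds the relations and known-entity
# embeddings of every query in the global batch, so it can score all of them
# against its local window of entities.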

# Local indices of the entities to score against queries
ent_slice = torch.minimum(
step * self.window_size
+ torch.arange(self.window_size, device=relation.device),
torch.tensor(self.sharding.max_entity_per_shard - 1),
)
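# Indices past the end of the shard are clamped to the last local entity, so
# the final window keeps a fixed size; the duplicated scores this produces are
# expected to be discarded when results are aggregated on host.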
negative_embedding = self.entity_embedding[ent_slice]

if self.negative_sampler.corruption_scheme == "h":
scores = self.score_fn.score_heads(
negative_embedding,
relation_all.flatten(end_dim=1),
tail_embedding_all.flatten(end_dim=1),
)
elif self.negative_sampler.corruption_scheme == "t":
scores = self.score_fn.score_tails(
head_embedding_all.flatten(end_dim=1),
relation_all.flatten(end_dim=1),
negative_embedding,
)

# Send back queries to original shard
scores = (
all_to_all(
scores.reshape(n_shard, shard_bs, self.window_size),
n_shard,
)
.transpose(0, 1)
.flatten(start_dim=1)
) # shape (bs, n_shard * ws)
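# Row i now holds the scores of the i-th local query; column block
# [s * window_size, (s + 1) * window_size) corresponds to the current window
# of entities stored on shard s.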

return scores
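Below is a minimal host-side sketch of how AllScoresBESS could be driven block by block. It is illustrative only: the PlaceholderNegativeSampler constructor arguments, the pre-built score_fn, the relation/head tensors and the poptorch wrapping are assumptions rather than part of this commit; in practice this loop, triple filtering and the mapping of score columns back to global entity IDs are handled by besskge.pipeline.AllScoresPipeline.

import torch

from besskge.bess import AllScoresBESS
from besskge.negative_sampler import PlaceholderNegativeSampler

# Assumed to exist already: a sharded score function `score_fn` with
# negative_sample_sharing=True, plus (1, shard_bs)-shaped `relation` and `head`
# index tensors for one micro-batch of (h, r, ?) queries.
candidate_sampler = PlaceholderNegativeSampler(corruption_scheme="t")
bess_module = AllScoresBESS(candidate_sampler, score_fn, window_size=1000)

# Each forward call returns one block of shape (shard_bs, n_shard * window_size).
# On IPU the module would first be compiled, e.g. with poptorch.inferenceModel;
# the host then loops over the n_step windows and collects the blocks.
score_blocks = []
for step in range(bess_module.n_step):
    score_blocks.append(
        bess_module(torch.tensor(step, dtype=torch.int32), relation, head=head)
    )
# Mapping block columns back to global entity IDs and masking padding or
# filtered triples is left to besskge.pipeline.AllScoresPipeline.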