Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AllScoresBESS and high-level pipeline #33

Merged
merged 11 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
ci:
runs-on: ubuntu-latest
container: graphcore/pytorch:3.3.0-ubuntu-20.04-20230703
timeout-minutes: 15
timeout-minutes: 25
steps:
- uses: actions/checkout@v3
- name: Install dev-requirements
Expand Down
1 change: 1 addition & 0 deletions besskge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def load_custom_ops_so() -> None:
loss,
metric,
negative_sampler,
pipeline,
scoring,
sharding,
utils,
Expand Down
153 changes: 147 additions & 6 deletions besskge/bess.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,22 +707,22 @@ def forward(
and respective scores are kept to be used in the next iteration, while
the rest are discarded.

:param relation: shape: (shard_bs,)
:param relation: shape: (1, shard_bs,)
Relation indices.
:param head: shape: (shard_bs,)
:param head: shape: (1, shard_bs,)
Head indices, if known. Default: None.
:param tail: shape: (shard_bs,)
:param tail: shape: (1, shard_bs,)
Tail indices, if known. Default: None.
:param negative: shape: (n_shard, B, padded_negative)
:param negative: shape: (1, n_shard, B, padded_negative)
Candidates to score against the queries.
It can be the same set for all queries (B=1),
or specific for each query in the batch (B=shard_bs).
If None, score each query against all entities in the knowledge
graph. Default: None.
:param triple_mask: shape: (shard_bs,)
:param triple_mask: shape: (1, shard_bs,)
Mask to filter the triples in the micro-batch
before computing metrics. Default: None.
:param negative_mask: shape: (n_shard, B, padded_negative)
:param negative_mask: shape: (1, n_shard, B, padded_negative)
If candidates are provided, mask to discard padding
negatives when computing best completions.
Requires the use of :code:`mask_on_gather=True` in the candidate
Expand Down Expand Up @@ -919,3 +919,144 @@ def loop_body(
)

return out_dict


class AllScoresBESS(torch.nn.Module):
"""
Distributed scoring of (h, r, ?) or (?, r, t) queries against
the entities in the knowledge graph, returning all scores to
host in blocks, based on the BESS :cite:p:`BESS`
inference scheme.
To be used in combination with a batch sampler based on a
"h_shard"/"t_shard"-partitioned triple set.
Since each iteration on IPU computes only part of the scores
(based on the size of the sliding window), metrics should be
computed on host after aggregating data (see
:class:`besskge.pipeline.AllScoresPipeline`).

Only to be used for inference.
"""

def __init__(
self,
candidate_sampler: PlaceholderNegativeSampler,
score_fn: BaseScoreFunction,
window_size: int = 1000,
) -> None:
"""
Initialize AllScores BESS-KGE module.

:param candidate_sampler:
:class:`besskge.negative_sampler.PlaceholderNegativeSampler` class,
specifying corruption scheme.
:param score_fn:
Scoring function.
:param window_size:
Size of the sliding window, namely the number of negative entities
scored against each query at each step on IPU and returned to host.
Should be decreased with large batch sizes, to avoid an OOM error.
Default: 1000.
"""
super().__init__()
self.sharding = score_fn.sharding
self.score_fn = score_fn
self.negative_sampler = candidate_sampler
self.window_size = window_size

if not score_fn.negative_sample_sharing:
raise ValueError("AllScoresBESS requires using negative sample sharing")

if self.negative_sampler.corruption_scheme not in ["h", "t"]:
raise ValueError("AllScoresBESS only support 'h', 't' corruption scheme")

if not isinstance(self.negative_sampler, PlaceholderNegativeSampler):
raise ValueError(
"AllScoresBESS requires a `PlaceholderNegativeSampler`"
" candidate_sampler"
)

self.entity_embedding = self.score_fn.entity_embedding
self.entity_embedding_size: int = self.entity_embedding.shape[-1]

self.candidate = torch.arange(self.window_size, dtype=torch.int32)
self.n_step = int(
np.ceil(self.sharding.max_entity_per_shard / self.window_size)
)

def forward(
self,
step: torch.Tensor,
relation: torch.Tensor,
head: Optional[torch.Tensor] = None,
tail: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward step.

Similarly to :class:`ScoreMovingBessKGE`, candidates are scored on the
device where they are gathered, then scores for the same query against
candidates in different shards are collected together via an AllToAll.

:param step:
The index of the block (of size self.window_size) of entities
on each IPU to score against queries.
:param relation: shape: (1, shard_bs,)
Relation indices.
:param head: shape: (1, shard_bs,)
Head indices, if known. Default: None.
:param tail: shape: (1, shard_bs,)
Tail indices, if known. Default: None.

:return:
The scores for the completions.
"""

relation = relation.squeeze(0)
if head is not None:
head = head.squeeze(0)
if tail is not None:
tail = tail.squeeze(0)

n_shard = self.sharding.n_shard
shard_bs = relation.shape[0]

relation_all = all_gather(relation, n_shard)
if self.negative_sampler.corruption_scheme == "h":
tail_embedding = self.entity_embedding[tail]
tail_embedding_all = all_gather(tail_embedding, n_shard)
elif self.negative_sampler.corruption_scheme == "t":
head_embedding = self.entity_embedding[head]
head_embedding_all = all_gather(head_embedding, n_shard)

# Local indices of the entities to score against queries
ent_slice = torch.minimum(
step * self.window_size
+ torch.arange(self.window_size, device=relation.device),
torch.tensor(self.sharding.max_entity_per_shard - 1),
)
negative_embedding = self.entity_embedding[ent_slice]

if self.negative_sampler.corruption_scheme == "h":
scores = self.score_fn.score_heads(
negative_embedding,
relation_all.flatten(end_dim=1),
tail_embedding_all.flatten(end_dim=1),
)
elif self.negative_sampler.corruption_scheme == "t":
scores = self.score_fn.score_tails(
head_embedding_all.flatten(end_dim=1),
relation_all.flatten(end_dim=1),
negative_embedding,
)

# Send back queries to original shard
scores = (
all_to_all(
scores.reshape(n_shard, shard_bs, self.window_size),
n_shard,
)
.transpose(0, 1)
.flatten(start_dim=1)
) # shape (bs, n_shard * ws)

return scores
Loading