Skip to content

Commit

Permalink
make return all scores optional, to avoid host OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCatt91 committed Oct 25, 2023
1 parent fbe91e9 commit c6325c4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
18 changes: 16 additions & 2 deletions besskge/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
corruption_scheme: str,
score_fn: BaseScoreFunction,
evaluation: Optional[Evaluation] = None,
return_scores: bool = False,
windows_size: int = 1000,
use_ipu_model: bool = False,
) -> None:
Expand All @@ -51,6 +52,11 @@ def __init__(
:param evaluation:
Evaluation module, for computing metrics.
Default: None.
:param return_scores:
If True, store and return scores of all queries' completions.
For a large number of queries/entities, this can cause the host
to go OOM.
Default: False.
:param windows_size:
Size of the sliding window, namely the number of negative entities
scored against each query at each step on IPU and returned to host.
Expand All @@ -61,6 +67,10 @@ def __init__(
"""
super().__init__()
self.batch_sampler = batch_sampler
if not (evaluation or return_scores):
raise ValueError(
"Nothing to return. Provide `evaluation` or set" " `return_scores=True`"
)
if corruption_scheme not in ["h", "t"]:
raise ValueError("corruption_scheme needs to be either 'h' or 't'")
if (
Expand All @@ -82,6 +92,7 @@ def __init__(
)
self.score_fn = score_fn
self.evaluation = evaluation
self.return_scores = return_scores
self.window_size = windows_size
self.bess_module = AllScoresBESS(
self.candidate_sampler, self.score_fn, self.window_size
Expand Down Expand Up @@ -167,7 +178,8 @@ def forward(self) -> Dict[str, Any]:
batch_scores_filt = batch_scores[triple_mask.flatten()][
:, np.unique(np.concatenate(batch_idx), return_index=True)[1]
][:, : self.bess_module.sharding.n_entity]
scores.append(torch.clone(batch_scores_filt))
if self.return_scores:
scores.append(torch.clone(batch_scores_filt))

if self.evaluation:
assert (
Expand All @@ -190,7 +202,9 @@ def forward(self) -> Dict[str, Any]:
if self.evaluation.return_ranks:
ranks.append(batch_ranks)

out = dict(scores=torch.concat(scores, dim=0))
out = dict()
if scores:
out["scores"] = torch.concat(scores, dim=0)
if ids:
out["triple_idx"] = torch.concat(ids, dim=0)
if self.evaluation:
Expand Down
1 change: 1 addition & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_all_scores_pipeline(corruption_scheme: str, compute_metrics: bool) -> N
corruption_scheme,
score_fn,
evaluation,
return_scores=True,
windows_size=1000,
use_ipu_model=True,
)
Expand Down

0 comments on commit c6325c4

Please sign in to comment.