From 14f39f8866bd3b507e7cd66ac1397990a7a3855b Mon Sep 17 00:00:00 2001 From: Daniel Justus Date: Fri, 27 Oct 2023 17:11:05 +0200 Subject: [PATCH] fp32 topk on cpu (#34) * fp32 topk on cpu * black --- besskge/pipeline.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 49e8507..535a1d6 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -271,7 +271,11 @@ def forward(self) -> Dict[str, Any]: if self.return_scores: scores.append(batch_scores_filt) if self.return_topk: - topk_ids.append(torch.topk(batch_scores_filt, k=self.k, dim=-1).indices) + topk_ids.append( + torch.topk( + batch_scores_filt.to(torch.float32), k=self.k, dim=-1 + ).indices + ) out = dict() if scores: