Skip to content

Commit

Permalink
Optimized inference performance
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 15, 2025
1 parent 5c9d397 commit 6a4f89b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ def predict(
logits = self._predict_single_batch(X[batch : batch + batch_size])
if self.multilabel:
probs = torch.sigmoid(logits)
for sample in probs:
sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > threshold]
pred.append(sample_labels)
mask = (probs > threshold).cpu().numpy()
pred.extend([np.array(self.classes)[np.flatnonzero(row)] for row in mask])
else:
pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
return np.array(pred, dtype=object)
Expand Down

0 comments on commit 6a4f89b

Please sign in to comment.