Skip to content

Commit

Permalink
Fixed merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 16, 2025
2 parents 2dc5b17 + 2d51516 commit 5003768
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def predict(
pred.extend([self.classes[np.flatnonzero(row)] for row in mask])
else:
pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
return np.array(pred, dtype=object)
if self.multilabel:
# Return as object array to allow for lists of varying lengths.
return np.array(pred, dtype=object)
else:
return np.array(pred)

@torch.no_grad()
def _predict_single_batch(self, X: list[str]) -> torch.Tensor:
Expand Down

0 comments on commit 5003768

Please sign in to comment.