diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index dfc0face..f917c0ad 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -518,6 +518,13 @@ def predict_proba( probs = torch.stack(probs, axis=1) else: probs = np.stack(probs, axis=1) + if list(self.labels) != list(self.model_head.classes_): + # If the user has specified labels when instantiating the model, we have to take into account + # the possibility of the model head having reordered the labels. + head_labels = list(self.model_head.classes_) + user_labels = list(self.labels) + reorder_map = np.array([head_labels.index(label) for label in user_labels]) + probs = probs[:, reorder_map] outputs = self._output_type_conversion(probs, as_numpy=as_numpy) return outputs[0] if is_singular else outputs