diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 1a7a89e..afe1019 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) _RANDOM_SEED = 42 -LabelType = TypeVar("LabelType", list[str | int], list[list[str | int]]) +LabelType = TypeVar("LabelType", list[str], list[list[str]]) class StaticModelForClassification(FinetunableStaticModel): @@ -230,18 +230,18 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - if isinstance(y[0], (str, int)): + if isinstance(y[0], str): # Check if all labels are strings. - if not all(isinstance(label, (str, int)) for label in y): + if not all(isinstance(label, str) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False - classes = sorted({str(label) for label in y}) + classes = sorted(set(y)) else: # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") self.multilabel = True - classes = sorted({str(label) for label in chain.from_iterable(y)}) + classes = sorted(set(chain.from_iterable(y))) self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension @@ -258,7 +258,6 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> :param y: The labels. :param max_length: The maximum length of the input. :return: A TextDataset. - :raises ValueError: If the labels are inconsistent. """ # This is a speed optimization. # assumes a mean token length of 10, which is really high, so safe. @@ -273,18 +272,16 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) mapping = {label: idx for idx, label in enumerate(self.classes_)} for i, sample_labels in enumerate(y): - if not isinstance(sample_labels, (list, tuple)): - raise ValueError("For multilabel classification, each label should be a list or tuple.") - indices = [mapping[str(label)] for label in sample_labels] + indices = [mapping[label] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: - labels_tensor = torch.tensor([self.classes_.index(str(label)) for label in y], dtype=torch.long) + labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) def _train_test_split( self, X: list[str], - y: LabelType, + y: list[str] | list[list[str]], test_size: float, ) -> tuple[list[str], list[str], LabelType, LabelType]: """