Skip to content

Commit

Permalink
Use previous logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 16, 2025
1 parent ba29feb commit 3dcddf5
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)
_RANDOM_SEED = 42

- LabelType = TypeVar("LabelType", list[str | int], list[list[str | int]])
+ LabelType = TypeVar("LabelType", list[str], list[list[str]])


class StaticModelForClassification(FinetunableStaticModel):
Expand Down Expand Up @@ -230,18 +230,18 @@ def _initialize(self, y: LabelType) -> None:
:param y: The labels.
:raises ValueError: If the labels are inconsistent.
"""
- if isinstance(y[0], (str, int)):
+ if isinstance(y[0], str):
# Check if all labels are strings.
- if not all(isinstance(label, (str, int)) for label in y):
+ if not all(isinstance(label, str) for label in y):
raise ValueError("Inconsistent label types in y. All labels must be strings.")

Check warning on line 236 in model2vec/train/classifier.py

View check run for this annotation

Codecov / codecov/patch

model2vec/train/classifier.py#L236

Added line #L236 was not covered by tests
self.multilabel = False
- classes = sorted({str(label) for label in y})
+ classes = sorted(set(y))
else:
# Check if all labels are lists or tuples.
if not all(isinstance(label, (list, tuple)) for label in y):
raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.")

Check warning on line 242 in model2vec/train/classifier.py

View check run for this annotation

Codecov / codecov/patch

model2vec/train/classifier.py#L242

Added line #L242 was not covered by tests
self.multilabel = True
- classes = sorted({str(label) for label in chain.from_iterable(y)})
+ classes = sorted(set(chain.from_iterable(y)))

self.classes_ = classes
self.out_dim = len(self.classes_) # Update output dimension
Expand All @@ -258,7 +258,6 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
:param y: The labels.
:param max_length: The maximum length of the input.
:return: A TextDataset.
- :raises ValueError: If the labels are inconsistent.
"""
# This is a speed optimization.
# assumes a mean token length of 10, which is really high, so safe.
Expand All @@ -273,18 +272,16 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float)
mapping = {label: idx for idx, label in enumerate(self.classes_)}
for i, sample_labels in enumerate(y):
- if not isinstance(sample_labels, (list, tuple)):
- raise ValueError("For multilabel classification, each label should be a list or tuple.")
- indices = [mapping[str(label)] for label in sample_labels]
+ indices = [mapping[label] for label in sample_labels]
labels_tensor[i, indices] = 1.0
else:
- labels_tensor = torch.tensor([self.classes_.index(str(label)) for label in y], dtype=torch.long)
+ labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long)
return TextDataset(tokenized, labels_tensor)

def _train_test_split(
self,
X: list[str],
- y: LabelType,
+ y: list[str] | list[list[str]],
test_size: float,
) -> tuple[list[str], list[str], LabelType, LabelType]:
"""
Expand Down

0 comments on commit 3dcddf5

Please sign in to comment.