Skip to content

Commit

Permalink
Added int as possible label type
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 16, 2025
1 parent b4df861 commit ba29feb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
indices = [mapping[str(label)] for label in sample_labels]
labels_tensor[i, indices] = 1.0
else:
labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long)
labels_tensor = torch.tensor([self.classes_.index(str(label)) for label in y], dtype=torch.long)
return TextDataset(tokenized, labels_tensor)

def _train_test_split(
Expand Down

0 comments on commit ba29feb

Please sign in to comment.