Skip to content

Commit

Permalink
Changed classes to np array
Browse files: browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 15, 2025
1 parent 6a4f89b commit 3609e62
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)

@property
def classes(self) -> list[str]:
def classes(self) -> np.ndarray:
"""Return all classes in the correct order."""
return self.classes_
return np.array(self.classes_)

def construct_head(self) -> nn.Sequential:
"""Constructs a simple classifier head."""
Expand Down Expand Up @@ -93,7 +93,7 @@ def predict(
if self.multilabel:
probs = torch.sigmoid(logits)
mask = (probs > threshold).cpu().numpy()
pred.extend([np.array(self.classes)[np.flatnonzero(row)] for row in mask])
pred.extend([self.classes[np.flatnonzero(row)] for row in mask])
else:
pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()])
return np.array(pred, dtype=object)
Expand Down Expand Up @@ -275,7 +275,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) ->
indices = [mapping[label] for label in sample_labels]
labels_tensor[i, indices] = 1.0
else:
labels_tensor = torch.tensor([self.classes.index(label) for label in cast(list[str], y)], dtype=torch.long)
labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long)
return TextDataset(tokenized, labels_tensor)

def _train_test_split(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: T
s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers)
assert s.vectors.shape == mock_vectors.shape
assert s.w.shape[0] == mock_vectors.shape[0]
assert s.classes == s.classes_
assert s.classes == ["0", "1"]
assert list(s.classes) == s.classes_
assert list(s.classes) == ["0", "1"]

head = s.construct_head()
assert head[0].in_features == mock_vectors.shape[1]
Expand Down

0 comments on commit 3609e62

Please sign in to comment.