feat: replace 8m by 32m for training (#182)
* feat: replace 8m by 32m for training

* small docs
stephantul authored Feb 12, 2025
1 parent b83fb56 commit dd160fb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions model2vec/train/base.py
@@ -45,7 +45,7 @@ def construct_head(self) -> nn.Sequential:

     @classmethod
     def from_pretrained(
-        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-8m", **kwargs: Any
+        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
     ) -> ModelType:
         """Load the model from a pretrained model2vec model."""
         model = StaticModel.from_pretrained(model_name)
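For context, a minimal usage sketch of what the new default means for callers; StaticModelForClassification is assumed here as the concrete subclass exported by model2vec.train, since the class name itself does not appear in this diff:

from model2vec.train import StaticModelForClassification

# Omitting model_name now pulls minishlab/potion-base-32m instead of potion-base-8m.
model = StaticModelForClassification.from_pretrained(out_dim=2)

# The previous default stays available by passing it explicitly.
small_model = StaticModelForClassification.from_pretrained(out_dim=2, model_name="minishlab/potion-base-8m")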
@@ -81,10 +81,9 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
         embedded = self.embeddings(input_ids)
-        # Simulate actual mean
-        # Zero out the padding
+        # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
-        # embedded = embedded.sum(1)
+        # Mean pooling by dividing by the length
         embedded = embedded / length[:, None]

         return nn.functional.normalize(embedded)
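As a rough illustration of the pooling these comments describe, here is a toy sketch with made-up shapes; it assumes w holds per-token weights and zeros is the padding mask, which is not fully visible in this hunk:

import torch

embedded = torch.randn(2, 4, 8)                              # (batch, seq_len, dim) token embeddings
zeros = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])   # 1 for real tokens, 0 for padding
w = zeros                                                    # per-token weights; padding contributes nothing
length = zeros.sum(1) + 1e-16                                # number of real tokens per sequence

pooled = torch.bmm(w[:, None, :], embedded).squeeze(1)       # (batch, 1, seq) @ (batch, seq, dim) -> weighted sum
pooled = pooled / length[:, None]                            # divide by length: mean over real tokens
pooled = torch.nn.functional.normalize(pooled)               # unit-normalise, as in the return above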
@@ -106,7 +105,7 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
         """
         encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
         encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
-        return pad_sequence(encoded_ids, batch_first=True)
+        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

     @property
     def device(self) -> str:
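A small illustration of what the new padding_value changes; the pad id below is a hypothetical value, while self.pad_id in the real code comes from the model's tokenizer:

import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = 999  # hypothetical pad token id, for illustration only
encoded_ids = [torch.tensor([5, 6, 7]), torch.tensor([8])]
batch = pad_sequence(encoded_ids, batch_first=True, padding_value=pad_id)
# tensor([[  5,   6,   7],
#         [  8, 999, 999]])
# With the old call, padding defaulted to 0, which collides with whichever token has id 0.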
