diff --git a/model2vec/train/base.py b/model2vec/train/base.py
index 65f4b45..db1d539 100644
--- a/model2vec/train/base.py
+++ b/model2vec/train/base.py
@@ -45,7 +45,7 @@ def construct_head(self) -> nn.Sequential:
 
     @classmethod
     def from_pretrained(
-        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-8m", **kwargs: Any
+        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
     ) -> ModelType:
         """Load the model from a pretrained model2vec model."""
         model = StaticModel.from_pretrained(model_name)
@@ -81,10 +81,9 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
         embedded = self.embeddings(input_ids)
-        # Simulate actual mean
-        # Zero out the padding
+        # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
-        # embedded = embedded.sum(1)
+        # Mean pooling by dividing by the length
         embedded = embedded / length[:, None]
         return nn.functional.normalize(embedded)
 
@@ -106,7 +105,7 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tens
         """
         encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
         encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
-        return pad_sequence(encoded_ids, batch_first=True)
+        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
 
     @property
     def device(self) -> str:
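A minimal standalone sketch (not the library code) of the pooling behaviour the `_encode` hunk performs, assuming a per-token weight vector analogous to the `w` computed earlier in `_encode`; the vocabulary size, embedding dimension, weights, and input ids below are made up for illustration:

    # Sketch of weighted mean pooling over padded token ids.
    import torch
    import torch.nn as nn

    pad_id = 0
    vocab_size, dim = 10, 4
    embeddings = nn.Embedding(vocab_size, dim)
    token_weights = torch.rand(vocab_size)  # hypothetical per-token weights

    # Two sequences; the second is padded with pad_id, as produced by
    # pad_sequence(..., padding_value=pad_id) in the tokenize hunk.
    input_ids = torch.tensor([[3, 5, 7], [2, 4, pad_id]])

    w = token_weights[input_ids]            # (batch, seq) weights
    zeros = (input_ids != pad_id).float()   # 1 for real tokens, 0 for padding
    w = w * zeros                           # padding contributes nothing
    length = zeros.sum(1) + 1e-16           # real-token count; epsilon avoids /0
    embedded = embeddings(input_ids)        # (batch, seq, dim)

    # Weighted sum over the sequence dimension via batched matmul
    pooled = torch.bmm(w[:, None, :], embedded).squeeze(1)
    # Mean pooling by dividing by the unpadded length, then L2-normalize
    pooled = nn.functional.normalize(pooled / length[:, None])
    print(pooled.shape)  # torch.Size([2, 4])

Note that the padding mask `(input_ids != pad_id)` only excludes padding correctly when the batch is actually padded with `pad_id`, which is what the `padding_value=self.pad_id` change to `tokenize` ensures.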