From acd24d763552e49cee0d4936b5265131ded8a34f Mon Sep 17 00:00:00 2001
From: stephantul
Date: Tue, 11 Feb 2025 11:29:06 +0100
Subject: [PATCH 1/2] feat: replace 8m by 32m for training

---
 model2vec/train/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model2vec/train/base.py b/model2vec/train/base.py
index 65f4b45..52407af 100644
--- a/model2vec/train/base.py
+++ b/model2vec/train/base.py
@@ -45,7 +45,7 @@ def construct_head(self) -> nn.Sequential:
 
     @classmethod
     def from_pretrained(
-        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-8m", **kwargs: Any
+        cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
     ) -> ModelType:
         """Load the model from a pretrained model2vec model."""
         model = StaticModel.from_pretrained(model_name)

From b8bc33b75fe9a74e0de37df41429e4c0c20b28c4 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Wed, 12 Feb 2025 09:42:43 +0100
Subject: [PATCH 2/2] small docs

---
 model2vec/train/base.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/model2vec/train/base.py b/model2vec/train/base.py
index 52407af..db1d539 100644
--- a/model2vec/train/base.py
+++ b/model2vec/train/base.py
@@ -81,10 +81,9 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
         embedded = self.embeddings(input_ids)
-        # Simulate actual mean
-        # Zero out the padding
+        # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
-        # embedded = embedded.sum(1)
+        # Mean pooling by dividing by the length
         embedded = embedded / length[:, None]
         return nn.functional.normalize(embedded)
 
@@ -106,7 +105,7 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tens
         """
         encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
         encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
-        return pad_sequence(encoded_ids, batch_first=True)
+        return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
 
     @property
     def device(self) -> str:
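
For context, a minimal standalone sketch of the padding-aware mean pooling that the reworded comments and the padding_value change relate to. This is not the actual model2vec/train/base.py implementation; pad_id, embeddings, and the toy token ids below are illustrative assumptions.

    import torch
    from torch import nn
    from torch.nn.utils.rnn import pad_sequence

    pad_id = 0  # assumed pad token id (hypothetical; the real value comes from the tokenizer)
    embeddings = nn.Embedding(100, 8, padding_idx=pad_id)  # toy vocabulary size and dimension

    # Two sequences of different lengths, padded into one batch with the pad id.
    seqs = [torch.tensor([5, 7, 9]), torch.tensor([3, 4])]
    input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_id)

    # Weigh each token: a 0/1 mask zeroes out the padding positions.
    mask = (input_ids != pad_id).float()                       # (batch, seq)
    # Add a small epsilon to avoid division by zero for empty sequences.
    length = mask.sum(1) + 1e-16                               # (batch,)
    embedded = embeddings(input_ids)                           # (batch, seq, dim)
    # Sum the non-padding token embeddings via a batched matrix product.
    summed = torch.bmm(mask[:, None, :], embedded).squeeze(1)  # (batch, dim)
    # Mean pooling by dividing by the true (unpadded) length, then L2-normalize.
    pooled = nn.functional.normalize(summed / length[:, None])
    print(pooled.shape)  # torch.Size([2, 8])

Passing an explicit padding_value that matches the pad id, as the second patch does, keeps the mask consistent with the ids produced by tokenize.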