Skip to content

Commit

Permalink
small docs
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Feb 12, 2025
1 parent 13cf761 commit b8bc33b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions model2vec/train/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
# Add a small epsilon to avoid division by zero
length = zeros.sum(1) + 1e-16
embedded = self.embeddings(input_ids)
# Simulate actual mean
# Zero out the padding
# Weigh each token
embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
# embedded = embedded.sum(1)
# Mean pooling by dividing by the length
embedded = embedded / length[:, None]

return nn.functional.normalize(embedded)
Expand All @@ -106,7 +105,7 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tens
"""
encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
return pad_sequence(encoded_ids, batch_first=True)
return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

@property
def device(self) -> str:
Expand Down

0 comments on commit b8bc33b

Please sign in to comment.