Docstrings
stephantul committed Oct 3, 2024
1 parent: b6785a5 · commit: 9caf747
Showing 1 changed file with 10 additions and 6 deletions.
model2vec/distill/distillation.py (16 changes: 10 additions & 6 deletions)
@@ -17,15 +17,15 @@
 logger = logging.getLogger(__name__)
 
 
-PCAType = int | None | Literal["auto"]
+PCADimType = int | None | Literal["auto"]
 
 
 def distill_from_model(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerFast,
     vocabulary: list[str] | None = None,
     device: str = "cpu",
-    pca_dims: PCAType = 256,
+    pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
 ) -> StaticModel:
@@ -42,7 +42,9 @@ def distill_from_model(
     :param tokenizer: The tokenizer to use.
     :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
     :param device: The device to use.
-    :param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
+    :param pca_dims: The number of components to use for PCA.
+        If this is None, we don't apply PCA.
+        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
@@ -136,7 +138,7 @@ def distill(
     model_name: str,
     vocabulary: list[str] | None = None,
     device: str = "cpu",
-    pca_dims: PCAType = 256,
+    pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
 ) -> StaticModel:
@@ -152,7 +154,9 @@ def distill(
     :param model_name: The model name to use. Any sentencetransformer compatible model works.
     :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
     :param device: The device to use.
-    :param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
+    :param pca_dims: The number of components to use for PCA.
+        If this is None, we don't apply PCA.
+        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :return: A StaticModel
@@ -172,7 +176,7 @@ def distill(
     )
 
 
-def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCAType, apply_zipf: bool) -> np.ndarray:
+def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply_zipf: bool) -> np.ndarray:
     """Post process embeddings by applying PCA and Zipf weighting."""
     if pca_dims is not None:
         if pca_dims == "auto":
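Taken together, the diff means pca_dims (shared by distill_from_model and distill) accepts an int, None, or 'auto'. A minimal usage sketch of the three modes; the import path follows model2vec's public API, and the model name is only a placeholder for any sentence-transformers-compatible model:

from model2vec.distill import distill

# An int reduces the embeddings to that many dimensions with PCA (the default is 256).
m_256 = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)

# None skips PCA entirely; embeddings keep the teacher model's dimensionality.
m_full = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=None)

# "auto" still fits PCA but does not reduce dimensionality.
m_auto = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims="auto")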

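The diff truncates the body of _post_process_embeddings after the pca_dims == "auto" branch. A hedged sketch of how the documented semantics could be implemented; the sklearn PCA call and the log-rank Zipf weighting here are assumptions for illustration, not the repository's actual code:

from typing import Literal

import numpy as np
from sklearn.decomposition import PCA

PCADimType = int | None | Literal["auto"]


def _post_process_embeddings_sketch(embeddings: np.ndarray, pca_dims: PCADimType, apply_zipf: bool) -> np.ndarray:
    """Apply optional PCA, then optional Zipf weighting (illustrative only)."""
    if pca_dims is not None:
        # "auto" keeps the original dimensionality but still fits PCA.
        dims = embeddings.shape[1] if pca_dims == "auto" else pca_dims
        if dims > embeddings.shape[1]:
            raise ValueError("PCA dimension is larger than the embedding dimension.")
        embeddings = PCA(n_components=dims).fit_transform(embeddings)
    if apply_zipf:
        # Assumed weighting: scale row i by log(1 + i), treating row order as a
        # frequency rank. Note row 0 gets weight 0 under this choice; the real
        # weighting scheme may differ.
        embeddings = embeddings * np.log(1 + np.arange(embeddings.shape[0]))[:, None]
    return embeddings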