diff --git a/docs/img/backends_benchmark_cpu.png b/docs/img/backends_benchmark_cpu.png
index 72ee321da..bb2600cac 100644
Binary files a/docs/img/backends_benchmark_cpu.png and b/docs/img/backends_benchmark_cpu.png differ
diff --git a/docs/img/backends_benchmark_gpu.png b/docs/img/backends_benchmark_gpu.png
index 11e0a00ee..e2a7bfaf6 100644
Binary files a/docs/img/backends_benchmark_gpu.png and b/docs/img/backends_benchmark_gpu.png differ
diff --git a/docs/package_reference/util.md b/docs/package_reference/util.md
index 3e81f6de2..495a914fc 100644
--- a/docs/package_reference/util.md
+++ b/docs/package_reference/util.md
@@ -10,7 +10,7 @@
 ## Model Optimization
 ```eval_rst
 .. automodule:: sentence_transformers.backend
-    :members: export_optimized_onnx_model, export_dynamic_quantized_onnx_model
+    :members: export_optimized_onnx_model, export_dynamic_quantized_onnx_model, export_static_quantized_openvino_model
 ```
 
 ## Similarity Metrics
diff --git a/docs/sentence_transformer/usage/efficiency.rst b/docs/sentence_transformer/usage/efficiency.rst
index c30770078..e596104e9 100644
--- a/docs/sentence_transformer/usage/efficiency.rst
+++ b/docs/sentence_transformer/usage/efficiency.rst
@@ -138,7 +138,13 @@ See this example for exporting a model with :doc:`optimization level 3 `:
+
+.. tab:: Hugging Face Hub Model
+
+   Only quantize once::
+
+      from sentence_transformers import SentenceTransformer, export_static_quantized_openvino_model
+
+      model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")
+      export_static_quantized_openvino_model(
+          model,
+          quantization_config=None,
+          model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+          push_to_hub=True,
+          create_pr=True,
+      )
+
+   Before the pull request gets merged::
+
+      from sentence_transformers import SentenceTransformer
+
+      pull_request_nr = 2 # TODO: Update this to the number of your pull request
+      model = SentenceTransformer(
+          "all-MiniLM-L6-v2",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+          revision=f"refs/pr/{pull_request_nr}"
+      )
+
+   Once the pull request gets merged::
+
+      from sentence_transformers import SentenceTransformer
+
+      model = SentenceTransformer(
+          "all-MiniLM-L6-v2",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+      )
+
+.. tab:: Local Model
+
+   Only quantize once::
+
+      from sentence_transformers import SentenceTransformer, export_static_quantized_openvino_model
+      from optimum.intel import OVQuantizationConfig
+
+      model = SentenceTransformer("path/to/my/mpnet-legal-finetuned", backend="openvino")
+      quantization_config = OVQuantizationConfig()
+      export_static_quantized_openvino_model(model, quantization_config, "path/to/my/mpnet-legal-finetuned")
+
+   After quantizing::
+
+      from sentence_transformers import SentenceTransformer
+
+      model = SentenceTransformer(
+          "path/to/my/mpnet-legal-finetuned",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+      )
+
 Benchmarks
 ----------
 
@@ -388,7 +481,7 @@ The following images show the benchmark results for the different backends on GP
     openvino: OpenVINO, via backend="openvino".
-    openvino-igpu: OpenVINO, via backend="openvino" and model_kwargs={"device": "GPU"} to use the iGPU from my CPU.
+    openvino-qint8: OpenVINO quantized to int8 via export_static_quantized_openvino_model(..., OVQuantizationConfig(), ...) and backend="openvino".
@@ -428,13 +521,13 @@ Based on the benchmarks, this flowchart should help you decide which backend to
     A -->|CPU| C(Is a 0.4% accuracy loss acceptable?)
     B -->|yes| D[onnx-O4]
     B -->|no| F[float16]
-    C -->|yes| G[onnx-int8]
+    C -->|yes| G[openvino-qint8]
     C -->|no| H(Do you have an Intel CPU?)
     H -->|yes| I[openvino]
     H -->|no| J[onnx]
     click D "#optimizing-onnx-models"
     click F "#pytorch"
-    click G "#quantizing-onnx-models"
+    click G "#quantizing-openvino-models"
     click I "#openvino"
     click J "#onnx"
diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py
index 1ba4558e8..488ecdfe9 100644
--- a/sentence_transformers/__init__.py
+++ b/sentence_transformers/__init__.py
@@ -6,7 +6,11 @@
 import importlib
 import os
 
-from sentence_transformers.backend import export_dynamic_quantized_onnx_model, export_optimized_onnx_model
+from sentence_transformers.backend import (
+    export_dynamic_quantized_onnx_model,
+    export_optimized_onnx_model,
+    export_static_quantized_openvino_model,
+)
 from sentence_transformers.cross_encoder.CrossEncoder import CrossEncoder
 from sentence_transformers.datasets import ParallelSentencesDataset, SentencesDataset
 from sentence_transformers.LoggingHandler import LoggingHandler
@@ -37,4 +41,5 @@
     "quantize_embeddings",
     "export_optimized_onnx_model",
     "export_dynamic_quantized_onnx_model",
+    "export_static_quantized_openvino_model",
 ]
diff --git a/sentence_transformers/backend.py b/sentence_transformers/backend.py
index 355f40d83..9c60e4613 100644
--- a/sentence_transformers/backend.py
+++ b/sentence_transformers/backend.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import os
 import shutil
 import tempfile
 from pathlib import Path
@@ -9,11 +8,17 @@
 
 import huggingface_hub
 
+from sentence_transformers.util import disable_datasets_caching, is_datasets_available
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sentence_transformers.SentenceTransformer import SentenceTransformer
 
+    try:
+        from optimum.intel import OVQuantizationConfig
+    except ImportError:
+        pass
     try:
         from optimum.onnxruntime.configuration import OptimizationConfig, QuantizationConfig
     except ImportError:
@@ -97,7 +102,7 @@
     if file_suffix is None:
         file_suffix = "optimized"
 
-    save_or_push_to_hub_onnx_model(
+    save_or_push_to_hub_model(
         export_function=lambda save_dir: optimizer.optimize(optimization_config, save_dir, file_suffix=file_suffix),
         export_function_name="export_optimized_onnx_model",
         config=optimization_config,
@@ -105,6 +110,7 @@
         push_to_hub=push_to_hub,
         create_pr=create_pr,
         file_suffix=file_suffix,
+        backend="onnx",
     )
 
 
@@ -180,7 +186,7 @@
     if file_suffix is None:
         file_suffix = f"{quantization_config.weights_dtype.name.lower()}_quantized"
 
-    save_or_push_to_hub_onnx_model(
+    save_or_push_to_hub_model(
         export_function=lambda save_dir: quantizer.quantize(quantization_config, save_dir, file_suffix=file_suffix),
         export_function_name="export_dynamic_quantized_onnx_model",
         config=quantization_config,
@@ -188,10 +194,122 @@
         push_to_hub=push_to_hub,
         create_pr=create_pr,
         file_suffix=file_suffix,
+        backend="onnx",
     )
 
 
-def save_or_push_to_hub_onnx_model(
+def export_static_quantized_openvino_model(
+    model: SentenceTransformer,
+    quantization_config: OVQuantizationConfig | dict | None,
+    model_name_or_path: str,
+    dataset_name: str | None = None,
+    dataset_config_name: str | None = None,
+    dataset_split: str | None = None,
+    column_name: str | None = None,
+    push_to_hub: bool = False,
+    create_pr: bool = False,
+    file_suffix: str = "qint8_quantized",
+) -> None:
+    """
+    Export a quantized OpenVINO model from a SentenceTransformer model.
+
+    This function applies Post-Training Static Quantization (PTQ) using a calibration dataset, which calibrates
+    quantization constants without requiring model retraining. The default quantization configuration converts
+    the model to int8 precision, enabling faster inference while maintaining accuracy.
+
+    See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.
+
+    Args:
+        model (SentenceTransformer): The SentenceTransformer model to be quantized. Must be loaded with `backend="openvino"`.
+        quantization_config (OVQuantizationConfig | dict | None): The quantization configuration. If None, default values are used.
+        model_name_or_path (str): The path or Hugging Face Hub repository name where the quantized model will be saved.
+        dataset_name (str, optional): The name of the dataset to load for calibration.
+            If not specified, the `sst2` subset of the `glue` dataset will be used by default.
+        dataset_config_name (str, optional): The specific configuration of the dataset to load.
+        dataset_split (str, optional): The split of the dataset to load (e.g., 'train', 'test'). Defaults to None.
+        column_name (str, optional): The column name in the dataset to use for calibration. Defaults to None.
+        push_to_hub (bool, optional): Whether to push the quantized model to the Hugging Face Hub. Defaults to False.
+        create_pr (bool, optional): Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.
+        file_suffix (str, optional): The suffix to add to the quantized model file name. Defaults to `qint8_quantized`.
+
+    Raises:
+        ImportError: If the required packages `optimum` and `openvino` are not installed.
+        ValueError: If the provided model is not a valid SentenceTransformer model loaded with `backend="openvino"`.
+        ValueError: If the provided quantization_config is not valid.
+
+    Returns:
+        None
+    """
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.models.Transformer import Transformer
+
+    try:
+        from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizationConfig, OVQuantizer
+    except ImportError:
+        raise ImportError(
+            "Please install datasets, optimum-intel and openvino to use this function. "
+            "You can install them with pip: `pip install datasets optimum[openvino]`"
+        )
+    if not is_datasets_available():
+        raise ImportError(
+            "Please install datasets to use this function. You can install it with pip: `pip install datasets`"
+        )
+
+    if (
+        not isinstance(model, SentenceTransformer)
+        or not len(model)
+        or not isinstance(model[0], Transformer)
+        or not isinstance(model[0].auto_model, OVModelForFeatureExtraction)
+    ):
+        raise ValueError(
+            'The model must be a Transformer-based SentenceTransformer model loaded with `backend="openvino"`.'
+        )
+
+    if quantization_config is None:
+        quantization_config = OVQuantizationConfig()
+
+    ov_model: OVModelForFeatureExtraction = model[0].auto_model
+    ov_config = OVConfig(quantization_config=quantization_config)
+    quantizer = OVQuantizer.from_pretrained(ov_model)
+
+    if any(param is not None for param in [dataset_name, dataset_config_name, dataset_split, column_name]) and not all(
+        param is not None for param in [dataset_name, dataset_config_name, dataset_split, column_name]
+    ):
+        raise ValueError(
+            "Either specify all of `dataset_name`, `dataset_config_name`, `dataset_split`, and `column_name`, or leave them all unspecified."
+        )
+
+    def preprocess_function(examples):
+        return model.tokenizer(examples, padding="max_length", max_length=384, truncation=True)
+
+    dataset_name = dataset_name if dataset_name is not None else "glue"
+    dataset_config_name = dataset_config_name if dataset_config_name is not None else "sst2"
+    dataset_split = dataset_split if dataset_split is not None else "train"
+    column_name = column_name if column_name is not None else "sentence"
+    with disable_datasets_caching():
+        calibration_dataset = quantizer.get_calibration_dataset(
+            dataset_name=dataset_name,
+            dataset_config_name=dataset_config_name,
+            preprocess_function=lambda examples: preprocess_function(examples[column_name]),
+            num_samples=quantization_config.num_samples if quantization_config is not None else 300,
+            dataset_split=dataset_split,
+        )
+
+    save_or_push_to_hub_model(
+        export_function=lambda save_dir: quantizer.quantize(
+            calibration_dataset, save_directory=save_dir, ov_config=ov_config
+        ),
+        export_function_name="export_static_quantized_openvino_model",
+        config=quantization_config,
+        model_name_or_path=model_name_or_path,
+        push_to_hub=push_to_hub,
+        create_pr=create_pr,
+        file_suffix=file_suffix,
+        backend="openvino",
+    )
+
+
+def save_or_push_to_hub_model(
     export_function: Callable,
     export_function_name: str,
     config,
@@ -199,14 +317,35 @@ def save_or_push_to_hub_onnx_model(
     push_to_hub: bool = False,
     create_pr: bool = False,
     file_suffix: str | None = None,
+    backend: str = "onnx",
 ):
-    if push_to_hub:
-        with tempfile.TemporaryDirectory() as save_dir:
-            export_function(save_dir)
-            file_name = f"model_{file_suffix}.onnx"
-            source = (Path(save_dir) / file_name).as_posix()
-            destination = (Path("onnx") / file_name).as_posix()
-
+    if backend == "onnx":
+        file_name = f"model_{file_suffix}.onnx"
+    elif backend == "openvino":
+        file_name = f"openvino_model_{file_suffix}.xml"
+
+    with tempfile.TemporaryDirectory() as save_dir:
+        export_function(save_dir)
+
+        # OpenVINO models are saved in a nested directory
+        if backend == "openvino":
+            save_dir = Path(save_dir) / backend
+            # and we need to attach the file_suffix for both the .xml and .bin files
+            shutil.move(save_dir / "openvino_model.xml", save_dir / file_name)
+            shutil.move(save_dir / "openvino_model.bin", (save_dir / file_name).with_suffix(".bin"))
+            save_dir = save_dir.as_posix()
+
+        # Because we upload folders and save_dir now has unnecessary files (tokenizer.json, config.json, etc.),
+        # we move the main file to a nested directory
+        if backend == "onnx":
+            dst_dir = Path(save_dir) / backend
+            dst_dir.mkdir(parents=True, exist_ok=True)
+            source = Path(save_dir) / file_name
+            destination = dst_dir / file_name
+            shutil.move(source, destination)
+            save_dir = dst_dir.as_posix()
+
+        if push_to_hub:
             commit_description = ""
             if create_pr:
                 opt_config_string = repr(config).replace("(", "(\n\t").replace(", ", ",\n\t").replace(")", "\n)")
@@ -230,8 +369,8 @@ def save_or_push_to_hub_onnx_model(
     model = SentenceTransformer(
         "{model_name_or_path}",
         revision=f"refs/pr/{{pr_number}}",
-        backend="onnx",
-        model_kwargs={{"file_name": "{destination}"}},
+        backend="{backend}",
+        model_kwargs={{"file_name": "{file_name}"}},
     )
 
     # Verify that everything works as expected
@@ -243,23 +382,27 @@ def save_or_push_to_hub_onnx_model(
     ```
     """
 
-            huggingface_hub.upload_file(
-                path_or_fileobj=source,
-                path_in_repo=destination,
+            huggingface_hub.upload_folder(
+                folder_path=save_dir,
+                path_in_repo=backend,
                 repo_id=model_name_or_path,
                 repo_type="model",
-                commit_message=f"Add exported ONNX model {file_name!r}",
+                commit_message=f"Add exported {backend} model {file_name!r}",
                 commit_description=commit_description,
                 create_pr=create_pr,
             )
 
-    else:
-        with tempfile.TemporaryDirectory() as save_dir:
-            export_function(save_dir)
-
-            file_name = f"model_{file_suffix}.onnx"
-            source = os.path.join(save_dir, file_name)
-            destination = os.path.join(model_name_or_path, "onnx", file_name)
+        else:
+            dst_dir = Path(model_name_or_path) / backend
             # Create destination if it does not exist
-            os.makedirs(os.path.dirname(destination), exist_ok=True)
+            dst_dir.mkdir(parents=True, exist_ok=True)
+
+            source = Path(save_dir) / file_name
+            destination = dst_dir / file_name
             shutil.copy(source, destination)
+
+            # OpenVINO has a second file to save: the .bin file
+            if backend == "openvino":
+                bin_source = (Path(save_dir) / file_name).with_suffix(".bin")
+                bin_destination = (Path(dst_dir) / file_name).with_suffix(".bin")
+                shutil.copy(bin_source, bin_destination)
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index b2b5053a3..f9e82986b 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -260,7 +260,7 @@ def _backend_should_export(
     file_name = model_args.get("file_name", target_file_name)
     subfolder = model_args.get("subfolder", None)
 
-    primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else file_name
+    primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else Path(file_name).as_posix()
     secondary_full_path = (
         Path(subfolder, self.backend, file_name).as_posix()
         if subfolder
diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index bb4238aae..53d8247fa 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1475,3 +1475,21 @@ def is_training_available() -> bool:
     Transformers models, i.e. Huggingface datasets and Huggingface accelerate.
     """
     return is_accelerate_available() and is_datasets_available()
+
+
+@contextmanager
+def disable_datasets_caching():
+    """
+    A context manager that will disable caching in the datasets library.
+    """
+    from datasets import disable_caching, enable_caching, is_caching_enabled
+
+    is_originally_enabled = is_caching_enabled()
+
+    try:
+        if is_originally_enabled:
+            disable_caching()
+        yield
+    finally:
+        if is_originally_enabled:
+            enable_caching()
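Two usage notes on the new APIs, as sketches rather than part of the patch. First, the calibration-dataset arguments (`dataset_name`, `dataset_config_name`, `dataset_split`, `column_name`) are all-or-none and default to the `sentence` column of the `train` split of `glue`/`sst2`. A minimal sketch of calibrating on a different dataset follows; the `stanfordnlp/imdb` dataset, its `plain_text` config, its `text` column, and the local output path are illustrative stand-ins for your own data, not values from this patch: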
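```python
from sentence_transformers import SentenceTransformer, export_static_quantized_openvino_model
from optimum.intel import OVQuantizationConfig

# The export function requires a model loaded with the OpenVINO backend
model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")

# Specify all four dataset arguments together, or none of them, per the
# all-or-none ValueError above; these values are illustrative placeholders
export_static_quantized_openvino_model(
    model,
    quantization_config=OVQuantizationConfig(num_samples=300),
    model_name_or_path="path/to/quantized-model",  # local directory to save into
    dataset_name="stanfordnlp/imdb",
    dataset_config_name="plain_text",
    dataset_split="train",
    column_name="text",
)

# Per save_or_push_to_hub_model above, the quantized model lands in
# path/to/quantized-model/openvino/ as openvino_model_qint8_quantized.xml
# plus its .bin weights file
```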
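Second, the new `disable_datasets_caching` context manager only turns caching off when it was enabled, and restores the original state on exit even if the body raises; note that the `@contextmanager` decorator relies on `contextmanager` being imported from `contextlib` in `util.py`. A minimal behavioral sketch, assuming `datasets` is installed and caching starts out enabled: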
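```python
from datasets import is_caching_enabled
from sentence_transformers.util import disable_datasets_caching

assert is_caching_enabled()  # datasets caching is on by default

with disable_datasets_caching():
    # Anything loaded here, e.g. the calibration set fetched by
    # quantizer.get_calibration_dataset(), skips the datasets cache
    assert not is_caching_enabled()

# The finally-block restores the original caching state
assert is_caching_enabled()
```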