From a3f476261a191a30004f477564225c7b82812c27 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 13 Dec 2023 16:10:03 +0100
Subject: [PATCH] Proper sentence-transformers ONNX export support (#1589)

* proper sentence-transformers onnx export support

* update doc

* style

* fix test

* fix tests

---
 docs/source/exporters/onnx/overview.mdx       |  11 +-
 optimum/exporters/onnx/__main__.py            |   3 +-
 optimum/exporters/onnx/model_configs.py       |  42 +++++++
 optimum/exporters/onnx/model_patcher.py       |  45 +++++++
 optimum/exporters/tasks.py                    | 114 ++++++++++++------
 optimum/utils/__init__.py                     |   1 +
 optimum/utils/import_utils.py                 |   7 +-
 optimum/utils/input_generators.py             |   2 +-
 optimum/utils/testing_utils.py                |  12 +-
 tests/exporters/exporters_utils.py            |   7 +-
 .../exporters/onnx/test_exporters_onnx_cli.py |  26 +++-
 tests/exporters/onnx/test_onnx_export.py      |   5 +-
 12 files changed, 225 insertions(+), 50 deletions(-)

diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index 82b30f1e13b..63a326d4f05 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
 
 🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides classes, functions, and a command line interface to perform the export easily.
 
-Supported architectures:
+Supported architectures from [🤗 Transformers](https://huggingface.co/docs/transformers/index):
 
 - AST
 - Audio Spectrogram Transformer
@@ -89,7 +89,6 @@ Supported architectures:
 - SpeechT5
 - Splinter
 - SqueezeBert
-- Stable Diffusion
 - Swin
 - T5
 - TROCR
@@ -105,9 +104,15 @@ Supported architectures:
 - XLM-Roberta
 - Yolos
 
-Supported architectures (Timm):
+Supported architectures from [🤗 Diffusers](https://huggingface.co/docs/diffusers/index):
+- Stable Diffusion
+
+Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index):
 - Resnext50-32x4d
 - Resnext50d-32x4d
 - Resnext101-32x4d
 - Resnext101-32x8d
 - Resnext101-64x4d
+
+Supported architectures from [Sentence Transformers](https://github.com/UKPLab/sentence-transformers):
+- All Transformer and CLIP-based models.
diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index c466f54c423..d822da56d27 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -266,8 +266,7 @@ def main_export(
         _variant (`str`, defaults to `default`):
             Specify the variant of the ONNX export to use.
         library_name (`Optional[str]`, defaults to `None`):
-            The library of the model(`"tansformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
-            the library name for the checkpoint.
+            The library of the model (`"transformers"` or `"diffusers"` or `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint.
         legacy (`bool`, defaults to `False`):
             Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
        **kwargs_shapes (`Dict`):
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index db210507d68..74672c6c6d0 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -63,6 +63,8 @@
 from .model_patcher import (
     FalconModelPatcher,
     SAMModelPatcher,
+    SentenceTransformersCLIPPatcher,
+    SentenceTransformersTransformerPatcher,
     SpeechT5ModelPatcher,
     VisionEncoderDecoderPatcher,
     WavLMModelPatcher,
@@ -799,6 +801,32 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         return {"pixel_values": {0: "batch_size"}}
 
 
+class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "token_embeddings": {0: "batch_size", 1: "sequence_length"},
+            "sentence_embedding": {0: "batch_size"},
+        }
+
+    # The forward of a sentence-transformers model takes a single dictionary of features as
+    # input, while the ONNX export passes input_ids and attention_mask as separate arguments.
+    # The patcher below bridges the two signatures during the export.
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
     TEXT_CONFIG = "text_config"
     VISION_CONFIG = "vision_config"
@@ -826,6 +854,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         }
 
 
+class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "text_embeds": {0: "text_batch_size"},
+            "image_embeds": {0: "image_batch_size"},
+        }
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3
     # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index ecafb73b973..16335e69a97 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -744,3 +744,48 @@ def patched_forward(
             return filterd_outputs
 
         self.patched_forward = patched_forward
+
+
+class SentenceTransformersTransformerPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        def patched_forward(input_ids, attention_mask):
+            result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})
+
+            return result
+
+        self.patched_forward = patched_forward
+
+
+class SentenceTransformersCLIPPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        def patched_forward(input_ids, attention_mask, pixel_values):
+            vision_outputs = model[0].model.vision_model(pixel_values=pixel_values)
+            image_embeds = model[0].model.visual_projection(vision_outputs[1])
+
+            text_outputs = model[0].model.text_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+            )
+            text_embeds = model[0].model.text_projection(text_outputs[1])
+
+            if len(model) > 1:
+                image_embeds = model[1:](image_embeds)
+                text_embeds = model[1:](text_embeds)
+
+            return {"text_embeds": text_embeds, "image_embeds": image_embeds}
+
+        self.patched_forward = patched_forward
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 3241b3a822a..759fa7c878a 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -195,10 +195,16 @@ class TasksManager:
         "image-classification": "create_model",
     }
 
+    _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
+        "feature-extraction": "SentenceTransformer",
+        "sentence-similarity": "SentenceTransformer",
+    }
+
     _LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {
-        "transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
         "diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
+        "sentence_transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
         "timm": _TIMM_TASKS_TO_MODEL_LOADERS,
+        "transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
     }
 
     if is_tf_available():
@@ -254,9 +260,10 @@ class TasksManager:
     # Reverse dictionaries str -> str, where several model loaders may map to the same task
     _LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP = {
-        "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
         "diffusers": get_model_loaders_to_tasks(_DIFFUSERS_TASKS_TO_MODEL_LOADERS),
+        "sentence_transformers": get_model_loaders_to_tasks(_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
         "timm": get_model_loaders_to_tasks(_TIMM_TASKS_TO_MODEL_LOADERS),
+        "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
     }
     _LIBRARY_TO_TF_MODEL_LOADERS_TO_TASKS_MAP = {
         "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS),
@@ -871,6 +878,16 @@ class TasksManager:
             "semantic-segmentation",
             onnx="SegformerOnnxConfig",
         ),
+        "sentence-transformers-clip": supported_tasks_mapping(
+            "feature-extraction",
+            "sentence-similarity",
+            onnx="SentenceTransformersCLIPOnnxConfig",
+        ),
+        "sentence-transformers-transformer": supported_tasks_mapping(
+            "feature-extraction",
+            "sentence-similarity",
+            onnx="SentenceTransformersTransformerOnnxConfig",
+        ),
         "sew": supported_tasks_mapping(
             "feature-extraction",
             "automatic-speech-recognition",
@@ -1354,6 +1371,9 @@ def determine_framework(
         ):
             # stable diffusion case
             framework = "pt"
+        elif "config_sentence_transformers.json" in all_files:
+            # The Sentence Transformers library relies on PyTorch.
+ framework = "pt" else: if request_exception is not None: raise RequestsConnectionError( @@ -1559,6 +1579,10 @@ def infer_library_from_model( model_info = huggingface_hub.model_info(model_name_or_path, revision=revision) library_name = getattr(model_info, "library_name", None) + # sentence-transformers package name is sentence_transformers + if library_name is not None: + library_name = library_name.replace("-", "_") + if library_name is None: all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir) @@ -1578,17 +1602,16 @@ def infer_library_from_model( library_name = "timm" elif hasattr(model_config, "_diffusers_version"): library_name = "diffusers" + elif any(file_path.startswith("sentence_") for file_path in all_files): + library_name = "sentence_transformers" else: library_name = "transformers" if library_name is None: raise ValueError( - "The library_name could not be automatically inferred. If using the command-line, please provide the argument --library (transformers,diffusers,timm)!" + "The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`." ) - if library_name == "sentence-transformers": - return "transformers" - return library_name @classmethod @@ -1647,6 +1670,17 @@ def standardize_model_attributes( model_type = json.load(fp)["architecture"] setattr(model.config, "model_type", model_type) + elif library_name == "sentence_transformers": + if "Transformer" in model[0].__class__.__name__: + model.config = model[0].auto_model.config + model.config.model_type = "sentence-transformers-transformer" + elif "CLIP" in model[0].__class__.__name__: + model.config = model[0].model.config + model.config.model_type = "sentence-transformers-clip" + else: + raise ValueError( + f"The export of a sentence-transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add the support." + ) @staticmethod def get_all_tasks(): @@ -1747,39 +1781,45 @@ def get_model_from_task( if library_name == "timm": model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True) - TasksManager.standardize_model_attributes( - model_name_or_path, model, subfolder, revision, cache_dir, library_name + elif library_name == "sentence_transformers": + cache_folder = model_kwargs.pop("cache_folder", None) + use_auth_token = model_kwargs.pop("use_auth_token", None) + model = model_class( + model_name_or_path, device=device, cache_folder=cache_folder, use_auth_token=use_auth_token ) - return model - - try: - if framework == "pt": - kwargs["torch_dtype"] = torch_dtype - - if isinstance(device, str): - device = torch.device(device) - elif device is None: - device = torch.device("cpu") - - # TODO : fix EulerDiscreteScheduler loading to enable for SD models - if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers": - with device: - # Initialize directly in the requested device, to save allocation time. Especially useful for large - # models to initialize on cuda device. 
-                        model = model_class.from_pretrained(model_name_or_path, **kwargs)
-                else:
-                    model = model_class.from_pretrained(model_name_or_path, **kwargs).to(device)
-            else:
-                model = model_class.from_pretrained(model_name_or_path, **kwargs)
-        except OSError:
-            if framework == "pt":
-                logger.info("Loading TensorFlow model in PyTorch before exporting.")
-                kwargs["from_tf"] = True
-                model = model_class.from_pretrained(model_name_or_path, **kwargs)
-            else:
-                logger.info("Loading PyTorch model in TensorFlow before exporting.")
-                kwargs["from_pt"] = True
-                model = model_class.from_pretrained(model_name_or_path, **kwargs)
+        else:
+            try:
+                if framework == "pt":
+                    kwargs["torch_dtype"] = torch_dtype
+
+                    if isinstance(device, str):
+                        device = torch.device(device)
+                    elif device is None:
+                        device = torch.device("cpu")
+
+                    # TODO : fix EulerDiscreteScheduler loading to enable for SD models
+                    if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
+                        with device:
+                            # Initialize directly in the requested device, to save allocation time. Especially useful for large
+                            # models to initialize on cuda device.
+                            model = model_class.from_pretrained(model_name_or_path, **kwargs)
+                    else:
+                        model = model_class.from_pretrained(model_name_or_path, **kwargs).to(device)
+                else:
+                    model = model_class.from_pretrained(model_name_or_path, **kwargs)
+            except OSError:
+                if framework == "pt":
+                    logger.info("Loading TensorFlow model in PyTorch before exporting.")
+                    kwargs["from_tf"] = True
+                    model = model_class.from_pretrained(model_name_or_path, **kwargs)
+                else:
+                    logger.info("Loading PyTorch model in TensorFlow before exporting.")
+                    kwargs["from_pt"] = True
+                    model = model_class.from_pretrained(model_name_or_path, **kwargs)
+
+        TasksManager.standardize_model_attributes(
+            model_name_or_path, model, subfolder, revision, cache_dir, library_name
+        )
 
         return model
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 73eb86bdae5..889edbbbd48 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -36,6 +36,7 @@
     is_onnx_available,
     is_onnxruntime_available,
     is_pydantic_available,
+    is_sentence_transformers_available,
     is_timm_available,
     is_torch_onnx_support_available,
     require_numpy_strictly_lower,
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index 11d8ece3fc2..2bbfdfaf0ac 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -48,7 +48,8 @@
 _accelerate_available = importlib.util.find_spec("accelerate") is not None
 _diffusers_available = importlib.util.find_spec("diffusers") is not None
 _auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None
-_timm_available = importlib.util.find_spec("diffusers") is not None
+_timm_available = importlib.util.find_spec("timm") is not None
+_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None
 
 torch_version = None
 if is_torch_available():
@@ -107,6 +108,10 @@ def is_timm_available():
     return _timm_available
 
 
+def is_sentence_transformers_available():
+    return _sentence_transformers_available
+
+
 def is_auto_gptq_available():
     if _auto_gptq_available:
         version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index aa1f785309c..b86086f8fbe 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -812,7 +812,7 @@ def __init__(
             output_channels if output_channels is not None else normalized_config.vision_config.output_channels
         )
 
-    def generate(self, input_name: str, framework: str = "pt"):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         shape = [self.batch_size, self.output_channels, self.image_embedding_size, self.image_embedding_size]
         return self.random_float_tensor(shape, framework=framework)
diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py
index 9559d289bc6..f1c2f668e3c 100644
--- a/optimum/utils/testing_utils.py
+++ b/optimum/utils/testing_utils.py
@@ -24,7 +24,13 @@
 
 import torch
 
-from . import is_accelerate_available, is_auto_gptq_available, is_diffusers_available, is_timm_available
+from . import (
+    is_accelerate_available,
+    is_auto_gptq_available,
+    is_diffusers_available,
+    is_sentence_transformers_available,
+    is_timm_available,
+)
 
 
 # Used to test the hub
@@ -137,6 +143,10 @@ def require_timm(test_case):
     return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
 
 
+def require_sentence_transformers(test_case):
+    return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case)
+
+
 def grid_parameters(
     parameters: Dict[str, Iterable[Any]],
     yield_dict: bool = False,
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index c738ca5389b..a87fb9d042c 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -130,7 +130,7 @@
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-RobertaModel",
     "roformer": "hf-internal-testing/tiny-random-RoFormerModel",
-    # "sam": "fxmarty/sam-vit-tiny-random",  # TODO: re-enable once PyTorch 2.1 is released, see https://github.com/huggingface/optimum/pull/1301
+    "sam": "fxmarty/sam-vit-tiny-random",
     "segformer": "hf-internal-testing/tiny-random-SegformerModel",
     "splinter": "hf-internal-testing/tiny-random-SplinterModel",
     "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
@@ -295,3 +295,8 @@
     "resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
     "resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
 }
+
+PYTORCH_SENTENCE_TRANSFORMERS_MODEL = {
+    "sentence-transformers-clip": "sentence-transformers/clip-ViT-B-32-multilingual-v1",
+    "sentence-transformers-transformer": "sentence-transformers/all-MiniLM-L6-v2",
+}
diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py
index 07c6fa292e9..6ea10f28c19 100644
--- a/tests/exporters/onnx/test_exporters_onnx_cli.py
+++ b/tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -33,13 +33,18 @@
     ONNX_DECODER_WITH_PAST_NAME,
     ONNX_ENCODER_NAME,
 )
-from optimum.utils.testing_utils import require_diffusers, require_timm
+from optimum.utils.testing_utils import require_diffusers, require_sentence_transformers, require_timm
 
 
 if is_torch_available():
     from optimum.exporters.tasks import TasksManager
 
-from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY, PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL
+from ..exporters_utils import (
+    PYTORCH_EXPORT_MODELS_TINY,
+    PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
+    PYTORCH_STABLE_DIFFUSION_MODEL,
+    PYTORCH_TIMM_MODEL,
+)
 
 
 def _get_models_to_test(export_models_dict: Dict):
@@ -204,6 +209,22 @@ def test_exporters_cli_pytorch_gpu_stable_diffusion(self, model_type: str, model
     def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: str):
         self._onnx_export(model_name, model_type, device="cuda", fp16=True)
 
+    @parameterized.expand(_get_models_to_test(PYTORCH_SENTENCE_TRANSFORMERS_MODEL))
+    @require_torch
+    @require_vision
+    @require_sentence_transformers
+    def test_exporters_cli_pytorch_cpu_sentence_transformers(
+        self,
+        test_name: str,
+        model_type: str,
+        model_name: str,
+        task: str,
+        variant: str,
+        monolith: bool,
+        no_post_process: bool,
+    ):
+        self._onnx_export(model_name, task, monolith, no_post_process, variant=variant)
+
     @parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL))
     @require_torch
     @require_vision
diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py
index eba2f01f61a..03e97ee69af 100644
--- a/tests/exporters/onnx/test_onnx_export.py
+++ b/tests/exporters/onnx/test_onnx_export.py
@@ -46,6 +46,7 @@
 from ..exporters_utils import (
     PYTORCH_EXPORT_MODELS_TINY,
+    PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
     PYTORCH_STABLE_DIFFUSION_MODEL,
     PYTORCH_TIMM_MODEL,
     TENSORFLOW_EXPORT_MODELS,
@@ -312,9 +313,9 @@ def test_all_models_tested(self):
             TasksManager._SUPPORTED_CLI_MODEL_TYPE
             - set(PYTORCH_EXPORT_MODELS_TINY.keys())
             - set(PYTORCH_TIMM_MODEL.keys())
+            - set(PYTORCH_SENTENCE_TRANSFORMERS_MODEL.keys())
         )
-        assert "sam" in missing_models_set  # See exporters_utils.py
-        if len(missing_models_set) > 1:
+        if len(missing_models_set) > 0:
             self.fail(f"Not testing all models. Missing models: {missing_models_set}")
 
     @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
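
Note (editorial addition, not part of the patch): a minimal end-to-end sketch of the feature this diff adds. It uses only APIs touched here, namely `main_export` with `library_name="sentence_transformers"` and the new `sentence-similarity` task, whose export declares the `token_embeddings` and `sentence_embedding` outputs in `SentenceTransformersTransformerOnnxConfig`. The output directory, the `model.onnx` file name, and the ONNX Runtime consumption step are illustrative assumptions rather than something this patch prescribes.

    # Sketch: export a Sentence Transformers checkpoint to ONNX, then run it.
    # Assumes optimum, sentence-transformers and onnxruntime are installed.
    # CLI equivalent (assumed output directory "st_onnx/"):
    #   optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 --task sentence-similarity st_onnx/
    from optimum.exporters.onnx import main_export

    main_export(
        "sentence-transformers/all-MiniLM-L6-v2",
        output="st_onnx",
        task="sentence-similarity",
        library_name="sentence_transformers",  # otherwise inferred, see infer_library_from_model above
    )

    # Consume the exported model with ONNX Runtime (assumed file name model.onnx).
    import onnxruntime
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    session = onnxruntime.InferenceSession("st_onnx/model.onnx")
    feed = tokenizer(["A test sentence."], return_tensors="np")
    # Output order follows the OnnxConfig above: token_embeddings, then sentence_embedding.
    token_embeddings, sentence_embedding = session.run(
        None, {"input_ids": feed["input_ids"], "attention_mask": feed["attention_mask"]}
    )
    print(sentence_embedding.shape)  # (1, hidden_size)

For a CLIP checkpoint (the `sentence-transformers-clip` model type), the exported graph additionally takes `pixel_values` and exposes `text_embeds` and `image_embeds` instead, matching `SentenceTransformersCLIPOnnxConfig`.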