Proper sentence-transformers ONNX export support (#1589)
* proper sentence-transformers onnx export support

* update doc

* style

* fix test

* fix tests
fxmarty authored Dec 13, 2023
1 parent dad6a8a commit a3f4762
Showing 12 changed files with 225 additions and 50 deletions.
11 changes: 8 additions & 3 deletions docs/source/exporters/onnx/overview.mdx
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.

🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides classes, functions, and a command line interface to perform the export easily.

Supported architectures:
Supported architectures from [🤗 Transformers](https://huggingface.co/docs/transformers/index):

- AST
- Audio Spectrogram Transformer
@@ -89,7 +89,6 @@ Supported architectures:
- SpeechT5
- Splinter
- SqueezeBert
- Stable Diffusion
- Swin
- T5
- TROCR
@@ -105,9 +104,15 @@ Supported architectures:
- XLM-Roberta
- Yolos

Supported architectures (Timm):
Supported architectures from [🤗 Diffusers](https://huggingface.co/docs/diffusers/index):
- Stable Diffusion

Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index):
- Resnext50-32x4d
- Resnext50d-32x4d
- Resnext101-32x4d
- Resnext101-32x8d
- Resnext101-64x4d

Supported architectures from [Sentence Transformers](https://github.com/UKPLab/sentence-transformers):
- All Transformer and CLIP-based models.
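For illustration (this snippet is not part of the diff), an export of a Sentence Transformers checkpoint could look like the following minimal sketch; the checkpoint id and output directory are placeholders:

```python
# Minimal sketch: exporting a Sentence Transformers checkpoint to ONNX.
# The checkpoint id and output directory are illustrative placeholders.
from optimum.exporters.onnx import main_export

main_export(
    "sentence-transformers/all-MiniLM-L6-v2",  # example checkpoint
    output="st_onnx",
    task="feature-extraction",
    library_name="sentence_transformers",  # optional; normally auto-detected
)
```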
3 changes: 1 addition & 2 deletions optimum/exporters/onnx/__main__.py
@@ -266,8 +266,7 @@ def main_export(
_variant (`str`, defaults to `default`):
Specify the variant of the ONNX export to use.
library_name (`Optional[str]`, defaults to `None`):
The library of the model (`"transformers"`, `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
the library name for the checkpoint.
The library of the model (`"transformers"`, `"diffusers"`, `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enables exporting decoder-only models in three files (without past, with past, and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
**kwargs_shapes (`Dict`):
42 changes: 42 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -63,6 +63,8 @@
from .model_patcher import (
FalconModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
VisionEncoderDecoderPatcher,
WavLMModelPatcher,
@@ -799,6 +801,32 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size"}}


class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"token_embeddings": {0: "batch_size", 1: "sequence_length"},
"sentence_embedding": {0: "batch_size"},
}

# The forward of SentenceTransformer models takes a single features dictionary rather
# than positional tensors, so the export needs a patcher that adapts the calling
# convention expected by the ONNX tracer.
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)


class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
TEXT_CONFIG = "text_config"
VISION_CONFIG = "vision_config"
@@ -826,6 +854,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"text_embeds": {0: "text_batch_size"},
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)


class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
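The `inputs`/`outputs` declared by these configs map directly to the I/O names of the exported graph. A hedged sketch of running such an export with ONNX Runtime (the file layout, `st_onnx/model.onnx`, and the saved tokenizer are assumptions based on Optimum's defaults, not part of this diff):

```python
# Sketch: querying the exported graph through the names declared in the config.
# "st_onnx/model.onnx" and the co-located tokenizer are assumptions.
import onnxruntime
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("st_onnx")
session = onnxruntime.InferenceSession("st_onnx/model.onnx", providers=["CPUExecutionProvider"])

feed = tokenizer(["hello world"], return_tensors="np")
token_embeddings, sentence_embedding = session.run(
    ["token_embeddings", "sentence_embedding"],
    {
        "input_ids": feed["input_ids"].astype("int64"),
        "attention_mask": feed["attention_mask"].astype("int64"),
    },
)
print(token_embeddings.shape)    # (batch_size, sequence_length, hidden_size)
print(sentence_embedding.shape)  # (batch_size, hidden_size)
```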
45 changes: 45 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -744,3 +744,48 @@ def patched_forward(
return filterd_outputs

self.patched_forward = patched_forward


class SentenceTransformersTransformerPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})

return result

self.patched_forward = patched_forward


class SentenceTransformersCLIPPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

def patched_forward(input_ids, attention_mask, pixel_values):
vision_outputs = model[0].model.vision_model(pixel_values=pixel_values)
image_embeds = model[0].model.visual_projection(vision_outputs[1])

text_outputs = model[0].model.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
text_embeds = model[0].model.text_projection(text_outputs[1])

if len(model) > 1:
image_embeds = model[1:](image_embeds)
text_embeds = model[1:](text_embeds)

return {"text_embeds": text_embeds, "image_embeds": image_embeds}

self.patched_forward = patched_forward
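Both patchers share one idea: `SentenceTransformer` modules take a single features dictionary, while the ONNX tracer needs positional tensor arguments. A self-contained sketch of that bridging pattern (illustrative only, not Optimum code):

```python
import torch

class DictForwardModel(torch.nn.Module):
    # Stand-in for SentenceTransformer's forward, which takes one features dict.
    def forward(self, features):
        hidden = features["input_ids"].float().unsqueeze(-1) * features["attention_mask"].float().unsqueeze(-1)
        return {"token_embeddings": hidden, "sentence_embedding": hidden.mean(dim=1)}

class PositionalWrapper(torch.nn.Module):
    # Exposes positional tensors, as the ONNX export tracer expects.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        return self.model({"input_ids": input_ids, "attention_mask": attention_mask})

wrapped = PositionalWrapper(DictForwardModel())
outputs = wrapped(torch.ones(2, 4, dtype=torch.long), torch.ones(2, 4, dtype=torch.long))
print(outputs["sentence_embedding"].shape)  # torch.Size([2, 1])
```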
114 changes: 77 additions & 37 deletions optimum/exporters/tasks.py
@@ -195,10 +195,16 @@ class TasksManager:
"image-classification": "create_model",
}

_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
"feature-extraction": "SentenceTransformer",
"sentence-similarity": "SentenceTransformer",
}

_LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {
"transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
"diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
"sentence_transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
"timm": _TIMM_TASKS_TO_MODEL_LOADERS,
"transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
}

if is_tf_available():
@@ -254,9 +260,10 @@ class TasksManager:

# Reverse dictionaries str -> str, where several model loaders may map to the same task
_LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP = {
"transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
"diffusers": get_model_loaders_to_tasks(_DIFFUSERS_TASKS_TO_MODEL_LOADERS),
"sentence_transformers": get_model_loaders_to_tasks(_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
"timm": get_model_loaders_to_tasks(_TIMM_TASKS_TO_MODEL_LOADERS),
"transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
}
_LIBRARY_TO_TF_MODEL_LOADERS_TO_TASKS_MAP = {
"transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS),
@@ -871,6 +878,16 @@ class TasksManager:
"semantic-segmentation",
onnx="SegformerOnnxConfig",
),
"sentence-transformers-clip": supported_tasks_mapping(
"feature-extraction",
"sentence-similarity",
onnx="SentenceTransformersCLIPOnnxConfig",
),
"sentence-transformers-transformer": supported_tasks_mapping(
"feature-extraction",
"sentence-similarity",
onnx="SentenceTransformersTransformerOnnxConfig",
),
"sew": supported_tasks_mapping(
"feature-extraction",
"automatic-speech-recognition",
@@ -1354,6 +1371,9 @@ def determine_framework(
):
# stable diffusion case
framework = "pt"
elif "config_sentence_transformers.json" in all_files:
# The Sentence Transformers library relies on PyTorch.
framework = "pt"
else:
if request_exception is not None:
raise RequestsConnectionError(
@@ -1559,6 +1579,10 @@ def infer_library_from_model(
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
library_name = getattr(model_info, "library_name", None)

# sentence-transformers package name is sentence_transformers
if library_name is not None:
library_name = library_name.replace("-", "_")

if library_name is None:
all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)

@@ -1578,17 +1602,16 @@ def infer_library_from_model(
library_name = "timm"
elif hasattr(model_config, "_diffusers_version"):
library_name = "diffusers"
elif any(file_path.startswith("sentence_") for file_path in all_files):
library_name = "sentence_transformers"
else:
library_name = "transformers"

if library_name is None:
raise ValueError(
"The library_name could not be automatically inferred. If using the command-line, please provide the argument --library (transformers,diffusers,timm)!"
"The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`."
)

if library_name == "sentence-transformers":
return "transformers"

return library_name

@classmethod
Expand Down Expand Up @@ -1647,6 +1670,17 @@ def standardize_model_attributes(
model_type = json.load(fp)["architecture"]

setattr(model.config, "model_type", model_type)
elif library_name == "sentence_transformers":
if "Transformer" in model[0].__class__.__name__:
model.config = model[0].auto_model.config
model.config.model_type = "sentence-transformers-transformer"
elif "CLIP" in model[0].__class__.__name__:
model.config = model[0].model.config
model.config.model_type = "sentence-transformers-clip"
else:
raise ValueError(
f"The export of a sentence-transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add the support."
)

@staticmethod
def get_all_tasks():
@@ -1747,39 +1781,45 @@ def get_model_from_task(

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
TasksManager.standardize_model_attributes(
model_name_or_path, model, subfolder, revision, cache_dir, library_name
elif library_name == "sentence_transformers":
cache_folder = model_kwargs.pop("cache_folder", None)
use_auth_token = model_kwargs.pop("use_auth_token", None)
model = model_class(
model_name_or_path, device=device, cache_folder=cache_folder, use_auth_token=use_auth_token
)
return model

try:
if framework == "pt":
kwargs["torch_dtype"] = torch_dtype

if isinstance(device, str):
device = torch.device(device)
elif device is None:
device = torch.device("cpu")

# TODO : fix EulerDiscreteScheduler loading to enable for SD models
if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
with device:
# Initialize directly in the requested device, to save allocation time. Especially useful for large
# models to initialize on cuda device.
model = model_class.from_pretrained(model_name_or_path, **kwargs)
else:
try:
if framework == "pt":
kwargs["torch_dtype"] = torch_dtype

if isinstance(device, str):
device = torch.device(device)
elif device is None:
device = torch.device("cpu")

# TODO : fix EulerDiscreteScheduler loading to enable for SD models
if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
with device:
# Initialize directly in the requested device, to save allocation time. Especially useful for large
# models to initialize on cuda device.
model = model_class.from_pretrained(model_name_or_path, **kwargs)
else:
model = model_class.from_pretrained(model_name_or_path, **kwargs).to(device)
else:
model = model_class.from_pretrained(model_name_or_path, **kwargs).to(device)
else:
model = model_class.from_pretrained(model_name_or_path, **kwargs)
except OSError:
if framework == "pt":
logger.info("Loading TensorFlow model in PyTorch before exporting.")
kwargs["from_tf"] = True
model = model_class.from_pretrained(model_name_or_path, **kwargs)
else:
logger.info("Loading PyTorch model in TensorFlow before exporting.")
kwargs["from_pt"] = True
model = model_class.from_pretrained(model_name_or_path, **kwargs)
model = model_class.from_pretrained(model_name_or_path, **kwargs)
except OSError:
if framework == "pt":
logger.info("Loading TensorFlow model in PyTorch before exporting.")
kwargs["from_tf"] = True
model = model_class.from_pretrained(model_name_or_path, **kwargs)
else:
logger.info("Loading PyTorch model in TensorFlow before exporting.")
kwargs["from_pt"] = True
model = model_class.from_pretrained(model_name_or_path, **kwargs)

TasksManager.standardize_model_attributes(
model_name_or_path, model, subfolder, revision, cache_dir, library_name
)

return model

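The detection logic added to `infer_library_from_model` falls back to inspecting repository files when the Hub metadata gives no library name. A simplified, self-contained sketch of that fallback (the Hub-metadata path and the timm check are elided; function and file names are illustrative):

```python
# Simplified sketch of the file-based fallback in infer_library_from_model.
def guess_library(all_files, model_config):
    if hasattr(model_config, "_diffusers_version"):
        return "diffusers"
    if any(f.startswith("sentence_") for f in all_files):
        return "sentence_transformers"
    return "transformers"

class DummyConfig:
    pass

print(guess_library(["sentence_bert_config.json", "modules.json"], DummyConfig()))
# -> sentence_transformers
```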
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -36,6 +36,7 @@
is_onnx_available,
is_onnxruntime_available,
is_pydantic_available,
is_sentence_transformers_available,
is_timm_available,
is_torch_onnx_support_available,
require_numpy_strictly_lower,
7 changes: 6 additions & 1 deletion optimum/utils/import_utils.py
@@ -48,7 +48,8 @@
_accelerate_available = importlib.util.find_spec("accelerate") is not None
_diffusers_available = importlib.util.find_spec("diffusers") is not None
_auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None
_timm_available = importlib.util.find_spec("diffusers") is not None
_timm_available = importlib.util.find_spec("timm") is not None
_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None

torch_version = None
if is_torch_available():
@@ -107,6 +108,10 @@ def is_timm_available():
return _timm_available


def is_sentence_transformers_available():
return _sentence_transformers_available


def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
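All these availability flags follow the same `importlib.util.find_spec` probe, which is also what makes the fixed bug visible: `_timm_available` previously probed `"diffusers"` instead of `"timm"`. A minimal sketch of the pattern:

```python
import importlib.util

# The module is probed without importing it, so the check is cheap and safe
# even when the package is absent.
_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None

def is_sentence_transformers_available() -> bool:
    return _sentence_transformers_available

if is_sentence_transformers_available():
    from sentence_transformers import SentenceTransformer  # only imported when present
```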
2 changes: 1 addition & 1 deletion optimum/utils/input_generators.py
@@ -812,7 +812,7 @@ def __init__(
output_channels if output_channels is not None else normalized_config.vision_config.output_channels
)

def generate(self, input_name: str, framework: str = "pt"):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape = [self.batch_size, self.output_channels, self.image_embedding_size, self.image_embedding_size]
return self.random_float_tensor(shape, framework=framework)

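The widened `generate` signature suggests that callers pass `int_dtype`/`float_dtype` as keyword arguments, so an override that omits them would raise a `TypeError`. A self-contained sketch of that failure mode (not Optimum code; class names are illustrative):

```python
class Base:
    def generate(self, input_name, framework="pt", int_dtype="int64", float_dtype="fp32"):
        ...

class OldOverride(Base):
    def generate(self, input_name, framework="pt"):  # missing the dtype kwargs
        return input_name

class FixedOverride(Base):
    def generate(self, input_name, framework="pt", int_dtype="int64", float_dtype="fp32"):
        return input_name

kwargs = {"framework": "pt", "int_dtype": "int64", "float_dtype": "fp32"}
print(FixedOverride().generate("pixel_values", **kwargs))  # works
try:
    OldOverride().generate("pixel_values", **kwargs)
except TypeError as exc:
    print("old signature fails:", exc)
```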