Enable the export of only one decoder #1257

Merged (84 commits) on Oct 16, 2023
Commits (84)
41b8f98
ONNX export decoder model refactorization
echarlaix Aug 3, 2023
f91a018
fix style
echarlaix Aug 4, 2023
4ce5fbe
fix index
echarlaix Aug 4, 2023
552eebc
merge main in branch
echarlaix Sep 8, 2023
aa40ba4
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 12, 2023
9fa05e4
fix IO bindings
echarlaix Sep 12, 2023
3a0d76a
format
echarlaix Sep 12, 2023
b0aa234
enable mpt support
echarlaix Sep 12, 2023
dfabefd
format
echarlaix Sep 12, 2023
35df7bd
add trust remote code
echarlaix Sep 13, 2023
469edc8
fix test
echarlaix Sep 13, 2023
77cc527
format
echarlaix Sep 13, 2023
4f72a7e
rm redundant
echarlaix Sep 13, 2023
599c31c
format
echarlaix Sep 13, 2023
dac2376
merge main in branch
echarlaix Sep 13, 2023
c13b645
fix
echarlaix Sep 13, 2023
0e83cd1
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 14, 2023
1f81f0b
Merge branch 'main' into refactorization-decoder-ort
echarlaix Sep 14, 2023
a0d0802
fix quantization
echarlaix Sep 14, 2023
7f65ce1
add test
echarlaix Sep 14, 2023
2840b81
format
echarlaix Sep 14, 2023
5fa7b20
format
echarlaix Sep 14, 2023
8011982
fix optimization
echarlaix Sep 14, 2023
b643308
fix opitmization
echarlaix Sep 15, 2023
ca9ce30
fix compatibility with legacy models
echarlaix Sep 15, 2023
144753a
format
echarlaix Sep 15, 2023
4ee6167
fix legacy models
echarlaix Sep 15, 2023
f2d0f84
format
echarlaix Sep 15, 2023
3ff719a
fix style
echarlaix Sep 15, 2023
d794141
format
echarlaix Sep 15, 2023
a34a16e
add export to main_export
echarlaix Sep 15, 2023
dfe7e5e
add legacy to ONNX export
echarlaix Sep 18, 2023
8d102f7
fix test
echarlaix Sep 18, 2023
62b8974
fix
echarlaix Sep 18, 2023
b8e18c3
rm unused import
echarlaix Sep 18, 2023
819691e
patch model to fix causal lm generation
echarlaix Sep 18, 2023
e259670
rm commen
echarlaix Sep 18, 2023
2f26201
add no psot process
echarlaix Sep 18, 2023
bed73d4
merge main in branch
echarlaix Sep 18, 2023
6d8acb4
fix
echarlaix Sep 18, 2023
52c1745
remove bloom caching
echarlaix Sep 18, 2023
1e9ba7e
fix
echarlaix Sep 19, 2023
4b68caa
format
echarlaix Sep 19, 2023
e5fd9f8
fix dynamic axis for position ids
echarlaix Sep 19, 2023
addad92
fix external data
echarlaix Sep 19, 2023
2c063c0
format
echarlaix Sep 19, 2023
1b47093
test
echarlaix Sep 19, 2023
35caaf2
test
echarlaix Sep 19, 2023
725857b
add model patcher
echarlaix Sep 19, 2023
46b26b5
format
echarlaix Sep 19, 2023
33957af
fix
echarlaix Sep 19, 2023
c2ec382
fix bart model patcher
echarlaix Sep 19, 2023
d86bce6
format
echarlaix Sep 19, 2023
be836b5
format
echarlaix Sep 20, 2023
b05f599
fix model patcher for opt models
echarlaix Sep 20, 2023
26d97e8
fix format
echarlaix Sep 20, 2023
4b6c3ed
add tmp onnxruntime max version
echarlaix Sep 20, 2023
615a219
add test
echarlaix Sep 20, 2023
b3525f8
format
echarlaix Sep 20, 2023
e0e2bae
tmp fix onnxruntime max version
echarlaix Sep 20, 2023
cbc935f
format
echarlaix Sep 20, 2023
624d91d
add test
echarlaix Sep 20, 2023
c558450
fix ort docker
echarlaix Sep 20, 2023
e72526d
fix format
echarlaix Sep 20, 2023
7926999
merge main in branch
echarlaix Sep 22, 2023
44ef0f1
add test
echarlaix Sep 22, 2023
ed8e74f
fix bart model patcher
echarlaix Sep 25, 2023
c13a170
raise when unsupported model
echarlaix Sep 25, 2023
524b668
add cached file
echarlaix Sep 25, 2023
8951ddf
minor
echarlaix Oct 3, 2023
2491ef3
add position warning
echarlaix Oct 4, 2023
0ab6e61
fixes
echarlaix Oct 5, 2023
1a7d491
enable post process after export to remove tied weights
echarlaix Oct 5, 2023
cd8d4be
comment
echarlaix Oct 5, 2023
e6de5e7
remove test
echarlaix Oct 5, 2023
4a32f7a
fix test
echarlaix Oct 5, 2023
a51686e
modify model
echarlaix Oct 6, 2023
e2f8a3b
remove deprecated use_merged in test
echarlaix Oct 6, 2023
52ce2d7
Merge branch 'main' into refactorization-decoder-ort
echarlaix Oct 9, 2023
b76f43a
Add mistral model patcher
echarlaix Oct 9, 2023
5b3d445
fix test
echarlaix Oct 9, 2023
5406f95
add slow test
echarlaix Oct 9, 2023
52e0c69
add workflow
echarlaix Oct 9, 2023
8883323
fix
echarlaix Oct 9, 2023
6 changes: 6 additions & 0 deletions optimum/commands/export/onnx.py
@@ -217,6 +217,11 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
optional_group.add_argument(
"--legacy",
action="store_true",
help="Export decoder-only models in two files (without and with past) instead of a single ONNX file.",
)

Review comment (Contributor): We could merge the no_position_ids and legacy arguments, as they correspond to the previous export behavior and no_position_ids is not in a release yet. WDYT?
Reply (echarlaix, author): Yes, good idea, will merge both.

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -256,5 +261,6 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
legacy=self.args.legacy,
**input_shapes,
)
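A hedged sketch of how the new flag surfaces through the Python API (model names and output paths are illustrative; legacy is the parameter added in this PR, the other arguments are standard main_export parameters):

```python
from optimum.exporters.onnx import main_export

# New default: the decoder is exported as a single ONNX file that serves both
# the first forward pass and subsequent passes that reuse the KV cache.
main_export("gpt2", output="gpt2_onnx", task="text-generation-with-past")

# Legacy behavior: two files, decoder_model.onnx and decoder_with_past_model.onnx.
main_export("gpt2", output="gpt2_onnx_legacy", task="text-generation-with-past", legacy=True)
```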
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -69,6 +69,7 @@ def _get_submodels_and_onnx_configs(
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
legacy: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -106,7 +107,7 @@
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("text-generation") and not monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
else:
@@ -185,6 +186,7 @@ def main_export(
_variant: str = "default",
library_name: Optional[str] = None,
no_position_ids: bool = False,
legacy: bool = False,
**kwargs_shapes,
):
"""
@@ -425,6 +427,7 @@
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
legacy=legacy,
)

if not is_stable_diffusion:
@@ -610,6 +613,7 @@ def main():
pad_token_id=args.pad_token_id,
for_ort=args.for_ort,
library_name=args.library_name,
legacy=args.legacy,
**input_shapes,
)

2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -585,7 +585,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
elif self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
else:
common_outputs = OrderedDict({"logits": {0: "batch_size"}})
common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
if self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
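A minimal sketch of what the changed axis mapping buys: with both axes of logits symbolic, one exported decoder graph can return logits for a full prompt at prefill time and for a single token during incremental decoding.

```python
from collections import OrderedDict

# Before: only the batch axis was dynamic, implicitly pinning the sequence axis.
before = OrderedDict({"logits": {0: "batch_size"}})
# After: (batch, prompt_len, vocab) at prefill and (batch, 1, vocab) at decode
# are both valid shapes for the same exported output.
after = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
```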
7 changes: 2 additions & 5 deletions optimum/exporters/onnx/config.py
@@ -92,7 +92,7 @@ def __init__(
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size"}}
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
else:
@@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
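The same reasoning applies to position_ids: a single-file decoder sees both multi-token and single-token inputs, so axis 1 must stay symbolic. A hedged sketch of how a caller could build position_ids that continue from the cached past (illustrative helper, not optimum code):

```python
import torch

def make_position_ids(attention_mask: torch.Tensor, past_length: int) -> torch.Tensor:
    # Shape is (batch, seq_len) in both phases: seq_len is the prompt length at
    # prefill (past_length == 0) and 1 during incremental decoding.
    seq_len = attention_mask.shape[1] - past_length
    positions = torch.arange(past_length, past_length + seq_len)
    return positions.unsqueeze(0).expand(attention_mask.shape[0], -1)
```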

41 changes: 36 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -55,7 +55,15 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import (
BartModelPatcher,
BloomModelPatcher,
LlamaModelPatcher,
MPTModelPatcher,
OPTModelPatcher,
SAMModelPatcher,
WavLMModelPatcher,
)


if TYPE_CHECKING:
@@ -215,11 +223,21 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return OPTModelPatcher(self, model, model_kwargs=model_kwargs)


class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
@@ -228,6 +246,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MPTModelPatcher(self, model, model_kwargs=model_kwargs)


class BloomOnnxConfig(TextDecoderOnnxConfig):
# Bloom does not require position_ids input.
@@ -261,6 +284,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
1: decoder_sequence_name,
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -400,7 +428,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return int_tensor


class BartOnnxConfig(TextSeq2SeqOnnxConfig):
class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
@@ -524,11 +552,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
)


class MBartOnnxConfig(BartOnnxConfig):
pass
class BartOnnxConfig(M2M100OnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BartModelPatcher(self, model, model_kwargs=model_kwargs)


class M2M100OnnxConfig(BartOnnxConfig):
class MBartOnnxConfig(BartOnnxConfig):
pass
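For context, a hedged sketch of how these patch_model_for_export overrides are consumed during export (the model and task are illustrative; TasksManager is optimum's exporter-config resolver):

```python
from transformers import AutoModelForCausalLM
from optimum.exporters.tasks import TasksManager

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config_ctor = TasksManager.get_exporter_config_constructor(
    exporter="onnx", model=model, task="text-generation"
)
onnx_config = config_ctor(model.config, use_past=True)
with onnx_config.patch_model_for_export(model):
    # Inside the context, OPT's attention-mask helper is swapped for a
    # traceable version; torch.onnx.export would run here.
    pass
```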


113 changes: 113 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -19,6 +19,8 @@

from transformers.utils import is_torch_available

from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask


if is_torch_available():
import torch
@@ -342,3 +344,114 @@ def patched_forward(
return {"iou_scores": iou_predictions, "pred_masks": low_res_masks}

self.patched_forward = patched_forward


class BloomModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
if self.patch:
self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask")

def __enter__(self):
super().__enter__()
if self.patch:
setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.patch:
setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask)


class LlamaModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
if self.patch:
self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask")

def __enter__(self):
super().__enter__()
if self.patch:
setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.patch:
setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


class BartModelPatcher(Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
if self.patch:
self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")

def __enter__(self):
super().__enter__()
if self.patch:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.patch:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


class OPTModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
if self.patch:
self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")

def __enter__(self):
super().__enter__()
if self.patch:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.patch:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


class MPTModelPatcher(BloomModelPatcher):
pass


class BlenderbotSmallModelPatcher(BartModelPatcher):
pass


class BlenderbotModelPatcher(BartModelPatcher):
pass


class PegasusModelPatcher(BartModelPatcher):
pass
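All of the patchers above share one mechanism: save the original attention-mask helper, install a traceable replacement on __enter__, and restore the original on __exit__. A self-contained sketch of that pattern:

```python
class AttributeSwap:
    """Temporarily replace obj.<name> with a substitute; a generic sketch of
    the save/swap/restore pattern the patchers above implement."""

    def __init__(self, obj, name, replacement):
        self.obj, self.name, self.replacement = obj, name, replacement

    def __enter__(self):
        self.original = getattr(self.obj, self.name)
        setattr(self.obj, self.name, self.replacement)
        return self.obj

    def __exit__(self, exc_type, exc_value, traceback):
        setattr(self.obj, self.name, self.original)
```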
54 changes: 34 additions & 20 deletions optimum/exporters/onnx/utils.py
@@ -29,6 +29,7 @@
logging,
)
from ...utils.import_utils import _diffusers_version
from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401
from ..tasks import TasksManager
from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME

@@ -158,15 +159,16 @@ def _get_submodels_for_export_stable_diffusion(


def _get_submodels_for_export_decoder(
model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool
model: Union["PreTrainedModel", "TFPreTrainedModel"],
use_past: bool,
legacy: bool = False,
) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]:
"""
Returns the decoder part of the model.
"""
models_for_export = {}
models_for_export = {ONNX_DECODER_NAME if legacy else "model": model}

models_for_export[ONNX_DECODER_NAME] = model
if use_past:
if legacy and use_past:
Review comment on lines +170 to +172 (Contributor): How about encoder-decoders when legacy=False?
Reply (echarlaix, author): This will impact only decoder models, as this modification is done in _get_submodels_for_export_decoder; encoder-decoder models will not be impacted (not sure if that answers your question).
Reply (Contributor): Yes, makes sense, thank you!
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model

return models_for_export
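A hedged summary of the mapping this function now returns, written as a checkable sketch (the constant values match optimum/exporters/onnx/constants.py):

```python
from optimum.exporters.onnx.constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME

def expected_submodel_keys(use_past: bool, legacy: bool):
    # Mirrors _get_submodels_for_export_decoder: the same underlying model is
    # registered once (new default) or twice (legacy: without past + with past).
    keys = [ONNX_DECODER_NAME if legacy else "model"]
    if legacy and use_past:
        keys.append(ONNX_DECODER_WITH_PAST_NAME)
    return keys

assert expected_submodel_keys(use_past=True, legacy=False) == ["model"]
assert expected_submodel_keys(use_past=True, legacy=True) == ["decoder_model", "decoder_with_past_model"]
```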
@@ -226,6 +228,7 @@ def get_encoder_decoder_models_for_export(
def get_decoder_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
config: "OnnxConfig",
legacy: bool = False,
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]:
"""
Returns two versions of the decoder that can be used together to perform fast generation:
@@ -245,31 +248,42 @@
`Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and
onnx configs for the encoder and decoder parts of the model.
"""
models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past)

models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy)

onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype}
if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
onnx_kwargs["no_position_ids"] = config.no_position_ids

onnx_config = config.__class__(
model.config,
use_past=config.use_past,
use_past_in_inputs=False,
**onnx_kwargs,
)
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config)

if config.use_past:
onnx_config_with_past = config.__class__(
if legacy:
onnx_config = config.__class__(
model.config,
use_past=True,
use_past_in_inputs=True,
use_past=config.use_past,
use_past_in_inputs=False,
**onnx_kwargs,
)
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
models_for_export[ONNX_DECODER_WITH_PAST_NAME],
onnx_config_with_past,
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config)

if config.use_past:
onnx_config_with_past = config.__class__(
model.config,
use_past=True,
use_past_in_inputs=True,
**onnx_kwargs,
)
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
models_for_export[ONNX_DECODER_WITH_PAST_NAME],
onnx_config_with_past,
)

else:
onnx_config = config.__class__(
model.config,
use_past=config.use_past,
use_past_in_inputs=config.use_past,
**onnx_kwargs,
)
models_for_export["model"] = (models_for_export["model"], onnx_config)

return models_for_export
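Finally, the consumer side of the new default, as a hedged end-to-end sketch (model name illustrative; ORTModelForCausalLM is optimum's ONNX Runtime wrapper and loads the single-file decoder this export path produces):

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True)  # single-decoder export

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```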
