-
Notifications
You must be signed in to change notification settings - Fork 487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable the export of only one decoder #1257
Changes from 69 commits
41b8f98
f91a018
4ce5fbe
552eebc
aa40ba4
9fa05e4
3a0d76a
b0aa234
dfabefd
35df7bd
469edc8
77cc527
4f72a7e
599c31c
dac2376
c13b645
0e83cd1
1f81f0b
a0d0802
7f65ce1
2840b81
5fa7b20
8011982
b643308
ca9ce30
144753a
4ee6167
f2d0f84
3ff719a
d794141
a34a16e
dfe7e5e
8d102f7
62b8974
b8e18c3
819691e
e259670
2f26201
bed73d4
6d8acb4
52c1745
1e9ba7e
4b68caa
e5fd9f8
addad92
2c063c0
1b47093
35caaf2
725857b
46b26b5
33957af
c2ec382
d86bce6
be836b5
b05f599
26d97e8
4b6c3ed
615a219
b3525f8
e0e2bae
cbc935f
624d91d
c558450
e72526d
7926999
44ef0f1
ed8e74f
c13a170
524b668
8951ddf
2491ef3
0ab6e61
1a7d491
cd8d4be
e6de5e7
4a32f7a
a51686e
e2f8a3b
52ce2d7
b76f43a
5b3d445
5406f95
52e0c69
8883323
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
logging, | ||
) | ||
from ...utils.import_utils import _diffusers_version | ||
from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401 | ||
from ..tasks import TasksManager | ||
from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME | ||
|
||
|
@@ -158,15 +159,16 @@ def _get_submodels_for_export_stable_diffusion( | |
|
||
|
||
def _get_submodels_for_export_decoder( | ||
model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool | ||
model: Union["PreTrainedModel", "TFPreTrainedModel"], | ||
use_past: bool, | ||
legacy: bool = False, | ||
) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]: | ||
""" | ||
Returns the decoder part of the model. | ||
""" | ||
models_for_export = {} | ||
models_for_export = {ONNX_DECODER_NAME if legacy else "model": model} | ||
|
||
models_for_export[ONNX_DECODER_NAME] = model | ||
if use_past: | ||
if legacy and use_past: | ||
Comment on lines +170 to +172
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about encoder-decoders when legacy=False? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will impact only decoder models, as this modification is done in […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, makes sense, thank you! |
||
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model | ||
|
||
return models_for_export | ||
|
@@ -226,6 +228,7 @@ def get_encoder_decoder_models_for_export( | |
def get_decoder_models_for_export( | ||
model: Union["PreTrainedModel", "TFPreTrainedModel"], | ||
config: "OnnxConfig", | ||
legacy: bool = False, | ||
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: | ||
""" | ||
Returns two versions of the decoder that can be used together to perform fast generation: | ||
|
@@ -245,31 +248,42 @@ def get_decoder_models_for_export( | |
`Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and | ||
onnx configs for the encoder and decoder parts of the model. | ||
""" | ||
models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) | ||
|
||
models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) | ||
|
||
onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype} | ||
if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: | ||
onnx_kwargs["no_position_ids"] = config.no_position_ids | ||
|
||
onnx_config = config.__class__( | ||
model.config, | ||
use_past=config.use_past, | ||
use_past_in_inputs=False, | ||
**onnx_kwargs, | ||
) | ||
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) | ||
|
||
if config.use_past: | ||
onnx_config_with_past = config.__class__( | ||
if legacy: | ||
onnx_config = config.__class__( | ||
model.config, | ||
use_past=True, | ||
use_past_in_inputs=True, | ||
use_past=config.use_past, | ||
use_past_in_inputs=False, | ||
**onnx_kwargs, | ||
) | ||
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( | ||
models_for_export[ONNX_DECODER_WITH_PAST_NAME], | ||
onnx_config_with_past, | ||
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) | ||
|
||
if config.use_past: | ||
onnx_config_with_past = config.__class__( | ||
model.config, | ||
use_past=True, | ||
use_past_in_inputs=True, | ||
**onnx_kwargs, | ||
) | ||
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( | ||
models_for_export[ONNX_DECODER_WITH_PAST_NAME], | ||
onnx_config_with_past, | ||
) | ||
|
||
else: | ||
onnx_config = config.__class__( | ||
model.config, | ||
use_past=config.use_past, | ||
use_past_in_inputs=config.use_past, | ||
**onnx_kwargs, | ||
) | ||
models_for_export["model"] = (models_for_export["model"], onnx_config) | ||
|
||
return models_for_export | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could merge `no_position_ids` and `legacy`, as they both correspond to the previous export behavior and `no_position_ids` is not in a release yet. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good idea, will merge both.