
Remove attn mask patching #1473

Closed

Conversation

baskrahmer
Contributor

What does this PR do?

Removes attention mask patching for specific models when doing an ONNX export.

Picked this up but I am not sure about:

  1. Whether to log a warning or raise an error when exporting with sequence_length=1, and what the exact scope of such a check should be. Right now it raises a warning for any model whose task is prefixed with text-generation; maybe this should be more specific (see the sketch below).
  2. Whether to add a warning/error for previously exported models with such a configuration.
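A minimal sketch of what option 1 could look like (the function and parameter names below are illustrative, not Optimum's actual API):

import logging

logger = logging.getLogger(__name__)

def check_decoder_export_shapes(task: str, sequence_length: int) -> None:
    # Warn when a decoder-only (text-generation) model is exported with sequence_length=1,
    # since tracing would then skip the causal-mask branch needed for the prefill step.
    if task.startswith("text-generation") and sequence_length == 1:
        logger.warning(
            "Exporting a text-generation model with sequence_length=1 may trace the wrong "
            "attention-mask branch; consider using sequence_length > 1 instead."
        )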

Fixes #1461

@fxmarty
Contributor

fxmarty commented Oct 24, 2023

Not sure if this context is given anywhere in the code base, but anyway:

That's great @baskrahmer, thank you for the simplification! For context, @echarlaix introduced a simplification for the ONNX export of decoder-only models in #1257, where a single ONNX model without subgraphs can be used to handle both the prefill and decode steps (contrary to the previous decoder_model_merged.onnx, which handled both with subgraphs).

However, to do that, the model traced during the ONNX export needs to encompass the causal mask generation. Unfortunately, some architectures such as Llama gate it behind a control flow that depends on the sequence length: https://github.com/huggingface/transformers/blob/fc142bd775ae4639f80a8b0085a5df33bd2853ce/src/transformers/models/llama/modeling_llama.py#L139-L147. So to export models with the new structure, we either need to patch the models to remove this control flow (what was done so far), or simply use sequence_length > 1 to go into the control flow during tracing. That is what I was suggesting in #1461 for simplification purposes.
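A minimal, self-contained sketch of the kind of sequence-length-dependent control flow described above (simplified for illustration; the function names here are not the actual transformers code):

from typing import Optional

import torch

def make_causal_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Positions above the diagonal get the most negative representable value; the rest stay 0.
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(mask, diagonal=1)

def prepare_decoder_attention_mask(input_shape) -> Optional[torch.Tensor]:
    causal_mask = None
    # This branch is only taken when the sequence length is greater than 1. If the model is
    # traced with sequence_length=1 during the ONNX export, the causal mask generation never
    # makes it into the exported graph, which breaks the prefill step.
    if input_shape[-1] > 1:
        causal_mask = make_causal_mask(input_shape[-1])
    return causal_mask

Tracing with sequence_length > 1 goes into that branch, so the mask construction is kept in the exported graph without any model patching.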

@baskrahmer baskrahmer force-pushed the remove_attn_mask_patching branch from eab6299 to 30a922c on October 26, 2023 at 16:40
@baskrahmer baskrahmer marked this pull request as ready for review on October 26, 2023 at 16:40
@baskrahmer
Contributor Author

@fxmarty thanks for the context. Sounds sensible :)

@fxmarty fxmarty requested review from echarlaix and fxmarty and removed request for echarlaix October 27, 2023 07:29
@fxmarty fxmarty left a comment (Contributor)


I'll test a bit more later!

Review comment on optimum/utils/modeling_utils.py (resolved)
@fxmarty
Contributor

fxmarty commented Oct 31, 2023

Hi @baskrahmer, following huggingface/transformers#27086, quite a few _make_causal and _prepare_attention_mask functions were removed or moved elsewhere, so this PR should fix the issue. I believe you can also remove the patching of _prepare_attn_mask and _make_causal_mask for Falcon (they don't exist anymore).

@fxmarty
Contributor

fxmarty commented Oct 31, 2023

see #1495

@baskrahmer
Contributor Author

I believe you can also remove the patching of _prepare_attn_mask and _make_causal_mask for Falcon (that don't exist anymore).

Not sure if I follow - you mean also removing this?

@baskrahmer baskrahmer force-pushed the remove_attn_mask_patching branch from 1c1a4be to 2df564d on October 31, 2023 at 19:55
@fxmarty
Contributor

fxmarty commented Nov 2, 2023

Hi @baskrahmer, sorry for the late reply. I meant this:

# In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
if isinstance(self._model, FalconModel):
    self._model._prepare_attn_mask = _falcon_prepare_attn_mask
else:
    self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask

EDIT: Nevermind, you already removed it!

I'm preparing a release for today in sync with the Transformers release, and we'll need this PR in. In the interest of time, I'll be pushing to your branch to get this PR merged; apologies in advance for that!

@baskrahmer
Contributor Author

@fxmarty thanks for the reply. All good, you are definitely more in the details here so feel free to change anything :)

@fxmarty
Contributor

fxmarty commented Nov 2, 2023

@baskrahmer #1509 is merged, based on your branch (I could not push to your branch); sorry for the rush, and thank you for your contribution!

@fxmarty fxmarty closed this Nov 2, 2023
Linked issue: Remove unnecessary _prepare_decoder_attention_mask patching