Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to export ColPali Model to ONNX #2074

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

akshayballal95
Copy link

What does this PR do?

This PR adds support for exporting the ColPali merged model to ONNX format. The model is based on the "pali gemma" model type, and thus, I have added it under the "feature-extraction" task. Do suggest if there is a better way to integrate this. If this looks fine with a few modifications, I can add support for the Paligemma text-generation task as well.

Before submitting

Who can review?

@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun

@akshayballal95
Copy link
Author

@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun

Are you open to merging this?

@echarlaix
Copy link
Collaborator

echarlaix commented Dec 6, 2024

Comment on lines +512 to +527
class ColPaliModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
outputs = self.orig_forward(
input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
)
return outputs

self.patched_forward = patched_forward
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it needed ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its needed because the original ColPali Model only takes **kwargs and no named arguments. This resulted in an error. This fixes that error

optimum/exporters/onnx/model_patcher.py Outdated Show resolved Hide resolved
optimum/exporters/onnx/model_configs.py Outdated Show resolved Hide resolved
@akshayballal95
Copy link
Author

Apologies for the delay @akshayballal95, could you add a test with a tiny random model like https://huggingface.co/hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration, can be added here https://github.com/huggingface/optimum/blob/main/tests/exporters/exporters_utils.py#L37

I have added the conversion test. It works fine locally.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants