Fix window_reverse to support batch inference when using onnx or tensorrt #166

Open
wants to merge 1 commit into main
Conversation

@jhwei commented Jan 18, 2025

This fix to window_reverse is useful when converting the model to ONNX, TensorRT, or other formats; it does not affect PyTorch inference or training.
The previous code can lead the converter to treat B as a constant value (1 in most cases), which produces wrong results in batch inference. The new code treats the batch dimension as dynamic by using -1.

This PR applies the same change as microsoft/Swin-Transformer#257.
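
For reference, the change is roughly the following (a paraphrased sketch, not the exact diff; see the linked Swin-Transformer PR for the original):

```python
# Before (schematic): the batch size was computed as a Python int, which the
# ONNX tracer bakes into the graph as a constant:
#   B = int(windows.shape[0] / (H * W / window_size / window_size))
#   x = windows.view(B, H // window_size, W // window_size,
#                    window_size, window_size, -1)
#   x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# After (schematic): the batch dimension is left as -1, so it stays dynamic;
# the channel dimension is read from the tensor instead:
def window_reverse(windows, window_size, H, W):
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size,
                     window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
```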

@ZhengPeng7 (Owner)

Thanks! I've tested inference and done some training, and both show no inconsistency.
But I'm not quite familiar with the ONNX problem you mentioned. Do you mean that fixing the batch size in the network architecture to a constant value is a bad idea for ONNX deployment? For example, after conversion with B=1, the converted ONNX model cannot run prediction with another batch size, and this modification fixes that problem?

@jhwei (Author) commented Jan 21, 2025

I am not an expert in ONNX either; I just did some experiments and found the issue.

To my understanding, ONNX builds a graph of operator nodes to run the inference. While generating the graph, data is split into tensors and constants, and constants are precomputed and saved in the graph.

In this case, B is regarded as a constant when generating the graph, and its value is taken from the example input. That is fine for fixed-size input, because the input shape used to generate the ONNX file and the one used at inference are the same.

ONNX also supports dynamic-shape input (which I am testing): you can mark certain dimensions as dynamic (usually the batch size), and I think the shape is then treated as a 'tensor' when generating the graph. However, with the previous code the batch size is frozen to the traced input shape (1), so you get an output shaped like (1, H, W, 2C) when you run inference with b=2.
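
To make the constant-baking behaviour concrete, here is a tiny, hypothetical example (toy modules of my own, not BiRefNet code): casting the batch size to a Python int freezes it into the exported graph, while reshaping with -1 keeps it dynamic.

```python
import torch
import torch.nn as nn


class FixedBatchReshape(nn.Module):
    def forward(self, x):
        B = int(x.shape[0])       # traced as the literal constant 1
        return x.reshape(B, -1)


class DynamicBatchReshape(nn.Module):
    def forward(self, x):
        return x.reshape(-1, 16)  # -1 keeps the batch dimension symbolic


dummy = torch.randn(1, 4, 4)
for name, model in [("fixed", FixedBatchReshape()), ("dynamic", DynamicBatchReshape())]:
    torch.onnx.export(
        model, dummy, f"{name}.onnx",
        input_names=["x"], output_names=["y"],
        dynamic_axes={"x": {0: "batch"}},
    )

# Feeding a (2, 4, 4) input to fixed.onnx yields a (1, 32) output instead of the
# expected (2, 16): the batch gets folded into the other dimensions, the same
# symptom as the (1, H, W, 2C) result described above. dynamic.onnx returns (2, 16).
# e.g. with onnxruntime:
#   import onnxruntime as ort
#   sess = ort.InferenceSession("fixed.onnx")
#   print(sess.run(None, {"x": torch.randn(2, 4, 4).numpy()})[0].shape)
```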

@jhwei (Author) commented Jan 21, 2025

I have successfully run inference with a dynamic batch size in both ONNX and TensorRT and will PR some example ipynb scripts later (it takes some time).

I will give some brief ideas here for anyone interested.

  1. If the final goal is fast inference, I recommend exporting torchvision's deform_conv as the standard ONNX operator DeformConv. This operator is defined in opset 19; onnxruntime has not implemented it yet, but TensorRT already has, so you can build an engine file and run fast inference (see the sketch after this list).

  2. If ONNX inference is needed, this PR helps with dynamic batch inference. However, there is a bug in deform_conv2d_onnx_exporter: it does something similar to the Swin Transformer code by setting the batch size to a fixed number. I tried changing the batch size to -1 and it worked, but I need more experiments to make sure it's the right way. Note that onnxruntime inference is far slower than the PyTorch version; converting the ONNX model to TensorRT helps, just as others have already done.
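
A rough sketch of idea 1 (my own assumption of how it could look, not code from this PR or from #167): register a custom symbolic function so that torchvision's deformable convolution is exported as the opset-19 DeformConv operator. The argument order of torchvision::deform_conv2d is taken from torchvision's op schema and should be double-checked against your installed version; only the use_mask=True case is handled here.

```python
import torch
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args


@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b")
def deform_conv2d_to_onnx(g, input, weight, offset, mask, bias,
                          stride_h, stride_w, pad_h, pad_w,
                          dilation_h, dilation_w, groups, offset_groups, use_mask):
    # Emit the standard DeformConv node (opset 19) instead of a custom op.
    # Inputs per the ONNX spec: X, W, offset, B (bias), mask.
    return g.op(
        "DeformConv",
        input, weight, offset, bias, mask,
        strides_i=[stride_h, stride_w],
        pads_i=[pad_h, pad_w, pad_h, pad_w],
        dilations_i=[dilation_h, dilation_w],
        group_i=groups,
        offset_group_i=offset_groups,
    )


# DeformConv only exists from opset 19 onwards, so register and export with 19.
register_custom_op_symbolic("torchvision::deform_conv2d", deform_conv2d_to_onnx, 19)
```

As noted above, such an ONNX file will not run in onnxruntime until DeformConv is implemented there, but TensorRT can consume it and build an engine.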

@ZhengPeng7 (Owner)

Hi @jhwei, thanks a lot for the explanation. However, as far as I tested, updating only those three lines does not enable dynamic batch size in ONNX model inference. Someone else also made another PR on dynamic input: #167. I'll also test the modification there.

But your PR makes sense for that part; I'll accept it after figuring out all these relevant things. Thanks anyway :)

@jhwei (Author) commented Jan 23, 2025

Thanks, @ZhengPeng7. PR #167 partially does exactly what I described in idea 1, and that is the fastest way.

However, that PR does not really implement dynamic batch inference, since the generated ONNX file still only accepts a batch size of 1; dynamic_axes should be passed to torch.onnx.export. In addition, when you run inference with a batch size larger than 1, the result is wrong without the changes to window_reverse.
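
For reference, a minimal sketch of passing dynamic_axes to torch.onnx.export (the stand-in model and the input/output names here are placeholders of mine, not the PR's script):

```python
import torch
import torch.nn as nn

# Stand-in model; in practice this would be the real BiRefNet instance in eval mode.
model = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1), nn.Sigmoid()).eval()
dummy = torch.randn(1, 3, 1024, 1024)

torch.onnx.export(
    model, dummy, "model_dynamic_batch.onnx",
    opset_version=17,
    input_names=["image"],
    output_names=["mask"],
    # Mark dim 0 of both input and output as a named dynamic axis.
    dynamic_axes={"image": {0: "batch"}, "mask": {0: "batch"}},
)

# The exported graph now accepts any batch size, e.g. (4, 3, 1024, 1024), as long
# as no layer inside the model bakes the batch size in as a constant (which is
# exactly what the window_reverse fix above addresses).
```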

I also just found out today that the timm Swin Transformer uses the same implementation. Besides, I think using C instead of B may not be the key point: the real cause may be that int() converts B into a constant, so keeping the original B taken from the size (without the int() cast) might also work.

Maybe I can work out a scratch demo tomorrow for both ideas 1 & 2 for a better explanation.

@jhwei (Author) commented Jan 23, 2025

@ZhengPeng7 Please see https://github.com/jhwei/BiRefNet/blob/45ec226ff1378fe0030d72c7bbdd7fbd7a7d3763/tutorials/BiRefNet_pth2onnx.ipynb for the demo supporting dynamic batch size. (The code may not be clean enough.)

For idea 2, I tested that changing "batch": batch to "batch": -1 works for dynamic batch size. However, this is just an intuitive fix and I haven't verified it theoretically.

@ZhengPeng7 (Owner)

Thanks a lot, @jhwei. I have to say that I've been too exhausted recently. But I'll definitely test all the things you made and reply to you about whether it works well.

@jhwei (Author) commented Jan 24, 2025

> Thanks a lot, @jhwei. I have to say that I've been too exhausted recently. But I'll definitely test all the things you made and reply to you about whether it works well.

Thank you, take your time. It's a great chance for me to explore ONNX and TRT conversion. Feel free to post any questions when you have time.

And, happy Chinese New Year in advance.
