Support for exporting traced graph? #225

Open
KyleErnewein opened this issue Aug 15, 2024 · 1 comment

Comments

@KyleErnewein

Hi, I have a use case that ingests ONNX models and needs to convert them to PyTorch, then export a traced graph (via torch.export.export()).

After converting ONNX to torch (via onnx2torch.convert()), I'm running into issues tracing the graph, due to dynamic control flow in the converted torch model.

Is there any plan for onnx2torch to support this type of use case? Or are there any recommendations for how to work around the dynamic control flow in the converted torch model?

An example of a problematic op is Reshape: the converted torch module branches on the runtime value of the shape parameter, to replicate ONNX's special handling of shape dimensions with value 0 (meaning: copy the corresponding input dimension).
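For illustration, here's a minimal sketch of that zero-dimension semantics in plain PyTorch (my own illustration, not onnx2torch's actual code):

import torch

x = torch.rand(10, 20)
onnx_shape = [0, -1]  # ONNX Reshape: 0 means "copy this dim from the input"

# Resolving the zeros requires the input's runtime shape, which is why the
# converted module ends up with a data-dependent branch:
resolved = [x.shape[i] if dim == 0 else dim for i, dim in enumerate(onnx_shape)]
y = x.reshape(resolved)  # resolved == [10, -1], so y has shape (10, 20)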

Here's code to reproduce that issue:

import io
import onnx
import onnx2torch
import torch

# Create ONNX model containing a single reshape node:
class M(torch.nn.Module):
    def forward(self, x):
        x = x.reshape(20, 10)
        return x

torch_args = (torch.rand(10, 20),)
with io.BytesIO() as tmp_file:
    torch.onnx.export(model=M(), args=torch_args, f=tmp_file)
    onnx_model = onnx.load_from_string(tmp_file.getvalue())

# Convert ONNX --> torch:
converted_torch = onnx2torch.convert(onnx_model)

# Export traced graph (ExportedProgram):
ep = torch.export.export(converted_torch, args=torch_args)

This raises the following error (snippet; the actual trace is very long):

...
UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "<eval_with_key>.1", line 6, in forward
    reshape = self.Reshape(input_1, constant);  input_1 = constant = None
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 36, in forward
    return _forward()
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 31, in _forward
    return self._do_reshape(input_tensor, shape)
  File "/usr/local/lib/python3.8/site-packages/onnx2torch/node_converters/reshape.py", line 20, in _do_reshape
    if torch.any(shape == 0):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
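
For what it's worth, one workaround I've been considering is to swap each converted Reshape for a module that resolves the zeros eagerly in Python, so that no tensor-dependent branch is traced. This is a rough, untested sketch: it assumes the converter class is named OnnxReshape and that the shape tensor is a constant the exporter can specialize on.

import torch

class StaticReshape(torch.nn.Module):
    # Resolve ONNX's zero-dims with plain Python ints instead of tensor ops,
    # so torch.export never sees data-dependent control flow.
    def forward(self, input_tensor, shape):
        resolved = [
            input_tensor.shape[i] if dim == 0 else dim
            for i, dim in enumerate(shape.tolist())
        ]
        return input_tensor.reshape(resolved)

# Swap every converted Reshape submodule (the class name is an assumption on my part):
for name, module in list(converted_torch.named_modules()):
    if type(module).__name__ == 'OnnxReshape':
        parent_name, _, attr = name.rpartition('.')
        parent = converted_torch.get_submodule(parent_name) if parent_name else converted_torch
        setattr(parent, attr, StaticReshape())

The error's suggestion of functorch.experimental.control_flow.cond doesn't seem directly applicable here, since the branch lives inside onnx2torch's converter rather than in my own code.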

Any insights would be appreciated. Thanks!

@KyleErnewein (Author)

For reference, this is my Python environment (output of pip freeze):

asttokens==2.4.1
backcall==0.2.0
decorator==5.1.1
executing==2.0.1
filelock==3.15.4
fsspec==2024.6.1
ipython==8.12.3
jedi==0.19.1
Jinja2==3.1.4
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mpmath==1.3.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
onnx==1.16.2
onnx2torch==1.5.15
parso==0.8.4
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
prompt_toolkit==3.0.47
protobuf==5.27.3
ptyprocess==0.7.0
pure_eval==0.2.3
Pygments==2.18.0
six==1.16.0
stack-data==0.6.3
sympy==1.13.2
torch==2.4.0
torchvision==0.19.0
traitlets==5.14.3
triton==3.0.0
typing_extensions==4.12.2
wcwidth==0.2.13
