Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PyTorch export tutorial #2617

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pavithraes
Copy link
Contributor

This PR updates the "Exporting StableHLO from PyTorch" tutorial. Specifically:

  • Update dependencies, combine the tensorflow dependency in the first section to avoid warnings (about using tensorflow-cpu) on Kaggle
  • Minor updates to narration and links, along with an introduction statement
  • [TODO] Fix dynamic batch size code, currently commented

(Supersedes #2616, same updates but we get a rich diff here)

"outputs": [],
"source": [
"# TODO: FIX\n",
"# Kernel crashes with \"INVALID_ARGUMENT: Non-broadcast dimensions must not be dynamic\"\n",
Copy link
Member

@GleasonK GleasonK Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TorchXLA coverage of dynamic shapes isn't great, for dynamic shape / export workflows we actually recommend torch_xla2 which uses PT->JAX->StableHLO, and has much better coverage for dynamic shapes:
https://github.com/pytorch/xla/tree/master/experimental/torch_xla2

Perhaps we just make a note of "For dynamic shape export we recommend using torch_xla2 (link), which converts PyTorch models to JAX prior to exporting to StableHLO and has high opset coverage along with great dynamic shape support."

If we want to add a tutorial on using torch_xla2, it likely wouldn't be too hard, here's a recent example (discord invite needed to view):
https://discord.com/channels/999073994483433573/1281512734902714379/1298361762839527467

I believe the following should work:

# Install torch_xla2
pip install torch-xla2 # Note: may need to follow GH instructions linked above but pypi package looks up to date.

# Export the program from native pytorch using torch.export
from torch.export import Dim
batch = Dim("batch", min=4, max=6)
dynamic_shapes = ({0: batch},)
dynamic_export = export(resnet18, sample_input, dynamic_shapes=dynamic_shapes)

# Convert the exported program to StableHLO using torch_xla2
from torch_xla2 import export
stablehlo = export.exported_program_to_stablehlo(dynamic_export).mlir_module()
print(stablehlo)

Note the print will include lots of debug info, so maybe use some indexing, or print starting with the word @main for 1k characters or something to keep the tutorial readable. I.e. print(stablehlo[1000:2000]) or something.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants