From 8e29429fda305021632a5c024136e20561273ff9 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 15 Jul 2024 09:14:54 -0700 Subject: [PATCH] [AOT] Use torch dialect in backend form instead of raw torch dialect (#81) With https://github.com/llvm/torch-mlir/pull/3541 we can now intercept the Torch dialect during TorchDynamo export at two stages: 1. OutputType.RAW: This gets us the torch dialect as-imported from the FX graph 2. OutputType.TORCH: This gets us the torch dialect in backend compliant form, after the raw torch goes through DecomposeComplexOps and ReduceOpVariants. We've been using 1 for all the AOT (e2e) tests, however this PR changes it to use 2, which is closer to the real backend lowering pipelines we use internally as well. --- requirements_lock.txt | 24 ++++++++++++------------ tools/aot/torch_exporter_harness.py | 5 +++++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/requirements_lock.txt b/requirements_lock.txt index c5089423..d0653563 100644 --- a/requirements_lock.txt +++ b/requirements_lock.txt @@ -6,15 +6,15 @@ # --find-links https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels -filelock==3.15.3 \ - --hash=sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103 \ - --hash=sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635 +filelock==3.15.4 \ + --hash=sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb \ + --hash=sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7 # via # torch # triton -fsspec==2024.6.0 \ - --hash=sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee \ - --hash=sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2 +fsspec==2024.6.1 \ + --hash=sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e \ + --hash=sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49 # via torch jinja2==3.1.4 \ --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ @@ -143,9 +143,9 @@ packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via torch-mlir -sympy==1.12.1 \ - --hash=sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88 \ - --hash=sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515 +sympy==1.13.0 \ + --hash=sha256:3b6af8f4d008b9a1a6a4268b335b984b23835f26d1d60b0526ebc71d48a25f57 \ + --hash=sha256:6b0b32a4673fb91bd3cac3b55406c8e01d53ae22780be467301cc452f6680c92 # via torch torch==2.4.0.dev20240604+cpu \ --hash=sha256:95dd17654e0f7c82a9ef50dca328222d7b434b7cc18713f0b83f99bc3f14fb2c \ @@ -153,9 +153,9 @@ torch==2.4.0.dev20240604+cpu \ # via # -r requirements.txt # torch-mlir -torch-mlir==20240620.127 \ - --hash=sha256:89b04afe4d39b273cefa141d69bcaf72a5517924d7b87799a88b3fd4c7786e7e \ - --hash=sha256:8ef2305b6f6846fddeda6fa9878548392eb7146cbba539db6c7da717b0f68c81 +torch-mlir==20240714.152 \ + --hash=sha256:0f4d1c75bd7f61f152f623c1d8cecfabae870b50f74cb4a1c98da227d25e19fb \ + --hash=sha256:e7030c986db3f204534ee3e4834d039899a606460bf951b6455b506c6cbdbf42 # via -r requirements.txt typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ diff --git a/tools/aot/torch_exporter_harness.py b/tools/aot/torch_exporter_harness.py index c30060e7..d59c3307 100644 --- a/tools/aot/torch_exporter_harness.py +++ b/tools/aot/torch_exporter_harness.py @@ -48,6 +48,11 @@ def main(): *loader_result.inputs, # unpack list of input tensors dynamic_shapes=loader_result.dynamic_shapes, import_symbolic_shape_expressions=True, + # This is the Torch dialect imported from Dynamo/FX export and run + # through `torchdynamo-export-to-torch-backend-pipeline` (which + # runs `ReduceOpVariantsPass` and `DecomposeComplexOpsPass`) to + # get it in a backend compliant form (aka torch backend contract). + output_type="torch", func_name=loader_result.func_name, )