Skip to content

Commit

Permalink
[BugFix,Doc] Revert dynamic shape in export tutorial
Browse files Browse the repository at this point in the history
ghstack-source-id: fc856218e840469a5bb0143241d100e9cc612538
Pull Request resolved: #2563
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 304e707 commit 9d292a0
Showing 1 changed file with 19 additions and 43 deletions.
62 changes: 19 additions & 43 deletions tutorials/sphinx-tutorials/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,51 +338,27 @@
# `AOTI documentation <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_:
#

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
path = str(Path(tmpdir) / "model.pt2")
with torch.no_grad():
pkg_path = aoti_compile_and_package(
exported_policy,
args=(),
kwargs={"pixels": pixels},
# Specify the generated shared library path
package_path=path,
)
print("pkg_path", pkg_path)

compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

#####################################
# An extra feature of AOTInductor is its capacity of dealing with dynamic shapes. This can be useful if you don't know
# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more
# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic:

batch_dim = torch.export.Dim("batch", min=1, max=32)
pixels_unsqueeze = pixels.unsqueeze(0)
exported_dynamic_policy = torch.export.export(
policy_transform,
args=(),
kwargs={"pixels": pixels_unsqueeze},
strict=False,
dynamic_shapes={"pixels": {0: batch_dim}},
)
# Then recompile and export
pkg_path = aoti_compile_and_package(
exported_dynamic_policy,
args=(),
kwargs={"pixels": pixels_unsqueeze},
package_path=path,
)
# from tempfile import TemporaryDirectory
#
# from torch._inductor import aoti_compile_and_package, aoti_load_package
#
# with TemporaryDirectory() as tmpdir:
# path = str(Path(tmpdir) / "model.pt2")
# with torch.no_grad():
# pkg_path = aoti_compile_and_package(
# exported_policy,
# args=(),
# kwargs={"pixels": pixels},
# # Specify the generated shared library path
# package_path=path,
# )
# print("pkg_path", pkg_path)
#
# compiled_module = aoti_load_package(pkg_path)
#
# print(compiled_module(pixels=pixels))

#####################################
# More information about this can be found in the
# `AOTInductor tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`_.
#
# Exporting TorchRL models with ONNX
# ----------------------------------
Expand Down

0 comments on commit 9d292a0

Please sign in to comment.