Bug Description
I am trying to use torch_tensorrt.dynamo.compile() to AOT compile the UNet portion of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I am able to export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile(), and save it with torch_tensorrt.save(). However, I encounter a runtime error when attempting to load the saved compiled UNet with torch.export.load().
To Reproduce
Run the code below
import functools

import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler


def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states


with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)

    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)

    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)

    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )

    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
Error message
...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
engine_node = gm.graph.get_attr(engine_name)
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
warnings.warn(
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 48
40 compiled_unet = torch_tensorrt.dynamo.compile(
41 unet_exported_program,
42 inputs=arg_inputs_unet,
43 enabled_precisions={torch.float16, torch.float32},
44 truncate_double=True,
45 )
47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476 in load(f, extra_files, expected_opset_version)
468 artifact: SerializedArtifact = SerializedArtifact(
469 serialized_exported_program,
470 serialized_state_dict,
471 serialized_constants,
472 serialized_example_inputs,
473 )
475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
478 return ep
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
2433 exported_program_dict = json.loads(exported_program_str)
2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
2435 return (
2436 ExportedProgramDeserializer(expected_opset_version)
-> 2437 .deserialize(
2438 serialized_exported_program,
2439 artifact.state_dict,
2440 artifact.constants,
2441 artifact.example_inputs,
2442 )
2443 )
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
2314 res = (
2315 GraphModuleDeserializer()
2316 .deserialize(
(...)
2322 )
2323 )
2324 range_constraints = self.deserialize_range_constraints(
2325 symbol_name_to_range,
2326 res.names_to_symbols,
2327 )
-> 2329 return ep.ExportedProgram(
2330 root=res.graph_module,
2331 graph=res.graph_module.graph,
2332 graph_signature=res.signature,
2333 state_dict=res.state_dict, # type: ignore[arg-type]
2334 range_constraints=range_constraints,
2335 module_call_graph=res.module_call_graph,
2336 example_inputs=res.example_inputs,
2337 constants=res.constants,
2338 verifiers=[load_verifier(v) for v in exported_program.verifiers],
2339 )
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
698 self._verifiers = verifiers
699 # Validate should be always the last step of the constructor.
--> 700 self.validate()
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
1115 @compatibility(is_backward_compatible=False)
1116 def validate(self):
-> 1117 self._validate()
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
1122 assert (
1123 len(self.verifiers) > 0
1124 ), "ExportedProgram must have at least one verifier."
1125 for v in self.verifiers:
-> 1126 v().check(self)
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
153 @final
154 def check(self, ep: "ExportedProgram") -> None:
--> 155 self._check_graph_module(ep.graph_module)
156 _verify_exported_program_module_call_graph(ep)
157 _verify_exported_program_signature(ep)
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
211 if not isinstance(mod, torch.fx.GraphModule):
212 continue
--> 214 mod.graph.lint()
215 for node in mod.graph.nodes:
216 # TODO(T140410192): should have fake tensor for all dialects
217 if node.op in {"call_module", "call_method"}:
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
1546 seen_values.add(node)
1548 if node.name in seen_names:
-> 1549 raise RuntimeError(f'Node redefined name {node.name}!')
1550 seen_names.add(node.name)
1552 # Check targets are legit
RuntimeError: Node redefined name getitem_130!
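The final frame shows Graph.lint() raising as soon as it sees a node name for the second time, so only the first collision is reported. A minimal sketch of the same check, which lists every repeated name instead, might help narrow this down. The helper name find_duplicate_node_names is hypothetical and not part of torch or Torch-TensorRT; it only assumes each node has a name attribute. The nodes below are stand-ins so the sketch runs without a GPU; passing compiled_unet.graph.nodes before saving is an untested assumption.

```python
from collections import Counter
from types import SimpleNamespace


def find_duplicate_node_names(nodes):
    """Return a dict mapping each repeated node name to its occurrence count.

    Hypothetical helper: accepts any iterable of objects with a `.name`
    attribute, such as the nodes of a torch.fx graph.
    """
    counts = Counter(node.name for node in nodes)
    return {name: count for name, count in counts.items() if count > 1}


# Stand-in nodes so the sketch runs without torch. With the real model one
# might pass compiled_unet.graph.nodes instead (assumption, not verified).
example_nodes = [
    SimpleNamespace(name="getitem_129"),
    SimpleNamespace(name="getitem_130"),
    SimpleNamespace(name="getitem_130"),
]
duplicates = find_duplicate_node_names(example_nodes)
print(duplicates)  # {'getitem_130': 2}
```

Since the verifier in the traceback lints every GraphModule it finds in the program, the same check may need to be repeated for nested submodules as well.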
Expected behavior
The code should load the saved compiled model without erroring out.
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240912+cu124
- PyTorch Version (e.g. 1.0): 2.5.0.dev20240912+cu124
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 22.04.4 LTS (x86_64)
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.11.10
- CUDA version: 12.4
- GPU models and configuration: 1/2 of an H100 (Configured with MIG)
- Any other relevant information: Using diffusers version 0.30.2
Additional context
I have to use functools.partial() in the code above because the default output of the UNet's forward method is the UNet2DConditionOutput dataclass. I tried to replace functools.partial() with torch.export.register_dataclass(), but hit the same runtime error described above.
Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.
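For readers unfamiliar with the functools.partial() workaround, here is a minimal sketch that runs without diffusers. FakeUNetOutput and fake_forward are hypothetical stand-ins for UNet2DConditionOutput and the UNet's forward method; the sketch only shows how binding return_dict=False changes the return value from a dataclass to a plain tuple.

```python
import functools
from dataclasses import dataclass


@dataclass
class FakeUNetOutput:
    """Hypothetical stand-in for diffusers' UNet2DConditionOutput."""
    sample: object


def fake_forward(sample, timestep, encoder_hidden_states, return_dict=True):
    """Hypothetical stand-in for the UNet forward method."""
    result = sample  # a real UNet would compute its prediction here
    if not return_dict:
        return (result,)
    return FakeUNetOutput(sample=result)


# Bind return_dict=False once, as done for unet_model.forward in the repro.
tuple_forward = functools.partial(fake_forward, return_dict=False)

default_out = fake_forward("x", 0, "ctx")
tuple_out = tuple_forward("x", 0, "ctx")
print(type(default_out).__name__)  # FakeUNetOutput
print(type(tuple_out).__name__)    # tuple
```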