🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving #3163

Closed
@readleyj

Description

Bug Description

I am trying to use torch_tensorrt.dynamo.compile() to AOT compile the UNet portion of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I am able to export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile() and save it with torch_tensorrt.save(). However, I am encountering a runtime error when attempting to load the saved compiled UNet with torch.export.load().

To Reproduce

Run the code below

import functools

import torch
import torch_tensorrt

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    
    return sample, timestep, encoder_hidden_states

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)
    
    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)
    
    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)
        
    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )
    
    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()

Error message

...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
  warnings.warn(

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 48
     40     compiled_unet = torch_tensorrt.dynamo.compile(
     41         unet_exported_program,
     42         inputs=arg_inputs_unet,
     43         enabled_precisions={torch.float16, torch.float32},
     44         truncate_double=True,
     45     )
     47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476 in load(f, extra_files, expected_opset_version)
    468 artifact: SerializedArtifact = SerializedArtifact(
    469     serialized_exported_program,
    470     serialized_state_dict,
    471     serialized_constants,
    472     serialized_example_inputs,
    473 )
    475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
    478 return ep

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
   2433 exported_program_dict = json.loads(exported_program_str)
   2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
   2435 return (
   2436     ExportedProgramDeserializer(expected_opset_version)
-> 2437     .deserialize(
   2438         serialized_exported_program,
   2439         artifact.state_dict,
   2440         artifact.constants,
   2441         artifact.example_inputs,
   2442     )
   2443 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
   2314 res = (
   2315     GraphModuleDeserializer()
   2316     .deserialize(
   (...)
   2322     )
   2323 )
   2324 range_constraints = self.deserialize_range_constraints(
   2325     symbol_name_to_range,
   2326     res.names_to_symbols,
   2327 )
-> 2329 return ep.ExportedProgram(
   2330     root=res.graph_module,
   2331     graph=res.graph_module.graph,
   2332     graph_signature=res.signature,
   2333     state_dict=res.state_dict,  # type: ignore[arg-type]
   2334     range_constraints=range_constraints,
   2335     module_call_graph=res.module_call_graph,
   2336     example_inputs=res.example_inputs,
   2337     constants=res.constants,
   2338     verifiers=[load_verifier(v) for v in exported_program.verifiers],
   2339 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
    698 self._verifiers = verifiers
    699 # Validate should be always the last step of the constructor.
--> 700 self.validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
   1115 @compatibility(is_backward_compatible=False)
   1116 def validate(self):
-> 1117     self._validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
   1122 assert (
   1123     len(self.verifiers) > 0
   1124 ), "ExportedProgram must have at least one verifier."
   1125 for v in self.verifiers:
-> 1126     v().check(self)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
    153 @final
    154 def check(self, ep: "ExportedProgram") -> None:
--> 155     self._check_graph_module(ep.graph_module)
    156     _verify_exported_program_module_call_graph(ep)
    157     _verify_exported_program_signature(ep)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
    211 if not isinstance(mod, torch.fx.GraphModule):
    212     continue
--> 214 mod.graph.lint()
    215 for node in mod.graph.nodes:
    216     # TODO(T140410192): should have fake tensor for all dialects
    217     if node.op in {"call_module", "call_method"}:

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
   1546     seen_values.add(node)
   1548     if node.name in seen_names:
-> 1549         raise RuntimeError(f'Node redefined name {node.name}!')
   1550     seen_names.add(node.name)
   1552 # Check targets are legit

RuntimeError: Node redefined name getitem_130!

Expected behavior

The code should load the saved compiled model without erroring out.
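
For reference, a hypothetical check I would expect to pass once loading works (this reuses arg_inputs_unet and expected_outputs_unet from the script above; the tolerances are illustrative, not a claim about the engine's accuracy):

import torch

# Hypothetical verification once torch.export.load() succeeds. With
# return_dict=False the UNet returns a tuple, so compare the first element.
loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
trt_outputs = loaded_unet(*arg_inputs_unet)
torch.testing.assert_close(trt_outputs[0], expected_outputs_unet[0], rtol=1e-2, atol=1e-2)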

Environment

  • Torch-TensorRT Version: 2.5.0.dev20240912+cu124
  • PyTorch Version: 2.5.0.dev20240912+cu124
  • CPU Architecture: x86_64
  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • How you installed PyTorch: pip
  • Python version: 3.11.10
  • CUDA version: 12.4
  • GPU models and configuration: 1/2 of an H100 (configured with MIG)
  • Any other relevant information: diffusers version 0.30.2

Additional context

I have to use functools.partial() in the code above because the UNet's forward method returns a UNet2DConditionOutput dataclass by default. I tried to get rid of functools.partial() by instead registering the dataclass with torch.export.register_dataclass(), but was met with the same runtime error described above.
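
For concreteness, a minimal sketch of that register_dataclass attempt (the import path is the one diffusers 0.30.x exposes, and the serialized_type_name string is my own arbitrary label):

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

# Register the output dataclass as a pytree node so torch.export can
# flatten/unflatten the UNet's default (return_dict=True) output.
torch.export.register_dataclass(
    UNet2DConditionOutput,
    serialized_type_name="diffusers.UNet2DConditionOutput",
)

# With the dataclass registered, the functools.partial(return_dict=False)
# wrapper should not be needed; export then proceeds as before.
unet_exported_program = torch.export.export(pipe.unet.eval(), arg_inputs_unet)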

Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.
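
As a control, here is a minimal sketch of that round trip (reusing unet_exported_program from the reproduction script; the file name is arbitrary):

import torch

# Saving and reloading the plain ExportedProgram, with no Torch-TensorRT
# step in between, completes without error, which suggests the duplicated
# node names are introduced during Torch-TensorRT compilation/re-export.
torch.export.save(unet_exported_program, "sd_unet_uncompiled.ep")
reloaded_unet = torch.export.load("sd_unet_uncompiled.ep").module()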
