🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving #3163

Open
readleyj opened this issue Sep 16, 2024 · 13 comments

@readleyj

readleyj commented Sep 16, 2024

Bug Description

I am trying to use torch_tensorrt.dynamo.compile() to AOT compile the UNet portion of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I am able to export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile() and save it with torch_tensorrt.save(). However, I am encountering a runtime error when attempting to load the saved compiled UNet with torch.export.load().

To Reproduce

Run the code below

import functools

import torch
import torch_tensorrt

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    
    return sample, timestep, encoder_hidden_states

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)
    
    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)
    
    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)
        
    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )
    
    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()

Error message

...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
  warnings.warn(

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 48
     40     compiled_unet = torch_tensorrt.dynamo.compile(
     41         unet_exported_program,
     42         inputs=arg_inputs_unet,
     43         enabled_precisions={torch.float16, torch.float32},
     44         truncate_double=True,
     45     )
     47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476, in load(f, extra_files, expected_opset_version)
    468 artifact: SerializedArtifact = SerializedArtifact(
    469     serialized_exported_program,
    470     serialized_state_dict,
    471     serialized_constants,
    472     serialized_example_inputs,
    473 )
    475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
    478 return ep

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
   2433 exported_program_dict = json.loads(exported_program_str)
   2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
   2435 return (
   2436     ExportedProgramDeserializer(expected_opset_version)
-> 2437     .deserialize(
   2438         serialized_exported_program,
   2439         artifact.state_dict,
   2440         artifact.constants,
   2441         artifact.example_inputs,
   2442     )
   2443 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
   2314 res = (
   2315     GraphModuleDeserializer()
   2316     .deserialize(
   (...)
   2322     )
   2323 )
   2324 range_constraints = self.deserialize_range_constraints(
   2325     symbol_name_to_range,
   2326     res.names_to_symbols,
   2327 )
-> 2329 return ep.ExportedProgram(
   2330     root=res.graph_module,
   2331     graph=res.graph_module.graph,
   2332     graph_signature=res.signature,
   2333     state_dict=res.state_dict,  # type: ignore[arg-type]
   2334     range_constraints=range_constraints,
   2335     module_call_graph=res.module_call_graph,
   2336     example_inputs=res.example_inputs,
   2337     constants=res.constants,
   2338     verifiers=[load_verifier(v) for v in exported_program.verifiers],
   2339 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
    698 self._verifiers = verifiers
    699 # Validate should be always the last step of the constructor.
--> 700 self.validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
   1115 @compatibility(is_backward_compatible=False)
   1116 def validate(self):
-> 1117     self._validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
   1122 assert (
   1123     len(self.verifiers) > 0
   1124 ), "ExportedProgram must have at least one verifier."
   1125 for v in self.verifiers:
-> 1126     v().check(self)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
    153 @final
    154 def check(self, ep: "ExportedProgram") -> None:
--> 155     self._check_graph_module(ep.graph_module)
    156     _verify_exported_program_module_call_graph(ep)
    157     _verify_exported_program_signature(ep)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
    211 if not isinstance(mod, torch.fx.GraphModule):
    212     continue
--> 214 mod.graph.lint()
    215 for node in mod.graph.nodes:
    216     # TODO(T140410192): should have fake tensor for all dialects
    217     if node.op in {"call_module", "call_method"}:

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
   1546     seen_values.add(node)
   1548     if node.name in seen_names:
-> 1549         raise RuntimeError(f'Node redefined name {node.name}!')
   1550     seen_names.add(node.name)
   1552 # Check targets are legit

RuntimeError: Node redefined name getitem_130!
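
The check that fails in this traceback is `Graph.lint()`, which raises as soon as two nodes in one graph share a name. A minimal sketch for listing such collisions, assuming only that each node exposes a `.name` attribute (as `torch.fx.Node` does). `find_duplicate_names` is a hypothetical helper, not part of torch or Torch-TensorRT, and the stand-in nodes are made up for illustration:

```python
from collections import Counter
from types import SimpleNamespace


def find_duplicate_names(nodes):
    """Return a dict mapping each repeated node name to how often it occurs."""
    counts = Counter(node.name for node in nodes)
    return {name: n for name, n in counts.items() if n > 1}


# Stand-in objects; with a real module one would pass gm.graph.nodes instead.
fake_nodes = [
    SimpleNamespace(name=n)
    for n in ("x", "getitem_130", "getitem_131", "getitem_130")
]
duplicates = find_duplicate_names(fake_nodes)
print(duplicates)  # {'getitem_130': 2}
```

On the real model the equivalent call would be `find_duplicate_names(compiled_unet.graph.nodes)`, assuming `compiled_unet` is a `torch.fx.GraphModule`. Whether the collision already exists before saving or only appears in the serialized program is not established in this thread.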

Expected behavior

The code should load the saved compiled model without erroring out.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240912+cu124
  • PyTorch Version (e.g. 1.0): 2.5.0.dev20240912+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04.4 LTS (x86_64)
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.10
  • CUDA version: 12.4
  • GPU models and configuration: 1/2 of an H100 (Configured with MIG)
  • Any other relevant information: Using diffusers version 0.30.2

Additional context

I have to use functools.partial() in the code above because the default output of the UNet's forward method is the UNet2DConditionOutput dataclass. I tried to get rid of functools.partial() by instead using torch.export.register_dataclass() but was met with the same runtime error mentioned above.
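
For readers unfamiliar with the pattern, here is a minimal sketch of what the functools.partial() rebinding does, using a toy model rather than the real UNet. `ToyModel` and `ToyOutput` are hypothetical stand-ins that only mimic the "dataclass by default, tuple on request" behavior:

```python
import functools
from dataclasses import dataclass


@dataclass
class ToyOutput:
    sample: float


class ToyModel:
    def forward(self, x, return_dict=True):
        # Returns a dataclass by default and a plain tuple when return_dict=False.
        out = x * 2
        return ToyOutput(sample=out) if return_dict else (out,)


model = ToyModel()
# Rebind forward on the instance so every call passes return_dict=False.
model.forward = functools.partial(model.forward, return_dict=False)
result = model.forward(3)
print(result)  # (6,)
```

The rebinding is stored on the instance, so the class and its other instances keep the default behavior.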

Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.

readleyj added the bug label on Sep 16, 2024
@lanluo-nvidia
Collaborator

@readleyj I tried this in my environment with today's latest main branch on an RTX 4080, and I don't get the error you pasted.
I can successfully load the UNet.

@readleyj
Author

readleyj commented Sep 18, 2024

@lanluo-nvidia Thank you for the reply. That is very strange. I will try with today's nightly and report back. Also, I am running this on an H100, could that possibly be the source of the issue?

@readleyj
Author

readleyj commented Sep 18, 2024

I tried again with today's nightly (torch_tensorrt==2.5.0.dev20240918+cu124, torch==dev20240912+cu124) and I am encountering the same runtime error. Additionally, the results for the compiled UNet match the original UNet. At this point, I am not sure if the issue is with Torch-TensorRT or torch.export.

@readleyj
Author

readleyj commented Sep 20, 2024

I also tried with release 2.4. There, I can successfully save and load the model, but the compiled model outputs are full of NaNs. In general, Stable Diffusion with Torch-TensorRT seems very problematic.

@lanluo-nvidia
Collaborator

@readleyj Yes, release 2.4 has bugs that have been fixed on the current main branch. Could you paste the code you use to generate the image after loading the UNet?
I will give it a try as well.

@readleyj
Author

readleyj commented Sep 21, 2024

@lanluo-nvidia After loading the UNet, I first check if the results match (expected_outputs_unet is defined in the previous code block)

with torch.inference_mode():    
    tensorrt_outputs_unet = loaded_unet(*arg_inputs_unet)
    for expected_output, tensorrt_output in zip(expected_outputs_unet, tensorrt_outputs_unet):
        assert torch.allclose(
            expected_output, tensorrt_output, 1e-2, 1e-2
        ), "UNet results do not match"
    
    print("UNet results match for Torch-TensorRT and Diffusers")

To generate an image, I plug the loaded UNet into a StableDiffusion pipeline as follows (code block assumes loaded_unet is already defined):

import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

PROMPT = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    
    class LoadedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.config.in_channels
            setattr(self, "config", pipe.unet.config)
            self.device = pipe.unet.device
    
        def forward(self, latent_model_input, t, encoder_hidden_states, **kwargs):
            sample = loaded_unet(latent_model_input, t, encoder_hidden_states)
            return sample
    
    pipe.unet = LoadedUNet()
    
    image = pipe(PROMPT,
                 num_inference_steps=50,
                 height=512,
                 width=512,
            ).images[0]

Note that you may receive a "Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and / or seed." warning from diffusers. This happens when the image is all black or gibberish.

@lanluo-nvidia
Collaborator

lanluo-nvidia commented Sep 21, 2024

@readleyj I have tried with the release/2.5 branch (this is our upcoming release branch, and it is more stable than the main branch, since main is getting all the latest changes from both PyTorch and Torch-TensorRT).
Test 1) Tested locally on my RTX 4080 with the release/2.5 branch:

python -m pip install --pre --editable . --extra-index-url https://download.pytorch.org/whl/test/cu124
python -m pip install --pre "torchvision>=0.20.0,<0.21.0" --index-url https://download.pytorch.org/whl/test/cu124
python /home/lanl/git/script/python/stable_diffusion/test_issue3163.py

It does throw the "UNet results do not match" error (the mismatch is actually very close to the rtol, atol of 1e-2, 1e-2),
but I do not see outputs full of NaNs. Also,
if I change rtol, atol to 1e-2, 5e-2, it is able to generate the image as expected:

lan added expected_outputs_unet=(tensor([[[[-3.0304e-02,  4.7119e-01,  1.3281e-01,  ...,  3.9185e-01,
           -1.1200e-01,  1.3062e-01],
          [-2.1191e-01,  2.7026e-01, -2.9688e-01,  ..., -3.6548e-01,
            9.2712e-02, -6.3867e-01],
          [-5.2441e-01,  1.8555e-02,  9.9414e-01,  ..., -1.0977e+00,
            3.9429e-02, -6.8164e-01],
          ...,
          [-6.3574e-01, -2.2583e-01,  3.3234e-02,  ...,  2.8491e-01,
           -4.9316e-01,  9.1455e-01],
          [ 5.8594e-01, -7.3779e-01,  6.9695e-03,  ...,  7.1094e-01,
           -3.3569e-01,  1.2830e-01],
          [-2.4805e-01, -1.2152e-01,  8.3643e-01,  ..., -2.8641e-02,
            1.9739e-01,  1.3367e-01]],

         [[ 2.7319e-01,  4.0063e-01, -4.8682e-01,  ..., -6.2939e-01,
            2.1790e-02,  4.9634e-01],
          [-9.3323e-02, -7.7393e-01,  2.1399e-01,  ...,  7.6953e-01,
            4.5410e-01, -3.1909e-01],
          [ 1.8079e-01,  4.2017e-01,  1.1699e+00,  ...,  1.3843e-01,
           -3.2898e-02, -1.3953e-01],
          ...,
          [-1.4839e-02,  1.2131e-02,  1.7859e-01,  ...,  5.9717e-01,
            8.0762e-01, -7.5684e-01],
          [ 6.3232e-01, -6.1035e-01,  1.9214e-01,  ..., -3.3496e-01,
            4.9048e-01,  7.0166e-01],
          [ 2.3340e-01, -6.1279e-01,  7.8271e-01,  ..., -1.9067e-01,
           -6.3965e-01, -1.7529e-01]],

         [[-8.2324e-01, -4.7180e-02, -8.0383e-02,  ..., -6.2109e-01,
            4.5319e-02,  1.5930e-01],
          [ 1.0908e+00, -7.1143e-01,  9.6484e-01,  ...,  4.6777e-01,
           -2.4548e-01, -5.6445e-01],
          [ 2.2278e-01,  1.2256e+00,  3.4302e-01,  ..., -3.1372e-01,
            3.3203e-01,  1.1426e-01],
          ...,
          [ 1.7578e-01, -2.4002e-02,  3.9581e-02,  ...,  1.4160e-01,
            2.4902e-01, -2.7515e-01],
          [ 6.4893e-01, -1.7891e+00,  3.4570e-01,  ...,  3.9868e-01,
            5.0977e-01,  5.0146e-01],
          [ 2.1948e-01, -2.0020e-01,  3.3862e-01,  ..., -2.5488e-01,
            7.9346e-02, -3.8794e-01]],

         [[-3.6206e-01, -5.7080e-01, -7.8369e-02,  ...,  4.7388e-01,
            4.5093e-01, -2.6636e-01],
          [ 4.5630e-01,  4.8340e-01,  5.4053e-01,  ..., -2.9175e-01,
            2.3331e-02, -5.2979e-01],
          [ 4.5728e-01, -3.1177e-01, -1.5879e+00,  ..., -1.6748e-01,
            1.8408e-01, -3.1592e-01],
          ...,
          [-2.6074e-01,  1.6028e-01, -5.9766e-01,  ...,  2.4963e-01,
            2.9688e-01, -1.1699e+00],
          [-2.1367e+00,  5.9619e-01,  6.1133e-01,  ..., -3.5962e-01,
           -4.8193e-01,  1.5167e-02],
          [ 9.3018e-01,  5.7471e-01, -4.0332e-01,  ...,  5.8691e-01,
           -1.6826e+00, -4.2450e-02]]],


        [[[-1.4170e+00,  3.9453e-01, -4.8438e-01,  ...,  2.2180e-01,
            4.1724e-01,  9.6252e-02],
          [ 1.5015e-01,  4.6851e-01,  3.3643e-01,  ...,  5.3467e-02,
           -1.9666e-01, -9.2773e-02],
          [ 1.0840e+00,  5.0244e-01,  8.7695e-01,  ...,  3.6957e-02,
           -1.0840e+00, -7.1436e-01],
          ...,
          [-7.7100e-01,  2.4207e-01, -3.6084e-01,  ..., -6.8298e-02,
           -2.1643e-01,  1.4391e-03],
          [-3.7964e-01, -2.0032e-01,  4.6173e-02,  ..., -2.1252e-01,
            2.0972e-01, -6.0608e-02],
          [ 3.5840e-01, -1.3125e+00,  4.1528e-01,  ..., -6.7871e-01,
            9.4434e-01, -3.8055e-02]],

         [[-1.2225e-01,  1.2488e-01, -3.2935e-01,  ..., -3.2690e-01,
           -4.9219e-01,  4.6460e-01],
          [ 1.8616e-01,  1.6821e-01, -7.6675e-03,  ..., -3.6224e-02,
            4.6509e-01, -4.9976e-01],
          [ 5.1758e-01,  5.4883e-01, -7.9004e-01,  ...,  3.2275e-01,
           -1.1780e-01, -9.6191e-01],
          ...,
          [-6.4331e-02, -4.7754e-01, -8.2031e-01,  ...,  1.2024e-01,
           -1.6125e-01, -1.5442e-01],
          [-3.6938e-01, -2.1045e-01, -5.3857e-01,  ...,  1.2512e-01,
           -1.1646e-01,  1.6172e+00],
          [ 7.7515e-02,  4.2578e-01,  3.3789e-01,  ...,  3.6377e-01,
           -9.4189e-01, -8.0176e-01]],

         [[-5.4736e-01,  1.9482e-01, -1.4111e+00,  ..., -1.4087e-01,
            7.7576e-02, -6.6833e-02],
          [ 3.4082e-01, -1.1267e-01,  3.2129e-01,  ...,  5.9473e-01,
           -9.3896e-01, -3.3350e-01],
          [ 8.4277e-01,  1.0020e+00, -8.1055e-01,  ..., -2.3669e-01,
           -5.0049e-01, -4.0503e-01],
          ...,
          [-1.1273e-01,  4.9194e-02, -2.6172e-01,  ...,  2.6880e-01,
            3.7744e-01, -2.0447e-02],
          [ 1.2832e+00, -9.6985e-02, -2.9150e-01,  ...,  1.1292e-01,
           -1.9116e-01, -2.1643e-01],
          [ 1.8347e-01, -4.4531e-01, -3.4180e-01,  ...,  1.2793e-01,
            3.6011e-01,  9.5215e-01]],

         [[ 1.1777e+00, -3.1174e-02,  2.1133e+00,  ..., -1.7981e-01,
            1.1401e-01,  2.7466e-01],
          [-4.1113e-01,  1.4771e-01, -4.5264e-01,  ..., -5.7080e-01,
           -6.2354e-01, -2.0126e-02],
          [-1.0283e+00, -7.2070e-01,  1.6321e-01,  ..., -1.0547e-01,
            1.8105e+00, -6.9824e-01],
          ...,
          [-6.0352e-01,  3.2440e-02, -2.5537e-01,  ...,  2.0691e-01,
           -5.2277e-02, -4.4482e-01],
          [-8.5840e-01, -4.8291e-01, -2.7051e-01,  ..., -1.1688e-01,
           -4.1113e-01,  5.1562e-01],
          [ 3.0469e-01,  5.5273e-01,  1.2769e-01,  ..., -3.8086e-02,
           -2.3511e-01, -1.5625e-01]]]], device='cuda:0', dtype=torch.float16),)
lan added successfully saved compiled model
lan added successfully loaded compiled model
lan added tensorrt_outputs_unet=(tensor([[[[-0.0281,  0.4717,  0.1340,  ...,  0.3911, -0.1110,  0.1306],
          [-0.2119,  0.2698, -0.2964,  ..., -0.3638,  0.0912, -0.6377],
          [-0.5249,  0.0165,  0.9927,  ..., -1.1055,  0.0362, -0.6802],
          ...,
          [-0.6377, -0.2261,  0.0383,  ...,  0.2847, -0.4937,  0.9146],
          [ 0.5815, -0.7383,  0.0041,  ...,  0.7085, -0.3379,  0.1245],
          [-0.2472, -0.1250,  0.8354,  ..., -0.0276,  0.1979,  0.1342]],

         [[ 0.2739,  0.3992, -0.4868,  ..., -0.6265,  0.0222,  0.4968],
          [-0.0953, -0.7720,  0.2144,  ...,  0.7710,  0.4548, -0.3210],
          [ 0.1858,  0.4216,  1.1719,  ...,  0.1354, -0.0334, -0.1392],
          ...,
          [-0.0163,  0.0100,  0.1794,  ...,  0.5981,  0.8042, -0.7524],
          [ 0.6294, -0.6099,  0.1896,  ..., -0.3352,  0.4866,  0.7021],
          [ 0.2328, -0.6152,  0.7822,  ..., -0.1882, -0.6387, -0.1779]],

         [[-0.8218, -0.0516, -0.0839,  ..., -0.6216,  0.0450,  0.1566],
          [ 1.0908, -0.7075,  0.9653,  ...,  0.4673, -0.2465, -0.5654],
          [ 0.2251,  1.2188,  0.3413,  ..., -0.3125,  0.3306,  0.1157],
          ...,
          [ 0.1783, -0.0231,  0.0443,  ...,  0.1445,  0.2466, -0.2778],
          [ 0.6450, -1.7891,  0.3435,  ...,  0.3984,  0.5098,  0.5015],
          [ 0.2174, -0.2021,  0.3389,  ..., -0.2559,  0.0767, -0.3879]],

         [[-0.3618, -0.5703, -0.0786,  ...,  0.4707,  0.4492, -0.2673],
          [ 0.4551,  0.4832,  0.5396,  ..., -0.2891,  0.0298, -0.5327],
          [ 0.4585, -0.3105, -1.5898,  ..., -0.1676,  0.1854, -0.3171],
          ...,
          [-0.2598,  0.1622, -0.6060,  ...,  0.2484,  0.2986, -1.1680],
          [-2.1426,  0.5972,  0.6147,  ..., -0.3577, -0.4790,  0.0156],
          [ 0.9312,  0.5732, -0.4019,  ...,  0.5850, -1.6826, -0.0422]]],


        [[[-1.4150,  0.3909, -0.4822,  ...,  0.2200,  0.4146,  0.0955],
          [ 0.1503,  0.4656,  0.3364,  ...,  0.0545, -0.1993, -0.0927],
          [ 1.0791,  0.5010,  0.8813,  ...,  0.0355, -1.0830, -0.7158],
          ...,
          [-0.7690,  0.2378, -0.3633,  ..., -0.0714, -0.2169,  0.0047],
          [-0.3752, -0.2000,  0.0457,  ..., -0.2123,  0.2108, -0.0576],
          [ 0.3579, -1.3125,  0.4180,  ..., -0.6763,  0.9458, -0.0367]],

         [[-0.1224,  0.1234, -0.3323,  ..., -0.3257, -0.4893,  0.4646],
          [ 0.1893,  0.1653, -0.0038,  ..., -0.0334,  0.4651, -0.5015],
          [ 0.5176,  0.5493, -0.7900,  ...,  0.3220, -0.1155, -0.9575],
          ...,
          [-0.0630, -0.4766, -0.8208,  ...,  0.1180, -0.1615, -0.1575],
          [-0.3704, -0.2101, -0.5396,  ...,  0.1234, -0.1164,  1.6182],
          [ 0.0759,  0.4243,  0.3369,  ...,  0.3630, -0.9458, -0.8022]],

         [[-0.5449,  0.1912, -1.4131,  ..., -0.1411,  0.0784, -0.0674],
          [ 0.3401, -0.1118,  0.3210,  ...,  0.5952, -0.9404, -0.3323],
          [ 0.8384,  1.0010, -0.8042,  ..., -0.2351, -0.5015, -0.4026],
          ...,
          [-0.1104,  0.0457, -0.2615,  ...,  0.2690,  0.3806, -0.0194],
          [ 1.2861, -0.0950, -0.2893,  ...,  0.1160, -0.1880, -0.2148],
          [ 0.1829, -0.4490, -0.3406,  ...,  0.1271,  0.3596,  0.9507]],

         [[ 1.1777, -0.0316,  2.1074,  ..., -0.1764,  0.1111,  0.2766],
          [-0.4094,  0.1512, -0.4502,  ..., -0.5728, -0.6245, -0.0194],
          [-1.0283, -0.7241,  0.1641,  ..., -0.1035,  1.8174, -0.6987],
          ...,
          [-0.6021,  0.0345, -0.2515,  ...,  0.2064, -0.0580, -0.4395],
          [-0.8604, -0.4844, -0.2668,  ..., -0.1143, -0.4133,  0.5127],
          [ 0.3044,  0.5537,  0.1259,  ..., -0.0396, -0.2357, -0.1576]]]],
       device='cuda:0', dtype=torch.float16),)
Traceback (most recent call last):
  File "/home/lanl/git/script/python/stable_diffusion/test_issue3163.py", line 47, in <module>
    assert torch.allclose(
AssertionError: UNet results do not match
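
The assertion depends on the tolerance rule documented for `torch.allclose(input, other, rtol, atol)`, where every element must satisfy |input - other| <= atol + rtol * |other|. A minimal pure-Python sketch of that rule for a single pair of numbers, showing how raising atol from 1e-2 to 5e-2 can flip the result. The values are hypothetical and not taken from the run above:

```python
def is_close(a, b, rtol, atol):
    """Scalar form of the documented torch.allclose rule: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# Hypothetical pair of values, chosen only to illustrate the two tolerance settings.
expected, actual = 0.500, 0.520

strict = is_close(expected, actual, rtol=1e-2, atol=1e-2)  # bound is 0.0152
loose = is_close(expected, actual, rtol=1e-2, atol=5e-2)   # bound is 0.0552
print(strict, loose)  # False True
```

The positional arguments after the two tensors are `rtol` then `atol`, so `torch.allclose(expected_output, tensorrt_output, 1e-2, 1e-2)` sets both to 1e-2.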

@lanluo-nvidia
Collaborator

Test 2) Tested locally on my RTX 4080 with the release/2.5 branch. I found that if I skip saving and loading and use the Torch-TensorRT compiled model directly for inference, the UNet results do match, and it also generates the images as expected, as in Test 1).

@lanluo-nvidia
Collaborator

lanluo-nvidia commented Sep 21, 2024

Test 3) Tested on an H100 using the release/2.5 Docker image:
docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash

in the docker container:
apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py

It does throw the following error:

lan added successfully saved compiled model
Traceback (most recent call last):
  File "/opt/torch_tensorrt/test_issue3163.py", line 42, in <module>
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/__init__.py", line 473, in load
    ep = deserialize(artifact, expected_opset_version)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2437, in deserialize
    .deserialize(
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2329, in deserialize
    return ep.ExportedProgram(
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 700, in __init__
    self.validate()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1117, in validate
    self._validate()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1126, in _validate
    v().check(self)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 155, in check
    self._check_graph_module(ep.graph_module)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 214, in _check_graph_module
    mod.graph.lint()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/graph.py", line 1549, in lint
    raise RuntimeError(f'Node redefined name {node.name}!')
RuntimeError: Node redefined name getitem_130!
root@s4124-0059:/opt/torch_tensorrt# nvidia-smi
Sat Sep 21 18:38:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:45:00.0 Off |                    0 |
| N/A   26C    P0             61W /  700W |       0MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |

@readleyj
Author

readleyj commented Sep 21, 2024

Yes, the error in Test 3) is exactly what I'm getting on my H100. I thought the problem might be with torch.export so I already created an issue on the PyTorch repo (pytorch/pytorch#136317)

@lanluo-nvidia
Collaborator

@readleyj It seems this only happens on the H100. I ran exactly the same test on an RTX 4080, using the same image and the same test code you provided, and it works.

Test 4) Tested on an RTX 4080 using the release/2.5 Docker image:
docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash

in the docker container:
apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py
The UNet results match (rtol, atol: 1e-2, 1e-2) and it is able to generate the image as expected.

@readleyj
Author

readleyj commented Sep 21, 2024

@lanluo-nvidia Also, on my H100 tests, the model successfully compiles, the UNet results match (using compiled_unet directly) and I can generate an image (if I use compiled_unet in place of loaded_unet). But it's saving and loading the compiled model that breaks. To me this seems like a torch.export issue but I'm not sure.

@readleyj
Author

@lanluo-nvidia Any updates on this? Should I expect this issue to be resolved soon or will this be on the backlog for a while? Unfortunately, I only have H100s at my disposal and this is blocking progress for me.

readleyj changed the title from "🐛 [Bug] Can't load UNet after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving" to "🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving" on Oct 18, 2024