
❓ [Question] How do you save a unet model compiled with Torch-TensorRT (Stable Diffusion XL) #3018

Open
@dru10


❓ Question

How do you save a unet model from Stable Diffusion XL after compiling it with Torch-TensorRT?

What you have already tried

I followed the compilation instructions from the tutorial (link). That did not fit my use case, because I want to save the compiled model to disk and load it later when inference is needed.

Next, I followed the instructions for saving a compiled model with the dynamo backend (link). This script summarizes what I'm doing:

import torch
import torch_tensorrt
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

inputs = [torch.randn((2, 4, 128, 128)).cuda()]  # After some digging, these are the input sizes needed to generate 1024x1024 images

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs)

But this yields the following error: TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
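
For reference, the latent shape used in the script above can be derived from the target image size. This is a minimal sketch, assuming the SDXL VAE downsamples by a factor of 8, the UNet takes 4 latent channels, and classifier-free guidance doubles the batch. `unet_latent_shape` is a hypothetical helper written for illustration, not a diffusers function:

```python
def unet_latent_shape(height, width, batch_size=1, guidance=True,
                      latent_channels=4, vae_scale_factor=8):
    """Return the (N, C, H, W) shape of the latent tensor fed to the UNet.

    Hypothetical helper; the default values are assumptions about SDXL.
    """
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of the VAE scale factor")
    # Classifier-free guidance runs the conditional and unconditional
    # branches in a single batch, so the batch dimension doubles.
    n = batch_size * 2 if guidance else batch_size
    return (n, latent_channels,
            height // vae_scale_factor, width // vae_scale_factor)


shape = unet_latent_shape(1024, 1024)
print(shape)  # (2, 4, 128, 128)
```

The latent tensor is only one of the inputs: as the error message states, `forward()` also requires `timestep` and `encoder_hidden_states`.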

So I also tried providing these arguments, using values I found by experimenting with the diffusers code:

kwargs = {
    "timestep": torch.tensor(951.0).cuda(),
    "encoder_hidden_states": torch.randn(
        (2, 77, 2048), dtype=torch.float16
    ).cuda(),
}

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)

And I get the same error, probably because the kwargs are not passed down to the model's forward call. After altering the code from torch export (which probably wasn't necessary), I got a different error: torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable
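
Both errors above can be reproduced without a GPU. The sketch below is a toy stand-in, not Torch-TensorRT or diffusers code. `toy_compile` is a hypothetical function that, by assumption, treats extra keyword arguments as compiler settings and calls `forward` with the positional inputs only. `ToyUNet` assumes a forward that checks `added_cond_kwargs` for a `text_embeds` key, which is one plausible source of the 'NoneType' message when that argument is left as `None`. `PositionalWrapper` is likewise hypothetical and shows one way to expose every input positionally:

```python
class ToyUNet:
    """Stand-in with a forward signature similar to a conditional UNet."""

    def forward(self, sample, timestep, encoder_hidden_states,
                added_cond_kwargs=None):
        # A membership test on None raises TypeError.
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError("text_embeds is required")
        return ("ok", sample, timestep, encoder_hidden_states)


def toy_compile(module, inputs, **settings):
    """Hypothetical compiler: keyword arguments are settings, not model inputs."""
    return module.forward(*inputs)


class PositionalWrapper:
    """Expose every model input as a positional argument."""

    def __init__(self, unet):
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states,
                text_embeds, time_ids):
        # Rebuild the keyword dictionary inside the wrapper.
        return self.unet.forward(
            sample, timestep, encoder_hidden_states,
            added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids},
        )


def error_name(fn):
    """Run fn and return the name of the exception it raises, if any."""
    try:
        fn()
    except Exception as exc:
        return type(exc).__name__
    return None


unet = ToyUNet()
# Keyword arguments never reach forward(): two positional arguments are missing.
missing_args = error_name(lambda: toy_compile(unet, ["latents"], timestep=951.0))
# All three positional inputs are given, but added_cond_kwargs stays None.
none_kwargs = error_name(lambda: toy_compile(unet, ["latents", 951.0, "hidden"]))
# The wrapper receives every input positionally and succeeds.
result = toy_compile(PositionalWrapper(unet),
                     ["latents", 951.0, "hidden", "text_embeds", "time_ids"])
print(missing_args, none_kwargs, result[0])  # TypeError TypeError ok
```

Whether the installed Torch-TensorRT release offers a dedicated parameter for keyword model inputs (newer releases appear to document one named `kwarg_inputs`) is an assumption worth checking against the documentation for that version.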

Any ideas on how to properly compile a unet model from Stable Diffusion XL and save it? Many thanks in advance.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 2.3.1+cu121
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04.3 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip install torch --index-url https://download.pytorch.org/whl/cu121
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: Python 3.10.12
  • CUDA version: 12.4
  • GPU models and configuration: NVIDIA GeForce RTX 4090
  • Any other relevant information:

Additional context
