Description
❓ Question
How do you save the UNet from Stable Diffusion XL after compiling it with Torch-TensorRT?
What you have already tried
I've tried following the compilation instructions from the tutorial (link), but they didn't fit my use case: I want to save the compiled model to disk and load it later, when inference is needed.
So I tried the instructions for saving a compiled model with the dynamo backend (link). This script summarizes what I'm doing:
```python
import torch
import torch_tensorrt
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# After some digging, these are the input sizes needed to generate 1024x1024 images
inputs = [torch.randn((2, 4, 128, 128)).cuda()]
trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs)
```
But this fails with the following error:
TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
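For reference, the tensor shapes in these scripts follow from the target image size. A minimal sketch, assuming the SDXL base 1.0 defaults (VAE downscaling by 8, 4 latent channels, 77-token prompts, 2048-dim concatenated text features from the two text encoders, a 1280-dim pooled embedding and 6 time ids) and classifier-free guidance doubling the batch; `sdxl_unet_input_shapes` is a hypothetical helper, not a diffusers function:

```python
def sdxl_unet_input_shapes(height: int, width: int, num_images: int = 1,
                           guidance: bool = True) -> dict:
    """Hypothetical helper: per-step UNet input shapes for SDXL base (assumed defaults)."""
    vae_scale_factor = 8   # the VAE downsamples each spatial dimension by 8
    latent_channels = 4    # SDXL latents have 4 channels
    text_seq_len = 77      # CLIP tokenizer max length
    text_hidden_dim = 2048  # 768 (CLIP ViT-L) + 1280 (OpenCLIP ViT-bigG), concatenated
    pooled_dim = 1280      # pooled text embedding, passed as "text_embeds"
    num_time_ids = 6       # original size (2) + crop top-left (2) + target size (2)

    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of 8")

    # Classifier-free guidance runs the unconditional and conditional passes in one batch
    batch = num_images * (2 if guidance else 1)
    return {
        "sample": (batch, latent_channels,
                   height // vae_scale_factor, width // vae_scale_factor),
        "timestep": (),  # a scalar tensor, as in torch.tensor(951.0)
        "encoder_hidden_states": (batch, text_seq_len, text_hidden_dim),
        "text_embeds": (batch, pooled_dim),
        "time_ids": (batch, num_time_ids),
    }


if __name__ == "__main__":
    for name, shape in sdxl_unet_input_shapes(1024, 1024).items():
        print(name, shape)
```

For a single 1024x1024 image with guidance, this reproduces the (2, 4, 128, 128) latent and the (2, 77, 2048) encoder_hidden_states used in this issue.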
So I tried to provide these arguments as well, using values I found by experimenting with the diffusers code:
```python
kwargs = {
    "timestep": torch.tensor(951.0).cuda(),
    "encoder_hidden_states": torch.randn(
        (2, 77, 2048), dtype=torch.float16
    ).cuda(),
}
trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)
```
And I get the same error, probably because the kwargs aren't passed down to the UNet's forward call. After modifying the torch.export code (which probably wasn't necessary), I got a different error:
torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable
Any ideas on how to properly compile (and save) the UNet from Stable Diffusion XL? Many thanks in advance.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- PyTorch Version (e.g., 1.0): 2.3.1+cu121
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 22.04.3 LTS
- How you installed PyTorch (conda, pip, libtorch, source): pip install torch --index-url https://download.pytorch.org/whl/cu121
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: Python 3.10.12
- CUDA version: 12.4
- GPU models and configuration: NVIDIA GeForce RTX 4090
- Any other relevant information:
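Most of these fields can also be gathered programmatically; PyTorch ships python -m torch.utils.collect_env for this. A minimal stdlib-based sketch that also reports the Torch-TensorRT version (which the list above does not state), assuming the distribution names torch, torch_tensorrt, tensorrt and diffusers; gather_env_info is a hypothetical helper, not part of any library. Note that torch.version.cuda reports the CUDA version PyTorch was built with (12.1 for a +cu121 wheel), which can differ from the system CUDA 12.4 listed above.

```python
import importlib.metadata
import platform


def package_version(name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


def gather_env_info() -> dict:
    """Hypothetical helper collecting the fields asked for in this issue template."""
    info = {
        "python": platform.python_version(),
        "os": platform.platform(),
        "cpu_arch": platform.machine(),
    }
    # Assumed distribution names; missing packages are reported rather than raising
    for pkg in ("torch", "torch_tensorrt", "tensorrt", "diffusers"):
        info[pkg] = package_version(pkg)

    try:
        import torch
    except ImportError:
        return info
    info["torch_cuda_build"] = torch.version.cuda  # CUDA version PyTorch was built with
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in gather_env_info().items():
        print(f"{key}: {value}")
```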