
❓ [Question] How do you save a UNet model compiled with Torch-TensorRT (Stable Diffusion XL) #3018

Open
dru10 opened this issue Jul 18, 2024 · 2 comments
Labels
question Further information is requested


@dru10

dru10 commented Jul 18, 2024

❓ Question

How do you save a Stable Diffusion XL UNet model that has been compiled with Torch-TensorRT?

What you have already tried

I've tried following the compilation instructions from the tutorial (link). That approach doesn't fit my use case, because I want to save the compiled model to disk and load it later, when inference is needed.
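A minimal sketch of the save/load round trip being asked about, assuming a Torch-TensorRT release that provides `torch_tensorrt.save` (the dynamo saving docs show it writing an ExportedProgram that `torch.export.load` can read back). The helper names and the default file name `unet_trt.ep` are hypothetical, and importing `torch_tensorrt` before loading is assumed to be needed so the TensorRT runtime ops are registered.

```python
# Sketch only: helper names and the default path are hypothetical.
# Assumes a Torch-TensorRT release that ships torch_tensorrt.save (check your version's docs).
# Imports are deferred so this file can be imported on a machine without CUDA/TensorRT.


def save_compiled_unet(trt_gm, example_inputs, path="unet_trt.ep"):
    """Serialize a dynamo-compiled module to disk as an ExportedProgram."""
    import torch_tensorrt

    # Pass the same example tensors that were used for compilation.
    torch_tensorrt.save(trt_gm, path, inputs=example_inputs)
    return path


def load_compiled_unet(path="unet_trt.ep"):
    """Load the saved ExportedProgram and return a callable module."""
    import torch
    import torch_tensorrt  # noqa: F401  (assumed: registers the TensorRT runtime ops)

    return torch.export.load(path).module()
```

Usage would be `trt_gm = torch_tensorrt.compile(module, ir="dynamo", inputs=example_inputs)` followed by `save_compiled_unet(trt_gm, example_inputs)` at build time, and `unet_trt = load_compiled_unet()` at inference time. This only helps once compilation itself succeeds, which is the actual blocker in this issue.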

So I followed the instructions for saving a compiled module with the dynamo backend (link). This script summarizes what I'm doing:

import torch
import torch_tensorrt
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

inputs = [torch.randn((2, 4, 128, 128), dtype=torch.float16).cuda()]  # After some digging, these are the input sizes needed to generate 1024x1024 images (fp16 to match the UNet weights)

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs)

But this yields the following error: TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'
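The error names two of the inputs that `forward()` still needs, and the SDXL UNet also reads `text_embeds` and `time_ids` through `added_cond_kwargs`. The sketch below derives every input shape from the output resolution, assuming the SDXL base 1.0 configuration: 8x VAE downsampling and 4 latent channels reproduce the (2, 4, 128, 128) sample shape above, and 77 tokens by 2048 features match the `encoder_hidden_states` shape used in the next attempt. The pooled `text_embeds` width of 1280 and the six `time_ids` are assumptions taken from that config, and the function name is hypothetical.

```python
def sdxl_unet_input_shapes(height=1024, width=1024, batch_size=1, guidance=True):
    """Example UNet input shapes, assuming the SDXL base 1.0 configuration (sketch)."""
    vae_scale_factor = 8         # the VAE downsamples each spatial dimension by 8
    latent_channels = 4          # UNet in_channels
    tokens = 77                  # CLIP text context length
    cross_attention_dim = 2048   # 768 (CLIP ViT-L) + 1280 (OpenCLIP ViT-bigG), concatenated
    pooled_dim = 1280            # pooled embedding from the second text encoder
    n_time_ids = 6               # original size (2) + crop top-left (2) + target size (2)

    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of 8")

    # Classifier-free guidance runs the unconditional and conditional passes as one batch.
    b = batch_size * (2 if guidance else 1)
    return {
        "sample": (b, latent_channels, height // vae_scale_factor, width // vae_scale_factor),
        "timestep": (),  # a scalar tensor
        "encoder_hidden_states": (b, tokens, cross_attention_dim),
        "text_embeds": (b, pooled_dim),
        "time_ids": (b, n_time_ids),
    }
```

For example, `sdxl_unet_input_shapes()` gives `(2, 4, 128, 128)` for `sample`, which is where the "after some digging" shape comes from: 1024 / 8 = 128, doubled batch for guidance.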

So I also tried to provide these arguments, using values I found by experimenting with the diffusers code:

kwargs = {
    "timestep": torch.tensor(951.0).cuda(),
    "encoder_hidden_states": torch.randn(
        (2, 77, 2048), dtype=torch.float16
    ).cuda(),
}

trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)

And I get the same error. Probably the kwargs don't get passed down to the model's forward() call. After altering the code from torch export (which probably wasn't necessary), I got this error instead: torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable

Any ideas on how to properly compile the UNet model from Stable Diffusion XL? Many thanks in advance.
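One common workaround, sketched here and not verified against SDXL, is to wrap the UNet so that every input is a positional tensor. `torch_tensorrt.compile(..., inputs=[...])` can then receive the full list, and the nested `added_cond_kwargs` dict is rebuilt inside `forward`. `SDXLUNetPositional` is a hypothetical name; `added_cond_kwargs` and `return_dict=False` are existing `UNet2DConditionModel.forward` parameters.

```python
import torch


class SDXLUNetPositional(torch.nn.Module):
    """Hypothetical wrapper exposing every SDXL UNet input as a positional tensor."""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, text_embeds, time_ids):
        # Rebuild the dict the SDXL UNet expects; leaving it as None is what
        # triggers "argument of type 'NoneType' is not iterable".
        added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
        # return_dict=False makes the UNet return a plain tuple, whose first item is the output.
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
```

Compilation would then look roughly like `torch_tensorrt.compile(SDXLUNetPositional(pipe.unet).eval(), ir="dynamo", inputs=[sample, timestep, encoder_hidden_states, text_embeds, time_ids], enabled_precisions={torch.float16})`, with the tensors on CUDA. For 1024x1024 the shapes would be (2, 4, 128, 128), a scalar timestep, (2, 77, 2048), (2, 1280) and (2, 6); the last two are assumptions based on the SDXL base config, and this exact call has not been tested here.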

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 2.3.1+cu121
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04.3 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip install torch --index-url https://download.pytorch.org/whl/cu121
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: Python 3.10.12
  • CUDA version: 12.4
  • GPU models and configuration: NVIDIA GeForce RTX 4090
  • Any other relevant information:

Additional context

@dru10 dru10 added the question Further information is requested label Jul 18, 2024
@dru10
Author

dru10 commented Jul 19, 2024

After altering the code from torch export (which probably wasn't necessary)

For reference, this is the modification I made in py/torch_tensorrt/dynamo/_tracer.py#L81:

exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))

And this is the traceback

Traceback (most recent call last):
  File "/workspace/torch-tensorrt/src/dummy.py", line 21, in <module>
    trt_gm = torch_tensorrt.compile(pipe.unet, ir="dynamo", inputs=inputs, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 248, in compile
    exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch_tensorrt/dynamo/_tracer.py", line 81, in trace
    exp_program = export(mod, tuple(torch_inputs), kwargs=kwargs, dynamic_shapes=tuple(dynamic_shapes))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper
    raise e
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper
    ep = fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export
    gm_torch_level = _export_to_torch_ir(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1272, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1644, in CONTAINS_OP
    self.push(right.call_method(self, "__contains__", [left], {}))
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/torch/_dynamo/variables/constant.py", line 182, in call_method
    result = search in self.value
torch._dynamo.exc.InternalTorchDynamoError: argument of type 'NoneType' is not iterable

from user code:
   File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1162, in forward
    aug_emb = self.get_aug_embed(
  File "/workspace/torch-tensorrt/venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 973, in get_aug_embed
    if "text_embeds" not in added_cond_kwargs:
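The last user-code frame explains the failure: `added_cond_kwargs` defaults to None, and a membership test on None raises exactly this TypeError. A simplified pure-Python sketch (hypothetical helpers, not diffusers code) of that check, plus the `time_ids` layout the SDXL pipeline builds, assuming the ordering original size, crop top-left, target size:

```python
def check_added_cond_kwargs(added_cond_kwargs):
    """Simplified stand-in for the membership checks in get_aug_embed (not diffusers code)."""
    # With added_cond_kwargs=None this line raises
    # TypeError: argument of type 'NoneType' is not iterable
    if "text_embeds" not in added_cond_kwargs:
        raise ValueError("'text_embeds' is required")  # hypothetical message
    if "time_ids" not in added_cond_kwargs:
        raise ValueError("'time_ids' is required")  # hypothetical message
    return sorted(added_cond_kwargs)


def sdxl_time_ids(original_size=(1024, 1024), crops_coords_top_left=(0, 0), target_size=(1024, 1024)):
    """Six micro-conditioning values: original size, crop top-left, target size."""
    return list(original_size + crops_coords_top_left + target_size)


none_error = None
try:
    check_added_cond_kwargs(None)  # reproduces the failure from the traceback
except TypeError as exc:
    none_error = str(exc)

ok_keys = check_added_cond_kwargs({"text_embeds": None, "time_ids": sdxl_time_ids()})
```

In the real pipeline these six values become a (batch, 6) tensor and `text_embeds` is the pooled output of the second text encoder; supplying both through `added_cond_kwargs` appears to be what the failing compile call was missing.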

@pangyoki

pangyoki commented Sep 3, 2024

I have the same question. Have you solved the problem?
