-
Notifications
You must be signed in to change notification settings - Fork 365
Added flux demo #3418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
cehongwang
wants to merge
30
commits into
main
Choose a base branch
from
flux-demo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Added flux demo #3418
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
f85820c
Added CPU offloading
cehongwang 4242743
Chagned CPU offload to default
cehongwang e87d27e
Added support to module with graph break
cehongwang 76cab94
Added back the control flag and fixed the CI
cehongwang 6352110
Chagned CPU offload to default
cehongwang 214e2e6
Added flux demo
cehongwang c9d8456
changed the file place and deleted unnecessary code
cehongwang e77737d
Fixed memory overhead and enabled Flux with Mutable Module
cehongwang 42c384d
Supported LoRA
cehongwang 33db1cb
Refined Flux demo, solved a bug of device mismatch, and prototyped Cu…
cehongwang c69f41c
Enabled Cuda Graph
cehongwang 8f44d7f
Enabled weight streaming and CudaGraph. Supported MTTM saving with dy…
cehongwang 044f4e6
Changed the Refitting test to disable CPU offload
cehongwang d383be4
Fixed Cuda Error
cehongwang 580fc03
Fixed the bug of SDXL Cuda Error
cehongwang 3e8323f
Changed the way to enable CudaGraph for MTTM
cehongwang 92ae47d
Finalize the refit revision
cehongwang 98cbd76
Fixed the comments
cehongwang 39ac60e
Correct the flux export example
cehongwang 6caf833
Added a textbox to display time the generation process takes
cehongwang c018151
Added perf script
cehongwang 9e390da
added back control flag
cehongwang ba76f6d
trying to add quantization to Flux
cehongwang 27dee53
Enable int8 and fp8 quantization for FLUX
cehongwang 044acdf
Optimized FLUX compilation memory usage
cehongwang b06e632
Optimized lowering and decomposition to benchmark quantization again
cehongwang 3dcf128
Fixed the benchmark typo
cehongwang 1581f0c
Use MutableTorchTensorRTModule to do quantization
cehongwang 18b6455
Added quantization debug script
cehongwang a4ff6bb
Fixed fp16 quantization error
cehongwang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
import argparse | ||
import re | ||
import time | ||
|
||
import gradio as gr | ||
import modelopt.torch.quantization as mtq | ||
import torch | ||
import torch_tensorrt | ||
from diffusers import FluxPipeline | ||
|
||
parser = argparse.ArgumentParser( | ||
description="Run Flux quantization with different dtypes" | ||
) | ||
|
||
parser.add_argument( | ||
"--dtype", | ||
choices=["fp8", "int8", "fp16"], | ||
default="fp16", | ||
help="Select the data type to use (fp8 or int8 or fp16)", | ||
) | ||
args = parser.parse_args() | ||
# Update enabled precisions based on dtype argument | ||
|
||
if args.dtype == "fp8": | ||
enabled_precisions = {torch.float8_e4m3fn, torch.float16} | ||
ptq_config = mtq.FP8_DEFAULT_CFG | ||
elif args.dtype == "int8": | ||
enabled_precisions = {torch.int8, torch.float16} | ||
ptq_config = mtq.INT8_DEFAULT_CFG | ||
ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None | ||
elif args.dtype == "fp16": | ||
enabled_precisions = {torch.float16} | ||
print(f"\nUsing {args.dtype}") | ||
|
||
|
||
DEVICE = "cuda:0" | ||
pipe = FluxPipeline.from_pretrained( | ||
"black-forest-labs/FLUX.1-dev", | ||
torch_dtype=torch.float16, | ||
) | ||
|
||
|
||
pipe.to(DEVICE).to(torch.float16) | ||
backbone = pipe.transformer | ||
backbone.eval() | ||
|
||
|
||
def filter_func(name): | ||
pattern = re.compile( | ||
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" | ||
) | ||
return pattern.match(name) is not None | ||
|
||
|
||
def do_calibrate( | ||
pipe, | ||
prompt: str, | ||
) -> None: | ||
""" | ||
Run calibration steps on the pipeline using the given prompts. | ||
""" | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=20, | ||
generator=torch.Generator("cuda").manual_seed(0), | ||
).images[0] | ||
|
||
|
||
def forward_loop(mod): | ||
# Switch the pipeline's backbone, run calibration | ||
pipe.transformer = mod | ||
do_calibrate( | ||
pipe=pipe, | ||
prompt="test", | ||
) | ||
|
||
|
||
if args.dtype != "fp16": | ||
backbone = mtq.quantize(backbone, ptq_config, forward_loop) | ||
mtq.disable_quantizer(backbone, filter_func) | ||
|
||
batch_size = 2 | ||
|
||
BATCH = torch.export.Dim("batch", min=1, max=8) | ||
dynamic_shapes = { | ||
"hidden_states": {0: BATCH}, | ||
"encoder_hidden_states": {0: BATCH}, | ||
"pooled_projections": {0: BATCH}, | ||
"timestep": {0: BATCH}, | ||
"txt_ids": {}, | ||
"img_ids": {}, | ||
"guidance": {0: BATCH}, | ||
"joint_attention_kwargs": {}, | ||
"return_dict": None, | ||
} | ||
|
||
settings = { | ||
"strict": False, | ||
"allow_complex_guards_as_runtime_asserts": True, | ||
"enabled_precisions": enabled_precisions, | ||
"truncate_double": True, | ||
"min_block_size": 1, | ||
"debug": False, | ||
"use_python_runtime": True, | ||
"immutable_weights": False, | ||
"offload_module_to_cpu": True, | ||
} | ||
|
||
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) | ||
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) | ||
pipe.transformer = trt_gm | ||
|
||
|
||
def generate_image(prompt, inference_step, batch_size=2): | ||
start_time = time.time() | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=inference_step, | ||
num_images_per_prompt=batch_size, | ||
).images | ||
end_time = time.time() | ||
return image, end_time - start_time | ||
|
||
|
||
generate_image(["Test"], 2) | ||
torch.cuda.empty_cache() | ||
|
||
|
||
def model_change(model): | ||
if model == "Torch Model": | ||
pipe.transformer = backbone | ||
backbone.to(DEVICE) | ||
else: | ||
backbone.to("cpu") | ||
pipe.transformer = trt_gm | ||
torch.cuda.empty_cache() | ||
|
||
|
||
def load_lora(path): | ||
|
||
pipe.load_lora_weights( | ||
path, | ||
adapter_name="lora1", | ||
) | ||
pipe.set_adapters(["lora1"], adapter_weights=[1]) | ||
pipe.fuse_lora() | ||
pipe.unload_lora_weights() | ||
print("LoRA loaded! Begin refitting") | ||
generate_image(["Test"], 2) | ||
print("Refitting Finished!") | ||
|
||
|
||
load_lora("/home/TensorRT/examples/apps/NGRVNG.safetensors") | ||
|
||
|
||
# Create Gradio interface | ||
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: | ||
gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") | ||
|
||
with gr.Row(): | ||
with gr.Column(): | ||
# Input components | ||
prompt_input = gr.Textbox( | ||
label="Prompt", placeholder="Enter your prompt here...", lines=3 | ||
) | ||
model_dropdown = gr.Dropdown( | ||
choices=["Torch Model", "Torch-TensorRT Accelerated Model"], | ||
value="Torch-TensorRT Accelerated Model", | ||
label="Model Variant", | ||
) | ||
|
||
lora_upload_path = gr.Textbox( | ||
label="LoRA Path", | ||
placeholder="Enter the LoRA checkpoint path here", | ||
value="/home/TensorRT/examples/apps/NGRVNG.safetensors", | ||
lines=2, | ||
) | ||
num_steps = gr.Slider( | ||
minimum=20, maximum=100, value=20, step=1, label="Inference Steps" | ||
) | ||
batch_size = gr.Slider( | ||
minimum=1, maximum=8, value=1, step=1, label="Batch Size" | ||
) | ||
|
||
generate_btn = gr.Button("Generate Image") | ||
load_lora_btn = gr.Button("Load LoRA") | ||
|
||
with gr.Column(): | ||
# Output component | ||
output_image = gr.Gallery(label="Generated Image") | ||
time_taken = gr.Textbox( | ||
label="Generation Time (seconds)", interactive=False | ||
) | ||
|
||
# Connect the button to the generation function | ||
model_dropdown.change(model_change, inputs=[model_dropdown]) | ||
load_lora_btn.click( | ||
fn=load_lora, | ||
inputs=[ | ||
lora_upload_path, | ||
], | ||
) | ||
|
||
# Update generate button click to include time output | ||
generate_btn.click( | ||
fn=generate_image, | ||
inputs=[ | ||
prompt_input, | ||
num_steps, | ||
batch_size, | ||
], | ||
outputs=[output_image, time_taken], | ||
) | ||
|
||
# Launch the interface | ||
if __name__ == "__main__": | ||
demo.launch() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# %% | ||
# Import the following libraries | ||
# ----------------------------- | ||
import re | ||
|
||
import modelopt.torch.opt as mto | ||
import modelopt.torch.quantization as mtq | ||
import torch | ||
import torch_tensorrt | ||
from diffusers import FluxPipeline | ||
from diffusers.models.attention_processor import Attention | ||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel | ||
from modelopt.torch.quantization.utils import export_torch_mode | ||
from torch.export._trace import _export | ||
from transformers import AutoModelForCausalLM | ||
|
||
# %% | ||
DEVICE = "cuda:0" | ||
pipe = FluxPipeline.from_pretrained( | ||
"black-forest-labs/FLUX.1-dev", | ||
torch_dtype=torch.float32, | ||
) | ||
pipe.transformer = FluxTransformer2DModel( | ||
num_layers=1, num_single_layers=1, guidance_embeds=True | ||
) | ||
|
||
pipe.to(DEVICE).to(torch.float32) | ||
# Store the config and transformer backbone | ||
config = pipe.transformer.config | ||
# global backbone | ||
backbone = pipe.transformer | ||
backbone.eval() | ||
|
||
|
||
def filter_func(name): | ||
pattern = re.compile( | ||
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" | ||
) | ||
return pattern.match(name) is not None | ||
|
||
|
||
def generate_image(pipe, prompt, image_name): | ||
seed = 42 | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=20, | ||
generator=torch.Generator("cuda").manual_seed(seed), | ||
).images[0] | ||
image.save(f"{image_name}.png") | ||
print(f"Image generated using {image_name} model saved as {image_name}.png") | ||
|
||
|
||
generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") | ||
|
||
# %% | ||
# Quantization | ||
|
||
|
||
def do_calibrate( | ||
pipe, | ||
prompt: str, | ||
) -> None: | ||
""" | ||
Run calibration steps on the pipeline using the given prompts. | ||
""" | ||
image = pipe( | ||
prompt, | ||
output_type="pil", | ||
num_inference_steps=20, | ||
generator=torch.Generator("cuda").manual_seed(0), | ||
).images[0] | ||
|
||
|
||
def forward_loop(mod): | ||
# Switch the pipeline's backbone, run calibration | ||
pipe.transformer = mod | ||
do_calibrate( | ||
pipe=pipe, | ||
prompt="test", | ||
) | ||
|
||
|
||
ptq_config = mtq.FP8_DEFAULT_CFG | ||
backbone = mtq.quantize(backbone, ptq_config, forward_loop) | ||
mtq.disable_quantizer(backbone, filter_func) | ||
|
||
|
||
# %% | ||
# Export the backbone using torch.export | ||
# -------------------------------------------------- | ||
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2`` | ||
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_ | ||
|
||
batch_size = 2 | ||
BATCH = torch.export.Dim("batch", min=1, max=2) | ||
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) | ||
# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. | ||
# To see this recommendation, you can try exporting using min=1, max=4096 | ||
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) | ||
dynamic_shapes = { | ||
"hidden_states": {0: BATCH}, | ||
"encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, | ||
"pooled_projections": {0: BATCH}, | ||
"timestep": {0: BATCH}, | ||
"txt_ids": {0: SEQ_LEN}, | ||
"img_ids": {0: IMG_ID}, | ||
"guidance": {0: BATCH}, | ||
"joint_attention_kwargs": {}, | ||
"return_dict": None, | ||
} | ||
# The guidance factor is of type torch.float32 | ||
dummy_inputs = { | ||
"hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float32).to( | ||
DEVICE | ||
), | ||
"encoder_hidden_states": torch.randn( | ||
(batch_size, 512, 4096), dtype=torch.float32 | ||
).to(DEVICE), | ||
"pooled_projections": torch.randn((batch_size, 768), dtype=torch.float32).to( | ||
DEVICE | ||
), | ||
"timestep": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), | ||
"txt_ids": torch.randn((512, 3), dtype=torch.float32).to(DEVICE), | ||
"img_ids": torch.randn((4096, 3), dtype=torch.float32).to(DEVICE), | ||
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), | ||
"joint_attention_kwargs": {}, | ||
"return_dict": False, | ||
} | ||
|
||
# This will create an exported program which is going to be compiled with Torch-TensorRT | ||
with export_torch_mode(): | ||
ep = _export( | ||
backbone, | ||
args=(), | ||
kwargs=dummy_inputs, | ||
dynamic_shapes=dynamic_shapes, | ||
strict=False, | ||
allow_complex_guards_as_runtime_asserts=True, | ||
) | ||
|
||
with torch_tensorrt.logging.debug(): | ||
trt_gm = torch_tensorrt.dynamo.compile( | ||
ep, | ||
inputs=dummy_inputs, | ||
enabled_precisions={torch.float8_e4m3fn}, | ||
truncate_double=True, | ||
min_block_size=1, | ||
debug=False, | ||
use_python_runtime=True, | ||
immutable_weights=True, | ||
offload_module_to_cpu=True, | ||
) | ||
|
||
|
||
del ep | ||
pipe.transformer = trt_gm | ||
pipe.transformer.config = config | ||
|
||
|
||
# %% | ||
trt_gm.device = torch.device(DEVICE) | ||
# Function which generates images from the flux pipeline | ||
|
||
for _ in range(2): | ||
generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") | ||
|
||
# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.