Commit

fix trt support (#410)
timudk authored Jan 31, 2025
1 parent 4343061 commit af9d2f0
Showing 13 changed files with 104 additions and 104 deletions.
69 changes: 8 additions & 61 deletions README.md
@@ -10,24 +10,22 @@ This repo contains minimal inference code to run image generation & editing with
```bash
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux

# Using venv
python3.10 -m venv .venv
source .venv/bin/activate
pip install -e ".[all]"
```

## Local installation with TRT support
### Local installation with TensorRT support

If you would like to install the repository with [TensorRT](https://github.com/NVIDIA/TensorRT) support, you currently need to use a PyTorch container image from NVIDIA. First install [enroot](https://github.com/NVIDIA/enroot), then follow the steps below:

```diff
-docker pull nvcr.io/nvidia/pytorch:24.10-py3
 cd $HOME && git clone https://github.com/black-forest-labs/flux
-cd $HOME/flux
-docker run --rm -it --gpus all -v $PWD:/workspace/flux nvcr.io/nvidia/pytorch:24.10-py3 /bin/bash
-# inside container
-cd /workspace/flux
-pip install -e ".[all]"
-pip install -r trt_requirements.txt
+enroot import 'docker://$oauthtoken@nvcr.io#nvidia/pytorch:25.01-py3'
+enroot create -n pti2501 nvidia+pytorch+25.01-py3.sqsh
+enroot start --rw -m ${PWD}/flux:/workspace/flux -r pti2501
+cd flux
+pip install -e ".[tensorrt]" --extra-index-url https://pypi.nvidia.com
```

### Models
@@ -55,57 +53,6 @@ We are offering an extensive suite of models. For more information about the inv

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.

We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:

```bash
python demo_gr.py --name flux-schnell --device cuda
```

Options:

- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo

To run the demo with the dev model and create a public link:

```bash
python demo_gr.py --name flux-dev --share
```

## Diffusers integration

`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use them with diffusers, install it from source:

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can use `FluxPipeline` to run the model:

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory

prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
    prompt,
    output_type="pil",
    num_inference_steps=4,  # use a larger number if you are using [dev]
    generator=torch.Generator("cpu").manual_seed(seed),
).images[0]
image.save("flux-schnell.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.

## API usage

Our API offers access to our models. It is documented here:
2 changes: 2 additions & 0 deletions demo_gr.py
@@ -15,6 +15,7 @@

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
t5 = load_t5(device, max_length=256 if is_schnell else 512)
clip = load_clip(device)
@@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
9 changes: 9 additions & 0 deletions docs/structural-conditioning.md
@@ -39,6 +39,15 @@ python -m src.flux.cli_control --loop --name <name>

where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`.

### TRT engine inference

You may also download the ONNX exports of [FLUX.1 Depth \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-onnx) and [FLUX.1 Canny \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-onnx). We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
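For example, one of the exports can be fetched with `huggingface_hub`. This is a minimal sketch: the `local_dir` value is a hypothetical path, and it assumes you have accepted the model license and are logged in (e.g. via `huggingface-cli login`):

```python
# Minimal sketch: download the FLUX.1 Depth [dev] ONNX export.
# Assumes HF authentication and an accepted license; local_dir is hypothetical.
from huggingface_hub import snapshot_download

onnx_dir = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-Depth-dev-onnx",
    local_dir="./flux-depth-dev-onnx",
)
print(onnx_dir)  # pass this path as ONNX_DIR in the command below
```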

```bash
TRT_ENGINE_DIR=<your_trt_engine_will_be_saved_here> ONNX_DIR=<path_of_downloaded_onnx_export> python src/flux/cli_control.py "<prompt>" --img_cond_path="assets/robot.webp" --trt --static_shape=False --name=<name> --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4`. For `fp4`, you need an NVIDIA GPU based on the [Blackwell architecture](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/).
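If you are unsure whether your GPU qualifies, a quick check with PyTorch can help. This is a hedged sketch; it assumes Blackwell parts report a CUDA compute capability of 10.0 or higher:

```python
# Hedged sketch: report whether the local GPU is likely to support the fp4 engines.
# Assumption: Blackwell GPUs expose CUDA compute capability >= 10.0.
import torch

major, minor = torch.cuda.get_device_capability()
print(f"compute capability {major}.{minor}; fp4 likely supported: {major >= 10}")
```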

## Diffusers usage

Flux Control (including the LoRAs) is also compatible with the `diffusers` Python library. Check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
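A hedged sketch of that route for FLUX.1 Depth [dev], assuming the `FluxControlPipeline` class and its `control_image` argument as described in the diffusers documentation; the depth map is a hypothetical precomputed input (e.g. from a depth estimator), not a raw RGB photo:

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# Hedged sketch, not this repo's code path: FLUX.1 Depth [dev] via diffusers.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("depth_map.png")  # hypothetical precomputed depth map
image = pipe(
    prompt="a robot standing in a sunlit meadow",
    control_image=control_image,
    num_inference_steps=30,
    guidance_scale=10.0,
).images[0]
image.save("flux-depth-dev.png")
```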
11 changes: 11 additions & 0 deletions docs/text-to-image.md
@@ -35,6 +35,17 @@ python -m flux --name <name> \
--prompt "<prompt>"
```

### TRT engine inference

You may also download the ONNX exports of [FLUX.1 \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-dev-onnx) and [FLUX.1 \[schnell\]](https://huggingface.co/black-forest-labs/FLUX.1-schnell-onnx). We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).

```bash
TRT_ENGINE_DIR=<your_trt_engine_will_be_saved_here> ONNX_DIR=<path_of_downloaded_onnx_export> python src/flux/cli.py "<prompt>" --trt --static_shape=False --name=<name> --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4`. For `fp4`, you need an NVIDIA GPU based on the [Blackwell architecture](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/).
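Since the CLI is a thin [Fire](https://github.com/google/python-fire) wrapper around `main`, the same run can also be driven from Python. This is a hedged sketch rather than an official entry point: it assumes `main` accepts the keyword arguments visible in this diff, and both directory paths are hypothetical:

```python
import os

# Hypothetical paths; set these before flux.cli builds the TRT runtime.
os.environ["TRT_ENGINE_DIR"] = "./trt_engines"
os.environ["ONNX_DIR"] = "./flux-schnell-onnx"

from flux.cli import main

# Mirrors: python src/flux/cli.py "<prompt>" --trt --static_shape=False \
#          --name=flux-schnell --trt_transformer_precision bf16
main(
    name="flux-schnell",
    prompt="a photo of a forest with mist swirling around the tree trunks",
    trt=True,
    trt_transformer_precision="bf16",
    static_shape=False,  # forwarded via **kwargs (assumption)
)
```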

### Streamlit and Gradio

We also provide a Streamlit demo that does both text-to-image and image-to-image. The demo can be run via

```bash
streamlit run demo_st.py
```
21 changes: 18 additions & 3 deletions pyproject.toml
@@ -9,8 +9,6 @@ requires-python = ">=3.10"
license = { file = "LICENSE.md" }
dynamic = ["version"]
dependencies = [
"torch == 2.5.1",
"torchvision",
"einops",
"fire >= 0.6.0",
"huggingface-hub",
@@ -25,6 +23,10 @@ dependencies = [
]

[project.optional-dependencies]
torch = [
"torch == 2.5.1",
"torchvision",
]
streamlit = [
"streamlit",
"streamlit-drawable-canvas",
@@ -33,9 +35,22 @@ streamlit = [
gradio = [
"gradio",
]
tensorrt = [
"tensorrt-cu12 == 10.8.0.43",
"colored",
"cuda-python",
"diffusers",
"nvidia-modelopt[torch,onnx] ~= 0.19.0",
"opencv-python ~= 4.8.0.74",
"onnx ~= 1.17.0",
"onnxruntime ~= 1.19.2",
"onnx-graphsurgeon",
"polygraphy ~= 0.49.9",
]
all = [
"flux[streamlit]",
"flux[gradio]",
"flux[streamlit]",
"flux[torch]",
]

[project.scripts]
27 changes: 21 additions & 6 deletions src/flux/cli.py
@@ -27,9 +27,7 @@ class SamplingOptions:


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = (
"Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
)
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
@@ -113,6 +111,7 @@ def main(
output_dir: str = "output",
add_sampling_metadata: bool = True,
trt: bool = False,
trt_transformer_precision: str = "bf16",
**kwargs: dict | None,
):
"""
@@ -135,6 +134,19 @@
trt: use TensorRT backend for optimized inference
kwargs: additional arguments for TensorRT support
"""

prompt = prompt.split("|")
if len(prompt) == 1:
prompt = prompt[0]
additional_prompts = None
else:
additional_prompts = prompt[1:]
prompt = prompt[0]

assert not (
(additional_prompts is not None) and loop
), "Do not provide additional prompts and set loop to True"

nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

if name not in configs:
@@ -193,6 +205,7 @@ def main(
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
transformer_precision=trt_transformer_precision,
)

torch.cuda.synchronize()
@@ -251,9 +264,7 @@ def main(
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5, clip, x, prompt=opts.prompt)
timesteps = get_schedule(
opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

# offload TEs to CPU, load model to gpu
if offload:
@@ -287,12 +298,16 @@ def main(
if loop:
print("-" * 80)
opts = parse_prompt(opts)
elif additional_prompts:
next_prompt = additional_prompts.pop(0)
opts.prompt = next_prompt
else:
opts = None

if trt:
trt_ctx_manager.stop_runtime()


def app():
Fire(main)

2 changes: 2 additions & 0 deletions src/flux/cli_control.py
@@ -177,6 +177,7 @@ def main(
img_cond_path: str = "assets/robot.webp",
lora_scale: float | None = 0.85,
trt: bool = False,
trt_transformer_precision: str = "bf16",
**kwargs: dict | None,
):
"""
@@ -272,6 +273,7 @@ def main(
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
transformer_precision=trt_transformer_precision,
)
torch.cuda.synchronize()

2 changes: 1 addition & 1 deletion src/flux/trt/engine/__init__.py
@@ -18,7 +18,7 @@
from flux.trt.engine.clip_engine import CLIPEngine
from flux.trt.engine.t5_engine import T5Engine
from flux.trt.engine.transformer_engine import TransformerEngine
from flux.trt.engine.vae_engine import VAEEngine, VAEDecoder, VAEEncoder
from flux.trt.engine.vae_engine import VAEDecoder, VAEEncoder, VAEEngine

__all__ = [
"BaseEngine",
3 changes: 1 addition & 2 deletions src/flux/trt/engine/vae_engine.py
@@ -15,10 +15,10 @@
# limitations under the License.

import torch
from cuda import cudart

from flux.trt.engine.base_engine import BaseEngine, Engine
from flux.trt.mixin import VAEMixin
from cuda import cudart


class VAEDecoder(VAEMixin, Engine):
@@ -162,7 +162,6 @@ def load(self):
if self.encoder is not None:
self.encoder.load()


def activate(
self,
device: str,
4 changes: 3 additions & 1 deletion src/flux/trt/exporter/vae_exporter.py
@@ -14,8 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from math import ceil

import torch

from flux.modules.autoencoder import Decoder, Encoder
from flux.trt.exporter.base_exporter import BaseExporter
from flux.trt.mixin import VAEMixin
1 change: 1 addition & 0 deletions src/flux/trt/mixin/clip_mixin.py
@@ -15,6 +15,7 @@
# limitations under the License.

from typing import Any

from flux.trt.mixin.base_mixin import BaseMixin


