Skip to content

Commit

Permalink
Dev (#90)
Browse files Browse the repository at this point in the history
* closes #89; build for newer CUDA archs

* closes #81; add community/optimize_sd15_with_controlnet_and_ip_adapter.py and fix doc in diffusion_pipeline_compiler
  • Loading branch information
chengzeyi authored Dec 25, 2023
1 parent dc38759 commit 3827f31
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/wheels_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ env:
# you need at least cuda 5.0 for some of the stuff compiled here.
# TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX"
# Feature 'f16 arithemetic and compare instructions' requires .target sm_53 or higher
TORCH_CUDA_ARCH_LIST: "6.0 6.1 7.0 7.5 8.0+PTX"
TORCH_CUDA_ARCH_LIST: "6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
MAX_JOBS: 2
DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
SFAST_APPEND_VERSION: 1
Expand Down
115 changes: 115 additions & 0 deletions community/optimize_sd15_with_controlnet_and_ip_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch
from diffusers import AutoPipelineForText2Image, EulerDiscreteScheduler, ControlNetModel
from diffusers.utils import load_image
from sfast.compilers.diffusion_pipeline_compiler import (compile,
CompilationConfig)
import numpy as np
import cv2
from PIL import Image

CUDA_DEVICE = "cuda:0"


def canny_process(image, width, height, low_threshold=100, high_threshold=200):
    """Build a 3-channel Canny edge image to use as the ControlNet condition.

    Args:
        image: Input image as a NumPy array (the script passes ``np.array``
            of an image returned by ``load_image``).
        width: Target width in pixels.
        height: Target height in pixels.
        low_threshold: Lower hysteresis threshold forwarded to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold forwarded to ``cv2.Canny``.

    Returns:
        A PIL image whose three channels all contain the same edge map.
    """
    resized = cv2.resize(image, (width, height))
    edges = cv2.Canny(resized, low_threshold, high_threshold)
    # The edge map is 2-D; add a channel axis and replicate it three times so
    # the result has an (H, W, 3) layout.
    edges = edges[:, :, None]
    stacked = np.concatenate([edges, edges, edges], axis=2)
    return Image.fromarray(stacked)


def reference_process(image, width, height):
    """Resize the reference image array and wrap it in a PIL image.

    Args:
        image: Input image as a NumPy array.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The resized image as a PIL image.
    """
    target_size = (width, height)
    return Image.fromarray(cv2.resize(image, target_size))


def load_model():
    """Create the SD 1.5 + ControlNet (canny) + IP-Adapter pipeline.

    The ControlNet and the base pipeline are both loaded in fp16. The
    pipeline's scheduler is replaced by an ``EulerDiscreteScheduler`` built
    from the existing scheduler config, the safety checker is disabled, the
    IP-Adapter weights are loaded, and the pipeline is moved to
    ``CUDA_DEVICE``.

    Returns:
        The configured pipeline object.
    """
    # NOTE(review): `name=` is forwarded verbatim to `from_pretrained`; it is
    # not obviously a documented argument -- confirm it has the intended effect.
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_canny",
        torch_dtype=torch.float16,
        variant="fp16",
        name="diffusion_pytorch_model.fp16.safetensors",
        use_safetensors=True,
    )

    pipeline = AutoPipelineForText2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        controlnet=controlnet,
    )

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
    pipeline.safety_checker = None

    pipeline.load_ip_adapter(
        "h94/IP-Adapter",
        subfolder="models",
        weight_name="ip-adapter_sd15.safetensors",
    )

    pipeline.to(torch.device(CUDA_DEVICE))
    return pipeline


def compile_model(model):
    """Compile the pipeline with sfast using the default configuration.

    xformers and Triton are switched on only when the corresponding package
    can be imported; otherwise a notice is printed and the option is left at
    its default. CUDA graph is always enabled.

    Args:
        model: The pipeline to compile.

    Returns:
        Whatever ``compile`` returns for the given model and config.
    """
    config = CompilationConfig.Default()

    # Optional accelerators: probe by importing, enable on success.
    try:
        import xformers  # noqa: F401
    except ImportError:
        print('xformers not installed, skip')
    else:
        config.enable_xformers = True

    try:
        import triton  # noqa: F401
    except ImportError:
        print('Triton not installed, skip')
    else:
        config.enable_triton = True

    config.enable_cuda_graph = True

    return compile(model, config)


if __name__ == "__main__":
    # Example inputs: one image for the canny ControlNet condition and one
    # image used as the IP-Adapter reference.
    control_url = 'https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/bird.png'
    reference_url = 'https://huggingface.co/datasets/diffusers/dog-example/resolve/main/alvan-nee-eoqnr8ikwFE-unsplash.jpeg'

    width, height = 768, 512

    # Download both images, convert them to arrays, then preprocess each to
    # the target resolution.
    control_loaded = load_image(control_url)
    reference_loaded = load_image(reference_url)
    control_array = np.array(control_loaded)
    reference_array = np.array(reference_loaded)
    control_img = canny_process(control_array, width, height)
    reference_img = reference_process(reference_array, width, height)

    model = compile_model(load_model())

    # NOTE: -1 is handed to `manual_seed` as-is; nothing here treats it as a
    # "pick a random seed" sentinel.
    seed = -1
    batch_size = 4
    generator = torch.Generator(device=CUDA_DEVICE).manual_seed(seed)

    prompt = "dog"
    negative_prompt = ""
    num_inference_steps = 20
    guidance_scale = 7.5
    controlnet_conditioning_scale = 1.0

    # Run the same request three times; only the last result is kept.
    # Presumably the earlier runs serve as warm-up for the compiled model --
    # TODO confirm.
    for _ in range(3):
        result = model(
            prompt=[prompt] * batch_size,
            negative_prompt=[negative_prompt] * batch_size,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=1,
            guidance_scale=guidance_scale,
            ip_adapter_image=[reference_img] * batch_size,
            image=[control_img] * batch_size,
            generator=generator,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
        images = result.images

    from sfast.utils.term_image import print_image

    # Render each generated image in the terminal.
    for image in images:
        print_image(image, max_width=80)
2 changes: 2 additions & 0 deletions src/sfast/compilers/diffusion_pipeline_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ class Default:
Whether to enable CNN optimization by fusion.
enable_fused_linear_geglu:
Whether to enable fused Linear-GEGLU kernel.
It uses fp16 for accumulation, so it could cause **quality degradation**.
prefer_lowp_gemm:
Whether to prefer low-precision GEMM and a series of fusion optimizations.
This will make the model faster, but may cause numerical issues.
These use fp16 for accumulation, so they could cause **quality degradation**.
enable_xformers:
Whether to enable xformers and hijack it to make it compatible with JIT tracing.
enable_cuda_graph:
Expand Down

0 comments on commit 3827f31

Please sign in to comment.