Update mmdit runner inputs, small attn reproducer, pad attention flag
monorimet committed Jun 14, 2024
1 parent 3495f63 commit 6c9d96d
Showing 4 changed files with 109 additions and 5 deletions.
models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py
@@ -284,6 +284,13 @@ def is_valid_file(arg):
     default="SD3_output.png",
     help="Path to output file for generated images.",
 )
+p.add_argument(
+    "--attn_repro",
+    default=False,
+    action="store_true",
+    help="Just compile attention reproducer for mmdit.",
+)
 
 
 ##############################################################################
 # IREE Compiler Options
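Usage note (not part of the diff): with this flag set, invoking the sd3_mmdit entry point as `python sd3_mmdit.py --attn_repro --compile_to vmfb ...` should compile only the attention reproducer added below, assuming the remaining required options are supplied.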
96 changes: 96 additions & 0 deletions models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py
@@ -52,7 +52,87 @@ def forward(
             return_dict=False,
         )[0]
         return noise_pred
+
+
+class MMDiTAttention(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, q, k, v):
+        return torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, dropout_p=0.0, is_causal=False
+        )


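A minimal eager-mode sketch (not part of the diff) of what this reproducer wraps; the shapes follow the qkv_shape used by export_attn below, and it assumes only a working torch install plus the MMDiTAttention class above:

import torch

attn = MMDiTAttention()
q, k, v = (torch.randn(2, 24, 4250, 64) for _ in range(3))
out = attn(q, k, v)  # SDPA over the 4250-token sequence, per head
print(out.shape)  # torch.Size([2, 24, 4250, 64])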
+@torch.no_grad()
+def export_attn(
+    precision="fp16",
+    device="cpu",
+    target_triple="x86_64-unknown-linux-gnu",
+    ireec_flags="",
+    compile_to="torch",
+    decomp_attn=False,
+    attn_spec=None,
+):
+    dtype = torch.float16 if precision == "fp16" else torch.float32
+    qkv_shape = (2, 24, 4250, 64)
+    attn_module = MMDiTAttention()
+    safe_name = "attn_repro_" + precision + "_" + target_triple
+    if decomp_attn:
+        safe_name += "_decomp"
+
+    if dtype == torch.float16:
+        attn_module = attn_module.half()
+
+    example_qkv = [
+        torch.empty(qkv_shape, dtype=dtype),
+        torch.empty(qkv_shape, dtype=dtype),
+        torch.empty(qkv_shape, dtype=dtype),
+    ]
+
+    decomp_list = []
+    if decomp_attn:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+        fxb = FxProgramsBuilder(attn_module)
+
+        @fxb.export_program(
+            args=(example_qkv,),
+        )
+        def _forward(module, inputs):
+            return module.forward(*inputs)
+
+        class CompiledAttn(CompiledModule):
+            run_forward = _forward
+
+        inst = CompiledAttn(context=Context(), import_to="IMPORT")
+        module_str = str(CompiledModule.get_mlir_module(inst))
+
+    if compile_to != "vmfb":
+        return module_str
+    else:
+        vmfb_path = utils.compile_to_vmfb(
+            module_str,
+            device,
+            target_triple,
+            ireec_flags,
+            safe_name,
+            return_path=True,
+            attn_spec=attn_spec,
+        )
+        return vmfb_path

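A usage sketch (not part of the diff), mirroring the __main__ path added further down; it assumes this module's imports are in scope and relies on compile_to != "vmfb" making export_attn return the MLIR module as a string:

mlir_str = export_attn(
    precision="fp16",
    device="cpu",
    target_triple="x86_64-unknown-linux-gnu",
    compile_to="torch",
    decomp_attn=True,
)
# the output filename here is illustrative
with open("attn_repro_fp16_decomp.mlir", "w") as f:
    f.write(mlir_str)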
 @torch.no_grad()
 def export_mmdit_model(
@@ -183,6 +263,22 @@ class CompiledMmdit(CompiledModule):
     logging.basicConfig(level=logging.DEBUG)
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args

+    if args.attn_repro:
+        mod_str = export_attn(
+            args.precision,
+            args.device,
+            args.iree_target_triple,
+            args.ireec_flags,
+            args.compile_to,
+            args.decomp_attn,
+            attn_spec=args.attn_spec,
+        )
+        if args.compile_to != "vmfb":
+            safe_name = "attn_repro_" + args.precision
+            with open(f"{safe_name}.mlir", "w+") as f:
+                f.write(mod_str)
+            print("Saved to", safe_name + ".mlir")
+        exit()
     if args.input_mlir:
         mmdit_model = None
     else:
models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
@@ -69,14 +69,15 @@ def run_diffusers_mmdit(
         dtype = torch.float16
     else:
         dtype = torch.float32
 
+    batch_size = args.batch_size * 2  # do classifier-free guidance
     hidden_states = torch.randn(
-        (args.batch_size, 16, args.height // 8, args.width // 8), dtype=dtype
+        (batch_size, 16, args.height // 8, args.width // 8), dtype=dtype
     )
     encoder_hidden_states = torch.randn(
-        (args.batch_size, args.max_length * 2, 4096), dtype=dtype
+        (batch_size, args.max_length * 2, 4096), dtype=dtype
     )
-    pooled_projections = torch.randn((args.batch_size, 2048), dtype=dtype)
+    pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
     timestep = torch.tensor([0], dtype=dtype)
 
     turbine_output = run_mmdit_turbine(
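Why the batch doubles (not part of the diff): classifier-free guidance evaluates the conditional and unconditional inputs in one batched forward pass, then combines the two noise predictions. An illustrative sketch with hypothetical names and shapes:

import torch

guidance_scale = 7.0
latent = torch.randn(1, 16, 128, 128)
cond_embeds = torch.randn(1, 154, 4096)
uncond_embeds = torch.randn(1, 154, 4096)

latents = torch.cat([latent, latent], dim=0)                    # batch_size * 2
prompt_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0)  # matches latents
# noise = mmdit(latents, prompt_embeds, ...)                    # one batched call
noise = torch.randn(2, 16, 128, 128)                            # stand-in for the model output
noise_uncond, noise_cond = noise.chunk(2)
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)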
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sd_inference/utils.py
@@ -57,7 +57,7 @@
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
     ],
     "unet": [""],
     "clip": [""],
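A note on the new flag (an assumption about its semantics, not stated in the diff): iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0} appears to pad selected dimensions of attention ops up to the given multiples, with 0 leaving a dimension untouched. The round-up arithmetic, as a sketch:

def round_up(x: int, multiple: int) -> int:
    # 0 means "do not pad this dimension"
    if multiple == 0:
        return x
    return -(-x // multiple) * multiple

# e.g. the 4250-token sequence from the attention reproducer above:
print(round_up(4250, 64))  # 4288
print(round_up(4250, 32))  # 4256
print(round_up(4250, 0))   # 4250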
