From 680cd6ec4228cd11bf198090a79c4c92cbf68de9 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Wed, 4 Dec 2024 11:24:17 -0800 Subject: [PATCH] Implement MMDIT block that is necessary for flux (#592) --- sharktank/sharktank/layers/__init__.py | 1 + sharktank/sharktank/layers/mmdit.py | 146 +++++++++++++++++++++++ sharktank/sharktank/layers/modulation.py | 42 +++++++ sharktank/sharktank/layers/testing.py | 79 ++++++++++++ sharktank/sharktank/ops/default_impls.py | 16 ++- sharktank/sharktank/ops/signatures.py | 6 +- sharktank/tests/layers/mmdit_test.py | 58 +++++++++ 7 files changed, 343 insertions(+), 5 deletions(-) create mode 100644 sharktank/sharktank/layers/mmdit.py create mode 100644 sharktank/sharktank/layers/modulation.py create mode 100644 sharktank/tests/layers/mmdit_test.py diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index fd56ec872..620c15672 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -17,5 +17,6 @@ from .ffn_block import FFN from .ffn_moe_block import FFNMOE from .mixture_of_experts_block import MoeBlock +from .mmdit import MMDITDoubleBlock from .configs import * diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py new file mode 100644 index 000000000..0b0750549 --- /dev/null +++ b/sharktank/sharktank/layers/mmdit.py @@ -0,0 +1,146 @@ +# Copyright 2024 Black Forest Labs. Inc. and Flux Authors +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""MMDIT Layers adapted from black-forest-labs' flux implementation +https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +""" + +import torch.nn.functional as F +import torch +from torch import Tensor + +from .. import ops + +from .base import Theta, ThetaLayer +from .linear import LinearLayer +from .modulation import ModulationLayer +from .norm import RMSNormLayer +from .paged_llama_attention_block import PagedLlamaAttentionBlock + + +def qk_norm(q, k, v, rms_q, rms_k): + return rms_q(q).to(v), rms_k(k).to(v) + + +# TODO: Work on unifying with the current RoPE layer +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +def attention(q, k, v, pe): + q, k = apply_rope(q, k, pe) # todo + + x = ops.scaled_dot_product_attention( + q=q, k=k, v=v, a=None, is_causal=True, scale=None + ) + x = ops.permute(x, (0, 2, 1, 3)) + x = x.view(x.shape[0], x.shape[1], -1) + + return x + + +class MMDITDoubleBlock(ThetaLayer): + def __init__(self, theta, num_heads: int): + super().__init__(theta) + + self.num_heads = num_heads + self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True)) + self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv"))) + self.add_module( + "img_attn_norm_q", + RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6), + ) + self.add_module( + "img_attn_norm_k", + RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6), + ) + self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj"))) + + self.add_module("img_mlp1", LinearLayer(theta("img_mlp.0"))) + self.add_module("img_mlp2", LinearLayer(theta("img_mlp.2"))) + + self.add_module("txt_mod", ModulationLayer(theta("txt_mod"), double=True)) + self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))) + self.add_module( + "txt_attn_norm_q", + RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6), + ) + self.add_module( + "txt_attn_norm_k", + RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6), + ) + self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj"))) + + self.add_module("txt_mlp1", LinearLayer(theta("txt_mlp.0"))) + self.add_module("txt_mlp2", LinearLayer(theta("txt_mlp.2"))) + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = ops.layer_norm(img, None, None, eps=1e-6) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn_qkv(img_modulated) + img_qkv_2 = img_qkv.view( + img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1 + ) # + img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4)) + img_q, img_k, img_v = img_qkv_3 + img_q, img_k = qk_norm( + img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k + ) + + # prepare text for attention + txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_qkv_2 = txt_qkv.view( + txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1 + ) # + txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4)) + txt_q, txt_k, txt_v = txt_qkv_3 + txt_q, txt_k = qk_norm( + txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k + ) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the image blocks + # TODO: Refactor this for code reuse with the txt blocks + img = img + img_mod1.gate * self.img_attn_proj(img_attn) + img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm( + img, None, None, eps=1e-6 + ) + img_mod2.shift + img_mlp_out1 = self.img_mlp1(img_mlp_in) + img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1) + img_mlp_out3 = self.img_mlp2(img_mlp_out2) + img = img + img_mod2.gate * img_mlp_out3 + + # calculate the text blocks + txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn) + txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm( + txt, None, None, eps=1e-6 + ) + txt_mod2.shift + txt_mlp_out1 = self.txt_mlp1(txt_mlp_in) + # TODO: Unify with modulation layer by taking act_fn as an arg + txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1) + txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2) + txt = txt + txt_mod2.gate * txt_mlp_out3 + + return img, txt diff --git a/sharktank/sharktank/layers/modulation.py b/sharktank/sharktank/layers/modulation.py new file mode 100644 index 000000000..7ef7adfa1 --- /dev/null +++ b/sharktank/sharktank/layers/modulation.py @@ -0,0 +1,42 @@ +# Copyright 2024 Black Forest Labs. Inc. and Flux Authors +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Modulation Layer adapted from black-forest-labs' flux implementation +https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +""" + +import torch +import torch.nn.functional as F + +from .. import ops + +from .base import Theta, ThetaLayer +from .linear import LinearLayer + + +class ModulationOut: + def __init__(self, shift, scale, gate): + self.shift = shift + self.scale = scale + self.gate = gate + + +class ModulationLayer(ThetaLayer): + def __init__(self, theta: Theta, double: bool): + super().__init__(theta) + + self.is_double = double + self.multiplier = 6 if double else 3 + self.add_module("lin", LinearLayer(theta("lin"))) + + def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]: + silu_result = ops.elementwise(F.silu, vec) + out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py index e2fc79d78..a21d5bf85 100644 --- a/sharktank/sharktank/layers/testing.py +++ b/sharktank/sharktank/layers/testing.py @@ -49,3 +49,82 @@ def make_llama_attention_block_theta( ), } ) + + +def make_mmdit_double_block_theta(dtype: torch.dtype | None = None) -> Theta: + return Theta( + { + "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "img_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072,), dtype=dtype) + ), + "img_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 3072), dtype=dtype) + ), + "img_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((9216,), dtype=dtype) + ), + "img_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((9216, 3072), dtype=dtype) + ), + "img_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((12288), dtype=dtype) + ), + "img_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((12288, 3072), dtype=dtype) + ), + "img_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072), dtype=dtype) + ), + "img_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 12288), dtype=dtype) + ), + "img_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((18432,), dtype=dtype) + ), + "img_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((18432, 3072), dtype=dtype) + ), + "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) + ), + "txt_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072,), dtype=dtype) + ), + "txt_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 3072), dtype=dtype) + ), + "txt_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((9216,), dtype=dtype) + ), + "txt_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((9216, 3072), dtype=dtype) + ), + "txt_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((12288), dtype=dtype) + ), + "txt_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((12288, 3072), dtype=dtype) + ), + "txt_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((3072), dtype=dtype) + ), + "txt_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((3072, 12288), dtype=dtype) + ), + "txt_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((18432,), dtype=dtype) + ), + "txt_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((18432, 3072), dtype=dtype) + ), + } + ) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index d117ada23..47e737fb1 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -304,16 +304,26 @@ def interpolate_default( ) -@layer_norm.override(Tensor, Tensor, Tensor) def layer_norm_default(input, weight, bias, *, eps): input = unbox_tensor(input) - weight = unbox_tensor(weight) - bias = unbox_tensor(bias) + if weight is not None: + weight = unbox_tensor(weight) + else: + weight = torch.ones(input.shape, dtype=input.dtype) + if bias is not None: + bias = unbox_tensor(bias) + else: + bias = torch.zeros(input.shape, dtype=input.dtype) return F.layer_norm( input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps ) +layer_norm.override(Tensor)(layer_norm_default) +layer_norm.override(Tensor, Tensor)(layer_norm_default) +layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default) + + # Linear def linear_default(input, weight, bias, *, accum_dtype) -> Tensor: input = unbox_tensor(input) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 408f00ec7..dc7fb108a 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -582,12 +582,14 @@ def layer_norm( def _layer_norm_trampoline( d: SignatureDispatcher, input: AnyTensor, - weight: AnyTensor, + weight: Optional[AnyTensor], bias: Optional[AnyTensor], *, eps: float, ): - tensors = [input, weight] + tensors = [input] + if weight is not None: + tensors.append(bias) if bias is not None: tensors.append(bias) for override in d.find_overrides(tensors): diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py new file mode 100644 index 000000000..5bd5ce39a --- /dev/null +++ b/sharktank/tests/layers/mmdit_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest + +import torch + +from iree.turbine import aot +from sharktank.layers import ( + MMDITDoubleBlock, +) +import sharktank.ops as ops +from sharktank.layers.testing import ( + make_mmdit_double_block_theta, +) +from sharktank.types.tensors import DefaultPrimitiveTensor + + +class MMDITTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(12345) + self.hidden_size = 3072 + self.num_heads = 24 + self.batch_size = 3 + + def testDoubleExport(self): + + theta = make_mmdit_double_block_theta() + mmdit = MMDITDoubleBlock( + theta=theta, + num_heads=self.num_heads, + ) + + img = torch.rand([self.batch_size, 1024, self.hidden_size]) + txt = torch.rand([self.batch_size, 512, self.hidden_size]) + vec = torch.rand([self.batch_size, self.hidden_size]) + rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2]) + mmdit.forward(img, txt, vec, rot) + fxb = aot.FxProgramsBuilder(mmdit) + + @fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False) + def _(model, img, txt, vec, rot) -> torch.Tensor: + return model.forward(img, txt, vec, rot) + + output = aot.export(fxb) + output.verify() + asm = str(output.mlir_module) + + +if __name__ == "__main__": + unittest.main()