Implement MMDIT block that is necessary for flux (#592)
KyleHerndon authored and eagarvey-amd committed Dec 13, 2024
1 parent 539be41 commit 680cd6e
Showing 7 changed files with 343 additions and 5 deletions.
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/__init__.py
@@ -17,5 +17,6 @@
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import MoeBlock
from .mmdit import MMDITDoubleBlock

from .configs import *
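With this export in place, the block can be imported straight from the layers package, which is the path the new test below uses:

from sharktank.layers import MMDITDoubleBlock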
146 changes: 146 additions & 0 deletions sharktank/sharktank/layers/mmdit.py
@@ -0,0 +1,146 @@
# Copyright 2024 Black Forest Labs. Inc. and Flux Authors
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""MMDIT Layers adapted from black-forest-labs' flux implementation
https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
"""

import torch.nn.functional as F
import torch
from torch import Tensor

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .modulation import ModulationLayer
from .norm import RMSNormLayer
from .paged_llama_attention_block import PagedLlamaAttentionBlock


def qk_norm(q, k, v, rms_q, rms_k):
return rms_q(q).to(v), rms_k(k).to(v)


# TODO: Work on unifying with the current RoPE layer
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe) # todo

x = ops.scaled_dot_product_attention(
q=q, k=k, v=v, a=None, is_causal=True, scale=None
)
x = ops.permute(x, (0, 2, 1, 3))
x = x.view(x.shape[0], x.shape[1], -1)

return x


class MMDITDoubleBlock(ThetaLayer):
def __init__(self, theta, num_heads: int):
super().__init__(theta)

self.num_heads = num_heads
self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True))
self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv")))
self.add_module(
"img_attn_norm_q",
RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6),
)
self.add_module(
"img_attn_norm_k",
RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6),
)
self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj")))

self.add_module("img_mlp1", LinearLayer(theta("img_mlp.0")))
self.add_module("img_mlp2", LinearLayer(theta("img_mlp.2")))

self.add_module("txt_mod", ModulationLayer(theta("txt_mod"), double=True))
self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv")))
self.add_module(
"txt_attn_norm_q",
RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6),
)
self.add_module(
"txt_attn_norm_k",
RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6),
)
self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj")))

self.add_module("txt_mlp1", LinearLayer(theta("txt_mlp.0")))
self.add_module("txt_mlp2", LinearLayer(theta("txt_mlp.2")))

def forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

# prepare image for attention
img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn_qkv(img_modulated)
img_qkv_2 = img_qkv.view(
img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
) #
img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
img_q, img_k, img_v = img_qkv_3
img_q, img_k = qk_norm(
img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
)

# prepare text for attention
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_qkv_2 = txt_qkv.view(
txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
) #
txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
txt_q, txt_k, txt_v = txt_qkv_3
txt_q, txt_k = qk_norm(
txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

# calculate the image blocks
# TODO: Refactor this for code reuse with the txt blocks
img = img + img_mod1.gate * self.img_attn_proj(img_attn)
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
img, None, None, eps=1e-6
) + img_mod2.shift
img_mlp_out1 = self.img_mlp1(img_mlp_in)
img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
img_mlp_out3 = self.img_mlp2(img_mlp_out2)
img = img + img_mod2.gate * img_mlp_out3

# calculate the text blocks
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
txt, None, None, eps=1e-6
) + txt_mod2.shift
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
# TODO: Unify with modulation layer by taking act_fn as an arg
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
txt = txt + txt_mod2.gate * txt_mlp_out3

return img, txt
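For orientation, a minimal usage sketch of the new block (not part of this commit): the theta helper and every tensor shape below are copied from the test added in this change, and the positional-embedding tensor is random stand-in data rather than real Flux rotary frequencies.

import torch
from sharktank.layers import MMDITDoubleBlock
from sharktank.layers.testing import make_mmdit_double_block_theta

theta = make_mmdit_double_block_theta()
block = MMDITDoubleBlock(theta=theta, num_heads=24)

img = torch.rand([3, 1024, 3072])        # image tokens
txt = torch.rand([3, 512, 3072])         # text tokens
vec = torch.rand([3, 3072])              # conditioning vector fed to the modulation layers
pe = torch.rand([3, 1, 1536, 64, 2, 2])  # rotary table; 1536 = 512 txt + 1024 img positions

img_out, txt_out = block.forward(img, txt, vec, pe)  # returns the updated (img, txt) pair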
42 changes: 42 additions & 0 deletions sharktank/sharktank/layers/modulation.py
@@ -0,0 +1,42 @@
# Copyright 2024 Black Forest Labs. Inc. and Flux Authors
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Modulation Layer adapted from black-forest-labs' flux implementation
https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
"""

import torch
import torch.nn.functional as F

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer


class ModulationOut:
def __init__(self, shift, scale, gate):
self.shift = shift
self.scale = scale
self.gate = gate


class ModulationLayer(ThetaLayer):
def __init__(self, theta: Theta, double: bool):
super().__init__(theta)

self.is_double = double
self.multiplier = 6 if double else 3
self.add_module("lin", LinearLayer(theta("lin")))

def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
silu_result = ops.elementwise(F.silu, vec)
out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1)

return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
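As a rough, self-contained sketch of how ModulationLayer's outputs are consumed by the double block above (the random lin_out stands in for self.lin(F.silu(vec)) in double mode; the 3072 and 1024 sizes follow the test theta below):

import torch
import torch.nn.functional as F

hidden = 3072
lin_out = torch.rand([3, 6 * hidden])  # stand-in for the layer's linear projection
shift, scale, gate, shift2, scale2, gate2 = lin_out[:, None, :].chunk(6, dim=-1)

x = torch.rand([3, 1024, hidden])      # e.g. the image token stream
x_norm = F.layer_norm(x, (hidden,), eps=1e-6)
x_mod = (1 + scale) * x_norm + shift   # pre-attention modulation
x = x + gate * x_mod                   # gated residual (the real block gates the attention/MLP output)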
79 changes: 79 additions & 0 deletions sharktank/sharktank/layers/testing.py
@@ -49,3 +49,82 @@ def make_llama_attention_block_theta(
),
}
)


def make_mmdit_double_block_theta(dtype: torch.dtype | None = None) -> Theta:
return Theta(
{
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072,), dtype=dtype)
),
"img_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"img_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((9216,), dtype=dtype)
),
"img_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"img_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((12288), dtype=dtype)
),
"img_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"img_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072), dtype=dtype)
),
"img_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"img_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((18432,), dtype=dtype)
),
"img_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((18432, 3072), dtype=dtype)
),
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072,), dtype=dtype)
),
"txt_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"txt_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((9216,), dtype=dtype)
),
"txt_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"txt_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((12288), dtype=dtype)
),
"txt_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"txt_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072), dtype=dtype)
),
"txt_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"txt_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((18432,), dtype=dtype)
),
"txt_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((18432, 3072), dtype=dtype)
),
}
)
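All of the sizes in this helper follow from a hidden size of 3072 with 24 heads; a quick sanity sketch of the relationships (numbers taken from the theta above and the test below):

hidden_size = 3072
num_heads = 24

head_dim = hidden_size // num_heads  # 128   -> q/k norm weight size
qkv_out = 3 * hidden_size            # 9216  -> fused q/k/v projection
mlp_hidden = 4 * hidden_size         # 12288 -> img_mlp.0 / txt_mlp.0 output size
mod_out = 6 * hidden_size            # 18432 -> shift/scale/gate for two stages

assert (head_dim, qkv_out, mlp_hidden, mod_out) == (128, 9216, 12288, 18432)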
16 changes: 13 additions & 3 deletions sharktank/sharktank/ops/default_impls.py
@@ -304,16 +304,26 @@ def interpolate_default(
)


-@layer_norm.override(Tensor, Tensor, Tensor)
 def layer_norm_default(input, weight, bias, *, eps):
     input = unbox_tensor(input)
-    weight = unbox_tensor(weight)
-    bias = unbox_tensor(bias)
+    if weight is not None:
+        weight = unbox_tensor(weight)
+    else:
+        weight = torch.ones(input.shape, dtype=input.dtype)
+    if bias is not None:
+        bias = unbox_tensor(bias)
+    else:
+        bias = torch.zeros(input.shape, dtype=input.dtype)
     return F.layer_norm(
         input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
     )


+layer_norm.override(Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)


# Linear
def linear_default(input, weight, bias, *, accum_dtype) -> Tensor:
input = unbox_tensor(input)
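With this change, callers may pass None for weight and/or bias and the default implementation substitutes all-ones and all-zeros tensors, which is what the MMDIT block above relies on. A minimal sketch of the new call form (import style copied from the test below):

import torch
import sharktank.ops as ops

x = torch.rand([3, 1024, 3072])
y = ops.layer_norm(x, None, None, eps=1e-6)  # weight defaults to ones, bias to zeros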
6 changes: 4 additions & 2 deletions sharktank/sharktank/ops/signatures.py
@@ -582,12 +582,14 @@ def layer_norm(
 def _layer_norm_trampoline(
     d: SignatureDispatcher,
     input: AnyTensor,
-    weight: AnyTensor,
+    weight: Optional[AnyTensor],
     bias: Optional[AnyTensor],
     *,
     eps: float,
 ):
-    tensors = [input, weight]
+    tensors = [input]
+    if weight is not None:
+        tensors.append(weight)
     if bias is not None:
         tensors.append(bias)
     for override in d.find_overrides(tensors):
58 changes: 58 additions & 0 deletions sharktank/tests/layers/mmdit_test.py
@@ -0,0 +1,58 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

logging.basicConfig(level=logging.DEBUG)

import unittest

import torch

from iree.turbine import aot
from sharktank.layers import (
MMDITDoubleBlock,
)
import sharktank.ops as ops
from sharktank.layers.testing import (
make_mmdit_double_block_theta,
)
from sharktank.types.tensors import DefaultPrimitiveTensor


class MMDITTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(12345)
self.hidden_size = 3072
self.num_heads = 24
self.batch_size = 3

def testDoubleExport(self):

theta = make_mmdit_double_block_theta()
mmdit = MMDITDoubleBlock(
theta=theta,
num_heads=self.num_heads,
)

img = torch.rand([self.batch_size, 1024, self.hidden_size])
txt = torch.rand([self.batch_size, 512, self.hidden_size])
vec = torch.rand([self.batch_size, self.hidden_size])
rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2])
mmdit.forward(img, txt, vec, rot)
fxb = aot.FxProgramsBuilder(mmdit)

@fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False)
def _(model, img, txt, vec, rot) -> torch.Tensor:
return model.forward(img, txt, vec, rot)

output = aot.export(fxb)
output.verify()
asm = str(output.mlir_module)


if __name__ == "__main__":
unittest.main()
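Beyond output.verify(), the exported program can also be written out for offline compilation. A small sketch, assuming aot.ExportOutput exposes the save_mlir helper as in other iree.turbine examples (the file name is illustrative):

output = aot.export(fxb)
output.verify()
output.save_mlir("mmdit.mlir")  # persist the MLIR for later compilation with IREE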
