Respond to PR feedback
KyleHerndon committed Nov 26, 2024
1 parent 319526e commit a3b4319
Showing 4 changed files with 31 additions and 8 deletions.
@@ -1,3 +1,13 @@
+# Copyright 2024 Black Forest Labs. Inc. and Flux Authors
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""MMDIT Layers adapted from black-forest-labs' flux implementation
+https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+"""
+
import torch.nn.functional as F
import torch
from torch import Tensor
@@ -15,6 +25,7 @@ def qk_norm(q, k, v, rms_q, rms_k):
return rms_q(q).to(v), rms_k(k).to(v)


+# TODO: Work on unifying with the current RoPE layer
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
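
The rest of apply_rope is collapsed in this view. For reference, the upstream flux implementation that this file adapts completes it roughly as follows; treat it as a sketch, since sharktank's adapted copy may not match line for line:

    from torch import Tensor  # already imported at the top of this file

    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
        # View the last dimension as (pairs, 1, 2) so each pair can be rotated.
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        # Apply the precomputed rotation packed into freqs_cis.
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        # Restore the original shapes and dtypes.
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)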
@@ -89,7 +100,7 @@ def forward(
img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
)

-# prepare txt for attention
+# prepare text for attention
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn_qkv(txt_modulated)
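
The next hunk attends over the text and image streams jointly: the lines elided here concatenate the per-stream q/k/v with the text tokens first, which is why the attention output is later split at txt.shape[1]. A sketch of that elided step in plain torch, with a hypothetical helper name:

    import torch

    def _concat_streams(txt_q, txt_k, txt_v, img_q, img_k, img_v):
        # Text tokens precede image tokens along the sequence dim
        # (dim=2 for [batch, heads, seq, head_dim]), so
        # attn[:, : txt.shape[1]] recovers the text portion afterwards.
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        return q, k, v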
@@ -110,7 +121,8 @@ def forward(
attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

-# calculate the img bloks
+# calculate the image blocks
+# TODO: Refactor this for code reuse with the txt blocks
img = img + img_mod1.gate * self.img_attn_proj(img_attn)
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
img, None, None, eps=1e-6
@@ -120,12 +132,13 @@ def forward(
img_mlp_out3 = self.img_mlp2(img_mlp_out2)
img = img + img_mod2.gate * img_mlp_out3

-# calculate the txt bloks
+# calculate the text blocks
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
txt, None, None, eps=1e-6
) + txt_mod2.shift
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
+# TODO: Unify with modulation layer by taking act_fn as an arg
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
txt = txt + txt_mod2.gate * txt_mlp_out3
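
The TODO comments added above point at duplicated logic between the image and text paths. One possible shape for a shared helper (hypothetical name, not part of this commit) reuses only the ops already visible in this diff:

    import torch.nn.functional as F
    import sharktank.ops as ops

    def _gated_block(x, attn_out, mod1, mod2, attn_proj, mlp1, mlp2):
        # Gated residual add of the attention projection.
        x = x + mod1.gate * attn_proj(attn_out)
        # Modulated layer norm into a GELU MLP, then a second gated residual.
        mlp_in = (1 + mod2.scale) * ops.layer_norm(x, None, None, eps=1e-6) + mod2.shift
        mlp_out = mlp2(ops.elementwise(F.gelu, mlp1(mlp_in)))
        return x + mod2.gate * mlp_out

Each stream would then be one call, e.g. img = _gated_block(img, img_attn, img_mod1, img_mod2, self.img_attn_proj, self.img_mlp1, self.img_mlp2), and the activation could become a parameter as the second TODO suggests.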
10 changes: 10 additions & 0 deletions sharktank/sharktank/layers/modulation.py
@@ -1,3 +1,13 @@
+# Copyright 2024 Black Forest Labs. Inc. and Flux Authors
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Modulation Layer adapted from black-forest-labs' flux implementation
+https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+"""
+
import torch
import torch.nn.functional as F

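Only the new license header is shown for modulation.py, so for context: the layer this file adapts maps the conditioning vector to the (shift, scale, gate) triples consumed as img_mod1/img_mod2 and txt_mod1/txt_mod2 in the MMDIT block above. The upstream flux version is roughly the following; sharktank's adaptation may restructure it around its own tensor and theta types, so treat this as a sketch:

    from dataclasses import dataclass
    from torch import Tensor, nn

    @dataclass
    class ModulationOut:
        shift: Tensor
        scale: Tensor
        gate: Tensor

    class Modulation(nn.Module):
        def __init__(self, dim: int, double: bool):
            super().__init__()
            self.is_double = double
            self.multiplier = 6 if double else 3
            self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

        def forward(self, vec: Tensor):
            # Project the conditioning vector into 3 (or 6) chunks:
            # one (shift, scale, gate) triple per modulated sub-block.
            out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
            return ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None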
4 changes: 2 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -309,11 +309,11 @@ def layer_norm_default(input, weight, bias, *, eps):
if weight is not None:
weight = unbox_tensor(weight)
else:
-weight = torch.ones(input.shape)
+weight = torch.ones(input.shape, dtype=input.dtype)
if bias is not None:
bias = unbox_tensor(bias)
else:
-bias = torch.zeros(input.shape)
+bias = torch.zeros(input.shape, dtype=input.dtype)
return F.layer_norm(
input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
)
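
The dtype= additions keep the synthesized default weight and bias in the input's dtype, so reduced-precision inputs are no longer paired with float32 parameters. A minimal check of the fixed behavior in plain torch, mirroring the defaults above rather than going through sharktank's op registry:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, dtype=torch.bfloat16)
    weight = torch.ones(x.shape, dtype=x.dtype)  # follows the input dtype, as in the fix
    bias = torch.zeros(x.shape, dtype=x.dtype)
    out = F.layer_norm(x, normalized_shape=weight.shape, weight=weight, bias=bias, eps=1e-6)
    assert out.dtype == torch.bfloat16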
6 changes: 3 additions & 3 deletions sharktank/tests/layers/mmdit_test.py
@@ -18,7 +18,7 @@
)
import sharktank.ops as ops
from sharktank.layers.testing import (
-make_mmdit_block_theta,
+make_mmdit_double_block_theta,
)
from sharktank.types.tensors import DefaultPrimitiveTensor

@@ -30,9 +30,9 @@ def setUp(self):
self.num_heads = 24
self.batch_size = 3

-def testExport(self):
+def testDoubleExport(self):

-theta = make_mmdit_block_theta()
+theta = make_mmdit_double_block_theta()
mmdit = MMDITDoubleBlock(
theta=theta,
num_heads=self.num_heads,
