From a3b431979a5587fe90897648639ec750490094c8 Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Tue, 26 Nov 2024 14:47:35 -0800
Subject: [PATCH] Respond to PR feedback

---
 .../layers/{mmdit_double.py => mmdit.py}  | 19 ++++++++++++++++---
 sharktank/sharktank/layers/modulation.py  | 10 ++++++++++
 sharktank/sharktank/ops/default_impls.py  |  4 ++--
 sharktank/tests/layers/mmdit_test.py      |  6 +++---
 4 files changed, 31 insertions(+), 8 deletions(-)
 rename sharktank/sharktank/layers/{mmdit_double.py => mmdit.py} (87%)

diff --git a/sharktank/sharktank/layers/mmdit_double.py b/sharktank/sharktank/layers/mmdit.py
similarity index 87%
rename from sharktank/sharktank/layers/mmdit_double.py
rename to sharktank/sharktank/layers/mmdit.py
index e279df6e9..0b0750549 100644
--- a/sharktank/sharktank/layers/mmdit_double.py
+++ b/sharktank/sharktank/layers/mmdit.py
@@ -1,3 +1,13 @@
+# Copyright 2024 Black Forest Labs, Inc. and Flux Authors
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""MMDIT Layers adapted from black-forest-labs' flux implementation
+https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+"""
+
 import torch.nn.functional as F
 import torch
 from torch import Tensor
@@ -15,6 +25,7 @@ def qk_norm(q, k, v, rms_q, rms_k):
     return rms_q(q).to(v), rms_k(k).to(v)
 
 
+# TODO: Work on unifying with the current RoPE layer
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
@@ -89,7 +100,7 @@ def forward(
             img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
         )
 
-        # prepare txt for attention
+        # prepare text for attention
         txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn_qkv(txt_modulated)
@@ -110,7 +121,8 @@ def forward(
         attn = attention(q, k, v, pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
-        # calculate the img bloks
+        # calculate the image blocks
+        # TODO: Refactor this for code reuse with the txt blocks
         img = img + img_mod1.gate * self.img_attn_proj(img_attn)
         img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
             img, None, None, eps=1e-6
@@ -120,12 +132,13 @@ def forward(
         ) + img_mod2.shift
         img_mlp_out1 = self.img_mlp1(img_mlp_in)
         img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
         img_mlp_out3 = self.img_mlp2(img_mlp_out2)
         img = img + img_mod2.gate * img_mlp_out3
 
-        # calculate the txt bloks
+        # calculate the text blocks
         txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
         txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
             txt, None, None, eps=1e-6
         ) + txt_mod2.shift
         txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
+        # TODO: Unify with modulation layer by taking act_fn as an arg
         txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
         txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
         txt = txt + txt_mod2.gate * txt_mlp_out3
diff --git a/sharktank/sharktank/layers/modulation.py b/sharktank/sharktank/layers/modulation.py
index 7c50f279c..7ef7adfa1 100644
--- a/sharktank/sharktank/layers/modulation.py
+++ b/sharktank/sharktank/layers/modulation.py
@@ -1,3 +1,13 @@
+# Copyright 2024 Black Forest Labs, Inc. and Flux Authors
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Modulation Layer adapted from black-forest-labs' flux implementation
+https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+"""
+
 import torch
 import torch.nn.functional as F
 
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index a98582ff2..47e737fb1 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -309,11 +309,11 @@ def layer_norm_default(input, weight, bias, *, eps):
     if weight is not None:
         weight = unbox_tensor(weight)
     else:
-        weight = torch.ones(input.shape)
+        weight = torch.ones(input.shape, dtype=input.dtype)
     if bias is not None:
         bias = unbox_tensor(bias)
     else:
-        bias = torch.zeros(input.shape)
+        bias = torch.zeros(input.shape, dtype=input.dtype)
     return F.layer_norm(
         input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
     )
diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py
index 9eddb0e2f..5bd5ce39a 100644
--- a/sharktank/tests/layers/mmdit_test.py
+++ b/sharktank/tests/layers/mmdit_test.py
@@ -18,7 +18,7 @@
 )
 import sharktank.ops as ops
 from sharktank.layers.testing import (
-    make_mmdit_block_theta,
+    make_mmdit_double_block_theta,
 )
 from sharktank.types.tensors import DefaultPrimitiveTensor
 
@@ -30,9 +30,9 @@ def setUp(self):
         self.num_heads = 24
         self.batch_size = 3
 
-    def testExport(self):
+    def testDoubleExport(self):
 
-        theta = make_mmdit_block_theta()
+        theta = make_mmdit_double_block_theta()
         mmdit = MMDITDoubleBlock(
             theta=theta,
             num_heads=self.num_heads,
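
Not part of the patch itself: the default_impls.py change makes the fallback
weight/bias in layer_norm_default follow the input dtype instead of the float32
default produced by torch.ones/torch.zeros. Below is a minimal standalone
sketch of the resulting pattern, assuming a bfloat16 activation; the tensor
shapes are arbitrary, and whether a float32 weight paired with a bfloat16 input
upcasts or is rejected depends on the PyTorch version and backend, so matching
dtypes explicitly is the safer default:

    import torch
    import torch.nn.functional as F

    # Arbitrary low-precision activation standing in for a layer input.
    x = torch.randn(2, 8, dtype=torch.bfloat16)

    # Mirror the patched fallback: weight/bias created with the input's dtype.
    weight = torch.ones(x.shape, dtype=x.dtype)
    bias = torch.zeros(x.shape, dtype=x.dtype)

    out = F.layer_norm(
        x, normalized_shape=weight.shape, weight=weight, bias=bias, eps=1e-6
    )
    print(out.dtype)  # torch.bfloat16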