DiT with decorator, triton fused_AdaLN and fineGrained #552

Open
wants to merge 9 commits into base: develop
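Everything in this PR is gated behind a callZKK environment variable: when it is set, the DiT-LLaMA blocks use a packed QKV projection, a packed w1/w3 feed-forward with a fused swiglu activation, the Triton adaptive_layer_norm and fused_adaLN_scale_residual kernels, and a transformer stack wrapped by paddle.incubate.layers.inference. As a rough eager-mode reference (not part of the diff), the two fused adaLN ops are expected to match the sketch below, reconstructed from the non-callZKK branch of TransformerBlock.forward; the op signatures are taken from their call sites in this diff and are assumptions rather than documented API.

import paddle.nn.functional as F

def adaptive_layer_norm_ref(x, scale, shift, weight, epsilon=1e-5):
    # LayerNorm (weight only, no bias) followed by the DiT-style modulation
    # x * (1 + scale) + shift, broadcast over the sequence dimension.
    h = F.layer_norm(x, x.shape[-1:], weight=weight, epsilon=epsilon)
    return h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def fused_adaLN_scale_residual_ref(x, attn_out, gate, scale, shift, weight, epsilon=1e-5):
    # Gated residual add, then LayerNorm + modulation of the result; both are
    # returned so the caller can reuse the residual for the next skip connection.
    residual = x + gate.unsqueeze(1) * attn_out
    h = F.layer_norm(residual, residual.shape[-1:], weight=weight, epsilon=epsilon)
    return residual, h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)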
@@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
dtype = paddle.float16

# Setting the callZKK environment variable enables the fused Triton kernels added
# in this PR, giving a speedup of almost 100%.
os.environ['callZKK'] = "True"
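# LazyGuard below defers parameter initialization, so the 7B weights are only
# materialized when the pretrained checkpoint is loaded.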

with paddle.LazyGuard():
pipe = DiTPipeline.from_pretrained("Alpha-VLLM/Large-DiT-7B-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
207 changes: 172 additions & 35 deletions ppdiffusers/ppdiffusers/models/dit_llama.py
@@ -25,6 +25,9 @@
from .modeling_utils import ModelMixin
from .transformer_2d import Transformer2DModelOutput

from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.incubate.tt import adaptive_layer_norm, fused_adaLN_scale_residual
import os

def TypePromote(x, y):
TYPE_PROMOTE_DICT = {
@@ -90,7 +93,7 @@ def forward(self, t):


class Attention(nn.Layer):
def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, callZKK=False):
"""
Initialize the Attention module.

@@ -108,6 +111,7 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
wq (nn.Linear): Linear transformation for queries.
wk (nn.Linear): Linear transformation for keys.
wv (nn.Linear): Linear transformation for values.
qkv (nn.Linear): Linear transformation for queries, keys and values.
wo (nn.Linear): Linear transformation for output.
cache_k (paddle.Tensor): Cached keys for attention.
cache_v (paddle.Tensor): Cached values for attention.
@@ -120,9 +124,14 @@ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True):
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads

self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False)
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.callZKK = callZKK
if not callZKK:
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False)
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
else:
self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias_attr=False)

self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False)

if qk_norm:
@@ -184,7 +193,15 @@ def apply_rotary_emb(xq, xk, freqs_cis):
Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
with paddle.amp.auto_cast(enable=False):
if not os.getenv('callZKK'):
with paddle.amp.auto_cast(enable=False):
xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3)
xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3)
return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype)
else:
xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
@@ -205,7 +222,13 @@ def forward(self, x, freqs_cis):

"""
bsz, seqlen, _ = tuple(x.shape)
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

if not self.callZKK:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
else:
qkv_out = self.qkv(x)
xq, xk, xv = paddle.split(qkv_out, 3, axis=-1)
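# The equal three-way split assumes n_kv_heads == n_heads; the matching packed
# weight is assembled in custom_modify_weight() by concatenating the wq/wk/wv
# checkpoints along the output axis.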

dtype = xq.dtype

xq = self.q_norm(xq)
@@ -253,7 +276,7 @@ def forward(self, x, freqs_cis):


class FeedForward(nn.Layer):
def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None, callZKK=False):
"""
Initialize the FeedForward module.

@@ -266,28 +289,88 @@ def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
dimension. Defaults to None.

Attributes:
w1 (nn.Linear): Linear transformation for the first
layer.
w1 (nn.Linear): Linear transformation for the first layer.
w2 (nn.Linear): Linear transformation for the second layer.
w3 (nn.Linear): Linear transformation for the third
layer.

w3 (nn.Linear): Linear transformation for the third layer.
w13 (nn.Linear): Linear transformation for the first and the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of))

self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False)
self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.callZKK = callZKK
if not callZKK:
self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False)
self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False)
else:
self.w13 = nn.Linear(dim, hidden_dim * 2, bias_attr=False)
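# w13 packs w1 and w3 into one projection, so a single GEMM followed by the
# fused swiglu activation (silu(first half) * second half) replaces the eager
# silu(w1(x)) * w3(x) computation.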

def compute_activation(self,
ffn1_out,
bias=None,
dequant_scales=None,
shift=None,
smooth=None,
act_method="swiglu",
compute_dtype="default",
quant_scale=-1,
quant_round_type=0,
quant_max_bound=0,
quant_min_bound=0):
origin_batch_size = ffn1_out.shape[0]
origin_seq_len = ffn1_out.shape[1]
ffn1_out = ffn1_out.reshape([origin_batch_size*origin_seq_len, ffn1_out.shape[-1]])
if in_dynamic_mode():
out = paddle._C_ops.fused_bias_act(
ffn1_out,
bias,
dequant_scales,
shift,
smooth,
act_method,
compute_dtype,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound
)
return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]])
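# Static-graph path: append the same fused_bias_act op through LayerHelper so
# the activation also works after dynamic-to-static conversion (e.g., under the
# paddle.incubate.layers.inference decorator used below).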

helper = LayerHelper("fused_bias_act")
out = helper.create_variable_for_type_inference(dtype=ffn1_out.dtype)
inputs = {}
inputs["x"] = ffn1_out
attrs = {
"act_method": act_method,
"compute_dtype": compute_dtype,
"quant_scale": quant_scale,
"quant_round_type": quant_round_type,
"quant_max_bound": quant_max_bound,
"quant_min_bound": quant_min_bound,
}
helper.append_op(
type="fused_bias_act",
inputs=inputs,
outputs={"out": out},
attrs=attrs,
)
return out.reshape([origin_batch_size, origin_seq_len, out.shape[-1]])

def forward(self, x):
xw1 = F.silu(self.w1(x))
xw3 = self.w3(x)
output = self.w2(xw1 * xw3)
return output
if not self.callZKK:
xw1 = F.silu(self.w1(x))
xw3 = self.w3(x)
output = self.w2(xw1 * xw3)
return output
else:
ffn1_out = self.w13(x)
ffn1_out = self.compute_activation(ffn1_out)
ffn2_out = self.w2(ffn1_out)
return ffn2_out


class TransformerBlock(nn.Layer):
@@ -303,6 +386,7 @@ def __init__(
norm_eps: float,
qk_norm: bool,
fused_attn: bool,
callZKK=False,
) -> None:
"""
Initialize a TransformerBlock.
@@ -337,10 +421,11 @@ def __init__(
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn)
self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, callZKK=callZKK)
mlp_hidden_dim = int(dim * mlp_ratio)
self.feed_forward = FeedForward(
dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier
dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier, callZKK=callZKK
)
self.layer_id = layer_id
self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False)
@@ -350,6 +435,8 @@ def __init__(
nn.Silu(),
nn.Linear(min(dim, 1024), 6 * dim),
)
self.norm_eps = norm_eps
self.callZKK = callZKK

def forward(self, x, freqs_cis, adaln_input=None):
"""
@@ -370,10 +457,17 @@ def forward(self, x, freqs_cis, adaln_input=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(
6, axis=1
)
h = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
)
out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
if not self.callZKK:
h = x + gate_msa.unsqueeze(1) * self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
)
out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
else:
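# adaptive_layer_norm fuses attention_norm with the adaLN modulate();
# fused_adaLN_scale_residual also returns the gated residual
# x + gate_msa * attention_out alongside the modulated ffn_norm of that
# residual, so one kernel replaces the residual add, ffn_norm and second
# modulate of the eager branch above.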
attention_out = self.attention(adaptive_layer_norm(x, scale_msa, shift_msa,
weight=self.attention_norm.weight, epsilon=self.norm_eps), freqs_cis)
residual_out, adaLN_out = fused_adaLN_scale_residual(x, attention_out, gate_msa, scale_mlp, shift_mlp,
weight=self.ffn_norm.weight, epsilon=self.norm_eps)
out = residual_out + gate_mlp.unsqueeze(1) * self.feed_forward(adaLN_out)
else:
h = x + self.attention(self.attention_norm(x), freqs_cis)
out = h + self.feed_forward(self.ffn_norm(h))
@@ -435,14 +529,14 @@ def __init__(
self.num_classes = num_classes
self.learn_sigma = learn_sigma
self.qk_norm = qk_norm

self.gradient_checkpointing = True
self.fused_attn = True

self.x_embedder = nn.Linear(in_channels * patch_size**2, dim)
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob)

self.callZKK = True if os.getenv('callZKK') else False
# 2. Define transformers blocks
self.layers = nn.LayerList(
[
@@ -457,10 +551,13 @@ def __init__(
norm_eps=norm_eps,
qk_norm=qk_norm,
fused_attn=self.fused_attn,
callZKK=self.callZKK,
)
for idx in range(num_layers)
]
)


# 3. Define output layers
self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
@@ -531,6 +628,21 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
)
return freqs_cis
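# Only transformer_blocks is wrapped by paddle.incubate.layers.inference, so
# presumably just the block stack is converted to a cached static graph
# (with_trt=False, cache_static_model=True); the embedders, final layer and
# unpatchify stay in dynamic mode.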

@paddle.incubate.layers.inference(with_trt=False,
cache_static_model=True,
collect_shape=False)
def transformer_blocks(self, x, adaln_input):
for layer in self.layers:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)
return x

def forward(
self,
hidden_states: paddle.Tensor,
@@ -556,16 +668,19 @@ def forward(
adaln_input = t + y

# 2. Blocks
for i, layer in enumerate(self.layers):
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input)
else:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)

if not self.callZKK:
for i, layer in enumerate(self.layers):
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input)
else:
x = layer(
x,
self.freqs_cis[: x.shape[1]],
adaln_input,
)
else:
x = self.transformer_blocks(x, adaln_input)

# 3. Output
hidden_states = self.final_layer(x, adaln_input)
output = self.unpatchify(hidden_states)
@@ -574,3 +689,25 @@
return (output,)

return Transformer2DModelOutput(sample=output)

@classmethod
def custom_modify_weight(cls, state_dict):
# Weight fusion is only needed when the callZKK optimized path is enabled.
if os.getenv('callZKK'):
for key in list(state_dict.keys()):
if 'feed_forward.w1.weight' in key:
w1 = state_dict.pop(key)
w3_key = key.replace('w1', 'w3')
w3 = state_dict.pop(w3_key)
w13 = paddle.concat([w1, w3], axis=1)
state_dict[key.replace('w1', 'w13')] = w13
if 'attention.wq.weight' in key or 'attention.wk.weight' in key or 'attention.wv.weight' in key:
part = key.split('.')[-2]
layer_id = key.split('.')[1]
qkv_key = f'layers.{layer_id}.attention.qkv.weight'
if part == 'wq' and qkv_key not in state_dict:
state_dict[qkv_key] = state_dict.pop(key)
elif part in ('wk', 'wv'):
qkv = state_dict.get(qkv_key)
if qkv is not None:
state_dict[qkv_key] = paddle.concat([qkv, state_dict.pop(key)], axis=1)
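A quick way to see that the fusion above is lossless: Paddle's nn.Linear stores its weight as [in_features, out_features], so concatenating the wq/wk/wv checkpoints along axis=1 and splitting the packed output into three equal parts reproduces the unfused projections (given n_kv_heads == n_heads, which the equal split in Attention.forward already assumes). A minimal self-contained check with made-up sizes, not part of the diff:

import paddle
import paddle.nn as nn

dim, n_heads, head_dim = 64, 4, 16
wq = nn.Linear(dim, n_heads * head_dim, bias_attr=False)
wk = nn.Linear(dim, n_heads * head_dim, bias_attr=False)
wv = nn.Linear(dim, n_heads * head_dim, bias_attr=False)

# Pack the three projections the same way custom_modify_weight does.
qkv = nn.Linear(dim, 3 * n_heads * head_dim, bias_attr=False)
qkv.weight.set_value(paddle.concat([wq.weight, wk.weight, wv.weight], axis=1))

x = paddle.randn([2, 8, dim])
xq, xk, xv = paddle.split(qkv(x), 3, axis=-1)
assert paddle.allclose(xq, wq(x), atol=1e-5).item()
assert paddle.allclose(xv, wv(x), atol=1e-5).item()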
5 changes: 5 additions & 0 deletions ppdiffusers/ppdiffusers/models/modeling_utils.py
@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return model

@classmethod
def custom_modify_weight(cls, state_dict):
pass
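# Hook that lets subclasses rewrite a checkpoint before it is loaded into the
# model; the DiT change above overrides it to pack wq/wk/wv into qkv and w1/w3
# into w13 when the callZKK path is enabled.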

@classmethod
def _load_pretrained_model(
cls,
@@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
error_msgs.append(
f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
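# Give the model class a chance to rewrite the state dict (e.g., fuse weights)
# before it is pushed into the model.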
cls.custom_modify_weight(state_dict)
faster_set_state_dict(model_to_load, state_dict)

missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))