[Bugfix] avoid yunchang undefined & format
DefTruth authored Jan 18, 2025
1 parent b2bcc38 commit d6332e9
Showing 1 changed file with 67 additions and 39 deletions.
xfuser/model_executor/layers/attention_processor.py (+67, -39)
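The functional fix is the added `import yunchang`: the file already asserted `yunchang.__version__ >= "0.6.0"`, but `from yunchang.kernels import AttnType` never binds the name `yunchang`, so the assert raised NameError. The remaining changes are Black-style reformatting. For orientation, the file's pre-existing guard around the optional HunyuanVideo processor follows this pattern (a minimal standalone sketch; the subclass name here is made up):

# Runnable even when diffusers is missing or too old: the ImportError is
# caught and the dependent class degrades to None instead of crashing
# at import time.
try:
    from diffusers.models.transformers.transformer_hunyuan_video import (
        HunyuanVideoAttnProcessor2_0,
    )
except ImportError:
    HunyuanVideoAttnProcessor2_0 = None

if HunyuanVideoAttnProcessor2_0 is not None:

    class MyHunyuanProcessor(HunyuanVideoAttnProcessor2_0):
        """Defined only when the upstream processor exists."""

else:
    MyHunyuanProcessor = None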
@@ -12,19 +12,21 @@
     JointAttnProcessor2_0,
     FluxAttnProcessor2_0,
     HunyuanAttnProcessor2_0,
-    CogVideoXAttnProcessor2_0
+    CogVideoXAttnProcessor2_0,
 )

 try:
-    from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoAttnProcessor2_0
+    from diffusers.models.transformers.transformer_hunyuan_video import (
+        HunyuanVideoAttnProcessor2_0,
+    )
 except ImportError:
     HunyuanVideoAttnProcessor2_0 = None

 from diffusers.models.embeddings import apply_rotary_emb

 from xfuser.core.distributed import (
     get_sequence_parallel_world_size,
-    get_pipeline_parallel_world_size
+    get_pipeline_parallel_world_size,
 )
 from xfuser.core.fast_attention import (
     xFuserFastAttention,
@@ -38,7 +40,7 @@
 from xfuser.logger import init_logger
 from xfuser.envs import PACKAGES_CHECKER

-if torch.__version__ >= '2.5.0':
+if torch.__version__ >= "2.5.0":
     from xfuser.model_executor.layers.usp import USP
 else:
     from xfuser.model_executor.layers.usp_legacy import USP
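An aside, not part of this commit: comparing raw version strings is fragile, since lexicographic ordering misorders two-digit components ("2.10.0" sorts before "2.5.0"). As far as I can tell, torch mitigates this itself, because torch.__version__ is a TorchVersion, a str subclass with version-aware comparisons. A sketch of the general-purpose alternative, assuming the packaging library is available:

from packaging.version import Version

assert not ("2.10.0" >= "2.5.0")               # plain string compare misorders
assert Version("2.10.0") >= Version("2.5.0")   # numeric compare is correct

# Applied to the gate above it would read:
# if Version(torch.__version__) >= Version("2.5.0"): ...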
@@ -195,7 +197,9 @@ def __init__(self):
                 use_kv_cache=self.use_long_ctx_attn_kvcache
             )
         else:
+            import yunchang
             from yunchang.kernels import AttnType
+
             assert yunchang.__version__ >= "0.6.0"
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
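This hunk is the bug named in the commit title: `from yunchang.kernels import AttnType` binds only `AttnType`, not the package name, so the pre-existing `assert yunchang.__version__ >= "0.6.0"` raised NameError until `import yunchang` was added. A standalone reproduction, using a stdlib module in place of yunchang:

# `from pkg.sub import name` imports the package but does not bind `pkg`
# in the current namespace.
from json.decoder import JSONDecoder

try:
    json.dumps({})  # NameError: name 'json' is not defined
except NameError as err:
    print(err)

import json          # the equivalent of the fix above

print(json.dumps({}))  # works now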
@@ -404,6 +408,7 @@ def __init__(self):
             )
         else:
             from yunchang.kernels import AttnType
+
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
                 attn_type=AttnType.TORCH,
@@ -467,19 +472,23 @@ def __call__(
                 query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                 key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                 value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

                 encoder_hidden_states_query_proj = None
                 encoder_hidden_states_key_proj = None
                 encoder_hidden_states_value_proj = None
             else:
-                encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                    batch_size, -1, attn.heads, head_dim
+                encoder_hidden_states_query_proj = (
+                    encoder_hidden_states_query_proj.view(
+                        batch_size, -1, attn.heads, head_dim
+                    )
                 )
                 encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                     batch_size, -1, attn.heads, head_dim
                 )
-                encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                    batch_size, -1, attn.heads, head_dim
+                encoder_hidden_states_value_proj = (
+                    encoder_hidden_states_value_proj.view(
+                        batch_size, -1, attn.heads, head_dim
+                    )
                 )
             query = query.view(batch_size, -1, attn.heads, head_dim)
             key = key.view(batch_size, -1, attn.heads, head_dim)
@@ -597,6 +606,7 @@ def __init__(self):
             )
         else:
             from yunchang.kernels import AttnType
+
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
                 attn_type=AttnType.TORCH,
@@ -679,7 +689,10 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb)

         #! ---------------------------------------- KV CACHE ----------------------------------------
-        if get_runtime_state().num_pipeline_patch > 1 and not self.use_long_ctx_attn_kvcache:
+        if (
+            get_runtime_state().num_pipeline_patch > 1
+            and not self.use_long_ctx_attn_kvcache
+        ):
             encoder_hidden_states_key_proj, key = key.split(
                 [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
             )
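For context, a toy sketch of the split above, with made-up sizes: the fused key tensor holds the encoder (text) tokens followed by the latent tokens along the sequence dimension, and `.split()` separates them so the text part can be handled per pipeline patch:

import torch

batch, heads, head_dim = 1, 2, 4
num_text, num_latent = 3, 5  # stand-ins for the real token counts
key = torch.randn(batch, heads, num_text + num_latent, head_dim)

encoder_key, latent_key = key.split([num_text, num_latent], dim=2)
assert encoder_key.shape == (batch, heads, num_text, head_dim)
assert latent_key.shape == (batch, heads, num_latent, head_dim)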
@@ -697,10 +710,11 @@ def __call__(
         #! ---------------------------------------- KV CACHE ----------------------------------------

         #! ---------------------------------------- ATTENTION ----------------------------------------
-        if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
-            hidden_states = USP(
-                query, key, value, dropout_p=0.0, is_causal=False
-            )
+        if (
+            get_pipeline_parallel_world_size() == 1
+            and get_runtime_state().split_text_embed_in_sp
+        ):
+            hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(
                 batch_size, -1, attn.heads * head_dim
             )
@@ -798,6 +812,7 @@ def __init__(self):
             )
         else:
             from yunchang.kernels import AttnType
+
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
                 attn_type=AttnType.TORCH,
@@ -1000,6 +1015,7 @@ def __init__(self):
             )
         else:
             from yunchang.kernels import AttnType
+
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
                 attn_type=AttnType.TORCH,
@@ -1063,10 +1079,11 @@ def __call__(
             )

         #! ---------------------------------------- ATTENTION ----------------------------------------
-        if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
-            hidden_states = USP(
-                query, key, value, dropout_p=0.0, is_causal=False
-            )
+        if (
+            get_pipeline_parallel_world_size() == 1
+            and get_runtime_state().split_text_embed_in_sp
+        ):
+            hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(
                 batch_size, -1, attn.heads * head_dim
             )
@@ -1104,9 +1121,7 @@ def __call__(
                 joint_strategy="front",
             )

-            hidden_states = hidden_states.reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
+            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
         else:
             if HAS_FLASH_ATTN:
                 from flash_attn import flash_attn_func
@@ -1144,7 +1159,6 @@ def __call__(
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-
         encoder_hidden_states, hidden_states = hidden_states.split(
             [text_seq_length, latent_seq_length], dim=1
         )
@@ -1177,6 +1191,7 @@ def __init__(self):
             )
         else:
             from yunchang.kernels import AttnType
+
             self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                 use_kv_cache=self.use_long_ctx_attn_kvcache,
                 attn_type=AttnType.TORCH,
@@ -1240,10 +1255,11 @@ def __call__(
             )

         #! ---------------------------------------- ATTENTION ----------------------------------------
-        if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
-            hidden_states = USP(
-                query, key, value, dropout_p=0.0, is_causal=False
-            )
+        if (
+            get_pipeline_parallel_world_size() == 1
+            and get_runtime_state().split_text_embed_in_sp
+        ):
+            hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(
                 batch_size, -1, attn.heads * head_dim
             )
@@ -1281,9 +1297,7 @@ def __call__(
                 joint_strategy="front",
             )

-            hidden_states = hidden_states.reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
+            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
         else:
             if HAS_FLASH_ATTN:
                 from flash_attn import flash_attn_func
@@ -1321,13 +1335,14 @@ def __call__(
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-
         encoder_hidden_states, hidden_states = hidden_states.split(
             [text_seq_length, latent_seq_length], dim=1
         )
         return hidden_states, encoder_hidden_states

+
 if HunyuanVideoAttnProcessor2_0 is not None:
+
     @xFuserAttentionProcessorRegister.register(HunyuanVideoAttnProcessor2_0)
     class xFuserHunyuanVideoAttnProcessor2_0(HunyuanVideoAttnProcessor2_0):
         def __init__(self):
@@ -1349,6 +1364,7 @@ def __init__(self):
                 )
             else:
                 from yunchang.kernels import AttnType
+
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
                     attn_type=AttnType.TORCH,
@@ -1395,14 +1411,20 @@ def __call__(
             if attn.add_q_proj is None and encoder_hidden_states is not None:
                 query = torch.cat(
                     [
-                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        apply_rotary_emb(
+                            query[:, :, : -encoder_hidden_states.shape[1]],
+                            image_rotary_emb,
+                        ),
                         query[:, :, -encoder_hidden_states.shape[1] :],
                     ],
                     dim=2,
                 )
                 key = torch.cat(
                     [
-                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        apply_rotary_emb(
+                            key[:, :, : -encoder_hidden_states.shape[1]],
+                            image_rotary_emb,
+                        ),
                         key[:, :, -encoder_hidden_states.shape[1] :],
                     ],
                     dim=2,
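A toy sketch of the slicing pattern above (shapes and the rotary stand-in are invented): rotary embeddings are applied only to the leading image/video tokens, while the trailing text tokens, the last `encoder_hidden_states.shape[1]` positions, pass through unchanged before re-concatenation:

import torch

def toy_rotary(x):  # stand-in for diffusers' apply_rotary_emb
    return -x       # any elementwise op demonstrates the split

text_len, total_len = 3, 10
query = torch.randn(1, 2, total_len, 4)  # (batch, heads, tokens, head_dim)

query = torch.cat(
    [toy_rotary(query[:, :, :-text_len]), query[:, :, -text_len:]],
    dim=2,
)
assert query.shape == (1, 2, total_len, 4)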
@@ -1417,9 +1439,13 @@ def __call__(
                 encoder_key = attn.add_k_proj(encoder_hidden_states)
                 encoder_value = attn.add_v_proj(encoder_hidden_states)

-                encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+                encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(
+                    1, 2
+                )
                 encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-                encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+                encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(
+                    1, 2
+                )

                 if attn.norm_added_q is not None:
                     encoder_query = attn.norm_added_q(encoder_query)
@@ -1440,10 +1466,11 @@ def __call__(
             num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens

             #! ---------------------------------------- ATTENTION ----------------------------------------
-            if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
-                hidden_states = USP(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
+            if (
+                get_pipeline_parallel_world_size() == 1
+                and get_runtime_state().split_text_embed_in_sp
+            ):
+                hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
                 hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
             elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
                 if get_runtime_state().split_text_embed_in_sp:
@@ -1520,5 +1547,6 @@ def __call__(
                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             return hidden_states, encoder_hidden_states

+
 else:
     xFuserHunyuanVideoAttnProcessor2_0 = None

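With the guard pattern intact, importing this module no longer depends on the diffusers version; callers can feature-test instead. A usage sketch, assuming xfuser is installed:

from xfuser.model_executor.layers.attention_processor import (
    xFuserHunyuanVideoAttnProcessor2_0,
)

if xFuserHunyuanVideoAttnProcessor2_0 is None:
    print("HunyuanVideo attention processor unavailable; upgrade diffusers")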