
Commit ff26394

charchit7 and Charchit Sharma authored
Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)
* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers
  - Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
  - Use the stored dimensions in forward() instead of recalculating them with a different formula
  - Fixes the inconsistency between __init__ (using // 6) and forward (using // 3)
  - Ensures split_sizes matches the dimensions used to create the rotary embeddings
* quality fix

Co-authored-by: Charchit Sharma <[email protected]>
1 parent 66e6a02 commit ff26394
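The description above hinges on the two integer formulas disagreeing for some head dimensions. A quick sketch of the arithmetic (64 is a purely illustrative attention_head_dim, not a value taken from this commit):

```python
# Illustrative check of the two formulas the commit reconciles.
# attention_head_dim = 64 is a hypothetical example value.
d = 64

# __init__ builds the rotary tables with these sizes:
h_dim = w_dim = 2 * (d // 6)          # 20
t_dim = d - h_dim - w_dim             # 24
init_sizes = [t_dim, h_dim, w_dim]    # [24, 20, 20]

# forward() previously recomputed the split with a different formula:
old_forward_sizes = [d - 2 * (d // 3), d // 3, d // 3]  # [22, 21, 21]

print(init_sizes, old_forward_sizes)
print(init_sizes == old_forward_sizes)  # False -> misaligned split boundaries
```

For head dimensions such as 128 the two formulas happen to coincide, which is how the inconsistency could survive unnoticed; for others, like 64 above, the split boundaries drift apart.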

File tree

2 files changed: +11 / -10 lines

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 5 additions & 5 deletions
@@ -389,6 +389,10 @@ def __init__(
         t_dim = attention_head_dim - h_dim - w_dim
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
 
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
+
         freqs_cos = []
         freqs_sin = []
 

@@ -412,11 +416,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
 
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
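The change is simply to compute the per-axis sizes once and cache them, so forward() cannot drift from the sizes used to build the tables. A minimal, self-contained sketch of that pattern (TinyRotary and its constructor arguments are hypothetical stand-ins, not the actual SkyReelsV2RotaryPosEmbed API):

```python
import torch

# Simplified stand-in showing the cached-split-sizes pattern from the patch.
class TinyRotary(torch.nn.Module):
    def __init__(self, attention_head_dim: int, max_seq_len: int = 32):
        super().__init__()
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        # Cache the sizes so forward() reuses exactly what __init__ used.
        self.t_dim, self.h_dim, self.w_dim = t_dim, h_dim, w_dim

        # Dummy "frequency" table with one column per rotary channel.
        self.register_buffer("freqs_cos", torch.randn(max_seq_len, attention_head_dim))

    def forward(self):
        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
        return self.freqs_cos.split(split_sizes, dim=1)


t, h, w = TinyRotary(attention_head_dim=64)()
print(t.shape[1], h.shape[1], w.shape[1])  # 24 20 20 -> matches the build-time sizes
```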

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 6 additions & 5 deletions
@@ -362,6 +362,11 @@ def __init__(
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
+
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
+
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
 
         freqs_cos = []

@@ -387,11 +392,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
 
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
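A short illustration of why the stale recomputation could go unnoticed at runtime: torch.split raises no error as long as the requested sizes sum to the dimension being split, so mismatched boundaries simply slice the table in the wrong places. The tensor and sizes below are dummies reusing the hypothetical head dimension of 64 from the earlier example:

```python
import torch

freqs = torch.arange(64).unsqueeze(0)   # stand-in for a (seq, head_dim) frequency table

build_sizes = [24, 20, 20]              # sizes the tables were built with in __init__
stale_sizes = [22, 21, 21]              # sizes the old forward() recomputed

good = freqs.split(build_sizes, dim=1)
bad = freqs.split(stale_sizes, dim=1)   # no error: both lists sum to 64

print([x.shape[1] for x in good])       # [24, 20, 20]
print([x.shape[1] for x in bad])        # [22, 21, 21] -> temporal/spatial channels misassigned
```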
