
Commit 020560d

Fix num_frames in i2v (#339)
* Fix num_frames in i2v
* Remove print in flash_attention
1 parent: af7d305 · commit: 020560d

File tree

2 files changed (+4, −6 lines):

  diffsynth/models/wan_video_dit.py
  diffsynth/pipelines/wan_video.py


diffsynth/models/wan_video_dit.py
Lines changed: 0 additions & 2 deletions

@@ -112,7 +112,6 @@ def half(x):
             causal=causal,
             deterministic=deterministic)[0].unflatten(0, (b, lq))
     elif FLASH_ATTN_2_AVAILABLE:
-        print(q_lens, lq, k_lens, lk, causal, window_size)
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -128,7 +127,6 @@ def half(x):
             causal=causal,
             window_size=window_size,
             deterministic=deterministic).unflatten(0, (b, lq))
-        print(x.shape)
     else:
         q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         k = k.unsqueeze(0).transpose(1, 2).to(dtype)

diffsynth/pipelines/wan_video.py
Lines changed: 4 additions & 4 deletions

@@ -150,16 +150,16 @@ def encode_prompt(self, prompt, positive=True):
         return {"context": prompt_emb}


-    def encode_image(self, image, height, width):
+    def encode_image(self, image, num_frames, height, width):
         with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
             image = self.preprocess_image(image.resize((width, height))).to(self.device)
             clip_context = self.image_encoder.encode_image([image])
-            msk = torch.ones(1, 81, height//8, width//8, device=self.device)
+            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
             msk[:, 1:] = 0
             msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
             msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
             msk = msk.transpose(1, 2)[0]
-            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0]
+            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
             y = torch.concat([msk, y])
         return {"clip_fea": clip_context, "y": [y]}

@@ -234,7 +234,7 @@ def __call__(
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, height, width)
+            image_emb = self.encode_image(input_image, num_frames, height, width)
         else:
             image_emb = {}
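For context (not part of the commit): a minimal, self-contained sketch of the mask construction that encode_image performs, showing why the hard-coded 81 (and 80 = 81 − 1) had to become num_frames (and num_frames − 1). The tensor operations follow the hunk above; the frame count and resolution below are illustrative values, not taken from the diff. Note the view step requires num_frames + 3 to be divisible by 4 (i.e. num_frames ≡ 1 mod 4), matching the VAE's temporal compression.

import torch

# Illustrative values only; 81 was the previously hard-coded frame count.
num_frames, height, width = 81, 480, 832

# Conditioning mask over latent-resolution frames: only the first (input image) frame is marked.
msk = torch.ones(1, num_frames, height // 8, width // 8)
msk[:, 1:] = 0

# Repeat the first frame 4x so the frame axis (num_frames + 3 entries) is divisible
# by the temporal compression factor of 4, then fold it into groups of 4.
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
msk = msk.transpose(1, 2)[0]

print(msk.shape)  # torch.Size([4, 21, 60, 104]) == (4, (num_frames + 3) // 4, height // 8, width // 8)

The VAE input built on the next line of the diff pads the single input image with num_frames − 1 zero frames, which is why both constants scale with num_frames rather than staying fixed at 81/80.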
