@@ -150,16 +150,16 @@ def encode_prompt(self, prompt, positive=True):
         return {"context": prompt_emb}


-    def encode_image(self, image, height, width):
+    def encode_image(self, image, num_frames, height, width):
         with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
             image = self.preprocess_image(image.resize((width, height))).to(self.device)
             clip_context = self.image_encoder.encode_image([image])
-            msk = torch.ones(1, 81, height // 8, width // 8, device=self.device)
+            msk = torch.ones(1, num_frames, height // 8, width // 8, device=self.device)
             msk[:, 1:] = 0
             msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
             msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
             msk = msk.transpose(1, 2)[0]
-            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0]
+            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image.device)], dim=1)], device=self.device)[0]
             y = torch.concat([msk, y])
         return {"clip_fea": clip_context, "y": [y]}

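Note on the hunk above (not part of the patch): `num_frames` can safely replace the hardcoded 81 because the mask construction repeats the first frame's mask 4 times and then views the time axis in groups of 4, so the `.view` only reshapes cleanly when `num_frames + 3` is a multiple of 4, i.e. `num_frames = 4k + 1` (81 satisfies this). A minimal sketch of the shape arithmetic, assuming the VAE's 4x temporal compression; the helper name is hypothetical:

```python
# Hypothetical helper illustrating the shapes produced in encode_image.
def latent_mask_shape(num_frames: int, height: int, width: int):
    # The msk.view(1, T // 4, 4, ...) step only works when the padded
    # time axis (num_frames + 3) is divisible by 4.
    assert num_frames % 4 == 1, "num_frames must be 4k + 1 (e.g. 81)"
    latent_frames = (num_frames + 3) // 4  # first frame + groups of 4
    # msk after repeat_interleave: (1, num_frames + 3, height // 8, width // 8)
    # after view + transpose:      (4, latent_frames, height // 8, width // 8)
    return (4, latent_frames, height // 8, width // 8)

print(latent_mask_shape(81, 480, 832))  # (4, 21, 60, 104), the old fixed case
```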
@@ -234,7 +234,7 @@ def __call__(
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, height, width)
+            image_emb = self.encode_image(input_image, num_frames, height, width)
         else:
             image_emb = {}

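With the second hunk, the frame count passed to `__call__` now flows through to `encode_image` instead of being pinned to the 81-frame case. A hedged usage sketch; only `input_image`, `num_frames`, `height`, and `width` are confirmed by the diff, the rest of the call pattern is an assumption:

```python
# Assumed call pattern for the patched pipeline; num_frames must still
# be of the form 4k + 1 for the mask reshape in encode_image to work.
video = pipe(
    prompt="a corgi running on the beach",
    input_image=image,
    num_frames=49,   # previously the image branch only worked with 81
    height=480,
    width=832,
)
```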