diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f3c57103f9b8..69b3ee8466f4 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -691,7 +691,7 @@ def _get_positional_embeddings( output_type="pt", ) pos_embedding = pos_embedding.flatten(0, 1) - joint_pos_embedding = torch.zeros( + joint_pos_embedding = pos_embedding.new_zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)