Skip to content

Commit

Permalink
[tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (
Browse files Browse the repository at this point in the history
…#10270)

fix joint pos embedding device
  • Loading branch information
a-r-r-o-w authored Dec 18, 2024
1 parent 862a7d5 commit c4c99c3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def _get_positional_embeddings(
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = torch.zeros(
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
Expand Down

0 comments on commit c4c99c3

Please sign in to comment.