Skip to content

Commit

Permalink
fix a dtype mismatch when use mps (#790)
Browse files Browse the repository at this point in the history
* fix a dtype mismatch when use mps

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
heiway and pre-commit-ci[bot] authored Jan 4, 2025
1 parent 74c3cb3 commit d76b917
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fish_speech/models/text2semantic/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def setup_caches(
def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
embeds = []
semantic_token_ids_tensor = torch.tensor(
self.semantic_token_ids, device=inp.device
self.semantic_token_ids, device=inp.device, dtype=inp.dtype
)

for i in range(self.config.num_codebooks):
Expand Down

0 comments on commit d76b917

Please sign in to comment.