From d76b9174d1ba8e9686492c49453c314fa583db1b Mon Sep 17 00:00:00 2001
From: heiway
Date: Sat, 4 Jan 2025 16:28:07 +0800
Subject: [PATCH] fix a dtype mismatch when use mps (#790)

* fix a dtype mismatch when use mps

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 fish_speech/models/text2semantic/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
index be24c598..32307f87 100644
--- a/fish_speech/models/text2semantic/llama.py
+++ b/fish_speech/models/text2semantic/llama.py
@@ -249,7 +249,7 @@ def setup_caches(
     def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
         embeds = []
         semantic_token_ids_tensor = torch.tensor(
-            self.semantic_token_ids, device=inp.device
+            self.semantic_token_ids, device=inp.device, dtype=inp.dtype
        )
        for i in range(self.config.num_codebooks):
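
For context, the one-line change above keeps the semantic-token lookup tensor consistent with the input tensor's dtype as well as its device: torch.tensor() built from a Python list of ints defaults to torch.int64, which need not match the dtype of the incoming token-id tensor on every backend. The snippet below is a minimal sketch of that pattern, not the fish-speech code itself; the helper name embed_like and the int32 example input are assumptions used only for illustration.

import torch

# Sketch of the pattern applied in the patch (hypothetical helper, not fish-speech code):
# build a constant lookup tensor that matches the incoming tensor's device AND dtype.
def embed_like(inp: torch.Tensor, token_ids: list[int]) -> torch.Tensor:
    # Without dtype=inp.dtype, torch.tensor(token_ids, ...) would default to int64.
    return torch.tensor(token_ids, device=inp.device, dtype=inp.dtype)

# Example: an int32 input tensor (assumed here for illustration) now gets an int32
# constant instead of the int64 default, so later element-wise comparisons use one dtype.
inp = torch.arange(8, dtype=torch.int32)
ids = embed_like(inp, [3, 5, 7])
assert ids.dtype == inp.dtype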