From 3e4ed8891cda3a645c5ceaa7b3827638d860fea4 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 19 Jul 2024 22:13:41 +0900
Subject: [PATCH] chore(format): run black on dev (#602)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 ChatTTS/model/gpt.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 7b9740892..e8bf1fc8f 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -406,12 +406,15 @@ def generate(
             attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                 attention_mask
             )
-
+
         progress = inputs_ids.size(1)
         # pre-allocate inputs_ids
         inputs_ids_buf = torch.zeros(
-            inputs_ids.size(0), progress+max_new_token, inputs_ids.size(2),
-            dtype=inputs_ids.dtype, device=inputs_ids.device,
+            inputs_ids.size(0),
+            progress + max_new_token,
+            inputs_ids.size(2),
+            dtype=inputs_ids.dtype,
+            device=inputs_ids.device,
         )
         inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
         del inputs_ids
@@ -502,7 +505,9 @@ def generate(
                 logits = logits.reshape(-1, logits.size(2))
                 # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
                 inputs_ids_sliced = inputs_ids.narrow(
-                    1, start_idx, inputs_ids.size(1)-start_idx,
+                    1,
+                    start_idx,
+                    inputs_ids.size(1) - start_idx,
                 ).permute(0, 2, 1)
                 logits_token = inputs_ids_sliced.reshape(
                     inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
@@ -510,9 +515,15 @@ def generate(
                 ).to(self.device)
                 del inputs_ids_sliced
             else:
-                logits_token = inputs_ids.narrow(
-                    1, start_idx, inputs_ids.size(1)-start_idx,
-                ).narrow(2, 0, 1).to(self.device)
+                logits_token = (
+                    inputs_ids.narrow(
+                        1,
+                        start_idx,
+                        inputs_ids.size(1) - start_idx,
+                    )
+                    .narrow(2, 0, 1)
+                    .to(self.device)
+                )
 
             logits /= temperature
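
Note (not part of the patch): the hunks above are formatting-only, but the code they touch pre-allocates the token buffer once and then writes into views of it. A minimal standalone sketch of that pattern follows; the sizes batch, seq_len, num_vq, max_new_token and the variable tail are placeholders chosen here for illustration, not values from the repository.

    import torch

    # Placeholder sizes for illustration only.
    batch, seq_len, num_vq, max_new_token = 2, 5, 4, 8
    inputs_ids = torch.zeros(batch, seq_len, num_vq, dtype=torch.long)

    progress = inputs_ids.size(1)
    # Pre-allocate one buffer sized for the prompt plus all new tokens, then
    # copy the prompt into its leading slice; later steps can write new tokens
    # in place instead of concatenating tensors every iteration.
    inputs_ids_buf = torch.zeros(
        inputs_ids.size(0),
        progress + max_new_token,
        inputs_ids.size(2),
        dtype=inputs_ids.dtype,
        device=inputs_ids.device,
    )
    inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)

    # narrow(dim, start, length) returns a view, equivalent to buf[:, start:start+length]
    start_idx = 3
    tail = inputs_ids_buf.narrow(1, start_idx, progress - start_idx)
    print(tail.shape)  # torch.Size([2, 2, 4])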