Skip to content

Commit

Permalink
add inplace argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Aug 10, 2024
1 parent a7fb442 commit 677722a
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ChatTTS/model/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def apply(
input_ids: torch.Tensor,
spk_emb_ids: int,
device: torch.device,
inplace: bool = True,
):
if isinstance(spk_emb, str):
spk_emb_tensor = torch.from_numpy(self._decode(spk_emb))
Expand All @@ -45,8 +46,10 @@ def apply(
.expand(emb.shape)
)
cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape)
torch.where(cond, n, emb, out=emb)
del cond, n
out = torch.where(cond, n, emb, out=emb if inplace else None)
if inplace:
del cond, n
return out

@staticmethod
@torch.no_grad()
Expand Down

0 comments on commit 677722a

Please sign in to comment.