-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
optimize: separate speaker from tokenizer
- move `spk_stat.pt` into config - move `sample_audio_speaker` to DVAE - update rvcmd to `v0.2.7`
- Loading branch information
Showing
12 changed files
with
171 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .dvae import DVAE | ||
from .gpt import GPT | ||
from .processors import gen_logits | ||
from .speaker import Speaker | ||
from .tokenizer import Tokenizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import lzma | ||
from typing import List, Optional, Union | ||
|
||
import pybase16384 as b14 | ||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
class Speaker: | ||
def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None: | ||
spk_stat = torch.from_numpy(np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()).to(device=device) | ||
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) | ||
self.dim = dim | ||
|
||
def sample_random(self) -> str: | ||
return self._encode(self._sample_random()) | ||
|
||
@torch.no_grad() | ||
def apply( | ||
self, | ||
emb: torch.Tensor, | ||
spk_emb: str, | ||
input_ids: torch.Tensor, | ||
spk_emb_ids: int, | ||
device: torch.device, | ||
): | ||
n = ( | ||
F.normalize( | ||
torch.from_numpy( | ||
self._decode(spk_emb), | ||
), | ||
p=2.0, | ||
dim=0, | ||
eps=1e-12, | ||
) | ||
.to(device) | ||
.unsqueeze_(0) | ||
.expand(emb.size(0), -1) | ||
.unsqueeze_(1) | ||
.expand(emb.shape) | ||
) | ||
cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape) | ||
torch.where(cond, n, emb, out=emb) | ||
del cond, n | ||
|
||
@staticmethod | ||
@torch.no_grad() | ||
def decorate_code_prompts( | ||
text: List[str], | ||
prompt: str, | ||
txt_smp: Optional[str], | ||
spk_emb: Optional[str], | ||
) -> List[str]: | ||
for i, t in enumerate(text): | ||
text[i] = ( | ||
t.replace("[Stts]", "") | ||
.replace("[spk_emb]", "") | ||
.replace("[empty_spk]", "") | ||
.strip() | ||
) | ||
""" | ||
see https://github.com/2noise/ChatTTS/issues/459 | ||
""" | ||
|
||
if prompt: | ||
text = [prompt + i for i in text] | ||
|
||
txt_smp = "" if txt_smp is None else txt_smp | ||
if spk_emb is not None: | ||
text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] | ||
else: | ||
text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] | ||
|
||
return text | ||
|
||
@staticmethod | ||
@torch.no_grad() | ||
def decorate_text_prompts(text: List[str], prompt: str) -> List[str]: | ||
return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] | ||
|
||
@staticmethod | ||
@torch.no_grad() | ||
def encode_prompt(prompt: torch.Tensor) -> str: | ||
arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16) | ||
shp = arr.shape | ||
assert len(shp) == 2, "prompt must be a 2D tensor" | ||
s = b14.encode_to_string( | ||
np.array(shp, dtype="<u2").tobytes() | ||
+ lzma.compress( | ||
arr.astype("<u2").tobytes(), | ||
format=lzma.FORMAT_RAW, | ||
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], | ||
), | ||
) | ||
del arr | ||
return s | ||
|
||
@staticmethod | ||
@torch.no_grad() | ||
def decode_prompt(prompt: str) -> torch.Tensor: | ||
dec = b14.decode_from_string(prompt) | ||
shp = np.frombuffer(dec[:4], dtype="<u2") | ||
p = np.frombuffer( | ||
lzma.decompress( | ||
dec[4:], | ||
format=lzma.FORMAT_RAW, | ||
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], | ||
), | ||
dtype="<u2", | ||
).copy() | ||
del dec | ||
return torch.from_numpy(p.astype(np.int32)).view(*shp) | ||
|
||
@torch.no_grad() | ||
def _sample_random(self) -> torch.Tensor: | ||
spk = ( | ||
torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype) | ||
.mul_(self.std) | ||
.add_(self.mean) | ||
) | ||
return spk | ||
|
||
@staticmethod | ||
@torch.no_grad() | ||
def _encode(spk_emb: torch.Tensor) -> str: | ||
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() | ||
s = b14.encode_to_string( | ||
lzma.compress( | ||
arr.tobytes(), | ||
format=lzma.FORMAT_RAW, | ||
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], | ||
), | ||
) | ||
del arr | ||
return s | ||
|
||
@staticmethod | ||
def _decode(spk_emb: str) -> np.ndarray: | ||
return np.frombuffer( | ||
lzma.decompress( | ||
b14.decode_from_string(spk_emb), | ||
format=lzma.FORMAT_RAW, | ||
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], | ||
), | ||
dtype=np.float16, | ||
).copy() |
Oops, something went wrong.