diff --git a/.github/workflows/checksum.yml b/.github/workflows/checksum.yml index 28c9b60da..55e6843d5 100644 --- a/.github/workflows/checksum.yml +++ b/.github/workflows/checksum.yml @@ -13,7 +13,7 @@ jobs: - name: Run RVC-Models-Downloader run: | - wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.6/rvcmd_linux_amd64.deb + wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.7/rvcmd_linux_amd64.deb sudo apt -y install ./rvcmd_linux_amd64.deb rm -f ./rvcmd_linux_amd64.deb rvcmd -notrs -w 1 -notui assets/chtts diff --git a/ChatTTS/config/config.py b/ChatTTS/config/config.py index 661407c3e..8689e8672 100644 --- a/ChatTTS/config/config.py +++ b/ChatTTS/config/config.py @@ -119,3 +119,4 @@ class Config: dvae: DVAE = DVAE() gpt: GPT = GPT() vocos: Vocos = Vocos() + spk_stat: str = "愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉癅乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆" diff --git a/ChatTTS/core.py b/ChatTTS/core.py index cb90bdeb9..3c749ab43 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -13,7 +13,7 @@ from huggingface_hub import snapshot_download from .config import Config -from .model import DVAE, GPT, gen_logits, Tokenizer +from .model import DVAE, GPT, gen_logits, Tokenizer, Speaker from .utils import ( check_all_assets, download_all_assets, @@ -152,25 +152,12 @@ def unload(self): if hasattr(self, module): delattr(self, module) self.__init__(logger) - + def sample_random_speaker(self) -> str: - return self.tokenizer._encode_spk_emb(self._sample_random_speaker()) + return self.speaker.sample_random() - @torch.inference_mode() def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str: - if isinstance(wav, np.ndarray): - wav = torch.from_numpy(wav).to(self.device) - return self.tokenizer._encode_prompt(self.dvae(wav, "encode").squeeze_(0)) - - @torch.no_grad() - def _sample_random_speaker(self) -> torch.Tensor: - dim: int = self.config.gpt.hidden_size - spk = ( - torch.randn(dim, device=self.std.device, dtype=self.std.dtype) - .mul_(self.std) - .add_(self.mean) - ) - return spk + return self.speaker.encode_prompt(self.dvae.sample_audio(wav)) @dataclass(repr=False, eq=False) class RefineTextParams: @@ -303,15 +290,7 @@ def _load( gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt - spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") - assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" - spk_stat: torch.Tensor = torch.load( - spk_stat_path, - weights_only=True, - mmap=True, - map_location=device, - ) - self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) + self.speaker = Speaker(self.config.gpt.hidden_size, self.config.spk_stat, device) self.logger.log(logging.INFO, "gpt loaded.") decoder = ( @@ -479,14 +458,14 @@ def _infer_code( temperature = params.temperature input_ids, attention_mask, text_mask = self.tokenizer.encode( - self.tokenizer.decorate_code_prompts( + self.speaker.decorate_code_prompts( text, params.prompt, params.txt_smp, params.spk_emb, ), self.config.gpt.num_vq, - prompt_str=params.spk_smp, + prompt=self.speaker.decode_prompt(params.spk_smp) if params.spk_smp is not None else None, device=self.device_gpt, ) start_idx = input_ids.shape[-2] @@ -544,8 +523,8 @@ def _infer_code( del text_mask if params.spk_emb is not None: - self.tokenizer.apply_spk_emb( - emb, params.spk_emb, input_ids, self.gpt.device_gpt + self.speaker.apply( + emb, params.spk_emb, input_ids, self.tokenizer.spk_emb_ids, self.gpt.device_gpt, ) result = gpt.generate( @@ -585,7 +564,7 @@ def _refine_text( text = [text] input_ids, attention_mask, text_mask = self.tokenizer.encode( - self.tokenizer.decorate_text_prompts(text, params.prompt), + self.speaker.decorate_text_prompts(text, params.prompt), self.config.gpt.num_vq, device=self.device_gpt, ) diff --git a/ChatTTS/model/__init__.py b/ChatTTS/model/__init__.py index c0bba7271..4a8bcde8f 100644 --- a/ChatTTS/model/__init__.py +++ b/ChatTTS/model/__init__.py @@ -1,4 +1,5 @@ from .dvae import DVAE from .gpt import GPT from .processors import gen_logits +from .speaker import Speaker from .tokenizer import Tokenizer diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index ff00980d8..5602071eb 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional, Literal, Tuple +from typing import List, Optional, Literal, Union import numpy as np import pybase16384 as b14 @@ -216,7 +216,7 @@ def __init__( coef = torch.rand(100) else: coef = torch.from_numpy( - np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32)) + np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy() ) self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2)) @@ -284,3 +284,9 @@ def forward( del vq_feats return torch.mul(dec_out, self.coef, out=dec_out) + + @torch.inference_mode() + def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav) + return self(wav, "encode").squeeze_(0) diff --git a/ChatTTS/model/speaker.py b/ChatTTS/model/speaker.py new file mode 100644 index 000000000..a0d3930fb --- /dev/null +++ b/ChatTTS/model/speaker.py @@ -0,0 +1,146 @@ +import lzma +from typing import List, Optional, Union + +import pybase16384 as b14 +import numpy as np +import torch +import torch.nn.functional as F + +class Speaker: + def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None: + spk_stat = torch.from_numpy(np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()).to(device=device) + self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) + self.dim = dim + + def sample_random(self) -> str: + return self._encode(self._sample_random()) + + @torch.no_grad() + def apply( + self, + emb: torch.Tensor, + spk_emb: str, + input_ids: torch.Tensor, + spk_emb_ids: int, + device: torch.device, + ): + n = ( + F.normalize( + torch.from_numpy( + self._decode(spk_emb), + ), + p=2.0, + dim=0, + eps=1e-12, + ) + .to(device) + .unsqueeze_(0) + .expand(emb.size(0), -1) + .unsqueeze_(1) + .expand(emb.shape) + ) + cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape) + torch.where(cond, n, emb, out=emb) + del cond, n + + @staticmethod + @torch.no_grad() + def decorate_code_prompts( + text: List[str], + prompt: str, + txt_smp: Optional[str], + spk_emb: Optional[str], + ) -> List[str]: + for i, t in enumerate(text): + text[i] = ( + t.replace("[Stts]", "") + .replace("[spk_emb]", "") + .replace("[empty_spk]", "") + .strip() + ) + """ + see https://github.com/2noise/ChatTTS/issues/459 + """ + + if prompt: + text = [prompt + i for i in text] + + txt_smp = "" if txt_smp is None else txt_smp + if spk_emb is not None: + text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] + else: + text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] + + return text + + @staticmethod + @torch.no_grad() + def decorate_text_prompts(text: List[str], prompt: str) -> List[str]: + return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] + + @staticmethod + @torch.no_grad() + def encode_prompt(prompt: torch.Tensor) -> str: + arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16) + shp = arr.shape + assert len(shp) == 2, "prompt must be a 2D tensor" + s = b14.encode_to_string( + np.array(shp, dtype=" torch.Tensor: + dec = b14.decode_from_string(prompt) + shp = np.frombuffer(dec[:4], dtype=" torch.Tensor: + spk = ( + torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype) + .mul_(self.std) + .add_(self.mean) + ) + return spk + + @staticmethod + @torch.no_grad() + def _encode(spk_emb: torch.Tensor) -> str: + arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() + s = b14.encode_to_string( + lzma.compress( + arr.tobytes(), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], + ), + ) + del arr + return s + + @staticmethod + def _decode(spk_emb: str) -> np.ndarray: + return np.frombuffer( + lzma.decompress( + b14.decode_from_string(spk_emb), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], + ), + dtype=np.float16, + ).copy() diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index e55d1158f..20f7c7ed4 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -6,12 +6,8 @@ """ from typing import List, Tuple, Optional, Union -import lzma -import numpy as np -import pybase16384 as b14 import torch -import torch.nn.functional as F from transformers import BertTokenizerFast from ..utils import del_all @@ -41,7 +37,7 @@ def encode( self, text: List[str], num_vq: int, - prompt_str: Optional[str] = None, + prompt: Optional[torch.Tensor] = None, device="cpu", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -51,8 +47,6 @@ def encode( max_attention_mask_len = -1 prompt_size = 0 - prompt = self._decode_prompt(prompt_str) if prompt_str is not None else None - if prompt is not None: assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq" prompt_size = prompt.size(1) @@ -142,123 +136,3 @@ def decode( return self._tokenizer.batch_decode( sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs ) - - @staticmethod - def _decode_spk_emb(spk_emb: str) -> np.ndarray: - return np.frombuffer( - lzma.decompress( - b14.decode_from_string(spk_emb), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], - ), - dtype=np.float16, - ).copy() - - @torch.no_grad() - def apply_spk_emb( - self, - emb: torch.Tensor, - spk_emb: str, - input_ids: torch.Tensor, - device: torch.device, - ): - n = ( - F.normalize( - torch.from_numpy( - self._decode_spk_emb(spk_emb), - ), - p=2.0, - dim=0, - eps=1e-12, - ) - .to(device) - .unsqueeze_(0) - .expand(emb.size(0), -1) - .unsqueeze_(1) - .expand(emb.shape) - ) - cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape) - torch.where(cond, n, emb, out=emb) - del cond, n - - @staticmethod - @torch.no_grad() - def _decode_prompt(prompt: str) -> torch.Tensor: - dec = b14.decode_from_string(prompt) - shp = np.frombuffer(dec[:4], dtype=" str: - arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16) - shp = arr.shape - assert len(shp) == 2, "prompt must be a 2D tensor" - s = b14.encode_to_string( - np.array(shp, dtype=" str: - arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() - s = b14.encode_to_string( - lzma.compress( - arr.tobytes(), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], - ), - ) - del arr - return s - - @staticmethod - @torch.no_grad() - def decorate_code_prompts( - text: List[str], - prompt: str, - txt_smp: Optional[str], - spk_emb: Optional[str], - ) -> List[str]: - for i, t in enumerate(text): - text[i] = ( - t.replace("[Stts]", "") - .replace("[spk_emb]", "") - .replace("[empty_spk]", "") - .strip() - ) - """ - see https://github.com/2noise/ChatTTS/issues/459 - """ - - if prompt: - text = [prompt + i for i in text] - - txt_smp = "" if txt_smp is None else txt_smp - if spk_emb is not None: - text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] - else: - text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] - - return text - - @staticmethod - @torch.no_grad() - def decorate_text_prompts(text: List[str], prompt: str) -> List[str]: - return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] diff --git a/ChatTTS/res/sha256_map.json b/ChatTTS/res/sha256_map.json index b64b89e81..ef47afb3d 100644 --- a/ChatTTS/res/sha256_map.json +++ b/ChatTTS/res/sha256_map.json @@ -2,7 +2,6 @@ "sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38", "sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5", "sha256_asset_GPT_pt" : "d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb", - "sha256_asset_spk_stat_pt" : "3228d8a4cbbf349d107a1b76d2f47820865bd3c9928c4bdfe1cefd5c7071105f", "sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58", "sha256_asset_tokenizer_special_tokens_map_json": "bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011", diff --git a/ChatTTS/utils/dl.py b/ChatTTS/utils/dl.py index 338721cf9..da21daa0b 100644 --- a/ChatTTS/utils/dl.py +++ b/ChatTTS/utils/dl.py @@ -51,7 +51,6 @@ def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) - "Decoder.pt", "DVAE_full.pt", "GPT.pt", - "spk_stat.pt", "Vocos.pt", ] for model in names: @@ -114,7 +113,7 @@ def download_dns_yaml(url: str, folder: str): logger.get_logger().info(f"downloaded into {folder}") -def download_all_assets(tmpdir: str, version="0.2.6"): +def download_all_assets(tmpdir: str, version="0.2.7"): import subprocess import platform diff --git a/examples/cmd/stream.py b/examples/cmd/stream.py index eac7e9b33..336229818 100644 --- a/examples/cmd/stream.py +++ b/examples/cmd/stream.py @@ -1,4 +1,3 @@ -import time import random import numpy as np @@ -148,7 +147,6 @@ def generate(self, streamchat, output_format=None): # 流式播放接口 def play(self, streamchat, wait=5): import pyaudio # please install it manually - import time p = pyaudio.PyAudio() print(p.get_device_count()) diff --git a/tests/#655.py b/tests/#655.py index 798e36a86..6488b7763 100644 --- a/tests/#655.py +++ b/tests/#655.py @@ -53,14 +53,14 @@ top_K=20, # top K decode ) input_ids, attention_mask, text_mask = chat.tokenizer.encode( - chat.tokenizer.decorate_code_prompts( + chat.speaker.decorate_code_prompts( text, params.prompt, params.txt_smp, params.spk_emb, ), chat.config.gpt.num_vq, - prompt_str=params.spk_smp, + prompt=chat.speaker.decode_prompt(params.spk_smp) if params.spk_smp is not None else None, device=chat.device_gpt, ) with torch.inference_mode(): diff --git a/tools/checksum/tmpl.go b/tools/checksum/tmpl.go index fe5617e64..cfee203ed 100644 --- a/tools/checksum/tmpl.go +++ b/tools/checksum/tmpl.go @@ -4,7 +4,6 @@ var files = [...]string{ "asset/Decoder.pt", "asset/DVAE_full.pt", "asset/GPT.pt", - "asset/spk_stat.pt", "asset/Vocos.pt", "asset/tokenizer/special_tokens_map.json", @@ -16,7 +15,6 @@ const jsontmpl = `{ "sha256_asset_Decoder_pt" : "%s", "sha256_asset_DVAE_full_pt" : "%s", "sha256_asset_GPT_pt" : "%s", - "sha256_asset_spk_stat_pt" : "%s", "sha256_asset_Vocos_pt" : "%s", "sha256_asset_tokenizer_special_tokens_map_json": "%s",