From 913bdcc1fc7349f79a775c8913710bd75b64cb75 Mon Sep 17 00:00:00 2001 From: BBC-Esq Date: Mon, 30 Dec 2024 21:59:29 -0500 Subject: [PATCH] Update core.py --- ChatTTS/core.py | 106 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index c178a9ad2..7c09c1252 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -61,54 +61,87 @@ def has_loaded(self, use_decoder=False): return not not_finish + # Modified def download_models( self, source: Literal["huggingface", "local", "custom"] = "local", force_redownload=False, custom_path: Optional[torch.serialization.FILE_LIKE] = None, + cache_dir: Optional[str] = None, + local_dir: Optional[str] = None, ) -> Optional[str]: if source == "local": - download_path = os.getcwd() - if ( - not check_all_assets(Path(download_path), self.sha256_map, update=True) - or force_redownload - ): + download_path = local_dir if local_dir else (cache_dir if cache_dir else os.getcwd()) + if local_dir: + # Skip hash checking when using local_dir with tempfile.TemporaryDirectory() as tmp: download_all_assets(tmpdir=tmp) - if not check_all_assets( - Path(download_path), self.sha256_map, update=False + else: + # Do hash checking for cache_dir or default path + if ( + not check_all_assets(Path(download_path), self.sha256_map, update=True) + or force_redownload ): - self.logger.error( - "download to local path %s failed.", download_path - ) - return None + with tempfile.TemporaryDirectory() as tmp: + download_all_assets(tmpdir=tmp) + if not check_all_assets( + Path(download_path), self.sha256_map, update=False + ): + self.logger.error( + "download to local path %s failed.", download_path + ) + return None + elif source == "huggingface": - hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) try: - download_path = get_latest_modified_file( - os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") - ) - except: - download_path = None - if download_path is None or force_redownload: - self.logger.log( - logging.INFO, - f"download from HF: https://huggingface.co/2Noise/ChatTTS", - ) - try: + if local_dir: + # Skip hash checking when using local_dir download_path = snapshot_download( repo_id="2Noise/ChatTTS", allow_patterns=["*.yaml", "*.json", "*.safetensors"], + local_dir=local_dir, + force_download=force_redownload ) - except: - download_path = None - else: - self.logger.log( - logging.INFO, f"load latest snapshot from cache: {download_path}" - ) - if download_path is None: - self.logger.error("download from huggingface failed.") - return None + elif cache_dir: + # Download to cache_dir and verify hashes + download_path = snapshot_download( + repo_id="2Noise/ChatTTS", + allow_patterns=["*.yaml", "*.json", "*.safetensors"], + cache_dir=cache_dir, + force_download=force_redownload + ) + if not check_all_assets(Path(download_path), self.sha256_map, update=False): + self.logger.error("Model verification failed") + return None + else: + # Original behavior for default HF cache + hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) + try: + download_path = get_latest_modified_file( + os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots") + ) + except: + download_path = None + if download_path is None or force_redownload: + self.logger.log( + logging.INFO, + f"download from HF: https://huggingface.co/2Noise/ChatTTS", + ) + download_path = snapshot_download( + repo_id="2Noise/ChatTTS", + allow_patterns=["*.yaml", "*.json", "*.safetensors"], + ) + if not check_all_assets(Path(download_path), self.sha256_map, update=False): + self.logger.error("Model verification failed") + return None + else: + self.logger.log( + logging.INFO, f"load latest snapshot from cache: {download_path}" + ) + except Exception as e: + self.logger.error(f"Failed to download models: {str(e)}") + download_path = None + elif source == "custom": self.logger.log(logging.INFO, f"try to load from local: {custom_path}") if not check_all_assets(Path(custom_path), self.sha256_map, update=False): @@ -116,8 +149,13 @@ def download_models( return None download_path = custom_path + if download_path is None: + self.logger.error("Model download failed") + return None + return download_path + # Modified def load( self, source: Literal["huggingface", "local", "custom"] = "local", @@ -129,8 +167,10 @@ def load( use_flash_attn=False, use_vllm=False, experimental: bool = False, + cache_dir: Optional[str] = None, + local_dir: Optional[str] = None, ) -> bool: - download_path = self.download_models(source, force_redownload, custom_path) + download_path = self.download_models(source, force_redownload, custom_path, cache_dir, local_dir) if download_path is None: return False return self._load(