Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Dec 31, 2024
1 parent 00c56ee commit 913bdcc
Showing 1 changed file with 73 additions and 33 deletions.
106 changes: 73 additions & 33 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,63 +61,101 @@ def has_loaded(self, use_decoder=False):

return not not_finish

# Modified
def download_models(
self,
source: Literal["huggingface", "local", "custom"] = "local",
force_redownload=False,
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
) -> Optional[str]:
if source == "local":
download_path = os.getcwd()
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
):
download_path = local_dir if local_dir else (cache_dir if cache_dir else os.getcwd())
if local_dir:
# Skip hash checking when using local_dir
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
else:
# Do hash checking for cache_dir or default path
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
):
self.logger.error(
"download to local path %s failed.", download_path
)
return None
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
self.logger.error(
"download to local path %s failed.", download_path
)
return None

elif source == "huggingface":
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
try:
if local_dir:
# Skip hash checking when using local_dir
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
local_dir=local_dir,
force_download=force_redownload
)
except:
download_path = None
else:
self.logger.log(
logging.INFO, f"load latest snapshot from cache: {download_path}"
)
if download_path is None:
self.logger.error("download from huggingface failed.")
return None
elif cache_dir:
# Download to cache_dir and verify hashes
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
cache_dir=cache_dir,
force_download=force_redownload
)
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
self.logger.error("Model verification failed")
return None
else:
# Original behavior for default HF cache
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
)
except:
download_path = None
if download_path is None or force_redownload:
self.logger.log(
logging.INFO,
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
)
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
)
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
self.logger.error("Model verification failed")
return None
else:
self.logger.log(
logging.INFO, f"load latest snapshot from cache: {download_path}"
)
except Exception as e:
self.logger.error(f"Failed to download models: {str(e)}")
download_path = None

elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
self.logger.error("check models in custom path %s failed.", custom_path)
return None
download_path = custom_path

if download_path is None:
self.logger.error("Model download failed")
return None

return download_path

# Modified
def load(
self,
source: Literal["huggingface", "local", "custom"] = "local",
Expand All @@ -129,8 +167,10 @@ def load(
use_flash_attn=False,
use_vllm=False,
experimental: bool = False,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
) -> bool:
download_path = self.download_models(source, force_redownload, custom_path)
download_path = self.download_models(source, force_redownload, custom_path, cache_dir, local_dir)
if download_path is None:
return False
return self._load(
Expand Down

0 comments on commit 913bdcc

Please sign in to comment.