Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNeodev authored Feb 28, 2025
1 parent edceb3f commit a0cba1b
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions rvc_inferpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from multiprocessing import cpu_count
from urllib.parse import urlparse
from io import BytesIO

from fairseq.models.hubert import HubertModel
import torch
import numpy as np
import soundfile as sf
Expand Down Expand Up @@ -161,21 +161,25 @@ def note_to_hz(note_name: str) -> float | None:
return None




def load_hubert(config, hubert_path: str = None):

"""
Load and return the Hubert model.
If the specified hubert_path does not exist, download base models.
Load and return the Hubert model using the HubertModel.from_pretrained API.
If hubert_path is not provided or does not exist, fallback to a default model.
"""
from fairseq import checkpoint_utils

if hubert_path is None or not os.path.exists(hubert_path):
for model_file in BASE_MODELS:
download_manager(os.path.join(BASE_DOWNLOAD_LINK, model_file), BASE_DIR)
hubert_path = "hubert_base.pt"
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
[hubert_path], suffix=""
# You can implement download logic here if needed.
hubert_path = "hubert_base.pt" # Default model file name

hubert_model = HubertModel.from_pretrained(
model_name_or_path=".", # Directory containing the model
checkpoint_file=hubert_path, # The checkpoint file to load
data_name_or_path="." # Dummy or actual data directory as required
)
hubert_model = models[0].to(config.device)

hubert_model = hubert_model.to(config.device)
hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
hubert_model.eval()
return hubert_model
Expand Down

0 comments on commit a0cba1b

Please sign in to comment.