Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNeodev authored Feb 28, 2025
1 parent a0cba1b commit 750e065
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions rvc_inferpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from multiprocessing import cpu_count
from urllib.parse import urlparse
from io import BytesIO
from fairseq.models.hubert import HubertModel
import torch
import numpy as np
import soundfile as sf
import torchaudio
from torchaudio.pipelines import HUBERT_BASE

# Third-party package imports from rvc_inferpy
from rvc_inferpy.split_audio import (
Expand Down Expand Up @@ -163,25 +164,23 @@ def note_to_hz(note_name: str) -> float | None:




def load_hubert(config, hubert_path: str = None):

"""
Load and return the Hubert model using the HubertModel.from_pretrained API.
If hubert_path is not provided or does not exist, fallback to a default model.
Load and return the HuBERT model using torchaudio.
If the specified hubert_path does not exist, use the pre-trained torchaudio model.
"""
if hubert_path is None or not os.path.exists(hubert_path):
# You can implement download logic here if needed.
hubert_path = "hubert_base.pt" # Default model file name

hubert_model = HubertModel.from_pretrained(
model_name_or_path=".", # Directory containing the model
checkpoint_file=hubert_path, # The checkpoint file to load
data_name_or_path="." # Dummy or actual data directory as required
)

print("Using torchaudio's pre-trained HuBERT model.")
hubert_model = HUBERT_BASE.get_model()
else:
print(f"Loading HuBERT model from {hubert_path}")
hubert_model = torch.jit.load(hubert_path)

hubert_model = hubert_model.to(config.device)
hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
hubert_model.eval()

return hubert_model


Expand Down

0 comments on commit 750e065

Please sign in to comment.