Skip to content

Commit

Permalink
Added support for additional models with data and config from this re…
Browse files Browse the repository at this point in the history
…po rather than relying on the old UVR repos. Added first new model using this method: mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt
  • Loading branch information
beveradb committed Sep 16, 2024
1 parent b5b72bb commit 45a7be2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 55 deletions.
10 changes: 10 additions & 0 deletions audio_separator/models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"vr_download_list": {},
"mdx_download_list": {},
"mdx23c_download_list": {},
"roformer_download_list": {
"Roformer Model: Mel-Roformer-Karaoke-Aufr33-Viperx": {
"mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
}
}
}
94 changes: 40 additions & 54 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" This file contains the Separator class, to facilitate the separation of stems from audio. """

from importlib import metadata
from importlib import metadata, resources
import os
import sys
import platform
Expand Down Expand Up @@ -74,33 +74,10 @@ def __init__(
output_single_stem=None,
invert_using_spec=False,
sample_rate=44100,
mdx_params={
"hop_length": 1024,
"segment_size": 256,
"overlap": 0.25,
"batch_size": 1,
"enable_denoise": False,
},
vr_params={
"batch_size": 16,
"window_size": 512,
"aggression": 5,
"enable_tta": False,
"enable_post_process": False,
"post_process_threshold": 0.2,
"high_end_process": False,
},
demucs_params={
"segment_size": "Default",
"shifts": 2,
"overlap": 0.25,
"segments_enabled": True,
},
mdxc_params={
"segment_size": 256,
"batch_size": 1,
"overlap": 8,
},
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
vr_params={"batch_size": 16, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
mdxc_params={"segment_size": 256, "batch_size": 1, "overlap": 8},
):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(log_level)
Expand Down Expand Up @@ -166,12 +143,7 @@ def __init__(

# These are parameters which users may want to configure so we expose them to the top-level Separator class,
# even though they are specific to a single model architecture
self.arch_specific_params = {
"MDX": mdx_params,
"VR": vr_params,
"Demucs": demucs_params,
"MDXC": mdxc_params,
}
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}

self.torch_device = None
self.torch_device_cpu = None
Expand Down Expand Up @@ -351,7 +323,7 @@ def list_supported_model_files(self):
self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)

model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
self.logger.debug(f"Model download list loaded")
self.logger.debug(f"UVR model download list loaded")

# model_downloads_list JSON structure / example snippet:
# {
Expand Down Expand Up @@ -410,18 +382,21 @@ def list_supported_model_files(self):
# Only show Demucs v4 models as we've only implemented support for v4
filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}

# Load the JSON file using importlib.resources
with resources.open_text("audio_separator", "models.json") as f:
audio_separator_models_list = json.load(f)
self.logger.debug(f"Audio-Separator model list loaded")

# Return a dict of supported model files grouped by architecture type (VR, MDX, Demucs, MDXC), merging the UVR download lists (including the VIP and Roformer lists) with the extra models declared in audio-separator's own models.json; Demucs is limited to the v4 models filtered above
model_files_grouped_by_type = {
"VR": model_downloads_list["vr_download_list"],
"MDX": {
**model_downloads_list["mdx_download_list"],
**model_downloads_list["mdx_download_vip_list"],
},
"VR": {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]},
"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]},
"Demucs": filtered_demucs_v4,
"MDXC": {
**model_downloads_list["mdx23c_download_list"],
**model_downloads_list["mdx23c_download_vip_list"],
**model_downloads_list["roformer_download_list"],
**audio_separator_models_list["roformer_download_list"],
},
}
return model_files_grouped_by_type
Expand All @@ -444,6 +419,8 @@ def download_model_files(self, model_filename):
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"

audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

yaml_config_filename = None

self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")
Expand All @@ -457,7 +434,12 @@ def download_model_files(self, model_filename):
self.logger.debug(f"Single file model identified: {model_friendly_name}")
self.model_friendly_name = model_friendly_name

self.download_file_if_not_exists(f"{model_repo_url_prefix}/{model_filename}", model_path)
try:
self.download_file_if_not_exists(f"{model_repo_url_prefix}/{model_filename}", model_path)
except RuntimeError:
self.logger.debug("Model not found in UVR repo, attempting to download from audio-separator models repo...")
self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{model_filename}", model_path)

self.print_uvr_vip_message()

self.logger.debug(f"Returning path for single model file: {model_path}")
Expand Down Expand Up @@ -488,8 +470,13 @@ def download_model_files(self, model_filename):
# Checkpoint models apparently use config_key as the model filename, but the value is a YAML config file name...
# Both need to be downloaded, but the model data YAML file actually comes from the application data repo...
elif config_key.endswith(".ckpt"):
download_url = f"{model_repo_url_prefix}/{config_key}"
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
try:
download_url = f"{model_repo_url_prefix}/{config_key}"
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
except RuntimeError:
self.logger.debug("Model not found in UVR repo, attempting to download from audio-separator models repo...")
download_url = f"{audio_separator_models_repo_url_prefix}/{config_key}"
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))

# In case the user specified the YAML filename as the model input instead of the model filename, correct that
if model_filename.endswith(".yaml"):
Expand All @@ -503,11 +490,15 @@ def download_model_files(self, model_filename):
yaml_config_filename = config_value
yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)

# Repo for model data and configuration sources from UVR
model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
yaml_config_url = f"{model_data_url_prefix}/mdx_model_data/mdx_c_configs/{yaml_config_filename}"

self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
try:
# Repo for model data and configuration sources from UVR
model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
yaml_config_url = f"{model_data_url_prefix}/mdx_model_data/mdx_c_configs/{yaml_config_filename}"
self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
except RuntimeError:
self.logger.debug("Model YAML config file not found in UVR repo, attempting to download from audio-separator models repo...")
yaml_config_url = f"{audio_separator_models_repo_url_prefix}/{yaml_config_filename}"
self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)

# MDX and VR models have config_value set to the model filename
else:
Expand Down Expand Up @@ -699,12 +690,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
}

# Instantiate the appropriate separator class depending on the model type
separator_classes = {
"MDX": "mdx_separator.MDXSeparator",
"VR": "vr_separator.VRSeparator",
"Demucs": "demucs_separator.DemucsSeparator",
"MDXC": "mdxc_separator.MDXCSeparator",
}
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}

if model_type not in self.arch_specific_params or model_type not in separator_classes:
raise ValueError(f"Model type not supported (yet): {model_type}")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-separator"
version = "0.19.4"
version = "0.20.0"
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
readme = "README.md"
packages = [{include = "audio_separator"}]
include = ["audio_separator/separator/models.json"]
homepage = "https://github.com/karaokenerds/python-audio-separator"
repository = "https://github.com/karaokenerds/python-audio-separator"
documentation = "https://github.com/karaokenerds/python-audio-separator/blob/main/README.md"
Expand Down

0 comments on commit 45a7be2

Please sign in to comment.