diff --git a/audio_separator/separator/common_separator.py b/audio_separator/separator/common_separator.py index eaf8ef8..0af07d3 100644 --- a/audio_separator/separator/common_separator.py +++ b/audio_separator/separator/common_separator.py @@ -47,7 +47,19 @@ class CommonSeparator: LEAD_VOCAL_STEM_LABEL = "Lead Vocals" BV_VOCAL_STEM_LABEL = "Backing Vocals" - NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM) + NON_ACCOM_STEMS = ( + VOCAL_STEM, + OTHER_STEM, + BASS_STEM, + DRUM_STEM, + GUITAR_STEM, + PIANO_STEM, + SYNTH_STEM, + STRINGS_STEM, + WOODWINDS_STEM, + BRASS_STEM, + WIND_INST_STEM, + ) def __init__(self, config): diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index 479b707..dd601f9 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -74,10 +74,33 @@ def __init__( output_single_stem=None, invert_using_spec=False, sample_rate=44100, - mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, - vr_params={"batch_size": 16, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, - demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, - mdxc_params={"segment_size": 256, "batch_size": 1, "overlap": 8}, + mdx_params={ + "hop_length": 1024, + "segment_size": 256, + "overlap": 0.25, + "batch_size": 1, + "enable_denoise": False, + }, + vr_params={ + "batch_size": 16, + "window_size": 512, + "aggression": 5, + "enable_tta": False, + "enable_post_process": False, + "post_process_threshold": 0.2, + "high_end_process": False, + }, + demucs_params={ + "segment_size": "Default", + "shifts": 2, + "overlap": 0.25, + "segments_enabled": True, + }, + mdxc_params={ + "segment_size": 256, + "batch_size": 1, + "overlap": 8, + }, ): self.logger = logging.getLogger(__name__) self.logger.setLevel(log_level) @@ -143,7 +166,12 @@ def __init__( # These are parameters which users may want to configure so we expose them to the top-level Separator class, # even though they are specific to a single model architecture - self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params} + self.arch_specific_params = { + "MDX": mdx_params, + "VR": vr_params, + "Demucs": demucs_params, + "MDXC": mdxc_params, + } self.torch_device = None self.torch_device_cpu = None @@ -385,9 +413,16 @@ def list_supported_model_files(self): # Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C model_files_grouped_by_type = { "VR": model_downloads_list["vr_download_list"], - "MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, + "MDX": { + **model_downloads_list["mdx_download_list"], + **model_downloads_list["mdx_download_vip_list"], + }, "Demucs": filtered_demucs_v4, - "MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"], **model_downloads_list["roformer_download_list"]}, + "MDXC": { + **model_downloads_list["mdx23c_download_list"], + **model_downloads_list["mdx23c_download_vip_list"], + **model_downloads_list["roformer_download_list"], + }, } return model_files_grouped_by_type @@ -664,7 +699,12 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360 } # Instantiate the appropriate separator class depending on the model type - separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"} + separator_classes = { + "MDX": "mdx_separator.MDXSeparator", + "VR": "vr_separator.VRSeparator", + "Demucs": "demucs_separator.DemucsSeparator", + "MDXC": "mdxc_separator.MDXCSeparator", + } if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(f"Model type not supported (yet): {model_type}") diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py index d441de3..02547cf 100755 --- a/audio_separator/utils/cli.py +++ b/audio_separator/utils/cli.py @@ -174,7 +174,12 @@ def main(): "post_process_threshold": args.vr_post_process_threshold, "high_end_process": args.vr_high_end_process, }, - demucs_params={"segment_size": args.demucs_segment_size, "shifts": args.demucs_shifts, "overlap": args.demucs_overlap, "segments_enabled": args.demucs_segments_enabled}, + demucs_params={ + "segment_size": args.demucs_segment_size, + "shifts": args.demucs_shifts, + "overlap": args.demucs_overlap, + "segments_enabled": args.demucs_segments_enabled, + }, mdxc_params={ "segment_size": args.mdxc_segment_size, "batch_size": args.mdxc_batch_size, diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 348ef8b..34bc3dc 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -20,10 +20,35 @@ def common_expected_args(): "output_single_stem": None, "invert_using_spec": False, "sample_rate": 44100, - "mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, - "vr_params": {"batch_size": 4, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, - "demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, - "mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0}, + "mdx_params": { + "hop_length": 1024, + "segment_size": 256, + "overlap": 0.25, + "batch_size": 1, + "enable_denoise": False, + }, + "vr_params": { + "batch_size": 4, + "window_size": 512, + "aggression": 5, + "enable_tta": False, + "enable_post_process": False, + "post_process_threshold": 0.2, + "high_end_process": False, + }, + "demucs_params": { + "segment_size": "Default", + "shifts": 2, + "overlap": 0.25, + "segments_enabled": True, + }, + "mdxc_params": { + "segment_size": 256, + "batch_size": 1, + "overlap": 8, + "override_model_segment_size": False, + "pitch_shift": 0, + }, } diff --git a/tools/calculate-model-hashes.py b/tools/calculate-model-hashes.py index 6e3d882..0887c36 100755 --- a/tools/calculate-model-hashes.py +++ b/tools/calculate-model-hashes.py @@ -80,12 +80,19 @@ def iterate_and_hash(directory): vr_model_data = load_json_data(VR_MODEL_DATA_LOCAL_PATH) mdx_model_data = load_json_data(MDX_MODEL_DATA_LOCAL_PATH) - combined_model_params = {**vr_model_data, **mdx_model_data} + combined_model_params = { + **vr_model_data, + **mdx_model_data, + } model_info_list = [] for file, file_path in sorted(model_files): file_hash = get_model_hash(file_path) - model_info = {"file": file, "hash": file_hash, "params": combined_model_params.get(file_hash, "Parameters not found")} + model_info = { + "file": file, + "hash": file_hash, + "params": combined_model_params.get(file_hash, "Parameters not found"), + } model_info_list.append(model_info) print(f"Writing model info list to {OUTPUT_PATH}")