Skip to content

Commit

Permalink
minor changes (#107)
Browse files Browse the repository at this point in the history
* Update cli.py

* Create code_formatter.yml

* Update calculate-model-hashes.py

* Update test_cli.py

* Update separator.py

* Update common_separator.py

* Delete .github/workflows/code_formatter.yml
  • Loading branch information
Bebra777228 authored Sep 3, 2024
1 parent d1fcf5b commit 6ae982c
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 16 deletions.
14 changes: 13 additions & 1 deletion audio_separator/separator/common_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,19 @@ class CommonSeparator:
LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
BV_VOCAL_STEM_LABEL = "Backing Vocals"

NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
NON_ACCOM_STEMS = (
VOCAL_STEM,
OTHER_STEM,
BASS_STEM,
DRUM_STEM,
GUITAR_STEM,
PIANO_STEM,
SYNTH_STEM,
STRINGS_STEM,
WOODWINDS_STEM,
BRASS_STEM,
WIND_INST_STEM,
)

def __init__(self, config):

Expand Down
56 changes: 48 additions & 8 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,33 @@ def __init__(
output_single_stem=None,
invert_using_spec=False,
sample_rate=44100,
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
vr_params={"batch_size": 16, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
mdxc_params={"segment_size": 256, "batch_size": 1, "overlap": 8},
mdx_params={
"hop_length": 1024,
"segment_size": 256,
"overlap": 0.25,
"batch_size": 1,
"enable_denoise": False,
},
vr_params={
"batch_size": 16,
"window_size": 512,
"aggression": 5,
"enable_tta": False,
"enable_post_process": False,
"post_process_threshold": 0.2,
"high_end_process": False,
},
demucs_params={
"segment_size": "Default",
"shifts": 2,
"overlap": 0.25,
"segments_enabled": True,
},
mdxc_params={
"segment_size": 256,
"batch_size": 1,
"overlap": 8,
},
):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(log_level)
Expand Down Expand Up @@ -143,7 +166,12 @@ def __init__(

# These are parameters which users may want to configure so we expose them to the top-level Separator class,
# even though they are specific to a single model architecture
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
self.arch_specific_params = {
"MDX": mdx_params,
"VR": vr_params,
"Demucs": demucs_params,
"MDXC": mdxc_params,
}

self.torch_device = None
self.torch_device_cpu = None
Expand Down Expand Up @@ -385,9 +413,16 @@ def list_supported_model_files(self):
# Return object with list of model names, which are the keys in vr_download_list, mdx_download_list, demucs_download_list, mdx23_download_list, mdx23c_download_list, grouped by type: VR, MDX, Demucs, MDX23, MDX23C
model_files_grouped_by_type = {
"VR": model_downloads_list["vr_download_list"],
"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]},
"MDX": {
**model_downloads_list["mdx_download_list"],
**model_downloads_list["mdx_download_vip_list"],
},
"Demucs": filtered_demucs_v4,
"MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"], **model_downloads_list["roformer_download_list"]},
"MDXC": {
**model_downloads_list["mdx23c_download_list"],
**model_downloads_list["mdx23c_download_vip_list"],
**model_downloads_list["roformer_download_list"],
},
}
return model_files_grouped_by_type

Expand Down Expand Up @@ -664,7 +699,12 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
}

# Instantiate the appropriate separator class depending on the model type
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
separator_classes = {
"MDX": "mdx_separator.MDXSeparator",
"VR": "vr_separator.VRSeparator",
"Demucs": "demucs_separator.DemucsSeparator",
"MDXC": "mdxc_separator.MDXCSeparator",
}

if model_type not in self.arch_specific_params or model_type not in separator_classes:
raise ValueError(f"Model type not supported (yet): {model_type}")
Expand Down
7 changes: 6 additions & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def main():
"post_process_threshold": args.vr_post_process_threshold,
"high_end_process": args.vr_high_end_process,
},
demucs_params={"segment_size": args.demucs_segment_size, "shifts": args.demucs_shifts, "overlap": args.demucs_overlap, "segments_enabled": args.demucs_segments_enabled},
demucs_params={
"segment_size": args.demucs_segment_size,
"shifts": args.demucs_shifts,
"overlap": args.demucs_overlap,
"segments_enabled": args.demucs_segments_enabled,
},
mdxc_params={
"segment_size": args.mdxc_segment_size,
"batch_size": args.mdxc_batch_size,
Expand Down
33 changes: 29 additions & 4 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,35 @@ def common_expected_args():
"output_single_stem": None,
"invert_using_spec": False,
"sample_rate": 44100,
"mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
"vr_params": {"batch_size": 4, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
"demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
"mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0},
"mdx_params": {
"hop_length": 1024,
"segment_size": 256,
"overlap": 0.25,
"batch_size": 1,
"enable_denoise": False,
},
"vr_params": {
"batch_size": 4,
"window_size": 512,
"aggression": 5,
"enable_tta": False,
"enable_post_process": False,
"post_process_threshold": 0.2,
"high_end_process": False,
},
"demucs_params": {
"segment_size": "Default",
"shifts": 2,
"overlap": 0.25,
"segments_enabled": True,
},
"mdxc_params": {
"segment_size": 256,
"batch_size": 1,
"overlap": 8,
"override_model_segment_size": False,
"pitch_shift": 0,
},
}


Expand Down
11 changes: 9 additions & 2 deletions tools/calculate-model-hashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,19 @@ def iterate_and_hash(directory):
vr_model_data = load_json_data(VR_MODEL_DATA_LOCAL_PATH)
mdx_model_data = load_json_data(MDX_MODEL_DATA_LOCAL_PATH)

combined_model_params = {**vr_model_data, **mdx_model_data}
combined_model_params = {
**vr_model_data,
**mdx_model_data,
}

model_info_list = []
for file, file_path in sorted(model_files):
file_hash = get_model_hash(file_path)
model_info = {"file": file, "hash": file_hash, "params": combined_model_params.get(file_hash, "Parameters not found")}
model_info = {
"file": file,
"hash": file_hash,
"params": combined_model_params.get(file_hash, "Parameters not found"),
}
model_info_list.append(model_info)

print(f"Writing model info list to {OUTPUT_PATH}")
Expand Down

0 comments on commit 6ae982c

Please sign in to comment.