Skip to content

Commit

Permalink
Implementing Custom Output File Naming for MDX, MDXC, and VR Models (#…
Browse files Browse the repository at this point in the history
…141)

* Update mdx_separator.py

* Update mdxc_separator.py

* Update vr_separator.py

* Update separator.py

* Update cli.py

* Update test_cli.py
  • Loading branch information
Bebra777228 authored Nov 3, 2024
1 parent 3dd25b6 commit 85909ad
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 13 deletions.
15 changes: 12 additions & 3 deletions audio_separator/separator/architectures/mdx_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def load_model(self):
self.model_run.to(self.torch_device).eval()
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -182,15 +184,22 @@ def separate(self, audio_file_path):

# Save and process the secondary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

# Save and process the primary stem if needed
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T

Expand Down
14 changes: 11 additions & 3 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def load_model(self):
self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
sys.exit(1)

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -152,14 +154,20 @@ def separate(self, audio_file_path):
self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
output_files.append(self.secondary_stem_output_path)

if not isinstance(source, dict) or not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
Expand Down
14 changes: 11 additions & 3 deletions audio_separator/separator/architectures/vr_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def __init__(self, common_config, arch_config: dict):

self.logger.info("VR Separator initialisation complete")

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.
Args:
audio_file_path (str): The path to the audio file to be processed.
primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
Returns:
list: A list of paths to the output files generated by the separation process.
Expand Down Expand Up @@ -195,7 +197,10 @@ def separate(self, audio_file_path):
self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling primary source to 44100Hz.")

self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if primary_output_name:
self.primary_stem_output_path = os.path.join(f"{primary_output_name}.{self.output_format.lower()}")
else:
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
Expand All @@ -213,7 +218,10 @@ def separate(self, audio_file_path):
self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
self.logger.debug("Resampling secondary source to 44100Hz.")

self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
if secondary_output_name:
self.secondary_stem_output_path = os.path.join(f"{secondary_output_name}.{self.output_format.lower()}")
else:
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
Expand Down
8 changes: 5 additions & 3 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
self.logger.debug("Loading model completed.")
self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')

def separate(self, audio_file_path):
def separate(self, audio_file_path, primary_output_name=None, secondary_output_name=None):
"""
Separates the audio file into different stems (e.g., vocals, instruments) using the loaded model.
Expand All @@ -747,6 +747,8 @@ def separate(self, audio_file_path):
Parameters:
- audio_file_path (str): The path to the audio file to be separated.
- primary_output_name (str, optional): Custom name for the primary output file. Defaults to None.
- secondary_output_name (str, optional): Custom name for the secondary output file. Defaults to None.
Returns:
- output_files (list of str): A list containing the paths to the separated audio stem files.
Expand All @@ -766,10 +768,10 @@ def separate(self, audio_file_path):
if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
self.logger.debug("Autocast available.")
with autocast_mode.autocast(self.torch_device.type):
output_files = self.model_instance.separate(audio_file_path)
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)
else:
self.logger.debug("Autocast unavailable.")
output_files = self.model_instance.separate(audio_file_path)
output_files = self.model_instance.separate(audio_file_path, primary_output_name, secondary_output_name)

# Clear GPU cache to free up memory
self.model_instance.clear_gpu_cache()
Expand Down
6 changes: 5 additions & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def main():
sample_rate_help = "modify the sample rate of the output audio (default: %(default)s). Example: --sample_rate=44100"
use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile"
use_autocast_help = "use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast"
primary_output_name_help = "Custom name for the primary output file (default: %(default)s). Example: --primary_output_name=custom_primary_output"
secondary_output_name_help = "Custom name for the secondary output file (default: %(default)s). Example: --secondary_output_name=custom_secondary_output"

common_params = parser.add_argument_group("Common Separation Parameters")
common_params.add_argument("--invert_spect", action="store_true", help=invert_spect_help)
Expand All @@ -65,6 +67,8 @@ def main():
common_params.add_argument("--sample_rate", type=int, default=44100, help=sample_rate_help)
common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help)
common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help)
common_params.add_argument("--primary_output_name", default=None, help=primary_output_name_help)
common_params.add_argument("--secondary_output_name", default=None, help=secondary_output_name_help)

mdx_segment_size_help = "larger consumes more resources, but may give better results (default: %(default)s). Example: --mdx_segment_size=256"
mdx_overlap_help = "amount of overlap between prediction windows, 0.001-0.999. higher is better but slower (default: %(default)s). Example: --mdx_overlap=0.25"
Expand Down Expand Up @@ -201,6 +205,6 @@ def main():
separator.load_model(model_filename=args.model_filename)

for audio_file in args.audio_files:
output_files = separator.separate(audio_file)
output_files = separator.separate(audio_file, primary_output_name=args.primary_output_name, secondary_output_name=args.secondary_output_name)

logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}")
44 changes: 44 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def test_cli_invert_spectrogram_argument(common_expected_args):
# Assertions
mock_separator.assert_called_once_with(**expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
Expand All @@ -240,6 +241,7 @@ def test_cli_use_autocast_argument(common_expected_args):
# Assertions
mock_separator.assert_called_once_with(**common_expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
Expand All @@ -254,3 +256,45 @@ def test_cli_use_autocast_argument(common_expected_args):

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)


# Test using primary_output_name argument
def test_cli_primary_output_name_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--primary_output_name=custom_primary_output"]
with patch("sys.argv", test_args):
with patch("audio_separator.separator.Separator") as mock_separator:
mock_separator_instance = mock_separator.return_value
mock_separator_instance.separate.return_value = ["output_file.mp3"]
main()

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name="custom_primary_output", secondary_output_name=None)


# Test using secondary_output_name argument
def test_cli_secondary_output_name_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--secondary_output_name=custom_secondary_output"]
with patch("sys.argv", test_args):
with patch("audio_separator.separator.Separator") as mock_separator:
mock_separator_instance = mock_separator.return_value
mock_separator_instance.separate.return_value = ["output_file.mp3"]
main()

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name=None, secondary_output_name="custom_secondary_output")


# Test using both primary_output_name and secondary_output_name arguments
def test_cli_both_output_names_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--primary_output_name=custom_primary_output", "--secondary_output_name=custom_secondary_output"]
with patch("sys.argv", test_args):
with patch("audio_separator.separator.Separator") as mock_separator:
mock_separator_instance = mock_separator.return_value
mock_separator_instance.separate.return_value = ["output_file.mp3"]
main()

# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator_instance.separate.assert_called_once_with("test_audio.mp3", primary_output_name="custom_primary_output", secondary_output_name="custom_secondary_output")

0 comments on commit 85909ad

Please sign in to comment.