Skip to content

Commit

Permalink
Add output_bitrate argument. (#104)
Browse files Browse the repository at this point in the history
* Add output_bitrate argument.

* Add install extra dependencies to readme.

* Add output_bitrate argument to cli.
  • Loading branch information
empz authored Aug 27, 2024
1 parent 7e79526 commit 47a3171
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ Once you're inside the conda env, run the following command to install the proje
poetry install
```
Install the extra dependencies that match your hardware, depending on whether you're running on a GPU or a CPU.
```sh
poetry install --extras "cpu"
```
or
```sh
poetry install --extras "gpu"
```
### Running the Command-Line Interface Locally
You can run the CLI command directly within the virtual environment. For example:
Expand Down
6 changes: 5 additions & 1 deletion audio_separator/separator/common_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, config):
# Output directory and format
self.output_dir = config.get("output_dir")
self.output_format = config.get("output_format")
self.output_bitrate = config.get("output_bitrate")

# Functional options which are applicable to all architectures and the user may tweak to affect the output
self.normalization_threshold = config.get("normalization_threshold")
Expand Down Expand Up @@ -250,9 +251,12 @@ def write_audio(self, stem_path: str, stem_source):
elif file_format == "mka":
file_format = "matroska"

# Set the bitrate to 320k for mp3 files if output_bitrate is not specified
bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate

# Export using the determined format
try:
audio_segment.export(stem_path, format=file_format)
audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
self.logger.debug(f"Exported audio file successfully to {stem_path}")
except (IOError, ValueError) as e:
self.logger.error(f"Error exporting audio file: {e}")
Expand Down
4 changes: 4 additions & 0 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Separator:
model_file_dir (str): The directory where model files are stored.
output_dir (str): The directory where output files will be saved.
output_format (str): The format of the output audio file.
output_bitrate (str): The bitrate of the output audio file.
normalization_threshold (float): The threshold for audio normalization.
output_single_stem (str): Option to output a single stem.
invert_using_spec (bool): Flag to invert using spectrogram.
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
model_file_dir="/tmp/audio-separator-models/",
output_dir=None,
output_format="WAV",
output_bitrate=None,
normalization_threshold=0.9,
output_single_stem=None,
invert_using_spec=False,
Expand Down Expand Up @@ -113,6 +115,7 @@ def __init__(
os.makedirs(self.output_dir, exist_ok=True)

self.output_format = output_format
self.output_bitrate = output_bitrate

if self.output_format is None:
self.output_format = "WAV"
Expand Down Expand Up @@ -652,6 +655,7 @@ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360
"model_path": model_path,
"model_data": model_data,
"output_format": self.output_format,
"output_bitrate": self.output_bitrate,
"output_dir": self.output_dir,
"normalization_threshold": self.normalization_threshold,
"output_single_stem": self.output_single_stem,
Expand Down
7 changes: 5 additions & 2 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def main():

model_filename_help = "model to use for separation (default: %(default)s). Example: -m 2_HP-UVR.pth"
output_format_help = "output format for separated files, any common format (default: %(default)s). Example: --output_format=MP3"
output_bitrate_help = "output bitrate for separated files, any ffmpeg-compatible bitrate (default: %(default)s). Example: --output_bitrate=320k"
output_dir_help = "directory to write output files (default: <current dir>). Example: --output_dir=/app/separated"
model_file_dir_help = "model files directory (default: %(default)s). Example: --model_file_dir=/app/models"
download_model_only_help = "Download a single model file only, without performing separation."

io_params = parser.add_argument_group("Separation I/O Params")
io_params.add_argument("-m", "--model_filename", default="model_bs_roformer_ep_317_sdr_12.9755.yaml", help=model_filename_help)
io_params.add_argument("--output_format", default="FLAC", help=output_format_help)
io_params.add_argument("--output_bitrate", default=None, help=output_bitrate_help)
io_params.add_argument("--output_dir", default=None, help=output_dir_help)
io_params.add_argument("--model_file_dir", default="/tmp/audio-separator-models/", help=model_file_dir_help)
io_params.add_argument("--download_model_only", action="store_true", help=download_model_only_help)
Expand Down Expand Up @@ -142,15 +144,16 @@ def main():
if not hasattr(args, "audio_file"):
parser.print_help()
sys.exit(1)

logger.info(f"Separator version {package_version} beginning with input file: {args.audio_file}")

separator = Separator(
log_formatter=log_formatter,
log_level=log_level,
model_file_dir=args.model_file_dir,
output_dir=args.output_dir,
output_format=args.output_format,
output_bitrate=args.output_bitrate,
normalization_threshold=args.normalization,
output_single_stem=args.single_stem,
invert_using_spec=args.invert_spect,
Expand Down

0 comments on commit 47a3171

Please sign in to comment.