diff --git a/everyvoice/wizard/basic.py b/everyvoice/wizard/basic.py
index 60ac793b..2b3090d3 100644
--- a/everyvoice/wizard/basic.py
+++ b/everyvoice/wizard/basic.py
@@ -1,4 +1,6 @@
 import json
+import sys
+import wave
 from pathlib import Path
 
 import questionary
@@ -199,6 +201,41 @@ class ConfigFormatStep(Step):
     # ConfigFormatStep writes the results to disk and exits, so it's not reversible.
     REVERSIBLE = False
 
+    def validate_wav_channels(self, wavs_dir: Path):
+        """Check that every WAV file under wavs_dir is mono or stereo.
+
+        Searches wavs_dir recursively for ``*.wav`` files. Prints an error
+        and exits with status 1 on the first file that cannot be parsed
+        or that has more than two channels. A directory containing no
+        ``*.wav`` files passes validation.
+        """
+        rich_print(
+            f"Validating the number of channels in WAV files located at '{wavs_dir}'..."
+        )
+        for wav_file in wavs_dir.rglob("*.wav"):
+            try:
+                with wave.open(str(wav_file), "rb") as wav:
+                    num_channels = wav.getnchannels()
+            except (wave.Error, EOFError) as e:
+                # wave.open raises EOFError (with an empty message), not
+                # wave.Error, when the file is empty or truncated.
+                reason = str(e) or "file is empty or truncated"
+                rich_print(
+                    f"[red]Error reading WAV file '{wav_file}': {reason}. Please check the file.[/red]"
+                )
+                sys.exit(1)
+            if num_channels > 2:
+                rich_print(
+                    f"[red]Error: The file '{wav_file}' has {num_channels} channels. "
+                    "Only single-channel or two-channel WAV files are supported. "
+                    "Please correct this and try again.[/red]"
+                )
+                sys.exit(1)
+
+        rich_print(
+            "[green]All WAV files have valid channel counts (1 or 2 channels).[/green]"
+        )
+
     def prompt(self):
         return get_response_from_menu_prompt(
             "Which format would you like to output the configuration to?",
@@ -236,6 +273,11 @@ def effect(self):
         )  # TODO: this should be fixed by https://github.com/EveryVoiceTTS/EveryVoice/issues/359
         for dataset in [key for key in self.state.keys() if key.startswith("dataset_")]:
             dataset_state = self.state[dataset]
+
+            # Validate WAV Channels
+            wavs_dir = Path(dataset_state[StepNames.wavs_dir_step]).expanduser()
+            self.validate_wav_channels(wavs_dir)
+
             # Add Cleaners
             # TODO: these should really be dataset-specific cleaners, not global cleaners
             # so this should be fixed by https://github.com/EveryVoiceTTS/EveryVoice/issues/359