diff --git a/audalign/__init__.py b/audalign/__init__.py index 462a6d0..70382b5 100644 --- a/audalign/__init__.py +++ b/audalign/__init__.py @@ -20,6 +20,7 @@ from pydub.utils import mediainfo import audalign.align as aligner +from audalign.config.fingerprint import FingerprintConfig import audalign.datalign as datalign import audalign.filehandler as filehandler from audalign.config import BaseConfig @@ -311,6 +312,7 @@ def write_processed_file( start_end: tuple = None, sample_rate: int = BaseConfig.sample_rate, normalize: bool = BaseConfig.normalize, + cant_read_extensions: list[str] = BaseConfig.cant_read_extensions, ) -> None: """ writes given file to the destination file after processing for fingerprinting @@ -329,6 +331,7 @@ def write_processed_file( start_end=start_end, sample_rate=sample_rate, normalize=normalize, + cant_read_extensions=cant_read_extensions, ) @@ -695,6 +698,7 @@ def write_shifts_from_results( write_multi_channel: bool = False, unprocessed: bool = False, normalize: bool = BaseConfig.normalize, + config: BaseConfig = None, ): """ For writing the results of an alignment with alternate source files or unprocessed files @@ -715,10 +719,12 @@ def write_shifts_from_results( unprocessed (bool): If true, writes files without processing. For total files, only doesn't normalize normalize (bool): if true, normalizes file when read """ + if config is None: + config = FingerprintConfig() if isinstance(read_from_dir, str): print("Finding audio files") read_from_dir = filehandler.get_audio_files_directory( - read_from_dir, full_path=True + read_from_dir, full_path=True, can_read_extensions=config.can_read_extensions, cant_read_extensions=config.cant_read_extensions ) if read_from_dir is not None: results_files = {} @@ -767,6 +773,7 @@ def convert_audio_file( start_end: tuple = None, sample_rate: int = None, normalize: bool = BaseConfig.normalize, + cant_read_extensions: list[str] = BaseConfig.cant_read_extensions, ): """ Convert audio file to type specified in destination path @@ -785,6 +792,7 @@ def convert_audio_file( start_end=start_end, sample_rate=sample_rate, normalize=normalize, + cant_read_extensions=cant_read_extensions, ) diff --git a/audalign/config/__init__.py b/audalign/config/__init__.py index ec96cdb..a5735f7 100644 --- a/audalign/config/__init__.py +++ b/audalign/config/__init__.py @@ -49,6 +49,7 @@ class BaseConfig(ABC): LOCALITY_SECS = "locality_seconds" ###################################################################### + # rankings settings # Add to ranking if second match is close rankings_second_is_close_add: int = 1 @@ -71,3 +72,28 @@ class BaseConfig(ABC): # used if rankings_get_top_num_match is not None. (used in visual) # subtracts second value from ranking if num matches is above first value rankings_num_matches_tups: typing.Optional[tuple] = None + + ###################################################################### + # filehandling settings + + # file types that can't be read and not explicitly filtered out by + # below extention lists will cause a crash + fail_on_decode_error = True + + # + cant_write_extensions = [".mov", ".mp4", ".m4a"] + cant_read_extensions = [".txt", ".md", ".pkf", ".py", ".pyc"] + can_read_extensions = [ + ".mov", + ".mp4", + ".m4a", + ".wav", + ".WAV", + ".mp3", + ".MOV", + ".ogg", + ".aiff", + ".aac", + ".wma", + ".flac", + ] diff --git a/audalign/filehandler.py b/audalign/filehandler.py index b50e817..658f556 100644 --- a/audalign/filehandler.py +++ b/audalign/filehandler.py @@ -12,6 +12,7 @@ from pydub.exceptions import CouldntDecodeError from audalign.config import BaseConfig +from audalign.config.fingerprint import FingerprintConfig try: import noisereduce @@ -19,24 +20,6 @@ # Optional dependency ... - -cant_write_ext = [".mov", ".mp4", ".m4a"] -cant_read_ext = [".txt", ".md", ".pkf", ".py", ".pyc"] -can_read_ext = [ - ".mov", - ".mp4", - ".m4a", - ".wav", - ".WAV", - ".mp3", - ".MOV", - ".ogg", - ".aiff", - ".aac", - ".wma", - ".flac", -] - def _import_optional_dependencies(func): @wraps(func) def wrapper_decorator(*args, **kwargs): @@ -135,7 +118,11 @@ def create_audiosegment( return audiofile -def get_audio_files_directory(directory_path: str, full_path: bool = False) -> list: +def get_audio_files_directory(directory_path: str, full_path: bool = False, + can_read_extensions: list[str] = BaseConfig.can_read_extensions, + cant_read_extensions: list[str] = BaseConfig.cant_read_extensions, + + ) -> list: """returns a list of the file paths in directory that are audio Args: @@ -146,7 +133,7 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False) -> l """ aud_list = [] for file_path, ext in find_files(directory_path): - if check_is_audio_file(file_path=file_path): + if check_is_audio_file(file_path=file_path, can_read_extensions=can_read_extensions, cant_read_extensions=cant_read_extensions): if full_path is False: aud_list += [os.path.basename(file_path)] else: @@ -154,12 +141,16 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False) -> l return aud_list -def check_is_audio_file(file_path: str) -> bool: +def check_is_audio_file( + file_path: str, + can_read_extensions: list[str] = BaseConfig.can_read_extensions, + cant_read_extensions: list[str] = BaseConfig.cant_read_extensions, + ) -> bool: ext = os.path.splitext(file_path)[1] try: - if ext in [".txt", ".json"] or ext in cant_read_ext: + if ext in [".txt", ".json"] or ext in cant_read_extensions: return False - elif ext.lower() not in can_read_ext: + elif ext.lower() not in can_read_extensions: AudioSegment.from_file(file_path) except CouldntDecodeError: return False @@ -172,6 +163,7 @@ def read( start_end: tuple = None, sample_rate=BaseConfig.sample_rate, normalize: bool = BaseConfig.normalize, + cant_read_extensions: list[str] = BaseConfig.cant_read_extensions, ): """ Reads any file supported by pydub (ffmpeg) and returns a numpy array and the bit depth @@ -186,7 +178,7 @@ def read( frame_rate (int): returns the bit depth """ - if os.path.splitext(filename)[1] in cant_read_ext: + if os.path.splitext(filename)[1] in cant_read_extensions: raise CouldntDecodeError audiofile = create_audiosegment( filename, start_end=start_end, sample_rate=sample_rate, normalize=normalize @@ -323,6 +315,7 @@ def _remove_noise( write_extension: str = None, destination_directory="", prop_decrease=1, + base_config: BaseConfig = FingerprintConfig(), **kwargs, ): @@ -343,7 +336,7 @@ def _remove_noise( file_name = os.path.basename(file_path) destination_name = os.path.join(destination_directory, file_name) - if os.path.splitext(destination_name)[1].lower() in cant_write_ext: + if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions: destination_name = os.path.splitext(destination_name)[0] + ".wav" if write_extension is not None: @@ -441,6 +434,7 @@ def _uniform_level( width: float = 5, overlap_ratio=0.5, exclude_min_db=-70, + base_config: BaseConfig = FingerprintConfig(), ): assert overlap_ratio < 1 and overlap_ratio >= 0 try: @@ -481,7 +475,7 @@ def _uniform_level( file_name = os.path.basename(file_path) if len(os.path.splitext(destination_name)[1]) == 0: destination_name = os.path.join(destination_name, file_name) - if os.path.splitext(destination_name)[1].lower() in cant_write_ext: + if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions: destination_name = os.path.splitext(destination_name)[0] + ".wav" if write_extension is not None: @@ -656,6 +650,7 @@ def _shift_write_separate( return_files: bool = False, unprocessed: bool = False, normalize: bool = BaseConfig.normalize, + base_config: BaseConfig = FingerprintConfig(), ): audsegs = _shift_prepend_space_audsegs( files_shifts=files_shifts, @@ -673,6 +668,7 @@ def _shift_write_separate( file_path=file_path, destination_path=destination_path, write_extension=write_extension, + base_config=base_config, ) audsegs = list(audsegs.values()) @@ -754,12 +750,13 @@ def _write_single_shift( file_path: str, destination_path: str, write_extension: str, + base_config: BaseConfig = FingerprintConfig(), ): file_name = os.path.basename(file_path) destination_name = os.path.join(destination_path, file_name) # type: ignore - if os.path.splitext(destination_name)[1] in cant_write_ext: + if os.path.splitext(destination_name)[1] in base_config.cant_write_extensions: destination_name = os.path.splitext(destination_name)[0] + ".wav" if write_extension: diff --git a/audalign/recognizers/__init__.py b/audalign/recognizers/__init__.py index 066400d..4a63148 100644 --- a/audalign/recognizers/__init__.py +++ b/audalign/recognizers/__init__.py @@ -46,7 +46,8 @@ def align_get_file_names( if target_aligning: file_names = [os.path.basename(x) for x in file_list] elif file_dir: - file_names = filehandler.get_audio_files_directory(file_dir) + file_names = filehandler.get_audio_files_directory( + file_dir, False, self.config.can_read_extensions, self.config.cant_read_extensions) elif fine_aud_file_dict: file_names = [os.path.basename(x) for x in fine_aud_file_dict.keys()] else: diff --git a/audalign/recognizers/correcognize/correcognize.py b/audalign/recognizers/correcognize/correcognize.py index 4ed72a9..eccc7fb 100644 --- a/audalign/recognizers/correcognize/correcognize.py +++ b/audalign/recognizers/correcognize/correcognize.py @@ -343,6 +343,7 @@ def get_array( _file_audsegs, sos, normalize: bool, + cant_read_extensions: list[str] = CorrelationConfig.cant_read_extensions, ): if _file_audsegs is not None: target_array = get_shifted_file( @@ -353,7 +354,7 @@ def get_array( ) else: target_array = read( - file_path, start_end=start_end, sample_rate=sample_rate, normalize=normalize + file_path, start_end=start_end, sample_rate=sample_rate, normalize=normalize, cant_read_extensions=cant_read_extensions, )[0] if sos is not None: target_array = signal.sosfilt(sos, target_array) diff --git a/audalign/recognizers/correcognizeSpectrogram/correcognize_spectrogram.py b/audalign/recognizers/correcognizeSpectrogram/correcognize_spectrogram.py index f8b91f8..5cb5552 100644 --- a/audalign/recognizers/correcognizeSpectrogram/correcognize_spectrogram.py +++ b/audalign/recognizers/correcognizeSpectrogram/correcognize_spectrogram.py @@ -360,6 +360,7 @@ def get_array( start_end=start_end, sample_rate=config.sample_rate, normalize=config.normalize, + cant_read_extensions=config.cant_read_extensions, )[0] if sos is not None: target_array = signal.sosfilt(sos, target_array) diff --git a/audalign/recognizers/fingerprint/__init__.py b/audalign/recognizers/fingerprint/__init__.py index a7ad92a..8f5f4e1 100644 --- a/audalign/recognizers/fingerprint/__init__.py +++ b/audalign/recognizers/fingerprint/__init__.py @@ -100,7 +100,7 @@ def align_get_file_names( fine_aud_file_dict: typing.Optional[dict], ) -> list: if target_aligning or file_dir: - file_names = filehandler.get_audio_files_directory(file_dir, full_path=True) + file_names = filehandler.get_audio_files_directory(file_dir, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions) elif fine_aud_file_dict: file_names = fine_aud_file_dict.keys() for name, fingerprints in zip(self.file_names, self.fingerprinted_files): @@ -171,12 +171,14 @@ def recognize( if against_path is not None: if os.path.isdir(against_path): for path in filehandler.get_audio_files_directory( - against_path, full_path=True + against_path, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions ): if path not in self.file_names and path not in to_fingerprint: to_fingerprint += [path] elif os.path.isfile(against_path): - if filehandler.check_is_audio_file(against_path): + if filehandler.check_is_audio_file(against_path, + self.config.can_read_extensions, + self.config.cant_read_extensions): to_fingerprint += [against_path] if len(to_fingerprint) > 0: self.fingerprint_directory(to_fingerprint) diff --git a/audalign/recognizers/fingerprint/fingerprinter.py b/audalign/recognizers/fingerprint/fingerprinter.py index 9ec807d..4d73137 100644 --- a/audalign/recognizers/fingerprint/fingerprinter.py +++ b/audalign/recognizers/fingerprint/fingerprinter.py @@ -47,6 +47,7 @@ def _fingerprint_worker( start_end=config.start_end, sample_rate=config.sample_rate, normalize=config.normalize, + cant_read_extensions=config.cant_read_extensions, ) except FileNotFoundError: print(f'"{file_path}" not found') diff --git a/audalign/recognizers/visrecognize/visrecognize.py b/audalign/recognizers/visrecognize/visrecognize.py index 25b22c1..9d8a540 100644 --- a/audalign/recognizers/visrecognize/visrecognize.py +++ b/audalign/recognizers/visrecognize/visrecognize.py @@ -386,6 +386,7 @@ def get_arrays( start_end=start_end, sample_rate=config.sample_rate, normalize=config.normalize, + cant_read_extensions=config.cant_read_extensions, ) arr2d = fingerprint.fingerprint(samples, config, retspec=True)