Skip to content

Commit

Permalink
refactor can cant read extensions to config
Browse files Browse the repository at this point in the history
  • Loading branch information
benfmiller committed Sep 22, 2024
1 parent 63a8aa0 commit 3f2277c
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 33 deletions.
10 changes: 9 additions & 1 deletion audalign/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydub.utils import mediainfo

import audalign.align as aligner
from audalign.config.fingerprint import FingerprintConfig
import audalign.datalign as datalign
import audalign.filehandler as filehandler
from audalign.config import BaseConfig
Expand Down Expand Up @@ -311,6 +312,7 @@ def write_processed_file(
start_end: tuple = None,
sample_rate: int = BaseConfig.sample_rate,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> None:
"""
writes given file to the destination file after processing for fingerprinting
Expand All @@ -329,6 +331,7 @@ def write_processed_file(
start_end=start_end,
sample_rate=sample_rate,
normalize=normalize,
cant_read_extensions=cant_read_extensions,
)


Expand Down Expand Up @@ -695,6 +698,7 @@ def write_shifts_from_results(
write_multi_channel: bool = False,
unprocessed: bool = False,
normalize: bool = BaseConfig.normalize,
config: BaseConfig = None,
):
"""
For writing the results of an alignment with alternate source files or unprocessed files
Expand All @@ -715,10 +719,12 @@ def write_shifts_from_results(
unprocessed (bool): If true, writes files without processing. For total files, only doesn't normalize
normalize (bool): if true, normalizes file when read
"""
if config is None:
config = FingerprintConfig()
if isinstance(read_from_dir, str):
print("Finding audio files")
read_from_dir = filehandler.get_audio_files_directory(
read_from_dir, full_path=True
read_from_dir, full_path=True, can_read_extensions=config.can_read_extensions, cant_read_extensions=config.cant_read_extensions
)
if read_from_dir is not None:
results_files = {}
Expand Down Expand Up @@ -767,6 +773,7 @@ def convert_audio_file(
start_end: tuple = None,
sample_rate: int = None,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
):
"""
Convert audio file to type specified in destination path
Expand All @@ -785,6 +792,7 @@ def convert_audio_file(
start_end=start_end,
sample_rate=sample_rate,
normalize=normalize,
cant_read_extensions=cant_read_extensions,
)


Expand Down
26 changes: 26 additions & 0 deletions audalign/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class BaseConfig(ABC):
LOCALITY_SECS = "locality_seconds"

######################################################################
# rankings settings

# Add to ranking if second match is close
rankings_second_is_close_add: int = 1
Expand All @@ -71,3 +72,28 @@ class BaseConfig(ABC):
# used if rankings_get_top_num_match is not None. (used in visual)
# subtracts second value from ranking if num matches is above first value
rankings_num_matches_tups: typing.Optional[tuple] = None

######################################################################
# filehandling settings

# file types that can't be read and not explicitly filtered out by
# below extention lists will cause a crash
fail_on_decode_error = True

#
cant_write_extensions = [".mov", ".mp4", ".m4a"]
cant_read_extensions = [".txt", ".md", ".pkf", ".py", ".pyc"]
can_read_extensions = [
".mov",
".mp4",
".m4a",
".wav",
".WAV",
".mp3",
".MOV",
".ogg",
".aiff",
".aac",
".wma",
".flac",
]
51 changes: 24 additions & 27 deletions audalign/filehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,14 @@
from pydub.exceptions import CouldntDecodeError

from audalign.config import BaseConfig
from audalign.config.fingerprint import FingerprintConfig

try:
import noisereduce
except ImportError:
# Optional dependency
...


cant_write_ext = [".mov", ".mp4", ".m4a"]
cant_read_ext = [".txt", ".md", ".pkf", ".py", ".pyc"]
can_read_ext = [
".mov",
".mp4",
".m4a",
".wav",
".WAV",
".mp3",
".MOV",
".ogg",
".aiff",
".aac",
".wma",
".flac",
]

def _import_optional_dependencies(func):
@wraps(func)
def wrapper_decorator(*args, **kwargs):
Expand Down Expand Up @@ -135,7 +118,11 @@ def create_audiosegment(
return audiofile


def get_audio_files_directory(directory_path: str, full_path: bool = False) -> list:
def get_audio_files_directory(directory_path: str, full_path: bool = False,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,

) -> list:
"""returns a list of the file paths in directory that are audio
Args:
Expand All @@ -146,20 +133,24 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False) -> l
"""
aud_list = []
for file_path, ext in find_files(directory_path):
if check_is_audio_file(file_path=file_path):
if check_is_audio_file(file_path=file_path, can_read_extensions=can_read_extensions, cant_read_extensions=cant_read_extensions):
if full_path is False:
aud_list += [os.path.basename(file_path)]
else:
aud_list += [file_path]
return aud_list


def check_is_audio_file(file_path: str) -> bool:
def check_is_audio_file(
file_path: str,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> bool:
ext = os.path.splitext(file_path)[1]
try:
if ext in [".txt", ".json"] or ext in cant_read_ext:
if ext in [".txt", ".json"] or ext in cant_read_extensions:
return False
elif ext.lower() not in can_read_ext:
elif ext.lower() not in can_read_extensions:
AudioSegment.from_file(file_path)
except CouldntDecodeError:
return False
Expand All @@ -172,6 +163,7 @@ def read(
start_end: tuple = None,
sample_rate=BaseConfig.sample_rate,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
):
"""
Reads any file supported by pydub (ffmpeg) and returns a numpy array and the bit depth
Expand All @@ -186,7 +178,7 @@ def read(
frame_rate (int): returns the bit depth
"""

if os.path.splitext(filename)[1] in cant_read_ext:
if os.path.splitext(filename)[1] in cant_read_extensions:
raise CouldntDecodeError
audiofile = create_audiosegment(
filename, start_end=start_end, sample_rate=sample_rate, normalize=normalize
Expand Down Expand Up @@ -323,6 +315,7 @@ def _remove_noise(
write_extension: str = None,
destination_directory="",
prop_decrease=1,
base_config: BaseConfig = FingerprintConfig(),
**kwargs,
):

Expand All @@ -343,7 +336,7 @@ def _remove_noise(

file_name = os.path.basename(file_path)
destination_name = os.path.join(destination_directory, file_name)
if os.path.splitext(destination_name)[1].lower() in cant_write_ext:
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down Expand Up @@ -441,6 +434,7 @@ def _uniform_level(
width: float = 5,
overlap_ratio=0.5,
exclude_min_db=-70,
base_config: BaseConfig = FingerprintConfig(),
):
assert overlap_ratio < 1 and overlap_ratio >= 0
try:
Expand Down Expand Up @@ -481,7 +475,7 @@ def _uniform_level(
file_name = os.path.basename(file_path)
if len(os.path.splitext(destination_name)[1]) == 0:
destination_name = os.path.join(destination_name, file_name)
if os.path.splitext(destination_name)[1].lower() in cant_write_ext:
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down Expand Up @@ -656,6 +650,7 @@ def _shift_write_separate(
return_files: bool = False,
unprocessed: bool = False,
normalize: bool = BaseConfig.normalize,
base_config: BaseConfig = FingerprintConfig(),
):
audsegs = _shift_prepend_space_audsegs(
files_shifts=files_shifts,
Expand All @@ -673,6 +668,7 @@ def _shift_write_separate(
file_path=file_path,
destination_path=destination_path,
write_extension=write_extension,
base_config=base_config,
)

audsegs = list(audsegs.values())
Expand Down Expand Up @@ -754,12 +750,13 @@ def _write_single_shift(
file_path: str,
destination_path: str,
write_extension: str,
base_config: BaseConfig = FingerprintConfig(),
):

file_name = os.path.basename(file_path)
destination_name = os.path.join(destination_path, file_name) # type: ignore

if os.path.splitext(destination_name)[1] in cant_write_ext:
if os.path.splitext(destination_name)[1] in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension:
Expand Down
3 changes: 2 additions & 1 deletion audalign/recognizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def align_get_file_names(
if target_aligning:
file_names = [os.path.basename(x) for x in file_list]
elif file_dir:
file_names = filehandler.get_audio_files_directory(file_dir)
file_names = filehandler.get_audio_files_directory(
file_dir, False, self.config.can_read_extensions, self.config.cant_read_extensions)
elif fine_aud_file_dict:
file_names = [os.path.basename(x) for x in fine_aud_file_dict.keys()]
else:
Expand Down
3 changes: 2 additions & 1 deletion audalign/recognizers/correcognize/correcognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def get_array(
_file_audsegs,
sos,
normalize: bool,
cant_read_extensions: list[str] = CorrelationConfig.cant_read_extensions,
):
if _file_audsegs is not None:
target_array = get_shifted_file(
Expand All @@ -353,7 +354,7 @@ def get_array(
)
else:
target_array = read(
file_path, start_end=start_end, sample_rate=sample_rate, normalize=normalize
file_path, start_end=start_end, sample_rate=sample_rate, normalize=normalize, cant_read_extensions=cant_read_extensions,
)[0]
if sos is not None:
target_array = signal.sosfilt(sos, target_array)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def get_array(
start_end=start_end,
sample_rate=config.sample_rate,
normalize=config.normalize,
cant_read_extensions=config.cant_read_extensions,
)[0]
if sos is not None:
target_array = signal.sosfilt(sos, target_array)
Expand Down
8 changes: 5 additions & 3 deletions audalign/recognizers/fingerprint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def align_get_file_names(
fine_aud_file_dict: typing.Optional[dict],
) -> list:
if target_aligning or file_dir:
file_names = filehandler.get_audio_files_directory(file_dir, full_path=True)
file_names = filehandler.get_audio_files_directory(file_dir, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions)
elif fine_aud_file_dict:
file_names = fine_aud_file_dict.keys()
for name, fingerprints in zip(self.file_names, self.fingerprinted_files):
Expand Down Expand Up @@ -171,12 +171,14 @@ def recognize(
if against_path is not None:
if os.path.isdir(against_path):
for path in filehandler.get_audio_files_directory(
against_path, full_path=True
against_path, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions
):
if path not in self.file_names and path not in to_fingerprint:
to_fingerprint += [path]
elif os.path.isfile(against_path):
if filehandler.check_is_audio_file(against_path):
if filehandler.check_is_audio_file(against_path,
self.config.can_read_extensions,
self.config.cant_read_extensions):
to_fingerprint += [against_path]
if len(to_fingerprint) > 0:
self.fingerprint_directory(to_fingerprint)
Expand Down
1 change: 1 addition & 0 deletions audalign/recognizers/fingerprint/fingerprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _fingerprint_worker(
start_end=config.start_end,
sample_rate=config.sample_rate,
normalize=config.normalize,
cant_read_extensions=config.cant_read_extensions,
)
except FileNotFoundError:
print(f'"{file_path}" not found')
Expand Down
1 change: 1 addition & 0 deletions audalign/recognizers/visrecognize/visrecognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def get_arrays(
start_end=start_end,
sample_rate=config.sample_rate,
normalize=config.normalize,
cant_read_extensions=config.cant_read_extensions,
)
arr2d = fingerprint.fingerprint(samples, config, retspec=True)

Expand Down

0 comments on commit 3f2277c

Please sign in to comment.