From fb1fed4d2ec77d607de8bb114a498f81fd144beb Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 16 Apr 2024 13:36:28 +0100 Subject: [PATCH] Add multiple paths to AmtDataset (#26) * clean up * update test * clean audio * add multiple paths to dataset * add tdqm --- amt/audio.py | 170 +--- amt/data.py | 105 +-- amt/inference/transcribe.py | 19 +- amt/run.py | 6 + amt/train.py | 79 +- baselines/giantmidi/transcribe_new_files.py | 67 -- baselines/hft_transformer/src/amt.py | 407 -------- .../hft_transformer/transcribe_new_files.py | 196 ---- baselines/requirements-baselines.txt | 3 - config/config.json | 2 +- .../2024-03-06__test-alignment-methods.ipynb | 886 ------------------ .../2024-03-07__run-aria-amt-and-evals.ipynb | 309 ------ ..._experiment-with-sound-augmentations.ipynb | 213 ----- scripts/eval/adjust.py | 58 ++ scripts/eval/dtw.sh | 5 - amt/evaluate.py => scripts/eval/mir.py | 11 + scripts/eval/mir.sh | 5 - scripts/eval/prune.sh | 6 - scripts/eval/split.sh | 4 - tests/test_data.py | 82 +- 20 files changed, 256 insertions(+), 2377 deletions(-) delete mode 100644 baselines/giantmidi/transcribe_new_files.py delete mode 100644 baselines/hft_transformer/src/amt.py delete mode 100644 baselines/hft_transformer/transcribe_new_files.py delete mode 100644 baselines/requirements-baselines.txt delete mode 100644 notebooks/2024-03-06__test-alignment-methods.ipynb delete mode 100644 notebooks/2024-03-07__run-aria-amt-and-evals.ipynb delete mode 100644 notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb create mode 100644 scripts/eval/adjust.py delete mode 100644 scripts/eval/dtw.sh rename amt/evaluate.py => scripts/eval/mir.py (92%) delete mode 100644 scripts/eval/mir.sh delete mode 100644 scripts/eval/prune.sh delete mode 100644 scripts/eval/split.sh diff --git a/amt/audio.py b/amt/audio.py index 447822e..7038d17 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -4,13 +4,7 @@ import random import torch import torchaudio -import torch.nn.functional as F import torchaudio.functional as AF -import numpy as np - -from functools import lru_cache -from subprocess import CalledProcessError, run -from typing import Optional, Union from amt.config import load_config from amt.tokenizer import AmtTokenizer @@ -28,160 +22,6 @@ TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): - """ - Open an audio file and read as mono waveform, resampling as necessary - - Parameters - ---------- - file: str - The audio file to open - - sr: int - The sample rate to resample the audio if necessary - - Returns - ------- - A NumPy array containing the audio waveform, in float32 dtype. - """ - - # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. - # fmt: off - cmd = [ - "ffmpeg", - "-nostdin", - "-threads", "0", - "-i", file, - "-f", "s16le", - "-ac", "1", - "-acodec", "pcm_s16le", - "-ar", str(sr), - "-" - ] - - # chat-gpt says that this will work for reading mp3 ?? 
not tested - # cmd = [ - # "ffmpeg", - # "-nostdin", - # "-threads", "0", - # "-i", file, - # "-ac", "1", - # "-ar", str(sr), - # "-" - # ] - - # fmt: on - try: - out = run(cmd, capture_output=True, check=True).stdout - except CalledProcessError as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e - - return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 - - -def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): - """ - Pad or trim the audio array to N_SAMPLES, as expected by the encoder. - """ - if torch.is_tensor(array): - if array.shape[axis] > length: - array = array.index_select( - dim=axis, index=torch.arange(length, device=array.device) - ) - - if array.shape[axis] < length: - pad_widths = [(0, 0)] * array.ndim - pad_widths[axis] = (0, length - array.shape[axis]) - array = F.pad( - array, [pad for sizes in pad_widths[::-1] for pad in sizes] - ) - else: - if array.shape[axis] > length: - array = array.take(indices=range(length), axis=axis) - - if array.shape[axis] < length: - pad_widths = [(0, 0)] * array.ndim - pad_widths[axis] = (0, length - array.shape[axis]) - array = np.pad(array, pad_widths) - - return array - - -@lru_cache(maxsize=None) -def mel_filters(device, n_mels: int) -> torch.Tensor: - """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Allows decoupling librosa dependency; saved using: - - np.savez_compressed( - "mel_filters.npz", - mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), - mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), - ) - """ - assert n_mels in {80, 128, 256}, f"Unsupported n_mels: {n_mels}" - - filters_path = os.path.join( - os.path.dirname(__file__), "assets", "mel_filters.npz" - ) - with np.load(filters_path, allow_pickle=False) as f: - return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) - - -def log_mel_spectrogram( - audio: Union[str, np.ndarray, torch.Tensor], - n_mels: int = 256, - padding: int = 0, - device: Optional[Union[str, torch.device]] = None, -): - """ - Compute the log-Mel spectrogram of - - Parameters - ---------- - audio: Union[str, np.ndarray, torch.Tensor], shape = (*) - The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz - - n_mels: int - The number of Mel-frequency filters, only 80 is supported - - padding: int - Number of zero samples to pad to the right - - device: Optional[Union[str, torch.device]] - If given, the audio tensor is moved to this device before STFT - - Returns - ------- - torch.Tensor, shape = (80, n_frames) - A Tensor that contains the Mel spectrogram - """ - if not torch.is_tensor(audio): - if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) - - if device is not None: - audio = audio.to(device) - if padding > 0: - audio = F.pad(audio, (0, padding)) - window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft( - audio, N_FFT, HOP_LENGTH, window=window, return_complex=True - ) - magnitudes = stft[..., :-1].abs() ** 2 - - filters = mel_filters(audio.device, n_mels) - mel_spec = filters @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - - return log_spec - - # Refactor default params are stored in config.json class AudioTransform(torch.nn.Module): def __init__( @@ -191,15 +31,15 @@ def __init__( max_snr: int = 50, max_dist_gain: int = 25, min_dist_gain: int = 0, - noise_ratio: float = 0.75, - 
reverb_ratio: float = 0.75, + noise_ratio: float = 0.9, + reverb_ratio: float = 0.9, applause_ratio: float = 0.01, bandpass_ratio: float = 0.15, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, - detune_ratio: float = 0.0, - detune_max_shift: float = 0.0, - spec_aug_ratio: float = 0.9, + detune_ratio: float = 0.1, + detune_max_shift: float = 0.15, + spec_aug_ratio: float = 0.95, ): super().__init__() self.tokenizer = AmtTokenizer() diff --git a/amt/data.py b/amt/data.py index ad8660f..f4103ae 100644 --- a/amt/data.py +++ b/amt/data.py @@ -85,7 +85,7 @@ def get_wav_mid_segments( midi_dict=midi_dict, start_ms=idx // samples_per_ms, end_ms=(idx + num_samples) / samples_per_ms, - max_pedal_len_ms=10000, + max_pedal_len_ms=15000, ) # Hardcoded to 2.5s @@ -148,8 +148,8 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str): safe_mid_path = shlex.quote(mid_path) safe_wav_path = shlex.quote(wav_path) - # Construct the command - command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}" + executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" + command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}' return command @@ -205,8 +205,6 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str): if os.path.isfile(audio_path_temp): os.remove(audio_path_temp) - print(f"Found {len(features)}") - with open(save_path, mode="a") as file: for wav, seq in features: wav_buffer = io.BytesIO() @@ -256,34 +254,37 @@ def build_synth_worker_fn( class AmtDataset(torch.utils.data.Dataset): - def __init__(self, load_path: str): + def __init__(self, load_paths: str | list): self.tokenizer = AmtTokenizer(return_tensors=True) self.config = load_config()["data"] self.mixup_fn = self.tokenizer.export_msg_mixup() - self.file_buff = open(load_path, mode="r") - self.file_mmap = mmap.mmap( - self.file_buff.fileno(), 0, access=mmap.ACCESS_READ - ) - - index_path = AmtDataset._get_index_path(load_path=load_path) - if os.path.isfile(index_path) is True: - self.index = self._load_index(load_path=index_path) - else: - print("Calculating index...") - self.index = self._build_index() - print( - f"Index of length {len(self.index)} calculated, saving to {index_path}" - ) - self._save_index(index=self.index, save_path=index_path) - def close(self): - if self.file_buff: - self.file_buff.close() - if self.file_mmap: - self.file_mmap.close() + if isinstance(load_paths, str): + load_paths = [load_paths] + self.file_buffs = [] + self.file_mmaps = [] + self.index = [] + + for path in load_paths: + buff = open(path, mode="r") + self.file_buffs.append(buff) + mmap_obj = mmap.mmap(buff.fileno(), 0, access=mmap.ACCESS_READ) + self.file_mmaps.append(mmap_obj) + + index_path = AmtDataset._get_index_path(load_path=path) + if os.path.isfile(index_path): + _index = self._load_index(load_path=index_path) + else: + print("Calculating index...") + _index = self._build_index(mmap_obj) + print( + f"Index of length {len(_index)} calculated, saving to {index_path}" + ) + self._save_index(index=_index, save_path=index_path) - def __del__(self): - self.close() + self.index.extend( + [(len(self.file_mmaps) - 1, pos) for pos in _index] + ) def __len__(self): return len(self.index) @@ -295,13 +296,13 @@ def _format(tok): return tuple(tok) return tok - self.file_mmap.seek(self.index[idx]) + file_id, pos = self.index[idx] + mmap_obj = self.file_mmaps[file_id] + mmap_obj.seek(pos) # Load data from line - wav = 
torch.load( - io.BytesIO(base64.b64decode(self.file_mmap.readline())) - ) - _seq = orjson.loads(base64.b64decode(self.file_mmap.readline())) + wav = torch.load(io.BytesIO(base64.b64decode(mmap_obj.readline()))) + _seq = orjson.loads(base64.b64decode(mmap_obj.readline())) _seq = [_format(tok) for tok in _seq] # Format seq _seq = self.mixup_fn(_seq) # Data augmentation @@ -317,18 +318,14 @@ def _format(tok): return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx - def _build_index(self): - self.file_mmap.seek(0) - index = [] - while True: - pos = self.file_mmap.tell() - self.file_mmap.readline() - if self.file_mmap.readline() == b"": - break - else: - index.append(pos) + def close(self): + for buff in self.file_buffs: + buff.close() + for mmap in self.file_mmaps: + mmap.close() - return index + def __del__(self): + self.close() def _save_index(self, index: list, save_path: str): with open(save_path, "w") as file: @@ -345,17 +342,17 @@ def _get_index_path(load_path: str): f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}" ) - def _build_index(self): - self.file_mmap.seek(0) + def _build_index(self, mmap_obj): + mmap_obj.seek(0) index = [] pos = 0 while True: pos_buff = pos - pos = self.file_mmap.find(b"\n", pos) + pos = mmap_obj.find(b"\n", pos) if pos == -1: break - pos = self.file_mmap.find(b"\n", pos + 1) + pos = mmap_obj.find(b"\n", pos + 1) if pos == -1: break @@ -433,16 +430,10 @@ def build( if shutil.which("cat") is None: print("The GNU cat command is not available") else: - print("Concatinating sharded dataset files") - shell_cmd = f"cat " - for _path in sharded_save_paths: - shell_cmd += f"{_path} " - print() - shell_cmd += f">> {save_path}" - - os.system(shell_cmd) for _path in sharded_save_paths: + shell_cmd = f"cat {_path} >> {save_path}" + os.system(shell_cmd) os.remove(_path) # Create index by loading object - AmtDataset(load_path=save_path) + AmtDataset(load_paths=save_path) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 4984dad..622b005 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -1,5 +1,4 @@ import os -import sys import signal import time import random @@ -185,14 +184,14 @@ def process_segments( [MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int ).cuda() - # for idx in ( - # pbar := tqdm( - # range(min_prefix_len, MAX_BLOCK_LEN - 1), - # total=MAX_BLOCK_LEN - (min_prefix_len + 1), - # leave=False, - # ) - # ): - for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1): + for idx in ( + pbar := tqdm( + range(min_prefix_len, MAX_BLOCK_LEN - 1), + total=MAX_BLOCK_LEN - (min_prefix_len + 1), + leave=False, + ) + ): + # for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1): with torch.backends.cuda.sdp_kernel( enable_flash=False, enable_mem_efficient=False, enable_math=True ): @@ -277,7 +276,7 @@ def gpu_manager( ) decode_token = torch.compile( decode_token, - # mode="reduce-overhead", + mode="reduce-overhead", # mode="max-autotune", fullgraph=True, ) diff --git a/amt/run.py b/amt/run.py index 1af6ced..88fdfb3 100644 --- a/amt/run.py +++ b/amt/run.py @@ -115,6 +115,8 @@ def build_synth( test_paths, ) = get_synth_mid_paths(mid_dir, csv_path) + print(f"Found {len(train_paths)} train and {len(test_paths)} test paths") + print(f"Building {train_file}") AmtDataset.build( load_paths=train_paths, @@ -187,6 +189,10 @@ def build_maestro(maestro_dir, train_file, val_file, test_file, num_procs): matched_paths_test, ) = get_matched_maestro_paths(maestro_dir) + print( + f"Found 
{len(matched_paths_train)}, {len(matched_paths_val)}, {len(matched_paths_test)} train, val, and test paths" + ) + print(f"Building {train_file}") AmtDataset.build( load_paths=matched_paths_train, diff --git a/amt/train.py b/amt/train.py index 540d862..ebad059 100644 --- a/amt/train.py +++ b/amt/train.py @@ -1,7 +1,6 @@ import os import sys import csv -import math import random import functools import argparse @@ -25,7 +24,7 @@ from amt.config import load_model_config from aria.utils import _load_weight -GRADIENT_ACC_STEPS = 2 +GRADIENT_ACC_STEPS = 32 # ----- USAGE ----- # @@ -175,9 +174,9 @@ def get_pretrain_optim( num_epochs: int, steps_per_epoch: int, ): - LR = 3e-4 + LR = 5e-4 END_RATIO = 0.1 - WARMUP_STEPS = 500 + WARMUP_STEPS = 1000 return _get_optim( lr=LR, @@ -209,22 +208,22 @@ def get_finetune_optim( def get_dataloaders( - train_data_path: str, + train_data_paths: str, val_data_path: str, batch_size: int, num_workers: int, ): logger = get_logger(__name__) logger.info("Indexing datasets...") - train_dataset = AmtDataset(load_path=train_data_path) - val_dataset = AmtDataset(load_path=val_data_path) + train_dataset = AmtDataset(load_paths=train_data_paths) + val_dataset = AmtDataset(load_paths=val_data_path) logger.info( f"Loaded datasets with length: train={len(train_dataset)}; val={len(val_dataset)}" ) # Pitch aug (to the sequence tensors) must be applied in the train # dataloader as it needs to be done to every element in the batch equally. - # Having this code running on the main process was causing a bottlekneck. + # Having this code running on the main process was causing a bottleneck. # Furthermore, distortion runs very slowly on the gpu, so we do it in # the dataloader instead. tensor_pitch_aug = AmtTokenizer().export_tensor_pitch_aug() @@ -260,6 +259,34 @@ def _collate_fn(seqs, max_pitch_shift: int): return train_dataloader, val_dataloader +def plot_spec(mel: torch.Tensor, name: str | int): + import matplotlib.pyplot as plt + + plt.figure(figsize=(10, 4)) + plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis") + plt.colorbar(format="%+2.0f dB") + plt.title("(mel)-Spectrogram") + plt.tight_layout() + plt.savefig(name) + plt.close() + + +def _debug(wav, mel, src, tgt, idx): + print("Running debug", idx) + for _idx in range(wav.shape[0]): + if os.path.isdir(f"debug/{idx}") is False: + os.makedirs(f"debug/{idx}") + torchaudio.save( + f"debug/{idx}/wav_{_idx}.wav", wav[_idx].unsqueeze(0).cpu(), 16000 + ) + plot_spec(mel[_idx].cpu(), f"debug/{idx}/mel_{_idx}.png") + tokenizer = AmtTokenizer() + src_dec = tokenizer.decode(src[_idx]) + mid_dict = tokenizer._detokenize_midi_dict(src_dec, 30000) + mid = mid_dict.to_midi() + mid.save(f"debug/{idx}/mid_{_idx}.mid") + + def _train( epochs: int, accelerator: accelerate.Accelerator, @@ -520,7 +547,7 @@ def val_loop(dataloader, _epoch: int, aug: bool): # how to register and restore this random state during checkpointing. 
def resume_train( model_name: str, - train_data_path: str, + train_data_paths: str, val_data_path: str, mode: str, num_workers: int, @@ -539,9 +566,8 @@ def resume_train( assert batch_size > 0, "Invalid batch size" assert torch.cuda.is_available() is True, "CUDA not available" assert os.path.isdir(checkpoint_dir), f"No dir at {checkpoint_dir}" - assert os.path.isfile( - train_data_path - ), f"No file found at {train_data_path}" + for _path in train_data_paths: + assert os.path.isfile(_path), f"No file found at {_path}" assert os.path.isfile(val_data_path), f"No file found at {val_data_path}" tokenizer = AmtTokenizer() @@ -567,7 +593,9 @@ def resume_train( f"model_name={model_name}, " f"mode={mode}, " f"epochs={epochs}, " + f"num_proc={accelerator.num_processes}, " f"batch_size={batch_size}, " + f"grad_acc_steps={GRADIENT_ACC_STEPS}, " f"num_workers={num_workers}, " f"checkpoint_dir={checkpoint_dir}, " f"resume_step={resume_step}, " @@ -586,7 +614,7 @@ def resume_train( logger.info(f"Loaded transform with config: {audio_transform.get_params()}") train_dataloader, val_dataloader = get_dataloaders( - train_data_path=train_data_path, + train_data_paths=train_data_paths, val_data_path=val_data_path, batch_size=batch_size, num_workers=num_workers, @@ -654,7 +682,7 @@ def resume_train( def train( model_name: str, - train_data_path: str, + train_data_paths: str, val_data_path: str, mode: str, num_workers: int, @@ -670,9 +698,8 @@ def train( assert epochs > 0, "Invalid number of epochs" assert batch_size > 0, "Invalid batch size" assert torch.cuda.is_available() is True, "CUDA not available" - assert os.path.isfile( - train_data_path - ), f"No file found at {train_data_path}" + for _path in train_data_paths: + assert os.path.isfile(_path), f"No file found at {_path}" assert os.path.isfile(val_data_path), f"No file found at {val_data_path}" if mode == "finetune": assert os.path.isfile(finetune_cp_path), "Invalid checkpoint path" @@ -692,7 +719,9 @@ def train( f"model_name={model_name}, " f"mode={mode}, " f"epochs={epochs}, " + f"num_proc={accelerator.num_processes}, " f"batch_size={batch_size}, " + f"grad_acc_steps={GRADIENT_ACC_STEPS}, " f"num_workers={num_workers}" ) @@ -714,7 +743,7 @@ def train( ) train_dataloader, val_dataloader = get_dataloaders( - train_data_path=train_data_path, + train_data_paths=train_data_paths, val_data_path=val_data_path, batch_size=batch_size, num_workers=num_workers, @@ -798,8 +827,8 @@ def parse_resume_args(): argp = argparse.ArgumentParser(prog="python amt/train.py resume") argp.add_argument("model", help="name of model config file") argp.add_argument("resume_mode", help="training mode", choices=["pt", "ft"]) - argp.add_argument("train_data", help="path to train data") - argp.add_argument("val_data", help="path to val data") + argp.add_argument("-train_data", nargs="+", help="paths to train data") + argp.add_argument("-val_data", help="path to val data") argp.add_argument("-cdir", help="checkpoint dir", type=str, required=True) argp.add_argument("-rstep", help="resume step", type=int, required=True) argp.add_argument("-repoch", help="resume epoch", type=int, required=True) @@ -817,8 +846,8 @@ def parse_resume_args(): def parse_train_args(): argp = argparse.ArgumentParser(prog="python amt/train.py pretrain") argp.add_argument("model", help="name of model config file") - argp.add_argument("train_data", help="path to train dir") - argp.add_argument("val_data", help="path to val dir") + argp.add_argument("-train_data", nargs="+", help="paths to train data") + 
argp.add_argument("-val_data", help="path to val dir") argp.add_argument( "-cpath", help="resuming checkpoint", type=str, required=False ) @@ -851,7 +880,7 @@ def parse_train_args(): train_args = parse_train_args() train( model_name=train_args.model, - train_data_path=train_args.train_data, + train_data_paths=train_args.train_data, val_data_path=train_args.val_data, mode="pretrain", num_workers=train_args.workers, @@ -864,7 +893,7 @@ def parse_train_args(): train_args = parse_train_args() train( model_name=train_args.model, - train_data_path=train_args.train_data, + train_data_paths=train_args.train_data, val_data_path=train_args.val_data, mode="finetune", num_workers=train_args.workers, @@ -878,7 +907,7 @@ def parse_train_args(): resume_args = parse_resume_args() resume_train( model_name=resume_args.model, - train_data_path=resume_args.train_data, + train_data_paths=resume_args.train_data, val_data_path=resume_args.val_data, mode="pretrain" if resume_args.resume_mode == "pt" else "finetune", num_workers=resume_args.workers, diff --git a/baselines/giantmidi/transcribe_new_files.py b/baselines/giantmidi/transcribe_new_files.py deleted file mode 100644 index 0650c73..0000000 --- a/baselines/giantmidi/transcribe_new_files.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import argparse -import time -import torch -import piano_transcription_inference -import glob - - -def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): - """Transcribe piano solo mp3s to midi files.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - os.makedirs(midis_dir, exist_ok=True) - - # Transcriptor - transcriptor = piano_transcription_inference.PianoTranscription(device=device) - - transcribe_time = time.time() - for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]): - print(n, mp3_path) - midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi') - midi_path = os.path.join(midis_dir, midi_file) - if os.path.exists(midi_path): - continue - - (audio, _) = ( - piano_transcription_inference - .load_audio(mp3_path, sr=piano_transcription_inference.sample_rate, mono=True) - ) - - try: - # Transcribe - transcribed_dict = transcriptor.transcribe(audio, midi_path) - print(transcribed_dict) - except: - print('Failed for this audio!') - - print('Time: {:.3f} s'.format(time.time() - transcribe_time)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Example of parser. ') - parser.add_argument('--mp3s_dir', type=str, required=True, help='') - parser.add_argument('--midis_dir', type=str, required=True, help='') - parser.add_argument( - '--begin_index', type=int, required=False, - help='File num., of an ordered list of files, to start transcribing from.', default=None - ) - parser.add_argument( - '--end_index', type=int, required=False, default=None, - help='File num., of an ordered list of files, to end transcription.' 
- ) - - # Parse arguments - args = parser.parse_args() - transcribe_piano( - mp3s_dir=args.mp3s_dir, - midis_dir=args.midis_dir, - begin_index=args.begin_index, - end_index=args.end_index - ) - -""" -python transcribe_new_files.py \ - transcribe_piano \ - --mp3s_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ - --midis_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model -""" \ No newline at end of file diff --git a/baselines/hft_transformer/src/amt.py b/baselines/hft_transformer/src/amt.py deleted file mode 100644 index 45f97ff..0000000 --- a/baselines/hft_transformer/src/amt.py +++ /dev/null @@ -1,407 +0,0 @@ -#! python - -import pickle -import torch -import numpy as np -import torchaudio -import pretty_midi - -class AMT(): - def __init__(self, config, model_path, batch_size=1, verbose_flag=False): - if verbose_flag is True: - print('torch version: '+torch.__version__) - print('torch cuda : '+str(torch.cuda.is_available())) - if torch.cuda.is_available(): - self.device = 'cuda' - else: - self.device = 'cpu' - - self.config = config - - if model_path == None: - self.model = None - else: - with open(model_path, 'rb') as f: - self.model = pickle.load(f) - self.model = self.model.to(self.device) - self.model.eval() - if verbose_flag is True: - print(self.model) - - self.batch_size = batch_size - - - def wav2feature(self, f_wav): - ### torchaudio - # torchaudio.transforms.MelSpectrogram() - # default - # sapmle_rate(16000) - # win_length(n_fft) - # hop_length(win_length//2) - # n_fft(400) - # f_min(0) - # f_max(None) - # pad(0) - # n_mels(128) - # window_fn(hann_window) - # center(True) - # power(2.0) - # pad_mode(reflect) - # onesided(True) - # norm(None) - ## melfilter: htk - ## normalize: none -> slaney - - wave, sr = torchaudio.load(f_wav) - wave_mono = torch.mean(wave, dim=0) - tr_fsconv = torchaudio.transforms.Resample(sr, self.config['feature']['sr']) - wave_mono_16k = tr_fsconv(wave_mono) - tr_mel = torchaudio.transforms.MelSpectrogram( - sample_rate=self.config['feature']['sr'], - n_fft=self.config['feature']['fft_bins'], - win_length=self.config['feature']['window_length'], - hop_length=self.config['feature']['hop_sample'], - pad_mode=self.config['feature']['pad_mode'], - n_mels=self.config['feature']['mel_bins'], - norm='slaney' - ) - mel_spec = tr_mel(wave_mono_16k) - a_feature = (torch.log(mel_spec + self.config['feature']['log_offset'])).T - - return a_feature - - - def transcript(self, a_feature, mode='combination', ablation_flag=False): - # a_feature: [num_frame, n_mels] - a_feature = np.array(a_feature, dtype=np.float32) - - a_tmp_b = np.full([self.config['input']['margin_b'], self.config['feature']['n_bins']], self.config['input']['min_value'], dtype=np.float32) - len_s = int(np.ceil(a_feature.shape[0] / self.config['input']['num_frame']) * self.config['input']['num_frame']) - a_feature.shape[0] - a_tmp_f = np.full([len_s+self.config['input']['margin_f'], self.config['feature']['n_bins']], self.config['input']['min_value'], dtype=np.float32) - a_input = torch.from_numpy(np.concatenate([a_tmp_b, a_feature, a_tmp_f], axis=0)) - # a_input: [margin_b+a_feature.shape[0]+len_s+margin_f, n_bins] - - a_output_onset_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_offset_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_mpe_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), 
dtype=np.float32) - a_output_velocity_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.int8) - - if mode == 'combination': - a_output_onset_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_offset_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_mpe_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_velocity_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.int8) - - self.model.eval() - for i in range(0, a_feature.shape[0], self.config['input']['num_frame']): - input_spec = (a_input[i:i+self.config['input']['margin_b']+self.config['input']['num_frame']+self.config['input']['margin_f']]).T.unsqueeze(0).to(self.device) - - with torch.no_grad(): - if mode == 'combination': - if ablation_flag is True: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.model(input_spec) - else: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.model(input_spec) - # output_onset: [batch_size, n_frame, n_note] - # output_offset: [batch_size, n_frame, n_note] - # output_mpe: [batch_size, n_frame, n_note] - # output_velocity: [batch_size, n_frame, n_note, n_velocity] - else: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A = self.model(input_spec) - - a_output_onset_A[i:i + self.config['input']['num_frame']] = (output_onset_A.squeeze(0)).to('cpu').detach().numpy() - a_output_offset_A[i:i + self.config['input']['num_frame']] = (output_offset_A.squeeze(0)).to('cpu').detach().numpy() - a_output_mpe_A[i:i + self.config['input']['num_frame']] = (output_mpe_A.squeeze(0)).to('cpu').detach().numpy() - a_output_velocity_A[i:i + self.config['input']['num_frame']] = (output_velocity_A.squeeze(0).argmax(2)).to('cpu').detach().numpy() - - if mode == 'combination': - a_output_onset_B[i:i+self.config['input']['num_frame']] = (output_onset_B.squeeze(0)).to('cpu').detach().numpy() - a_output_offset_B[i:i+self.config['input']['num_frame']] = (output_offset_B.squeeze(0)).to('cpu').detach().numpy() - a_output_mpe_B[i:i+self.config['input']['num_frame']] = (output_mpe_B.squeeze(0)).to('cpu').detach().numpy() - a_output_velocity_B[i:i+self.config['input']['num_frame']] = (output_velocity_B.squeeze(0).argmax(2)).to('cpu').detach().numpy() - - if mode == 'combination': - return a_output_onset_A, a_output_offset_A, a_output_mpe_A, a_output_velocity_A, a_output_onset_B, a_output_offset_B, a_output_mpe_B, a_output_velocity_B - else: - return a_output_onset_A, a_output_offset_A, a_output_mpe_A, a_output_velocity_A - - - def transcript_stride(self, a_feature, n_offset, mode='combination', ablation_flag=False): - # a_feature: [num_frame, n_mels] - a_feature = np.array(a_feature, dtype=np.float32) - - half_frame = int(self.config['input']['num_frame']/2) - a_tmp_b = np.full([self.config['input']['margin_b'] + n_offset, self.config['feature']['n_bins']], self.config['input']['min_value'], dtype=np.float32) - tmp_len = a_feature.shape[0] + self.config['input']['margin_b'] + self.config['input']['margin_f'] + half_frame - len_s = int(np.ceil(tmp_len / half_frame) * half_frame) - tmp_len - a_tmp_f = np.full([len_s+self.config['input']['margin_f']+(half_frame-n_offset), self.config['feature']['n_bins']], 
self.config['input']['min_value'], dtype=np.float32) - - a_input = torch.from_numpy(np.concatenate([a_tmp_b, a_feature, a_tmp_f], axis=0)) - # a_input: [n_offset+margin_b+a_feature.shape[0]+len_s+(half_frame-n_offset)+margin_f, n_bins] - - a_output_onset_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_offset_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_mpe_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_velocity_A = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.int8) - - if mode == 'combination': - a_output_onset_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_offset_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_mpe_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.float32) - a_output_velocity_B = np.zeros((a_feature.shape[0]+len_s, self.config['midi']['num_note']), dtype=np.int8) - - self.model.eval() - for i in range(0, a_feature.shape[0], half_frame): - input_spec = (a_input[i:i+self.config['input']['margin_b']+self.config['input']['num_frame']+self.config['input']['margin_f']]).T.unsqueeze(0).to(self.device) - - with torch.no_grad(): - if mode == 'combination': - if ablation_flag is True: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.model(input_spec) - else: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.model(input_spec) - # output_onset: [batch_size, n_frame, n_note] - # output_offset: [batch_size, n_frame, n_note] - # output_mpe: [batch_size, n_frame, n_note] - # output_velocity: [batch_size, n_frame, n_note, n_velocity] - else: - output_onset_A, output_offset_A, output_mpe_A, output_velocity_A = self.model(input_spec) - - a_output_onset_A[i:i+half_frame] = ( - (output_onset_A - .squeeze(0)[n_offset : n_offset+half_frame]) - .to('cpu').detach().numpy() - ) - a_output_offset_A[i:i+half_frame] = ( - (output_offset_A - .squeeze(0) - [n_offset:n_offset+half_frame]) - .to('cpu').detach().numpy() - ) - a_output_mpe_A[i:i+half_frame] = ( - (output_mpe_A - .squeeze(0) - [n_offset:n_offset+half_frame]) - .to('cpu').detach().numpy() - ) - a_output_velocity_A[i:i+half_frame] = (output_velocity_A.squeeze(0)[n_offset:n_offset+half_frame].argmax(2)).to('cpu').detach().numpy() - - if mode == 'combination': - a_output_onset_B[i:i+half_frame] = (output_onset_B.squeeze(0)[n_offset:n_offset+half_frame]).to('cpu').detach().numpy() - a_output_offset_B[i:i+half_frame] = (output_offset_B.squeeze(0)[n_offset:n_offset+half_frame]).to('cpu').detach().numpy() - a_output_mpe_B[i:i+half_frame] = (output_mpe_B.squeeze(0)[n_offset:n_offset+half_frame]).to('cpu').detach().numpy() - a_output_velocity_B[i:i+half_frame] = (output_velocity_B.squeeze(0)[n_offset:n_offset+half_frame].argmax(2)).to('cpu').detach().numpy() - - if mode == 'combination': - return ( - a_output_onset_A, - a_output_offset_A, - a_output_mpe_A, - a_output_velocity_A, - a_output_onset_B, - a_output_offset_B, - a_output_mpe_B, - a_output_velocity_B - ) - else: - return a_output_onset_A, a_output_offset_A, a_output_mpe_A, a_output_velocity_A - - - def mpe2note( - self, - a_onset=None, - a_offset=None, - 
a_mpe=None, - a_velocity=None, - thred_onset=0.5, - thred_offset=0.5, - thred_mpe=0.5, - mode_velocity='ignore_zero', - mode_offset='shorter' - ): - ## mode_velocity - ## org: 0-127 - ## ignore_zero: 0-127 (output note does not include 0) (default) - - ## mode_offset - ## shorter: use shorter one of mpe and offset (default) - ## longer : use longer one of mpe and offset - ## offset : use offset (ignore mpe) - - a_note = [] - hop_sec = float(self.config['feature']['hop_sample'] / self.config['feature']['sr']) - - for j in range(self.config['midi']['num_note']): - # find local maximum - a_onset_detect = [] - for i in range(len(a_onset)): - if a_onset[i][j] >= thred_onset: - left_flag = True - for ii in range(i-1, -1, -1): - if a_onset[i][j] > a_onset[ii][j]: - left_flag = True - break - elif a_onset[i][j] < a_onset[ii][j]: - left_flag = False - break - right_flag = True - for ii in range(i+1, len(a_onset)): - if a_onset[i][j] > a_onset[ii][j]: - right_flag = True - break - elif a_onset[i][j] < a_onset[ii][j]: - right_flag = False - break - if (left_flag is True) and (right_flag is True): - if (i == 0) or (i == len(a_onset) - 1): - onset_time = i * hop_sec - else: - if a_onset[i-1][j] == a_onset[i+1][j]: - onset_time = i * hop_sec - elif a_onset[i-1][j] > a_onset[i+1][j]: - onset_time = (i * hop_sec - (hop_sec * 0.5 * (a_onset[i-1][j] - a_onset[i+1][j]) / (a_onset[i][j] - a_onset[i+1][j]))) - else: - onset_time = (i * hop_sec + (hop_sec * 0.5 * (a_onset[i+1][j] - a_onset[i-1][j]) / (a_onset[i][j] - a_onset[i-1][j]))) - a_onset_detect.append({'loc': i, 'onset_time': onset_time}) - - a_offset_detect = [] - for i in range(len(a_offset)): - if a_offset[i][j] >= thred_offset: - left_flag = True - for ii in range(i-1, -1, -1): - if a_offset[i][j] > a_offset[ii][j]: - left_flag = True - break - elif a_offset[i][j] < a_offset[ii][j]: - left_flag = False - break - right_flag = True - for ii in range(i+1, len(a_offset)): - if a_offset[i][j] > a_offset[ii][j]: - right_flag = True - break - elif a_offset[i][j] < a_offset[ii][j]: - right_flag = False - break - if (left_flag is True) and (right_flag is True): - if (i == 0) or (i == len(a_offset) - 1): - offset_time = i * hop_sec - else: - if a_offset[i-1][j] == a_offset[i+1][j]: - offset_time = i * hop_sec - elif a_offset[i-1][j] > a_offset[i+1][j]: - offset_time = (i * hop_sec - (hop_sec * 0.5 * (a_offset[i-1][j] - a_offset[i+1][j]) / (a_offset[i][j] - a_offset[i+1][j]))) - else: - offset_time = (i * hop_sec + (hop_sec * 0.5 * (a_offset[i+1][j] - a_offset[i-1][j]) / (a_offset[i][j] - a_offset[i-1][j]))) - a_offset_detect.append({'loc': i, 'offset_time': offset_time}) - - time_next = 0.0 - time_offset = 0.0 - time_mpe = 0.0 - for idx_on in range(len(a_onset_detect)): - # onset - loc_onset = a_onset_detect[idx_on]['loc'] - time_onset = a_onset_detect[idx_on]['onset_time'] - - if idx_on + 1 < len(a_onset_detect): - loc_next = a_onset_detect[idx_on+1]['loc'] - # time_next = loc_next * hop_sec - time_next = a_onset_detect[idx_on+1]['onset_time'] - else: - loc_next = len(a_mpe) - time_next = (loc_next-1) * hop_sec - - # offset - loc_offset = loc_onset+1 - flag_offset = False - #time_offset = 0### - for idx_off in range(len(a_offset_detect)): - if loc_onset < a_offset_detect[idx_off]['loc']: - loc_offset = a_offset_detect[idx_off]['loc'] - time_offset = a_offset_detect[idx_off]['offset_time'] - flag_offset = True - break - if loc_offset > loc_next: - loc_offset = loc_next - time_offset = time_next - - # offset by MPE - # (1frame longer) - loc_mpe = loc_onset+1 - 
flag_mpe = False - # time_mpe = 0 ### - for ii_mpe in range(loc_onset+1, loc_next): - if a_mpe[ii_mpe][j] < thred_mpe: - loc_mpe = ii_mpe - flag_mpe = True - time_mpe = loc_mpe * hop_sec - break - ''' - # (right algorighm) - loc_mpe = loc_onset - flag_mpe = False - for ii_mpe in range(loc_onset+1, loc_next+1): - if a_mpe[ii_mpe][j] < thred_mpe: - loc_mpe = ii_mpe-1 - flag_mpe = True - time_mpe = loc_mpe * hop_sec - break - ''' - pitch_value = int(j+self.config['midi']['note_min']) - velocity_value = int(a_velocity[loc_onset][j]) - - if (flag_offset is False) and (flag_mpe is False): - offset_value = float(time_next) - elif (flag_offset is True) and (flag_mpe is False): - offset_value = float(time_offset) - elif (flag_offset is False) and (flag_mpe is True): - offset_value = float(time_mpe) - else: - if mode_offset == 'offset': - ## (a) offset - offset_value = float(time_offset) - elif mode_offset == 'longer': - ## (b) longer - if loc_offset >= loc_mpe: - offset_value = float(time_offset) - else: - offset_value = float(time_mpe) - else: - ## (c) shorter - if loc_offset <= loc_mpe: - offset_value = float(time_offset) - else: - offset_value = float(time_mpe) - if mode_velocity != 'ignore_zero': - a_note.append({'pitch': pitch_value, 'onset': float(time_onset), 'offset': offset_value, 'velocity': velocity_value}) - else: - if velocity_value > 0: - a_note.append({'pitch': pitch_value, 'onset': float(time_onset), 'offset': offset_value, 'velocity': velocity_value}) - - if ( - (len(a_note) > 1) and - (a_note[len(a_note)-1]['pitch'] == a_note[len(a_note)-2]['pitch']) and - (a_note[len(a_note)-1]['onset'] < a_note[len(a_note)-2]['offset']) - ): - a_note[len(a_note)-2]['offset'] = a_note[len(a_note)-1]['onset'] - - a_note = sorted(sorted(a_note, key=lambda x: x['pitch']), key=lambda x: x['onset']) - return a_note - - - def note2midi(self, a_note, f_midi): - midi = pretty_midi.PrettyMIDI() - instrument = pretty_midi.Instrument(program=0) - for note in a_note: - instrument.notes.append( - pretty_midi.Note( - velocity=note['velocity'], - pitch=note['pitch'], - start=note['onset'], - end=note['offset'] - ) - ) - midi.instruments.append(instrument) - midi.write(f_midi) - - return diff --git a/baselines/hft_transformer/transcribe_new_files.py b/baselines/hft_transformer/transcribe_new_files.py deleted file mode 100644 index 594bb44..0000000 --- a/baselines/hft_transformer/transcribe_new_files.py +++ /dev/null @@ -1,196 +0,0 @@ -#! 
python -import os -import argparse -import json -import sys -import glob -from baselines.hft_transformer.src import amt -from pydub import AudioSegment -from pydub.exceptions import CouldntDecodeError -import random -import torch -here = os.path.dirname(os.path.abspath(__file__)) - - -_AMT = None -def get_AMT(config_file=None, model_file=None): - global _AMT - if _AMT is None: - if config_file is None: - config_file = os.path.join(here, 'model_files/config-aug.json') - if model_file is None: - if torch.cuda.is_available(): - model_file = os.path.join(here, 'model_files/model-with-aug-data_006_009.pkl') - else: - model_file = os.path.join(here, 'model_files/model-with-aug-data_006_009_cpu.bin') - with open(config_file, 'r', encoding='utf-8') as f: - config = json.load(f) - if torch.cuda.is_available(): - _AMT = amt.AMT(config, model_file, verbose_flag=False) - else: - model = torch.load(model_file, map_location=torch.device('cpu')) - _AMT = amt.AMT(config, model_path=None, verbose_flag=False) - _AMT.model = model - return _AMT - -def check_and_convert_mp3_to_wav(fname): - wav_file = fname.replace('.mp3', '.wav') - if not os.path.exists(wav_file): - print('converting ' + fname + ' to .wav...') - try: - sound = AudioSegment.from_mp3(fname) - sound.export(fname.replace('.mp3', '.wav'), format="wav") - except CouldntDecodeError: - print('failed to convert ' + fname) - return None - return wav_file - - -def transcribe_file( - fname, - output_fname, - mode='combination', - thred_mpe=0.5, - thred_onset=0.5, - thred_offset=0.5, - n_stride=0, - ablation=False, - AMT=None -): - if AMT is None: - AMT = get_AMT() - - a_feature = AMT.wav2feature(fname) - - # transcript - if n_stride > 0: - output = AMT.transcript_stride(a_feature, n_stride, mode=mode, ablation_flag=ablation) - else: - output = AMT.transcript(a_feature, mode=mode, ablation_flag=ablation) - (output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity, - output_2nd_onset, output_2nd_offset, output_2nd_mpe, output_2nd_velocity) = output - - # note (mpe2note) - a_note_1st_predict = AMT.mpe2note( - a_onset=output_1st_onset, - a_offset=output_1st_offset, - a_mpe=output_1st_mpe, - a_velocity=output_1st_velocity, - thred_onset=thred_onset, - thred_offset=thred_offset, - thred_mpe=thred_mpe, - mode_velocity='ignore_zero', - mode_offset='shorter' - ) - - a_note_2nd_predict = AMT.mpe2note( - a_onset=output_2nd_onset, - a_offset=output_2nd_offset, - a_mpe=output_2nd_mpe, - a_velocity=output_2nd_velocity, - thred_onset=thred_onset, - thred_offset=thred_offset, - thred_mpe=thred_mpe, - mode_velocity='ignore_zero', - mode_offset='shorter' - ) - - AMT.note2midi(a_note_2nd_predict, output_fname) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - # necessary arguments - parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') - parser.add_argument('-input_file_to_transcribe', default=None, help='one file') - parser.add_argument('-output_dir', help='output directory') - parser.add_argument('-output_file', default=None, help='output file') - parser.add_argument('-f_config', help='config json file', default=None) - parser.add_argument('-model_file', help='input model file', default=None) - parser.add_argument('-start_index', help='start index', type=int, default=None) - parser.add_argument('-end_index', help='end index', type=int, default=None) - parser.add_argument('-skip_transcribe_mp3', action='store_true', default=False) - # parameters - parser.add_argument('-mode', help='mode to transcript 
(combination|single)', default='combination') - parser.add_argument('-thred_mpe', help='threshold value for mpe detection', type=float, default=0.5) - parser.add_argument('-thred_onset', help='threshold value for onset detection', type=float, default=0.5) - parser.add_argument('-thred_offset', help='threshold value for offset detection', type=float, default=0.5) - parser.add_argument('-n_stride', help='number of samples for offset', type=int, default=0) - parser.add_argument('-ablation', help='ablation mode', action='store_true') - args = parser.parse_args() - - assert (args.input_dir_to_transcribe is not None) or (args.input_file_to_transcribe is not None), "input file or directory is not specified" - - if args.input_dir_to_transcribe is not None: - if not args.skip_transcribe_mp3: - # list file - a_mp3s = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.mp3')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.mp3')) - ) - print(f'transcribing {len(a_mp3s)} files: [{str(a_mp3s)}]...') - list(map(check_and_convert_mp3_to_wav, a_mp3s)) - - a_list = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.wav')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.wav')) - ) - if (args.start_index is not None) or (args.end_index is not None): - if args.start_index is None: - args.start_index = 0 - if args.end_index is None: - args.end_index = len(a_list) - a_list = a_list[args.start_index:args.end_index] - # shuffle a_list - random.shuffle(a_list) - - elif args.input_file_to_transcribe is not None: - args.input_file_to_transcribe = check_and_convert_mp3_to_wav(args.input_file_to_transcribe) - if args.input_file_to_transcribe is None: - sys.exit() - a_list = [args.input_file_to_transcribe] - print(f'transcribing {str(a_list)} files...') - - # load model - AMT = get_AMT(args.f_config, args.model_file) - - long_filename_counter = 0 - for fname in a_list: - if args.output_file is not None: - output_fname = args.output_file - else: - output_fname = fname.replace('.wav', '') - if len(output_fname) > 200: - output_fname = output_fname[:200] + f'_fnabbrev-{long_filename_counter}' - output_fname += '_transcribed.mid' - output_fname = os.path.join(args.output_dir, os.path.basename(output_fname)) - if os.path.exists(output_fname): - continue - - print('[' + fname + ']') - try: - transcribe_file( - fname, - output_fname, - args.mode, - args.thred_mpe, - args.thred_onset, - args.thred_offset, - args.n_stride, - args.ablation, - AMT, - ) - except Exception as e: - print(e) - continue - - print('** done **') - - -""" -e.g. 
usage: - -python evaluation/transcribe_new_files.py \ - -input_dir_to_transcribe evaluation/glenn-gould-bach-data \ - -output_dir hft-evaluation-data/ \ -""" diff --git a/baselines/requirements-baselines.txt b/baselines/requirements-baselines.txt deleted file mode 100644 index b56d966..0000000 --- a/baselines/requirements-baselines.txt +++ /dev/null @@ -1,3 +0,0 @@ -pretty_midi -librosa -piano_transcription_inference diff --git a/config/config.json b/config/config.json index 9da2e4e..9442648 100644 --- a/config/config.json +++ b/config/config.json @@ -17,7 +17,7 @@ "n_mels": 256 }, "data": { - "stride_factor": 12, + "stride_factor": 15, "max_seq_len": 4096 } } \ No newline at end of file diff --git a/notebooks/2024-03-06__test-alignment-methods.ipynb b/notebooks/2024-03-06__test-alignment-methods.ipynb deleted file mode 100644 index 8e4d229..0000000 --- a/notebooks/2024-03-06__test-alignment-methods.ipynb +++ /dev/null @@ -1,886 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "de9ab3d6-09d8-42ba-aa3a-a892f03f376a", - "metadata": {}, - "source": [ - "# Test how to induce phone effect" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5864bcb4-6da0-4f22-9515-1395dfa9d56d", - "metadata": {}, - "outputs": [], - "source": [ - "import IPython\n", - "import torchaudio\n", - "import torchaudio.functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "6c4f9fa0-45a7-4faa-b4b6-a02ca7239deb", - "metadata": {}, - "outputs": [], - "source": [ - "fpath = 'test-transcription/hft-transcribed__02_R1_2004_05_Track05.wav'\n", - "waveform, sample_rate = torchaudio.load(fpath)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c91b8863-de1c-4c09-8106-9c4a8e0ab11c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "44100" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample_rate" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "183078c0-5db2-423e-831a-2da570bb5830", - "metadata": {}, - "outputs": [], - "source": [ - "# IPython.display.Audio(data=waveform, rate=sample_rate)" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "0623242c-e8aa-4a39-a318-bf61d67490f4", - "metadata": {}, - "outputs": [], - "source": [ - "phone_wav = F.highpass_biquad(waveform, sample_rate, cutoff_freq=1200)\n", - "phone_wav = F.lowpass_biquad(phone_wav, sample_rate, cutoff_freq=1400)\n", - "resample_rate = 6000\n", - "phone_wav = F.resample(phone_wav, orig_freq=sample_rate, new_freq=resample_rate, lowpass_filter_width=3)" - ] - }, - { - "cell_type": "markdown", - "id": "f17b5a6d-968b-4b53-ada1-496f4a13efdf", - "metadata": {}, - "source": [ - "# MIR_EVAL" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "ac2937f2-9852-4478-9820-bbc504b8c24f", - "metadata": {}, - "outputs": [], - "source": [ - "import pretty_midi\n", - "import numpy as np \n", - "import mir_eval\n", - "\n", - "def midi_to_intervals_and_pitches(midi_file_path):\n", - " \"\"\"\n", - " This function reads a MIDI file and extracts note intervals and pitches\n", - " suitable for use with mir_eval's transcription evaluation functions.\n", - " \"\"\"\n", - " # Load the MIDI file\n", - " midi_data = pretty_midi.PrettyMIDI(midi_file_path)\n", - " \n", - " # Prepare lists to collect note intervals and pitches\n", - " notes = []\n", - " for instrument in midi_data.instruments:\n", - " # Skip drum instruments\n", - " if not instrument.is_drum:\n", - " for note 
in instrument.notes:\n", - " notes.append([note.start, note.end, note.pitch])\n", - " notes = sorted(notes, key=lambda x: x[0])\n", - " notes = np.array(notes)\n", - " intervals, pitches = notes[:, :2], notes[:, 2]\n", - " intervals -= intervals[0][0]\n", - " return intervals, pitches\n", - "\n", - "def midi_to_hz(note, shift=0):\n", - " \"\"\"\n", - " Convert MIDI to HZ.\n", - "\n", - " Shift, if != 0, is subtracted from the MIDI note. Use \"2\" for the hFT augmented model transcriptions, else pitches won't match.\n", - " \"\"\"\n", - " # the one used in hFT transformer\n", - " return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12)\n", - " a = 440 # frequency of A (common value is 440Hz)\n", - " # return (a / 32) * (2 ** ((note - 9) / 12))" - ] - }, - { - "cell_type": "markdown", - "id": "f552f9d8-1fe1-4de3-8954-ae41b568a153", - "metadata": {}, - "source": [ - "# Kong's Alignment Method" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "a390d3bf-0816-4b5e-9e85-a9e409b4359b", - "metadata": {}, - "outputs": [], - "source": [ - "import csv\n", - "def get_stats(csv_path):\n", - " \"\"\"Parse aligned results csv file to get results.\n", - "\n", - " Args:\n", - " csv_path: str, aligned result path, e.g., xx_corresp.txt\n", - "\n", - " Returns:\n", - " stat_dict, dict, keys: \n", - " true positive (TP), \n", - " deletion (D), \n", - " insertion (I), \n", - " substitution (S), \n", - " error rate (ER), \n", - " ground truth number (N)\n", - " \"\"\"\n", - " with open(csv_path, 'r') as fr:\n", - " reader = csv.reader(fr, delimiter='\\t')\n", - " lines = list(reader)\n", - "\n", - " lines = lines[1 :]\n", - "\n", - " TP, D, I, S = 0, 0, 0, 0\n", - " align_counter = []\n", - " ref_counter = []\n", - "\n", - " for line in lines:\n", - " line = line[0 : -1]\n", - " [alignID, _, _, alignPitch, _, refID, _, _, refPitch, _] = line\n", - "\n", - " if alignID != '*' and refID != '*':\n", - " if alignPitch == refPitch:\n", - " TP += 1\n", - " else:\n", - " S += 1\n", - "\n", - " if alignID == '*':\n", - " D += 1\n", - "\n", - " if refID == '*':\n", - " I += 1\n", - "\n", - " N = TP + D + S\n", - " ER = (D + I + S) / N\n", - " stat_dict = {'TP': TP, 'D': D, 'I': I, 'S': S, 'ER': ER, 'N': N}\n", - " return stat_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "a82c99f6-eac2-46fc-8c36-3b20c893c26a", - "metadata": {}, - "outputs": [], - "source": [ - "import os \n", - "def align_files(ref_fp, est_fp):\n", - " align_tools_dir = '../../2017_midi_alignment'\n", - " ref_fn = os.path.basename(ref_fp)\n", - " est_fn = os.path.basename(est_fp)\n", - " ref_fn_name, ext = os.path.splitext(ref_fn)\n", - " est_fn_name, ext = os.path.splitext(est_fn)\n", - " \n", - " # Copy MIDI files\n", - " cmd = f'cp \"{ref_fp}\" \"{align_tools_dir}/{ref_fn}\"; '\n", - " cmd += f'cp \"{est_fp}\" \"{align_tools_dir}/{est_fn}\"; '\n", - " print(cmd)\n", - " os.system(cmd)\n", - " \n", - " # Align\n", - " cmd = f'cd {align_tools_dir}; '\n", - " # cmd += f'./MIDIToMIDIAlign.sh {ref_fn_name} {est_fn_name}; '\n", - " cmd += f'./MIDIToMIDIAlign.sh {ref_fn} {est_fn}; '\n", - " print(cmd)\n", - " os.system(cmd)" - ] - }, - { - "cell_type": "markdown", - "id": "8d8f5331-92e3-49ae-8def-6af5834a5a9b", - "metadata": {}, - "source": [ - "# Test Sample MIDI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35ee30b1-bed9-42f9-8011-086431aa60ae", - "metadata": {}, - "outputs": [], - "source": [ - "from importlib import reload\n", - "import sys\n", - "sys.path.insert(0, 
'../../aria-dl/hFT-Transformer/evaluation/')\n", - "import transcribe_new_files as t\n", - "import glob\n", - "import aria.utils\n", - "from importlib import reload \n", - "reload(aria)\n", - "import IPython\n", - "\n", - "all_maestro_files = sorted(glob.glob('../../corpus/maestro-v3.0.0/2004/*'))" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "51c500b6-20b9-4a4a-b677-8c14d58fd2aa", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "input_wav_file = 'test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav'\n", - "output_midi_file = 'test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi'\n", - "# t.transcribe_file(input_wav_file, output_midi_file)\n", - "gold_truth_midi_file = 'test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44129b9c-4fb1-49c7-8d10-e186395ad8b0", - "metadata": {}, - "outputs": [], - "source": [ - "aria.utils.midi_to_audio(\"test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi\")" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "id": "c94d739a-6583-4ee4-875b-3d0e8493220f", - "metadata": {}, - "outputs": [], - "source": [ - "import IPython\n", - "# IPython.display.Audio(data='test-transcription/hft-transcribed__02_R1_2004_05_Track05.wav', rate=44100)" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "id": "3385a328-0728-4b60-bd1b-e03957ac79b5", - "metadata": {}, - "outputs": [], - "source": [ - "import IPython\n", - "# IPython.display.Audio(data=input_wav_file, rate=44100)" - ] - }, - { - "cell_type": "markdown", - "id": "43746faf-a7bd-45a8-8a0c-4238e68b3d34", - "metadata": {}, - "source": [ - "#### evaluate using mir_eval" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "cf506615-9e3c-4f20-bf08-b12dd74a0670", - "metadata": {}, - "outputs": [], - "source": [ - "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gold_truth_midi_file)\n", - "est_intervals, est_pitches = midi_to_intervals_and_pitches(output_midi_file)\n", - "\n", - "ref_pitches_hz = midi_to_hz(ref_pitches)\n", - "est_pitches_hz = midi_to_hz(est_pitches, shift=2) ## shift=2 because hFT transcribes 2 notes above, for some reason" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "58c4b8af-4729-4c06-8a63-85198222bb0c", - "metadata": {}, - "outputs": [], - "source": [ - "scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "34cbfa89-371a-4466-a4b2-ebe1900e5ae9", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "matched_onsets = mir_eval.transcription.match_note_onsets(ref_intervals, est_intervals)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "cf1847d5-d253-4062-b8cf-9ed5142707ae", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([71., 55., 71., 59., 62., 72.])" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ref_pitches[[0, 1, 2, 3, 4, 5]]" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "93c07152-104c-4c6d-8f6c-f7472c31a7ec", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([73., 57., 73., 61., 64., 74.])" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - 
"source": [ - "est_pitches[[0, 1, 2, 3, 4, 5]]" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "187602e4-269d-448d-b300-f7e9087fd7a7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'{\\n \"Precision\": 0.7708092856226754,\\n \"Recall\": 0.7613377248543197,\\n \"F-measure\": 0.7660442291759608,\\n \"Average_Overlap_Ratio\": 0.8455788515638166,\\n \"Precision_no_offset\": 0.9976914197768373,\\n \"Recall_no_offset\": 0.9854319736508741,\\n \"F-measure_no_offset\": 0.9915238034542095,\\n \"Average_Overlap_Ratio_no_offset\": 0.7557460905388416,\\n \"Onset_Precision\": 0.9980761831473643,\\n \"Onset_Recall\": 0.9858120091208513,\\n \"Onset_F-measure\": 0.9919061882607865,\\n \"Offset_Precision\": 0.8117224573553931,\\n \"Offset_Recall\": 0.8017481631618951,\\n \"Offset_F-measure\": 0.806704480275317\\n}'" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import json\n", - "json.dumps(scores, indent=4)" - ] - }, - { - "cell_type": "markdown", - "id": "94c5ab54-c4be-45a5-92f4-9b99a40ffc50", - "metadata": {}, - "source": [ - "### evaluate using kong's method" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "77bb09e5-e989-4042-8d6d-977c232ac817", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi'" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output_midi_file" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "ba68de3b-b1a6-4f3d-815b-09403862a2a1", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cp \"test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\" \"../../2017_midi_alignment/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\"; cp \"test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi\" \"../../2017_midi_alignment/hft-transcribed__02_R1_2004_05_Track05.midi\"; \n", - "cd ../../2017_midi_alignment; ./MIDIToMIDIAlign.sh MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi hft-transcribed__02_R1_2004_05_Track05.midi; \n", - "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", - "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_hmm.txt\n", - "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", - "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", - "File not found: ./hft-transcribed__02_R1_2004_05_Track05.midi_match.txt\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Assertion failed: (ifs.is_open()), function ReadFile, file Midi_v170101.hpp, line 177.\n", - "./MIDIToMIDIAlign.sh: line 14: 68574 Abort trap: 6 $ProgramFolder/midi2pianoroll 0 $RelCurrentFolder/${I1}\n", - "Assertion failed: (ifs.is_open()), function ReadFile, file Midi_v170101.hpp, line 177.\n", - "./MIDIToMIDIAlign.sh: line 15: 68575 Abort trap: 6 $ProgramFolder/midi2pianoroll 0 $RelCurrentFolder/${I2}\n", - "./MIDIToMIDIAlign.sh: line 17: 68576 Segmentation fault: 11 $ProgramFolder/SprToFmt3x $RelCurrentFolder/${I1}_spr.txt $RelCurrentFolder/${I1}_fmt3x.txt\n", - 
"Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", - "./MIDIToMIDIAlign.sh: line 18: 68577 Abort trap: 6 $ProgramFolder/Fmt3xToHmm $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt\n", - "Assertion failed: (false), function ReadFile, file Hmm_v170225.hpp, line 69.\n", - "./MIDIToMIDIAlign.sh: line 20: 68578 Abort trap: 6 $ProgramFolder/ScorePerfmMatcher $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_spr.txt $RelCurrentFolder/${I2}_pre_match.txt 0.001\n", - "Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", - "./MIDIToMIDIAlign.sh: line 21: 68579 Abort trap: 6 $ProgramFolder/ErrorDetection $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_pre_match.txt $RelCurrentFolder/${I2}_err_match.txt 0\n", - "Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", - "./MIDIToMIDIAlign.sh: line 22: 68580 Abort trap: 6 $ProgramFolder/RealignmentMOHMM $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_err_match.txt $RelCurrentFolder/${I2}_realigned_match.txt 0.3\n", - "cp: ./hft-transcribed__02_R1_2004_05_Track05.midi_realigned_match.txt: No such file or directory\n", - "Assertion failed: (false), function ReadFile, file ScorePerfmMatch_v170503.hpp, line 86.\n", - "./MIDIToMIDIAlign.sh: line 25: 68582 Abort trap: 6 $ProgramFolder/MatchToCorresp $RelCurrentFolder/${I2}_match.txt $RelCurrentFolder/${I1}_spr.txt $RelCurrentFolder/${I2}_corresp.txt\n", - "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_realigned_match.txt: No such file or directory\n", - "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_err_match.txt: No such file or directory\n", - "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_pre_match.txt: No such file or directory\n" - ] - } - ], - "source": [ - "align_files(gold_truth_midi_file, output_midi_file)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "434270a9-91c5-4b75-b56f-d2f0cda631ab", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[34mCode\u001b[m\u001b[m/\n", - "LICENCE.txt\n", - "MANUAL.pdf\n", - "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV.mid\n", - "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_corresp.txt\n", - "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_match.txt\n", - "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_spr.txt\n", - "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\n", - "\u001b[31mMIDIToMIDIAlign.sh\u001b[m\u001b[m*\n", - "\u001b[31mMusicXMLToMIDIAlign.sh\u001b[m\u001b[m*\n", - "\u001b[34mPrograms\u001b[m\u001b[m/\n", - "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4.mid\n", - "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_corresp.txt\n", - "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_match.txt\n", - "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_spr.txt\n", - "\u001b[31mcompile.sh\u001b[m\u001b[m*\n", - "ex_align1.mid\n", - "\u001b[31mex_align2.mid\u001b[m\u001b[m*\n", - "ex_ref.musx\n", - "ex_ref.pdf\n", - "ex_ref.xml\n", - "ex_ref_fmt3x.txt\n", - "ex_ref_hmm.txt\n", - "hft-transcribed__02_R1_2004_05_Track05.midi\n", - "scriabin_etude_op_42_no_4_dery.mid\n", - "scriabin_etude_op_42_no_4_dery_fmt3x.txt\n", - "scriabin_etude_op_42_no_4_dery_hmm.txt\n", - "scriabin_etude_op_42_no_4_dery_spr.txt\n" - ] - } - ], - "source": [ - "ls 
../../2017_midi_alignment/" - ] - }, - { - "cell_type": "markdown", - "id": "4e1f906a-6e8a-46c4-88fe-fadac564ba39", - "metadata": {}, - "source": [ - "# Test Kong's Samples" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cbcd4410-403e-4565-96ec-c7067b9b4b76", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.insert(0, '..')\n", - "import amt.audio" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07e5a978-fc0f-444d-8aeb-851dfda1085a", - "metadata": {}, - "outputs": [], - "source": [ - "audio_transform = amt.audio.AudioTransform()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "365ccc37-2ab3-4f58-89bd-bb972b39075c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chopin, Frédéric, Études, Op.10, g0hoN6_HDVU.mid\n", - "Handel, George Frideric, Air in E major, HWV 425, bNzVz5byPqk.mid\n", - "Liszt, Franz, Hungarian Rhapsody No.2, S.244_2, LdH1hSWGFGU.mid\n", - "Ravel, Maurice, Jeux d'eau, v-QmwrhO3ec.mid\n" - ] - } - ], - "source": [ - "ls ../../GiantMIDI-Piano/midis_preview/" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "27d20498-f77f-4c34-b368-40a3ccefc871", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import os \n", - "df = pd.read_csv('../../GiantMIDI-Piano/midis_for_evaluation/groundtruth_maestro_giantmidi-piano.csv', sep='\\t')" - ] - }, - { - "cell_type": "code", - "execution_count": 62, - "id": "81c4dabd-be84-4d51-82a2-b2426f001ef6", - "metadata": {}, - "outputs": [], - "source": [ - "gt_folder = '../../GiantMIDI-Piano/midis_for_evaluation/ground_truth/'\n", - "giant_midi_folder = '../../GiantMIDI-Piano/midis_for_evaluation/giantmidi-piano/'\n", - "maestro_midi_folder = '../../GiantMIDI-Piano/midis_for_evaluation/maestro/'\n", - "gt_fn, giant_midi_fn, maestro_fn = df[['GroundTruth', 'GiantMIDI-Piano', 'Maestro']].iloc[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "0c8b2c1a-6dbb-4191-9a6f-43982b1dd65f", - "metadata": {}, - "outputs": [], - "source": [ - "gt_fp = os.path.join(gt_folder, gt_fn)\n", - "giant_midi_fp = os.path.join(giant_midi_folder, giant_midi_fn)\n", - "maestro_midi_fp = os.path.join(maestro_midi_folder, maestro_fn)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4313bfc4-ffa4-4182-8f17-5519421aa126", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "import mirdata\n", - "import mido \n", - "mido.MidiFile(filename=gt_fp)" - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "id": "f1b4c7b9-b3b5-48e0-a5cf-43e704c6fb51", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gt_fp)\n", - "est_intervals, est_pitches = midi_to_intervals_and_pitches(maestro_midi_fp)\n", - "\n", - "ref_pitches_hz = midi_to_hz(ref_pitches)\n", - "est_pitches_hz = midi_to_hz(est_pitches)" - ] - }, - { - "cell_type": "code", - "execution_count": 161, - "id": "b619d23b-080a-4ef1-b2f7-b3241a18e9cb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[(1, 1), (2, 2), (23, 23), (67, 66), (71, 70), (336, 385), (677, 709)]" - ] - }, - "execution_count": 161, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mir_eval.transcription.match_notes(\n", - " ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "id": 
"bcb0f59e-e097-455f-90c3-fa28a1b78a91", - "metadata": {}, - "outputs": [], - "source": [ - "# mir_eval.transcription.precision_recall_f1_overlap()\n", - "# mir_eval.transcription.evaluate()" - ] - }, - { - "cell_type": "code", - "execution_count": 165, - "id": "7096c4ac-a857-40d3-b14b-5b91443134fe", - "metadata": {}, - "outputs": [], - "source": [ - "scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "a70c27b3-f9f7-4bd6-8471-4493c9fd8a89", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('Precision', 0.008939974457215836),\n", - " ('Recall', 0.008816120906801008),\n", - " ('F-measure', 0.008877615726062145),\n", - " ('Average_Overlap_Ratio', 0.8771661491453748),\n", - " ('Precision_no_offset', 0.04469987228607918),\n", - " ('Recall_no_offset', 0.04408060453400504),\n", - " ('F-measure_no_offset', 0.04438807863031072),\n", - " ('Average_Overlap_Ratio_no_offset', 0.5049151558206666),\n", - " ('Onset_Precision', 0.2771392081736909),\n", - " ('Onset_Recall', 0.27329974811083124),\n", - " ('Onset_F-measure', 0.2752060875079264),\n", - " ('Offset_Precision', 0.4623243933588761),\n", - " ('Offset_Recall', 0.45591939546599497),\n", - " ('Offset_F-measure', 0.45909955611921366)])" - ] - }, - "execution_count": 166, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "scores" - ] - }, - { - "cell_type": "code", - "execution_count": 146, - "id": "273e921e-fdf6-4385-a17a-ab1e43b2f2f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Precision: 0.0012091898428053204, Recall: 0.0012594458438287153, F-measure: 0.0012338062924120913\n" - ] - } - ], - "source": [ - "import pretty_midi\n", - "import numpy as np\n", - "import mir_eval\n", - "\n", - "def midi_to_intervals_and_pitches(midi_file_path):\n", - " # Load the MIDI file\n", - " midi_data = pretty_midi.PrettyMIDI(midi_file_path)\n", - " \n", - " intervals = []\n", - " pitches = []\n", - " \n", - " for instrument in midi_data.instruments:\n", - " if not instrument.is_drum:\n", - " for note in instrument.notes:\n", - " start_time = note.start\n", - " end_time = note.end\n", - " intervals.append([start_time, end_time])\n", - " pitches.append(note.pitch)\n", - " \n", - " intervals = np.array(intervals)\n", - " pitches = np.array(pitches)\n", - " \n", - " return intervals, pitches\n", - "\n", - "# Load your reference and estimated MIDI files\n", - "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gt_fp)\n", - "est_intervals, est_pitches = midi_to_intervals_and_pitches(giant_midi_fp)\n", - "ref_pitches_hz = midi_to_hz(ref_pitches)\n", - "est_pitches_hz = midi_to_hz(est_pitches)\n", - "\n", - "# Evaluate using mir_eval\n", - "precision, recall, f_measure, _ = mir_eval.transcription.precision_recall_f1_overlap(\n", - " ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz\n", - ")\n", - "\n", - "print(f\"Precision: {precision}, Recall: {recall}, F-measure: {f_measure}\")" - ] - }, - { - "cell_type": "markdown", - "id": "3392980e-8d35-4764-a9f4-4d8e322791c5", - "metadata": {}, - "source": [ - "# Try using the GiantMIDI Method" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ffc236a1-8837-47d5-a02b-56eded2be3a4", - "metadata": {}, - "outputs": [], - "source": [ - "csv_path = f'{align_tools_dir}/{maestro_fn[: -4]}_corresp.txt'\n", - "maestro_stats = get_stats(csv_path)\n", - "\n", - "csv_path 
= f'{align_tools_dir}/{giant_midi_fn[: -4]}_corresp.txt'\n", - "giantmidi_stats = get_stats(csv_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 155, - "id": "2128d51d-9845-47ec-871a-2fff47fd9640", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'TP': 780, 'D': 8, 'I': 41, 'S': 6, 'ER': 0.06926952141057935, 'N': 794}" - ] - }, - "execution_count": 155, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "giantmidi_stats" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8d045113-f996-4959-a615-9edc909b08a5", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d37949b3-2ab3-426f-827b-7e77fe728a3c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ceee36a-3022-4c84-9e09-84cbd709cbae", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "dad2f979-56cd-4af6-acf1-64420ab61255", - "metadata": {}, - "outputs": [], - "source": [ - "# IPython.display.Audio(data=phone_wav, rate=resample_rate)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "feaced10-0bc3-44b3-8f9d-1647f203c89c", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb b/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb deleted file mode 100644 index 9ff83d0..0000000 --- a/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb +++ /dev/null @@ -1,309 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T02:01:24.262700Z", - "start_time": "2024-03-08T02:01:23.754126Z" - }, - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "mkdir: cannot create directory ‘../amt/assets’: File exists\r\n" - ] - } - ], - "source": [ - "! mkdir '../amt/assets'\n", - "! mkdir '../amt/assets/impulse'\n", - "! mkdir '../amt/assets/noise'" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T03:39:12.094229Z", - "start_time": "2024-03-08T03:39:07.414566Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "import subprocess\n", - "\n", - "MODEL_NAME = \"medium\"\n", - "CHECKPOINT_NAME = f\"med-81.safetensors\"" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T03:51:28.491243Z", - "start_time": "2024-03-08T03:51:28.477962Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "if not os.path.isfile(f\"{CHECKPOINT_NAME}\"):\n", - " ! 
wget https://storage.googleapis.com/aria-checkpoints/amt/{CHECKPOINT_NAME} {CHECKPOINT_NAME}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T03:39:12.177865Z", - "start_time": "2024-03-08T03:39:12.133447Z" - }, - "collapsed": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import amt.run\n", - "from importlib import reload\n", - "reload(amt.run)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T03:50:35.603392Z", - "start_time": "2024-03-08T03:39:13.021352Z" - }, - "collapsed": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "838861: Getting wav segments\n", - "838861: Finished file 1 - test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav\n", - "838861: 0 file(s) remaining in queue\n", - "839109: GPU task timeout\n", - "839109: Finished GPU tasks\n" - ] - } - ], - "source": [ - "# model_name, checkpoint_path, save_dir, load_path=None, load_dir=None, batch_size=16, multi_gpu=False\n", - "amt.run.transcribe(\n", - " model_name=MODEL_NAME,\n", - " checkpoint_path=CHECKPOINT_NAME,\n", - " load_path='test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav',\n", - " save_dir=\"test-transcription/aria-amt-tests\",\n", - " batch_size=1,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "# Evaluate Output" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T04:42:50.244198Z", - "start_time": "2024-03-08T04:42:46.774805Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "import amt.evaluate" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T04:04:40.989918Z", - "start_time": "2024-03-08T04:04:40.944807Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "t = 'test-transcription/aria-amt-tests/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.mid'" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T04:47:00.700219Z", - "start_time": "2024-03-08T04:46:51.855649Z" - }, - "collapsed": false - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f0b22ca267334ff99664db71cf5bf507", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1 [00:00" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "reload(amt.evaluate)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-08T04:46:00.278645Z", - "start_time": "2024-03-08T04:46:00.086183Z" - }, - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\r\n", - "test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav\r\n" - ] - } - ], - "source": [ - "ls 
test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb b/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb deleted file mode 100644 index f45c6dd..0000000 --- a/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb +++ /dev/null @@ -1,213 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "outputs": [ - { - "ename": "OSError", - "evalue": "dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mOSError\u001B[0m Traceback (most recent call last)", - "Input \u001B[0;32mIn [1]\u001B[0m, in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mIPython\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdisplay\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Audio\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m download_asset\n\u001B[1;32m 5\u001B[0m SAMPLE_WAV \u001B[38;5;241m=\u001B[39m download_asset(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtutorial-assets/steam-train-whistle-daniel_simon.wav\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 6\u001B[0m SAMPLE_RIR \u001B[38;5;241m=\u001B[39m download_asset(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/__init__.py:1\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ( \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[1;32m 2\u001B[0m _extension,\n\u001B[1;32m 3\u001B[0m compliance,\n\u001B[1;32m 4\u001B[0m datasets,\n\u001B[1;32m 5\u001B[0m functional,\n\u001B[1;32m 6\u001B[0m io,\n\u001B[1;32m 7\u001B[0m kaldi_io,\n\u001B[1;32m 8\u001B[0m models,\n\u001B[1;32m 9\u001B[0m pipelines,\n\u001B[1;32m 10\u001B[0m sox_effects,\n\u001B[1;32m 11\u001B[0m transforms,\n\u001B[1;32m 12\u001B[0m utils,\n\u001B[1;32m 13\u001B[0m )\n\u001B[1;32m 14\u001B[0m 
\u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackend\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_audio_backend, list_audio_backends, set_audio_backend\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:103\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m:\n\u001B[1;32m 100\u001B[0m \u001B[38;5;28;01mpass\u001B[39;00m\n\u001B[0;32m--> 103\u001B[0m \u001B[43m_init_extension\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:88\u001B[0m, in \u001B[0;36m_init_extension\u001B[0;34m()\u001B[0m\n\u001B[1;32m 85\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtorchaudio C++ extension is not available.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 86\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m\n\u001B[0;32m---> 88\u001B[0m \u001B[43m_load_lib\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlibtorchaudio\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 89\u001B[0m \u001B[38;5;66;03m# This import is for initializing the methods registered via PyBind11\u001B[39;00m\n\u001B[1;32m 90\u001B[0m \u001B[38;5;66;03m# This has to happen after the base library is loaded\u001B[39;00m\n\u001B[1;32m 91\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _torchaudio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:51\u001B[0m, in \u001B[0;36m_load_lib\u001B[0;34m(lib)\u001B[0m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mexists():\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m---> 51\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_library\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 52\u001B[0m torch\u001B[38;5;241m.\u001B[39mclasses\u001B[38;5;241m.\u001B[39mload_library(path)\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torch/_ops.py:643\u001B[0m, in \u001B[0;36m_Ops.load_library\u001B[0;34m(self, path)\u001B[0m\n\u001B[1;32m 638\u001B[0m path \u001B[38;5;241m=\u001B[39m _utils_internal\u001B[38;5;241m.\u001B[39mresolve_library_path(path)\n\u001B[1;32m 639\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m dl_open_guard():\n\u001B[1;32m 640\u001B[0m \u001B[38;5;66;03m# Import the shared library into the process, thus running its\u001B[39;00m\n\u001B[1;32m 641\u001B[0m \u001B[38;5;66;03m# static (global) initialization code in order to register custom\u001B[39;00m\n\u001B[1;32m 642\u001B[0m \u001B[38;5;66;03m# operators with the JIT.\u001B[39;00m\n\u001B[0;32m--> 643\u001B[0m \u001B[43mctypes\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mCDLL\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 644\u001B[0m 
\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mloaded_libraries\u001B[38;5;241m.\u001B[39madd(path)\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/ctypes/__init__.py:382\u001B[0m, in \u001B[0;36mCDLL.__init__\u001B[0;34m(self, name, mode, handle, use_errno, use_last_error, winmode)\u001B[0m\n\u001B[1;32m 379\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_FuncPtr \u001B[38;5;241m=\u001B[39m _FuncPtr\n\u001B[1;32m 381\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m handle \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 382\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m \u001B[43m_dlopen\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_name\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 383\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 384\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m handle\n", - "\u001B[0;31mOSError\u001B[0m: dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib" - ] - } - ], - "source": [ - "from IPython.display import Audio\n", - "\n", - "from torchaudio.utils import download_asset\n", - "\n", - "SAMPLE_WAV = download_asset(\"tutorial-assets/steam-train-whistle-daniel_simon.wav\")\n", - "SAMPLE_RIR = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav\")\n", - "SAMPLE_SPEECH = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042-8000hz.wav\")\n", - "SAMPLE_NOISE = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo-8000hz.wav\")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-11T20:15:46.720641Z", - "start_time": "2024-03-11T20:15:40.451397Z" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 15, - "outputs": [], - "source": [ - "import IPython" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-11T18:24:40.427485Z", - "start_time": "2024-03-11T18:24:40.423546Z" - } - } - }, - { - "cell_type": "code", - "execution_count": 18, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bach_old_1.mp3\r\n" - ] - } - ], - "source": [ - "ls scratch/files-with-reverb" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-11T18:25:01.398215Z", - "start_time": "2024-03-11T18:25:01.250743Z" - } - } - }, - { - "cell_type": "code", - "execution_count": 21, - "outputs": [ - { - "data": { - "text/plain": "", - "text/html": "\n \n " - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - 
"IPython.display.Audio('scratch/files-with-reverb/bach_old_1.mp3')" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-11T18:25:42.144419Z", - "start_time": "2024-03-11T18:25:42.058763Z" - } - } - }, - { - "cell_type": "code", - "execution_count": 22, - "outputs": [ - { - "ename": "OSError", - "evalue": "dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mOSError\u001B[0m Traceback (most recent call last)", - "Input \u001B[0;32mIn [22]\u001B[0m, in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpretrained\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m SpectralMaskEnhancement\n\u001B[1;32m 2\u001B[0m model \u001B[38;5;241m=\u001B[39m SpectralMaskEnhancement\u001B[38;5;241m.\u001B[39mfrom_hparams(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mspeechbrain/mtl-mimic-voicebank\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/__init__.py:4\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;124;03m\"\"\" Comprehensive speech processing toolkit\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mcore\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Stage, Brain, create_experiment_directory, parse_arguments\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m alignment \u001B[38;5;66;03m# noqa\u001B[39;00m\n\u001B[1;32m 6\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m dataio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/core.py:36\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 34\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mnn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mparallel\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m DistributedDataParallel \u001B[38;5;28;01mas\u001B[39;00m DDP\n\u001B[1;32m 35\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mhyperpyyaml\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m resolve_references\n\u001B[0;32m---> 36\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdistributed\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m 
run_on_main\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataloader\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m LoopedLoader\n\u001B[1;32m 38\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataloader\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m SaveableDataLoader\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/utils/__init__.py:11\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m filename\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.py\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m filename\u001B[38;5;241m.\u001B[39mstartswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m__\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[1;32m 9\u001B[0m __all__\u001B[38;5;241m.\u001B[39mappend(filename[:\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m3\u001B[39m])\n\u001B[0;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/utils/parameter_transfer.py:12\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mlogging\u001B[39;00m\n\u001B[1;32m 10\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m\n\u001B[0;32m---> 12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpretrained\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfetching\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m fetch\n\u001B[1;32m 13\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mcheckpoints\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m 14\u001B[0m DEFAULT_LOAD_HOOKS,\n\u001B[1;32m 15\u001B[0m DEFAULT_TRANSFER_HOOKS,\n\u001B[1;32m 16\u001B[0m PARAMFILE_EXT,\n\u001B[1;32m 17\u001B[0m get_default_hook,\n\u001B[1;32m 18\u001B[0m )\n\u001B[1;32m 20\u001B[0m logger \u001B[38;5;241m=\u001B[39m logging\u001B[38;5;241m.\u001B[39mgetLogger(\u001B[38;5;18m__name__\u001B[39m)\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/pretrained/__init__.py:3\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;124;03m\"\"\"Pretrained models\"\"\"\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01minterfaces\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/pretrained/interfaces.py:18\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\n\u001B[1;32m 17\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m 
\u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m---> 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\n\u001B[1;32m 19\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01msentencepiece\u001B[39;00m\n\u001B[1;32m 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtypes\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m SimpleNamespace\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/__init__.py:1\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ( \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[1;32m 2\u001B[0m _extension,\n\u001B[1;32m 3\u001B[0m compliance,\n\u001B[1;32m 4\u001B[0m datasets,\n\u001B[1;32m 5\u001B[0m functional,\n\u001B[1;32m 6\u001B[0m io,\n\u001B[1;32m 7\u001B[0m kaldi_io,\n\u001B[1;32m 8\u001B[0m models,\n\u001B[1;32m 9\u001B[0m pipelines,\n\u001B[1;32m 10\u001B[0m sox_effects,\n\u001B[1;32m 11\u001B[0m transforms,\n\u001B[1;32m 12\u001B[0m utils,\n\u001B[1;32m 13\u001B[0m )\n\u001B[1;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackend\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_audio_backend, list_audio_backends, set_audio_backend\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:103\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m:\n\u001B[1;32m 100\u001B[0m \u001B[38;5;28;01mpass\u001B[39;00m\n\u001B[0;32m--> 103\u001B[0m \u001B[43m_init_extension\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:88\u001B[0m, in \u001B[0;36m_init_extension\u001B[0;34m()\u001B[0m\n\u001B[1;32m 85\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtorchaudio C++ extension is not available.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 86\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m\n\u001B[0;32m---> 88\u001B[0m \u001B[43m_load_lib\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlibtorchaudio\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 89\u001B[0m \u001B[38;5;66;03m# This import is for initializing the methods registered via PyBind11\u001B[39;00m\n\u001B[1;32m 90\u001B[0m \u001B[38;5;66;03m# This has to happen after the base library is loaded\u001B[39;00m\n\u001B[1;32m 91\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _torchaudio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:51\u001B[0m, in \u001B[0;36m_load_lib\u001B[0;34m(lib)\u001B[0m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mexists():\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m---> 51\u001B[0m 
\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_library\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 52\u001B[0m torch\u001B[38;5;241m.\u001B[39mclasses\u001B[38;5;241m.\u001B[39mload_library(path)\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torch/_ops.py:643\u001B[0m, in \u001B[0;36m_Ops.load_library\u001B[0;34m(self, path)\u001B[0m\n\u001B[1;32m 638\u001B[0m path \u001B[38;5;241m=\u001B[39m _utils_internal\u001B[38;5;241m.\u001B[39mresolve_library_path(path)\n\u001B[1;32m 639\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m dl_open_guard():\n\u001B[1;32m 640\u001B[0m \u001B[38;5;66;03m# Import the shared library into the process, thus running its\u001B[39;00m\n\u001B[1;32m 641\u001B[0m \u001B[38;5;66;03m# static (global) initialization code in order to register custom\u001B[39;00m\n\u001B[1;32m 642\u001B[0m \u001B[38;5;66;03m# operators with the JIT.\u001B[39;00m\n\u001B[0;32m--> 643\u001B[0m \u001B[43mctypes\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mCDLL\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 644\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mloaded_libraries\u001B[38;5;241m.\u001B[39madd(path)\n", - "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/ctypes/__init__.py:382\u001B[0m, in \u001B[0;36mCDLL.__init__\u001B[0;34m(self, name, mode, handle, use_errno, use_last_error, winmode)\u001B[0m\n\u001B[1;32m 379\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_FuncPtr \u001B[38;5;241m=\u001B[39m _FuncPtr\n\u001B[1;32m 381\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m handle \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 382\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m \u001B[43m_dlopen\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_name\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 383\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 384\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m handle\n", - "\u001B[0;31mOSError\u001B[0m: dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib" - ] - } - ], - "source": [ - "from speechbrain.pretrained import SpectralMaskEnhancement\n", - "model = SpectralMaskEnhancement.from_hparams(\"speechbrain/mtl-mimic-voicebank\")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-11T18:38:33.342704Z", - "start_time": "2024-03-11T18:38:09.584639Z" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - 
"source": [], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/scripts/eval/adjust.py b/scripts/eval/adjust.py new file mode 100644 index 0000000..e3e8438 --- /dev/null +++ b/scripts/eval/adjust.py @@ -0,0 +1,58 @@ +import argparse +import glob +import os + +from aria.data.midi import MidiDict + + +def get_matched_paths(orig_dir: str, adj_dir: str): + # Assume that the files have the same path relative to their directory + res = [] + orig_paths = glob.glob(os.path.join(orig_dir, "**/*.mid"), recursive=True) + print(f"found {len(orig_paths)} mid files") + + for mid_path in orig_paths: + orig_rel_path = os.path.relpath(mid_path, orig_dir) + adj_path = os.path.join(adj_dir, orig_rel_path) + orig_path = os.path.join(orig_dir, orig_rel_path) + + res.append((os.path.abspath(orig_path), os.path.abspath(adj_path))) + + print(f"found {len(res)} matched mp3-midi pairs") + assert len(orig_paths) == len(res) + + return res + + +def adjust_mid(orig_path: str, adj_path: str): + assert os.path.isfile(adj_path) is False + mid_dict = MidiDict.from_midi(orig_path) + mid_dict.resolve_pedal() + mid = mid_dict.to_midi() + + os.makedirs(os.path.dirname(adj_path), exist_ok=True) + mid.save(adj_path) + + +def main(): + parser = argparse.ArgumentParser( + description="Remove duplicate MP3 files based on audio content." + ) + parser.add_argument( + "orig_dir", type=str, help="Directory to scan for duplicate MP3 files." + ) + parser.add_argument( + "adj_dir", type=str, help="Directory to scan for duplicate MP3 files." 
+ ) + args = parser.parse_args() + + matched_paths = get_matched_paths( + orig_dir=args.orig_dir, adj_dir=args.adj_dir + ) + + for orig_path, adj_path in matched_paths: + adjust_mid(orig_path, adj_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval/dtw.sh b/scripts/eval/dtw.sh deleted file mode 100644 index 1d25b71..0000000 --- a/scripts/eval/dtw.sh +++ /dev/null @@ -1,5 +0,0 @@ -python /home/loubb/work/aria-amt/scripts/eval/dtw.py \ - -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \ - -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid \ - -output_file /mnt/ssd1/amt/transcribed_data/0/aria-mid.csv - diff --git a/amt/evaluate.py b/scripts/eval/mir.py similarity index 92% rename from amt/evaluate.py rename to scripts/eval/mir.py index 282fc64..0bfc520 100644 --- a/amt/evaluate.py +++ b/scripts/eval/mir.py @@ -1,11 +1,14 @@ import glob from tqdm.auto import tqdm +from collections import defaultdict import pretty_midi import numpy as np import mir_eval import json import os +pretty_midi.pretty_midi.MAX_TICK = 1e10 + def midi_to_intervals_and_pitches(midi_file_path): """ @@ -73,6 +76,7 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): open(output_stats_file, "w") if output_stats_file is not None else None ) + res = defaultdict(list) for est_file, ref_file in tqdm(est_ref_pairs): ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file) est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file) @@ -84,8 +88,15 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): if output_fhandle is not None: output_fhandle.write(json.dumps(scores)) output_fhandle.write("\n") + for k, v in scores.items(): + res[k].append(v) else: print(json.dumps(scores, indent=4)) + for k, v in scores.items(): + res[k].append(v) + + for k, v in res.items(): + print(k, sum(v) / len(v)) if __name__ == "__main__": diff --git a/scripts/eval/mir.sh b/scripts/eval/mir.sh deleted file mode 100644 index 0364d4c..0000000 --- a/scripts/eval/mir.sh +++ /dev/null @@ -1,5 +0,0 @@ -python /home/loubb/work/aria-amt/amt/evaluate.py \ - --est-dir /home/loubb/work/aria-amt/maestro-ft \ - --ref-dir /mnt/ssd1/data/mp3/raw/maestro-mp3 \ - --output-stats-file out.json - \ No newline at end of file diff --git a/scripts/eval/prune.sh b/scripts/eval/prune.sh deleted file mode 100644 index 9b3f89a..0000000 --- a/scripts/eval/prune.sh +++ /dev/null @@ -1,6 +0,0 @@ -python /home/loubb/work/aria-amt/scripts/eval/prune.py \ - -mid_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid \ - -output_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid-pruned \ - -score_file /mnt/ssd1/amt/transcribed_data/0/pijama-mid.csv \ - -max_score 0.42 \ - # -dry \ No newline at end of file diff --git a/scripts/eval/split.sh b/scripts/eval/split.sh deleted file mode 100644 index 05faece..0000000 --- a/scripts/eval/split.sh +++ /dev/null @@ -1,4 +0,0 @@ -python /home/loubb/work/aria-amt/scripts/eval/split.py \ - -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid-pruned \ - -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \ - -csv_path aria-pruned-split.csv diff --git a/tests/test_data.py b/tests/test_data.py index 8e1010c..9c566ad 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -9,16 +9,18 @@ from amt.data import get_wav_mid_segments, AmtDataset from amt.tokenizer import AmtTokenizer -from amt.audio import AudioTransform, log_mel_spectrogram +from amt.audio import AudioTransform from amt.train import get_dataloaders from aria.data.midi import MidiDict +from torch.utils.data import 
DataLoader + logging.basicConfig(level=logging.INFO) if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") -MAESTRO_PATH = "/mnt/ssd1/amt/training_data/train.txt" +MAESTRO_PATH = "/mnt/ssd1/amt/training_data/maestro/train-s15.txt" def plot_spec(mel: torch.Tensor, name: str | int): @@ -68,13 +70,43 @@ def test_build(self): ).to_midi() mid.save(f"tests/test_results/trunc_{idx}.mid") + def test_build_multiple(self): + matched_paths = [ + ("tests/test_data/maestro.wav", "tests/test_data/maestro1.mid") + for _ in range(2) + ] + if os.path.isfile("tests/test_results/dataset_1.jsonl"): + os.remove("tests/test_results/dataset_1.jsonl") + if os.path.isfile("tests/test_results/dataset_2.jsonl"): + os.remove("tests/test_results/dataset_2.jsonl") + + AmtDataset.build( + load_paths=matched_paths, + save_path="tests/test_results/dataset_1.jsonl", + ) + + AmtDataset.build( + load_paths=matched_paths, + save_path="tests/test_results/dataset_2.jsonl", + ) + + dataset = AmtDataset( + [ + "tests/test_results/dataset_1.jsonl", + "tests/test_results/dataset_2.jsonl", + ] + ) + + for idx, (wav, src, tgt, _idx) in enumerate(dataset): + print(wav.shape, src.shape, tgt.shape) + def test_maestro(self): if not os.path.isfile(MAESTRO_PATH): return tokenizer = AmtTokenizer() audio_transform = AudioTransform() - dataset = AmtDataset(load_path=MAESTRO_PATH) + dataset = AmtDataset(load_paths=MAESTRO_PATH) print(f"Dataset length: {len(dataset)}") for idx, (wav, src, tgt, __idx) in enumerate(dataset): src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt) @@ -105,6 +137,34 @@ def test_maestro(self): for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec): self.assertEqual(src_tok, tgt_tok) + def test_tensor_pitch_aug(self): + tokenizer = AmtTokenizer() + audio_transform = AudioTransform() + dataset = AmtDataset(load_paths=MAESTRO_PATH) + tensor_pitch_aug = AmtTokenizer().export_tensor_pitch_aug() + + dataloader = DataLoader( + dataset, + batch_size=4, + num_workers=1, + shuffle=False, + ) + + for batch in dataloader: + wav, src, tgt, idxs = batch + + src_p = tensor_pitch_aug(seq=src.clone(), shift=1)[0] + src_p_dec = tokenizer.decode(src_p) + + src_np = src.clone()[0] + src_np_dec = tokenizer.decode(src_np) + + for x, y in zip(src_p_dec, src_np_dec): + if x == "

": + break + else: + print(x, y) + class TestAug(unittest.TestCase): def test_spec(self): @@ -128,13 +188,6 @@ def test_spec(self): torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE) - log_mel = log_mel_spectrogram(wav) - plot_spec(log_mel.squeeze(0), "orig") - - _mel = audio_transform.mel_transform(spec) - _log_mel = audio_transform.norm_mel(_mel) - plot_spec(_log_mel.squeeze(0), "new") - def test_pitch_aug(self): tokenizer = AmtTokenizer(return_tensors=True) tensor_pitch_aug_fn = tokenizer.export_tensor_pitch_aug() @@ -184,13 +237,6 @@ def test_detune(self): torchaudio.save("tests/test_results/orig_gl.wav", gl_wav, SAMPLE_RATE) torchaudio.save("tests/test_results/detune.wav", shift_wav, SAMPLE_RATE) - log_mel = log_mel_spectrogram(wav) - plot_spec(log_mel.squeeze(0), "orig") - - _mel = audio_transform.mel_transform(spec) - _log_mel = audio_transform.norm_mel(_mel) - plot_spec(_log_mel.squeeze(0), "new") - def test_mels(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform() @@ -278,7 +324,7 @@ def load_data(self, dataloader, num_batches=100): def test_profile_dl(self): train_dataloader, val_dataloader = get_dataloaders( - train_data_path="/weka/proj-aria/aria-amt/data/train.jsonl", + train_data_paths="/weka/proj-aria/aria-amt/data/train.jsonl", val_data_path="/weka/proj-aria/aria-amt/data/train.jsonl", batch_size=16, num_workers=0,