From 1c9d66637bd558b3f51c22fc3cc4199ff6065b8b Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 9 Apr 2024 21:02:36 +0100 Subject: [PATCH] Fix training and inference (#24) * fix * fix batched * adjust * transfer * add cleanup * working * move_to_node * fix config * fix msg * add synth dataset gen * remote changes * local changes * add scripts * fix audio params --- .gitignore | 1 - amt/audio.py | 21 +- amt/data.py | 92 ++++- amt/evaluate.py | 78 ++--- amt/inference/transcribe.py | 327 +++++++++++------- amt/run.py | 155 ++++++++- amt/tokenizer.py | 11 +- amt/train.py | 189 +++++----- config/models/medium-final.json | 11 + .../models/{medium.json => small-final.json} | 4 +- requirements.txt | 5 +- scripts/eval/dedupe.py | 72 ++++ scripts/eval/dtw.py | 211 +++++++++++ scripts/eval/dtw.sh | 5 + scripts/eval/mir.sh | 5 + scripts/eval/prune.py | 91 +++++ scripts/eval/prune.sh | 6 + scripts/eval/req-eval.txt | 4 + scripts/eval/split.py | 57 +++ scripts/eval/split.sh | 4 + tests/test_data.py | 17 +- 21 files changed, 1064 insertions(+), 302 deletions(-) create mode 100644 config/models/medium-final.json rename config/models/{medium.json => small-final.json} (77%) create mode 100644 scripts/eval/dedupe.py create mode 100644 scripts/eval/dtw.py create mode 100644 scripts/eval/dtw.sh create mode 100644 scripts/eval/mir.sh create mode 100644 scripts/eval/prune.py create mode 100644 scripts/eval/prune.sh create mode 100644 scripts/eval/req-eval.txt create mode 100644 scripts/eval/split.py create mode 100644 scripts/eval/split.sh diff --git a/.gitignore b/.gitignore index 026b7af..1d36432 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ # data files *.csv -*.json *.xls *.xlsx *.pkl diff --git a/amt/audio.py b/amt/audio.py index 6c37f3f..447822e 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -191,15 +191,15 @@ def __init__( max_snr: int = 50, max_dist_gain: int = 25, min_dist_gain: int = 0, - noise_ratio: float = 0.95, - reverb_ratio: float = 0.95, + noise_ratio: float = 0.75, + reverb_ratio: float = 0.75, applause_ratio: float = 0.01, bandpass_ratio: float = 0.15, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, - detune_ratio: float = 0.1, - detune_max_shift: float = 0.15, - spec_aug_ratio: float = 0.5, + detune_ratio: float = 0.0, + detune_max_shift: float = 0.0, + spec_aug_ratio: float = 0.9, ): super().__init__() self.tokenizer = AmtTokenizer() @@ -223,7 +223,10 @@ def __init__( self.detune_ratio = detune_ratio self.detune_max_shift = detune_max_shift self.spec_aug_ratio = spec_aug_ratio - self.reduction_resample_rate = 6000 # Hardcoded? + + self.time_mask_param = 2500 + self.freq_mask_param = 15 + self.reduction_resample_rate = 6000 # Audio aug impulse_paths = self._get_paths( @@ -263,10 +266,10 @@ def __init__( ) self.spec_aug = torch.nn.Sequential( torchaudio.transforms.FrequencyMasking( - freq_mask_param=15, iid_masks=True + freq_mask_param=self.freq_mask_param, iid_masks=True ), torchaudio.transforms.TimeMasking( - time_mask_param=1000, iid_masks=True + time_mask_param=self.time_mask_param, iid_masks=True ), ) @@ -281,6 +284,8 @@ def get_params(self): "detune_ratio": self.detune_ratio, "detune_max_shift": self.detune_max_shift, "spec_aug_ratio": self.spec_aug_ratio, + "time_mask_param": self.time_mask_param, + "freq_mask_param": self.freq_mask_param, } def _get_paths(self, dir_path): diff --git a/amt/data.py b/amt/data.py index 377b760..ad8660f 100644 --- a/amt/data.py +++ b/amt/data.py @@ -1,6 +1,8 @@ import mmap import os import io +import random +import shlex import base64 import shutil import orjson @@ -16,7 +18,14 @@ from amt.audio import pad_or_trim -# Occasionally the worker util goes to 0 for some reason, debug this +def _check_onset_threshold(seq: list, onset: int): + for tok_1, tok_2 in zip(seq, seq[1:]): + if isinstance(tok_1, tuple) and tok_1[0] in ("on", "off"): + _onset = tok_2[1] + if _onset > onset: + return True + + return False def get_wav_mid_segments( @@ -24,6 +33,7 @@ def get_wav_mid_segments( mid_path: str = "", return_json: bool = False, stride_factor: int | None = None, + pad_last=False, ): """This function yields tuples of matched log mel spectrograms and tokenized sequences (np.array, list). If it is given only an audio path @@ -61,10 +71,12 @@ def get_wav_mid_segments( # Create features total_samples = wav.shape[-1] + pad_factor = 2 if pad_last is True else 1 res = [] for idx in range( 0, - total_samples - (num_samples - num_samples // stride_factor), + total_samples + - (num_samples - pad_factor * (num_samples // stride_factor)), num_samples // stride_factor, ): audio_feature = pad_or_trim(wav[idx:], length=num_samples) @@ -75,6 +87,12 @@ def get_wav_mid_segments( end_ms=(idx + num_samples) / samples_per_ms, max_pedal_len_ms=10000, ) + + # Hardcoded to 2.5s + if _check_onset_threshold(mid_feature, 2500) is False: + print("No note messages after 2.5s - skipping") + continue + else: mid_feature = [] @@ -86,6 +104,56 @@ def get_wav_mid_segments( return res +def pianoteq_cmd_fn(mid_path: str, wav_path: str): + presets = [ + "C. Bechstein", + "C. Bechstein Close Mic", + "C. Bechstein Under Lid", + "C. Bechstein 440", + "C. Bechstein Recording", + "C. Bechstein Werckmeister III", + "C. Bechstein Neidhardt III", + "C. Bechstein mesotonic", + "C. Bechstein well tempered", + "HB Steinway D Blues", + "HB Steinway D Pop", + "HB Steinway D New Age", + "HB Steinway D Prelude", + "HB Steinway D Felt I", + "HB Steinway D Felt II", + "HB Steinway Model D", + "HB Steinway D Classical Recording", + "HB Steinway D Jazz Recording", + "HB Steinway D Chamber Recording", + "HB Steinway D Studio Recording", + "HB Steinway D Intimate", + "HB Steinway D Cinematic", + "HB Steinway D Close Mic Classical", + "HB Steinway D Close Mic Jazz", + "HB Steinway D Player Wide", + "HB Steinway D Player Clean", + "HB Steinway D Trio", + "HB Steinway D Duo", + "HB Steinway D Cabaret", + "HB Steinway D Bright", + "HB Steinway D Hyper Bright", + "HB Steinway D Prepared", + "HB Steinway D Honky Tonk", + ] + + preset = random.choice(presets) + + # Safely quote the preset name, MIDI path, and WAV path + safe_preset = shlex.quote(preset) + safe_mid_path = shlex.quote(mid_path) + safe_wav_path = shlex.quote(wav_path) + + # Construct the command + command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}" + + return command + + def write_features(audio_path: str, mid_path: str, save_path: str): features = get_wav_mid_segments( audio_path=audio_path, @@ -121,7 +189,7 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str): try: get_synth_audio( - cli_cmd=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp + cli_cmd_fn=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp ) except: if os.path.isfile(audio_path_temp): @@ -133,7 +201,11 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str): mid_path=mid_path, return_json=False, ) - os.remove(audio_path_temp) + + if os.path.isfile(audio_path_temp): + os.remove(audio_path_temp) + + print(f"Found {len(features)}") with open(save_path, mode="a") as file: for wav, seq in features: @@ -174,7 +246,11 @@ def build_synth_worker_fn( while not load_path_queue.empty(): mid_path = load_path_queue.get() - write_synth_features(cli_cmd, mid_path, worker_save_path) + try: + write_synth_features(cli_cmd, mid_path, worker_save_path) + except Exception as e: + print("Failed") + print(e) save_path_queue.put(worker_save_path) @@ -239,7 +315,7 @@ def _format(tok): seq_len=self.config["max_seq_len"], ) - return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt) + return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx def _build_index(self): self.file_mmap.seek(0) @@ -254,7 +330,7 @@ def _build_index(self): return index - def _save_index(self, index: list[int], save_path: str): + def _save_index(self, index: list, save_path: str): with open(save_path, "w") as file: for idx in index: file.write(f"{idx}\n") @@ -325,7 +401,7 @@ def build( ] else: # Build synthetic dataset - assert len(load_paths[0]) == 1, "Invalid load paths" + assert isinstance(load_paths[0], str), "Invalid load paths" print("Building synthetic dataset") worker_processes = [ Process( diff --git a/amt/evaluate.py b/amt/evaluate.py index d79e885..282fc64 100644 --- a/amt/evaluate.py +++ b/amt/evaluate.py @@ -42,27 +42,32 @@ def midi_to_hz(note, shift=0): # return (a / 32) * (2 ** ((note - 9) / 12)) +def get_matched_files(est_dir: str, ref_dir: str): + # We assume that the files have the same path relative to their directory + + res = [] + est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True) + print(f"found {len(est_paths)} est files") + + for est_path in est_paths: + est_rel_path = os.path.relpath(est_path, est_dir) + ref_path = os.path.join( + ref_dir, os.path.splitext(est_rel_path)[0] + ".midi" + ) + if os.path.isfile(ref_path): + res.append((est_path, ref_path)) + + print(f"found {len(res)} matched est-ref pairs") + + return res + + def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): """ Evaluate the estimated pitches against the reference pitches using mir_eval. """ - # Evaluate the estimated pitches against the reference pitches - ref_midi_files = glob.glob(f"{ref_dir}/*.mid*") - est_midi_files = glob.glob(f"{est_dir}/*.mid*") - - est_ref_pairs = [] - for est_fpath in est_midi_files: - ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath)) - if ref_fpath in ref_midi_files: - est_ref_pairs.append((est_fpath, ref_fpath)) - if ref_fpath.replace(".mid", ".midi") in ref_midi_files: - est_ref_pairs.append( - (est_fpath, ref_fpath.replace(".mid", ".midi")) - ) - else: - print( - f"Reference file not found for {est_fpath} (ref file: {ref_fpath})" - ) + + est_ref_pairs = get_matched_files(est_dir, ref_dir) output_fhandle = ( open(output_stats_file, "w") if output_stats_file is not None else None @@ -104,38 +109,9 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): help="Path to the file to save the evaluation stats", ) - # add mir_eval and dtw subparsers - subparsers = parser.add_subparsers(help="sub-command help") - mir_eval_parse = subparsers.add_parser( - "run_mir_eval", - help="Run standard mir_eval evaluation on MAESTRO test set.", - ) - mir_eval_parse.add_argument( - "--shift", - type=int, - default=0, - help="Shift to apply to the estimated pitches.", - ) - - # to come - dtw_eval_parse = subparsers.add_parser( - "run_dtw", - help="Run dynamic time warping evaluation on a specified dataset.", - ) - args = parser.parse_args() - if not hasattr(args, "command"): - parser.print_help() - print("Unrecognized command") - exit(1) - - # todo: should we add an option to run transcription again every time we wish to evaluate? - # that way, we can run both tests with a range of different audio augmentations right here. - # -> We expect that baseline methods will fall flat on these, while aria-amt will be OK. - - if args.command == "run_mir_eval": - evaluate_mir_eval( - args.est_dir, args.ref_dir, args.output_stats_file, args.shift - ) - elif args.command == "run_dtw": - pass + evaluate_mir_eval( + args.est_dir, + args.ref_dir, + args.output_stats_file, + ) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 9109e6a..4984dad 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -1,4 +1,6 @@ import os +import sys +import signal import time import random import logging @@ -10,7 +12,6 @@ import torch._inductor.config from torch.multiprocessing import Queue -from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm from functools import wraps from torch.cuda import is_bf16_supported @@ -31,15 +32,17 @@ CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR -def _setup_logger(): +def _setup_logger(name: str | None = None): + logger_name = f"[{name}] " if name else "" logger = logging.getLogger(__name__) for h in logger.handlers[:]: logger.removeHandler(h) logger.propagate = False logger.setLevel(logging.DEBUG) + # Adjust the formatter to include the name before the PID if provided formatter = logging.Formatter( - "[%(asctime)s] %(process)d: [%(levelname)s] %(message)s", + f"[%(asctime)s] {logger_name}%(process)d: [%(levelname)s] %(message)s", ) ch = logging.StreamHandler() @@ -110,12 +113,12 @@ def update_seq_end_idxs_( prefix_lens: torch.Tensor, idx: int, ): - # Update eos_idxs if next tok is eos_tok - eos_mask = next_tok_ids == 1 + # Update eos_idxs if next tok is eos_tok and not eos_id < idx + eos_mask = (next_tok_ids == 1) & (eos_idxs > idx) eos_idxs[eos_mask] = idx # Update eos_idxs if next tok in onset > 20000 - offset_mask = next_tok_ids >= 2418 + offset_mask = (next_tok_ids >= 2418) & (eos_idxs > idx) eos_idxs[offset_mask] = idx - 2 # Don't update toks in prefix or after eos_idx @@ -136,6 +139,7 @@ def wrapper(*args, **kwargs): return wrapper +@torch.no_grad() def decode_token( model: AmtEncoderDecoder, x: torch.Tensor, @@ -163,22 +167,23 @@ def process_segments( tokenizer: AmtTokenizer, logger: logging.Logger, ): - audio_segs = torch.stack( - [audio_seg for (audio_seg, prefix), _ in tasks] - ).cuda() - log_mels = audio_transform.log_mel(audio_segs) + log_mels = audio_transform.log_mel( + torch.stack([audio_seg.cuda() for (audio_seg, prefix), _ in tasks]) + ) audio_features = model.encoder(xa=log_mels) raw_prefixes = [prefix for (audio_seg, prefix), _ in tasks] prefix_lens = torch.tensor( [len(prefix) for prefix in raw_prefixes], dtype=torch.int - ) + ).cuda() min_prefix_len = min(prefix_lens).item() prefixes = [ tokenizer.trunc_seq(prefix, MAX_BLOCK_LEN) for prefix in raw_prefixes ] seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda() - eos_idxs = torch.tensor([MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int) + eos_idxs = torch.tensor( + [MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int + ).cuda() # for idx in ( # pbar := tqdm( @@ -243,34 +248,39 @@ def process_segments( return results -# There is a memory leak in here somewhere def gpu_manager( gpu_batch_queue: Queue, result_queue: Queue, model: AmtEncoderDecoder, batch_size: int, + compile: bool = False, gpu_id: int | None = None, ): - logger = _setup_logger() + if gpu_id: + logger = _setup_logger(name=f"GPU-{gpu_id}") + else: + logger = _setup_logger(name=f"GPU") + logger.info("Started GPU manager") if gpu_id is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - global decode_token, recalculate_tok_ids model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN) model.cuda() model.eval() - if batch_size == 1: - recalculate_tok_ids = torch.compile( - recalculate_tok_ids, mode="max-autotune-no-cudagraphs" + if compile is True: + global decode_token, recalculate_tok_ids + if batch_size == 1: + recalculate_tok_ids = torch.compile( + recalculate_tok_ids, mode="max-autotune-no-cudagraphs" + ) + decode_token = torch.compile( + decode_token, + # mode="reduce-overhead", + # mode="max-autotune", + fullgraph=True, ) - decode_token = torch.compile( - decode_token, - # mode="reduce-overhead", - mode="max-autotune", - fullgraph=True, - ) audio_transform = AudioTransform().cuda() tokenizer = AmtTokenizer(return_tensors=True) @@ -278,9 +288,9 @@ def gpu_manager( try: while True: try: - batch = gpu_batch_queue.get(timeout=10) + batch = gpu_batch_queue.get(timeout=30) except Exception as e: - logger.info(f"GPU timedout waiting for batch") + logger.info(f"GPU timed out waiting for batch") break else: try: @@ -339,13 +349,13 @@ def gpu_batch_manager( gpu_batch_queue: Queue, batch_size: int, ): - logger = _setup_logger() + logger = _setup_logger(name="B") logger.info("Started batch manager") try: tasks = [] while True: try: - task, pid = gpu_task_queue.get(timeout=0.2) + task, pid = gpu_task_queue.get(timeout=0.05) except Exception as e: pass else: @@ -360,7 +370,9 @@ def gpu_batch_manager( # Get new batch and add to batch queue if len(tasks) < batch_size: - logger.info("Not enough tasks - padding batch") + logger.warning( + f"Not enough tasks ({len(tasks)}) - padding batch" + ) while len(tasks) < batch_size: _pad_task, _pid = tasks[0] tasks.append((_pad_task, -1)) @@ -397,7 +409,6 @@ def _truncate_seq( seq: list, start_ms: int, end_ms: int, - logger: logging.Logger, tokenizer: AmtTokenizer = AmtTokenizer(), ): # Truncates and shifts a sequence by retokenizing the underlying midi_dict @@ -408,16 +419,15 @@ def _truncate_seq( random.shuffle(unclosed_notes) return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok] else: - try: - _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) - res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) - except Exception as e: - logger.error(f"Truncate segment failed: {e}") + _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) + if len(_mid_dict.note_msgs) == 0: return [tokenizer.bos_tok] else: - if res[-1] == tokenizer.eos_tok: - res.pop() - return res + res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) + + if res[-1] == tokenizer.eos_tok: + res.pop() + return res def transcribe_file( @@ -433,63 +443,89 @@ def transcribe_file( audio_segments = [ f for f, _ in get_wav_mid_segments( - audio_path=file_path, stride_factor=STRIDE_FACTOR + audio_path=file_path, + stride_factor=STRIDE_FACTOR, + pad_last=True, ) ] res = [] seq = [tokenizer.bos_tok] concat_seq = [tokenizer.bos_tok] - for idx, audio_seg in enumerate(audio_segments): + idx = 0 + while audio_segments: init_idx = len(seq) # Add to gpu queue and wait for results - gpu_task_queue.put(((audio_seg, seq), pid)) + gpu_task_queue.put(((audio_segments.pop(0), seq), pid)) while True: - # Issue with this logic perhaps - gpu_result = result_queue.get(timeout=300) - if gpu_result["pid"] == pid: - seq = gpu_result["result"] - break + try: + gpu_result = result_queue.get(timeout=0.1) + except Exception as e: + pass else: - result_queue.put(gpu_result) - - concat_seq += _shift_onset( - seq[init_idx:], - idx * CHUNK_LEN_MS, - ) - - if idx == len(audio_segments) - 1: - res.append(concat_seq) - elif concat_seq[-1] == tokenizer.eos_tok: - res.append(concat_seq) - seq = [tokenizer.bos_tok] - concat_seq = [tokenizer.bos_tok] - logger.info(f"Finished segment (eos_tok): {file_path}") - else: - # This might need it's logic adjusted + if gpu_result["pid"] == pid: + seq = gpu_result["result"] + break + else: + result_queue.put(gpu_result) - seq = _truncate_seq( + try: + next_seq = _truncate_seq( seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS, - logger=logger, ) + except Exception as e: + logger.info( + f"Skipping segment {idx} (failed to transcribe): {file_path}" + ) + logger.debug(traceback.format_exc()) + seq = [tokenizer.bos_tok] + else: + if seq[-1] == tokenizer.eos_tok: + logger.info(f"Seen eos_tok at segment {idx}: {file_path}") + seq = seq[:-1] - if len(seq) == 1: - logger.error(f"Failed to transcribe segment: {file_path}") - if len(concat_seq) > 500: - res.append(concat_seq) - else: - pass - # logger.info(f"Sequence too short ({len(concat_seq)})") - + if len(next_seq) == 1: + logger.info(f"Skipping segment {idx} (silence): {file_path}") seq = [tokenizer.bos_tok] - concat_seq = [tokenizer.bos_tok] + else: + concat_seq += _shift_onset( + seq[init_idx:], + idx * CHUNK_LEN_MS, + ) + seq = next_seq + + idx += 1 + + res.append(concat_seq) return res +def get_save_path( + file_path: str, + input_dir: str, + save_dir: str, + idx: int | str = "", +): + if input_dir is None: + save_path = os.path.join( + save_dir, + os.path.splitext(os.path.basename(file_path))[0] + f"{idx}.mid", + ) + else: + input_rel_path = os.path.relpath(file_path, input_dir) + save_path = os.path.join( + save_dir, os.path.splitext(input_rel_path)[0] + f"{idx}.mid" + ) + if not os.path.isdir(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + return save_path + + def process_file( file_path: str, file_queue: Queue, @@ -517,29 +553,14 @@ def _save_seq(_seq: list, _save_path: str): mid.save(_save_path) except Exception as e: logger.error(f"Failed to save {_save_path}") - - def _get_save_path(_file_path: str, _idx: int | str = ""): - if input_dir is None: - save_path = os.path.join( - save_dir, - os.path.splitext(os.path.basename(file_path))[0] - + f"{_idx}.mid", - ) - else: - input_rel_path = os.path.relpath(_file_path, input_dir) - save_path = os.path.join( - save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid" - ) - if not os.path.isdir(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path), exist_ok=True) - - return save_path + logger.debug(traceback.format_exc()) + logger.debug(_seq) def remove_failures_from_queue_(_queue: Queue, _pid: int): _buff = [] while True: try: - _buff.append(_queue(timout=5)) + _buff.append(_queue.get(timeout=5)) except Exception: break @@ -564,18 +585,32 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): return logger.info(f"Finished file: {file_path}") - _idx = 0 for seq in seqs: - if len(seq) < 1000: + if len(seq) < 500: logger.info("Skipping seq - too short") - continue - _save_seq(seq, _get_save_path(file_path, _idx)) - _idx += 1 + else: + logger.debug( + f"Saving seq of length {len(seq)} from file: {file_path}" + ) + + _save_seq(seq, get_save_path(file_path, input_dir, save_dir)) - logger.info(f"Transcribed into {_idx} segment(s)") logger.info(f"{file_queue.qsize()} file(s) remaining in queue") +def watchdog(main_gpu_pid: int, child_pids: list): + while True: + if not os.path.exists(f"/proc/{main_gpu_pid}"): + print("Cleaning up children...") + for pid in child_pids: + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + pass + + time.sleep(1) + + def worker( file_queue: Queue, gpu_task_queue: Queue, @@ -584,7 +619,7 @@ def worker( input_dir: str | None = None, tasks_per_worker: int = 1, ): - logger = _setup_logger() + logger = _setup_logger(name="F") tokenizer = AmtTokenizer() threads = [] try: @@ -629,8 +664,10 @@ def batch_transcribe( batch_size: int = 16, input_dir: str | None = None, gpu_ids: int | None = None, - quantize: bool = True, + quantize: bool = False, + compile: bool = False, ): + assert os.name == "posix", "UNIX/LINUX is the only supported OS" torch.multiprocessing.set_start_method("spawn") num_gpus = len(gpu_ids) if gpu_ids is not None else 1 logger = _setup_logger() @@ -639,17 +676,31 @@ def batch_transcribe( os.remove("transcribe.log") if quantize is True: - logger.info("Quantising weights to int8") - model = quantize_int8(model) + logger.info("Quantising decoder weights to int8") + model.decoder = quantize_int8(model.decoder) + + file_queue = Queue() + for file_path in file_paths: + if ( + os.path.isfile(get_save_path(file_path, input_dir, save_dir)) + is False + ): + file_queue.put(file_path) + elif len(file_paths) == 1: + file_queue.put(file_path) + + logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}") + + num_workers = min( + min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), + file_queue.qsize(), + ) gpu_task_queue = Queue() gpu_batch_queue = Queue() result_queue = Queue() - file_queue = Queue() - for file_path in file_paths: - file_queue.put(file_path) - num_workers = min(batch_size * num_gpus, len(file_paths)) + child_pids = [] logger.info(f"Creating {num_workers} file worker(s)") worker_processes = [ multiprocessing.Process( @@ -660,47 +711,73 @@ def batch_transcribe( result_queue, save_dir, input_dir, - # Wait for all threads to finish - 4, + 5, ), ) for _ in range(num_workers) ] + + for p in worker_processes: + p.start() + child_pids.append(p.pid) + gpu_batch_manager_process = multiprocessing.Process( target=gpu_batch_manager, args=(gpu_task_queue, gpu_batch_queue, batch_size), ) + gpu_batch_manager_process.start() + child_pids.append(gpu_batch_manager_process.pid) + time.sleep(5) start_time = time.time() - if num_gpus == 1: - gpu_manager_processes = [ - multiprocessing.Process( - target=gpu_manager, - args=(gpu_batch_queue, result_queue, model, batch_size), - ) - ] - else: + + if num_gpus > 1: gpu_manager_processes = [ multiprocessing.Process( target=gpu_manager, - args=(gpu_batch_queue, result_queue, model, batch_size, gpu_id), + args=( + gpu_batch_queue, + result_queue, + model, + batch_size, + compile, + gpu_id, + ), ) for gpu_id in gpu_ids ] + for p in gpu_manager_processes: + p.start() + watchdog_process = multiprocessing.Process( + target=watchdog, args=(gpu_batch_manager_process.pid, child_pids) + ) + watchdog_process.start() + else: + gpu_manager_processes = None + watchdog_process = multiprocessing.Process( + target=watchdog, args=(os.getpid(), child_pids) + ) + watchdog_process.start() + gpu_manager( + gpu_batch_queue, + result_queue, + model, + batch_size, + compile, + ) - for p in worker_processes: - p.start() - time.sleep(5) - gpu_batch_manager_process.start() - for p in gpu_manager_processes: - p.start() + if gpu_manager_processes is not None: + for p in gpu_manager_processes: + p.join() - # Watch for file workers to finish for p in worker_processes: + p.terminate() p.join() - for p in gpu_manager_processes: - p.join() + gpu_batch_manager_process.terminate() + gpu_batch_manager_process.join() + watchdog_process.terminate() + watchdog_process.join() print("Took", (time.time() - start_time) / 60, "mins to transcribe files") diff --git a/amt/run.py b/amt/run.py index 1b2bcc5..1af6ced 100644 --- a/amt/run.py +++ b/amt/run.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import argparse -import sys import os import glob @@ -11,7 +10,6 @@ # TODO: Implement a way of inferring the tokenizer name automatically def _add_maestro_args(subparser): subparser.add_argument("dir", help="MAESTRO directory path") - subparser.add_argument("csv", help="MAESTRO csv path") subparser.add_argument("-train", help="train save path", required=True) subparser.add_argument("-val", help="val save path", required=True) subparser.add_argument("-test", help="test save path", required=True) @@ -19,7 +17,20 @@ def _add_maestro_args(subparser): "-mp", help="number of processes to use", type=int, - required=False, + default=1, + ) + + +def _add_synth_args(subparser): + subparser.add_argument("dir", help="Directory containing MIDIs") + subparser.add_argument("csv", help="Split csv") + subparser.add_argument("-train", help="train save path", required=True) + subparser.add_argument("-test", help="test save path", required=True) + subparser.add_argument( + "-mp", + help="number of processes to use", + type=int, + default=1, ) @@ -32,6 +43,12 @@ def _add_transcribe_args(subparser): subparser.add_argument( "-load_dir", help="dir containing mp3/wav files", required=False ) + subparser.add_argument( + "-maestro", + help="get file paths from maestro val/test sets", + action="store_true", + default=False, + ) subparser.add_argument( "-save_dir", help="dir to save midi files", required=True ) @@ -44,30 +61,86 @@ def _add_transcribe_args(subparser): action="store_true", default=False, ) + subparser.add_argument( + "-compile", + help="use the pytorch compiler to generate a cuda graph", + action="store_true", + default=False, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) -def build_maestro( - maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs +def get_synth_mid_paths(mid_dir: str, csv_path: str): + assert os.path.isdir(mid_dir), "directory doesn't exist" + assert os.path.isfile(csv_path), "csv not found" + + train_paths = [] + test_paths = [] + with open(csv_path, "r") as f: + dict_reader = DictReader(f) + for entry in dict_reader: + mid_path = os.path.normpath( + os.path.join(mid_dir, entry["mid_path"]) + ) + + assert os.path.isfile(mid_path), "file missing" + if entry["split"] == "train": + train_paths.append(mid_path) + elif entry["split"] == "test": + test_paths.append(mid_path) + else: + raise ValueError("Invalid split") + + return train_paths, test_paths + + +def build_synth( + mid_dir: str, + csv_path: str, + train_file: str, + test_file: str, + num_procs: int, ): - from amt.data import AmtDataset + from amt.data import AmtDataset, pianoteq_cmd_fn - assert os.path.isdir(maestro_dir), "MAESTRO directory not found" - assert os.path.isfile(maestro_csv_file), "MAESTRO csv not found" if os.path.isfile(train_file): print(f"Dataset file already exists at {train_file} - removing") os.remove(train_file) - if os.path.isfile(val_file): - print(f"Dataset file already exists at {val_file} - removing") - os.remove(val_file) if os.path.isfile(test_file): print(f"Dataset file already exists at {test_file} - removing") os.remove(test_file) + ( + train_paths, + test_paths, + ) = get_synth_mid_paths(mid_dir, csv_path) + + print(f"Building {train_file}") + AmtDataset.build( + load_paths=train_paths, + save_path=train_file, + num_processes=num_procs, + cli_cmd_fn=pianoteq_cmd_fn, + ) + print(f"Building {test_file}") + AmtDataset.build( + load_paths=test_paths, + save_path=test_file, + num_processes=num_procs, + cli_cmd_fn=pianoteq_cmd_fn, + ) + + +def get_matched_maestro_paths(maestro_dir): + assert os.path.isdir(maestro_dir), "MAESTRO directory not found" + + maestro_csv_path = os.path.join(maestro_dir, "maestro-v3.0.0.csv") + assert os.path.isfile(maestro_csv_path), "MAESTRO csv not found" + matched_paths_train = [] matched_paths_val = [] matched_paths_test = [] - with open(maestro_csv_file, "r") as f: + with open(maestro_csv_path, "r") as f: dict_reader = DictReader(f) for entry in dict_reader: audio_path = os.path.normpath( @@ -92,6 +165,28 @@ def build_maestro( else: print("Invalid split") + return matched_paths_train, matched_paths_val, matched_paths_test + + +def build_maestro(maestro_dir, train_file, val_file, test_file, num_procs): + from amt.data import AmtDataset + + if os.path.isfile(train_file): + print(f"Dataset file already exists at {train_file} - removing") + os.remove(train_file) + if os.path.isfile(val_file): + print(f"Dataset file already exists at {val_file} - removing") + os.remove(val_file) + if os.path.isfile(test_file): + print(f"Dataset file already exists at {test_file} - removing") + os.remove(test_file) + + ( + matched_paths_train, + matched_paths_val, + matched_paths_test, + ) = get_matched_maestro_paths(maestro_dir) + print(f"Building {train_file}") AmtDataset.build( load_paths=matched_paths_train, @@ -118,8 +213,11 @@ def transcribe( save_dir, load_path=None, load_dir=None, + maestro=False, batch_size=16, multi_gpu=False, + quantize=False, + compile=False, ): """ Transcribe audio files to midi using the given model and checkpoint. @@ -158,7 +256,10 @@ def transcribe( trans_mode = "single" if load_dir: assert os.path.isdir(load_dir), "load directory doesn't exist" - trans_mode = "batch" + if maestro is True: + trans_mode = "maestro" + else: + trans_mode = "batch" if not os.path.exists(save_dir): os.makedirs(save_dir) assert os.path.isdir(save_dir), "save dir doesn't exist" @@ -189,6 +290,14 @@ def transcribe( ) print(f"Found {len(found_mp3)} mp3 and {len(found_wav)} wav files") file_paths = found_mp3 + found_wav + elif trans_mode == "maestro": + matched_train_paths, matched_val_paths, matched_test_paths = ( + get_matched_maestro_paths(load_dir) + ) + val_mp3_paths = [ap for ap, mp in matched_val_paths] + test_mp3_paths = [ap for ap, mp in matched_test_paths] + file_paths = test_mp3_paths # val_mp3_paths + test_mp3_paths + assert len(file_paths) == 177, "Invalid maestro files" else: file_paths = [load_path] batch_size = 1 @@ -205,6 +314,8 @@ def transcribe( batch_size=batch_size, input_dir=load_dir, gpu_ids=gpu_ids, + quantize=quantize, + compile=compile, ) else: @@ -214,6 +325,8 @@ def transcribe( save_dir=save_dir, batch_size=batch_size, input_dir=load_dir, + quantize=quantize, + compile=compile, ) @@ -225,10 +338,14 @@ def main(): subparser_maestro = subparsers.add_parser( "maestro", help="Commands to build the maestro dataset." ) + subparser_synth = subparsers.add_parser( + "synth", help="Commands to build the maestro dataset." + ) subparser_transcribe = subparsers.add_parser( "transcribe", help="Commands to run transcription." ) _add_maestro_args(subparser_maestro) + _add_synth_args(subparser_synth) _add_transcribe_args(subparser_transcribe) args = parser.parse_args() @@ -240,21 +357,31 @@ def main(): elif args.command == "maestro": build_maestro( maestro_dir=args.dir, - maestro_csv_file=args.csv, train_file=args.train, val_file=args.val, test_file=args.test, num_procs=args.mp, ) + elif args.command == "synth": + build_synth( + mid_dir=args.dir, + csv_path=args.csv, + train_file=args.train, + test_file=args.test, + num_procs=args.mp, + ) elif args.command == "transcribe": transcribe( model_name=args.model_name, checkpoint_path=args.checkpoint_path, load_path=args.load_path, load_dir=args.load_dir, + maestro=args.maestro, save_dir=args.save_dir, batch_size=args.bs, multi_gpu=args.multi_gpu, + quantize=args.q8, + compile=args.compile, ) else: print("Unrecognized command") diff --git a/amt/tokenizer.py b/amt/tokenizer.py index c368673..8b38cae 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -329,15 +329,16 @@ def _detokenize_midi_dict( ) elif tok_1_type == "on": if (tok_2_type, tok_3_type) != ("onset", "vel"): - print("Unexpected token order:", tok_1, tok_2, tok_3) - raise ValueError + raise ValueError( + "Unexpected token order:", tok_1, tok_2, tok_3 + ) else: notes_to_close[tok_1_data] = (tok_2_data, tok_3_data) elif tok_1_type == "off": if tok_2_type != "onset": - print("Unexpected token order:", tok_1, tok_2, tok_3) - if DEBUG: - raise Exception + raise ValueError( + "Unexpected token order:", tok_1, tok_2, tok_3 + ) else: # Process note and add to note msgs note_to_close = notes_to_close.pop(tok_1_data, None) diff --git a/amt/train.py b/amt/train.py index eee1b8e..864ed93 100644 --- a/amt/train.py +++ b/amt/train.py @@ -1,11 +1,13 @@ import os import sys import csv +import math import random import functools import argparse import logging import torch +import torchaudio import accelerate from torch import nn as nn @@ -148,7 +150,7 @@ def _get_optim( warmup_lrs = torch.optim.lr_scheduler.LinearLR( optimizer, - start_factor=0.000001, + start_factor=1e-8, end_factor=1, total_iters=warmup, ) @@ -194,7 +196,7 @@ def get_finetune_optim( ): LR = 1e-4 END_RATIO = 0.1 - WARMUP_STEPS = 500 + WARMUP_STEPS = 1000 return _get_optim( lr=LR, @@ -229,7 +231,7 @@ def get_dataloaders( audio_transform = AudioTransform() def _collate_fn(seqs, max_pitch_shift: int): - wav, src, tgt = torch.utils.data.default_collate(seqs) + wav, src, tgt, idxs = torch.utils.data.default_collate(seqs) # Pitch aug pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift) @@ -239,13 +241,13 @@ def _collate_fn(seqs, max_pitch_shift: int): # Distortion wav = audio_transform.distortion_aug_cpu(wav) - return wav, src, tgt, pitch_shift + return wav, src, tgt, pitch_shift, idxs train_dataloader = DataLoader( train_dataset, batch_size=batch_size, num_workers=num_workers, - collate_fn=functools.partial(_collate_fn, max_pitch_shift=5), + collate_fn=functools.partial(_collate_fn, max_pitch_shift=4), shuffle=True, ) val_dataloader = DataLoader( @@ -258,14 +260,6 @@ def _collate_fn(seqs, max_pitch_shift: int): return train_dataloader, val_dataloader -def rolling_average(prev_avg: float, x_n: float, n: int): - # Returns rolling average without needing to recalculate - if n == 0: - return x_n - else: - return ((prev_avg * (n - 1)) / n) + (x_n / n) - - def _train( epochs: int, accelerator: accelerate.Accelerator, @@ -292,6 +286,18 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int): ) _accelerator.save_state(checkpoint_dir) + def get_max_norm(named_parameters): + max_grad_norm = {"val": 0.0} + for name, parameter in named_parameters: + if parameter.grad is not None and parameter.requires_grad is True: + grad_norm = parameter.grad.data.norm(2).item() + # logger.debug(f"{name}: {grad_norm}") + if grad_norm >= max_grad_norm["val"]: + max_grad_norm["name"] = name + max_grad_norm["val"] = grad_norm + + return max_grad_norm + # This is all slightly messy as train_loop and val_loop make use of the # variables in the wider scope. Perhaps refactor this at some point. def train_loop( @@ -319,58 +325,73 @@ def train_loop( leave=False, ) ): - step = __step + _resume_step + 1 - - wav, src, tgt, pitch_shift = batch - with torch.no_grad(): - mel = audio_transform.forward(wav, shift=pitch_shift) - logits = model(mel, src) # (b_sz, s_len, v_sz) - logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss - loss = loss_fn(logits, tgt) + with accelerator.accumulate(model): + step = __step + _resume_step + 1 - # Calculate statistics - loss_buffer.append(loss.item()) - if len(loss_buffer) > TRAILING_LOSS_STEPS: - loss_buffer.pop(0) - trailing_loss = sum(loss_buffer) / len(loss_buffer) - avg_train_loss = rolling_average( - avg_train_loss, loss.item(), __step - ) + wav, src, tgt, pitch_shift, idxs = batch - # Logging - logger.debug( - f"EPOCH {_epoch} STEP {step}: " - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing_loss={round(trailing_loss, 4)}, " - f"average_loss={round(avg_train_loss, 4)}" - ) - if accelerator.is_main_process: - loss_writer.writerow([_epoch, step, loss.item()]) - pbar.set_postfix_str( - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing={round(trailing_loss, 4)}" - ) + mel = audio_transform.forward(wav, shift=pitch_shift) + logits = model(mel, src) # (b_sz, s_len, v_sz) + logits = logits.transpose( + 1, 2 + ) # Transpose for CrossEntropyLoss + loss = loss_fn(logits, tgt) - # Backwards step - accelerator.backward(loss) - if accelerator.sync_gradients: + # Calculate statistics + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH {_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss_buffer[-1], 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + if accelerator.is_main_process: + loss_writer.writerow([_epoch, step, loss_buffer[-1]]) + + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss_buffer[-1], 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + # Backwards step + accelerator.backward(loss) + + max_grad_norm_bn = get_max_norm(model.named_parameters()) accelerator.clip_grad_norm_(model.parameters(), 1.0) + max_grad_norm_an = get_max_norm(model.named_parameters()) - optimizer.step() - optimizer.zero_grad() - if scheduler: - scheduler.step() - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - - if steps_per_checkpoint: - if step % steps_per_checkpoint == 0: - make_checkpoint( - _accelerator=accelerator, - _epoch=_epoch, - _step=step, + if max_grad_norm_bn["val"] > 1.5: + logger.warning( + f"Seen large grad_norm {max_grad_norm_bn['name']}: {max_grad_norm_bn['val']} -> {max_grad_norm_an['val']}" ) + logger.debug(accelerator.gather(loss)) + logger.debug(accelerator.gather(idxs)) + elif math.isnan(trailing_loss): + logger.error(accelerator.gather(loss)) + logger.error(loss_buffer) + logger.error(accelerator.gather(idxs)) + + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if steps_per_checkpoint: + if step % steps_per_checkpoint == 0: + make_checkpoint( + _accelerator=accelerator, + _epoch=_epoch, + _step=step, + ) logger.info( f"EPOCH {_epoch}/{epochs + start_epoch}: Finished training - " @@ -379,8 +400,9 @@ def train_loop( return avg_train_loss + @torch.no_grad() def val_loop(dataloader, _epoch: int, aug: bool): - avg_val_loss = 0 + loss_buffer = [] model.eval() for step, batch in ( pbar := tqdm( @@ -389,26 +411,25 @@ def val_loop(dataloader, _epoch: int, aug: bool): leave=False, ) ): - wav, src, tgt = batch - with torch.no_grad(): - if aug == False: - mel = audio_transform.log_mel(wav) - elif aug == True: - # Apply aug without distortion or spec-augment - mel = audio_transform.log_mel( - audio_transform.aug_wav(wav), detune=True - ) - else: - raise TypeError + wav, src, tgt, idxs = batch + + if aug == False: + mel = audio_transform.log_mel(wav) + elif aug == True: + # Apply aug without distortion or spec-augment + mel = audio_transform.log_mel( + audio_transform.aug_wav(wav), detune=True + ) + else: + raise TypeError - logits = model(mel, src) - logits = logits.transpose( - 1, 2 - ) # Transpose for CrossEntropyLoss - loss = loss_fn(logits, tgt) + logits = model(mel, src) + logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss + loss = loss_fn(logits, tgt) # Logging - avg_val_loss = rolling_average(avg_val_loss, loss.item(), step) + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + avg_val_loss = sum(loss_buffer) / len(loss_buffer) pbar.set_postfix_str(f"average_loss={round(avg_val_loss, 4)}") # EPOCH @@ -425,7 +446,7 @@ def val_loop(dataloader, _epoch: int, aug: bool): steps_per_checkpoint > 1 ), "Invalid checkpoint mode value (too small)" - TRAILING_LOSS_STEPS = 200 + TRAILING_LOSS_STEPS = 100 PAD_ID = train_dataloader.dataset.tokenizer.pad_id logger = get_logger(__name__) # Accelerate logger loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID) @@ -531,7 +552,9 @@ def resume_train( assert os.path.isfile(val_data_path), f"No file found at {val_data_path}" tokenizer = AmtTokenizer() - accelerator = accelerate.Accelerator(project_dir=project_dir) + accelerator = accelerate.Accelerator( + project_dir=project_dir, gradient_accumulation_steps=4 + ) if accelerator.is_main_process: project_dir = setup_project_dir(project_dir) logger = setup_logger(project_dir) @@ -662,7 +685,9 @@ def train( assert os.path.isfile(finetune_cp_path), "Invalid checkpoint path" tokenizer = AmtTokenizer() - accelerator = accelerate.Accelerator(project_dir=project_dir) + accelerator = accelerate.Accelerator( + project_dir=project_dir, gradient_accumulation_steps=4 + ) if accelerator.is_main_process: project_dir = setup_project_dir(project_dir) logger = setup_logger(project_dir) @@ -801,6 +826,9 @@ def parse_train_args(): argp.add_argument("model", help="name of model config file") argp.add_argument("train_data", help="path to train dir") argp.add_argument("val_data", help="path to val dir") + argp.add_argument( + "-cpath", help="resuming checkpoint", type=str, required=False + ) argp.add_argument("-epochs", help="train epochs", type=int, required=True) argp.add_argument("-bs", help="batch size", type=int, default=32) argp.add_argument("-workers", help="number workers", type=int, default=1) @@ -849,6 +877,7 @@ def parse_train_args(): num_workers=train_args.workers, batch_size=train_args.bs, epochs=train_args.epochs, + finetune_cp_path=train_args.cpath, steps_per_checkpoint=train_args.spc, project_dir=train_args.pdir, ) diff --git a/config/models/medium-final.json b/config/models/medium-final.json new file mode 100644 index 0000000..69b79c7 --- /dev/null +++ b/config/models/medium-final.json @@ -0,0 +1,11 @@ +{ + "n_mels": 256, + "n_audio_ctx": 1500, + "n_audio_state": 768, + "n_audio_head": 12, + "n_audio_layer": 12, + "n_text_ctx": 4096, + "n_text_state": 768, + "n_text_head": 12, + "n_text_layer": 12 +} \ No newline at end of file diff --git a/config/models/medium.json b/config/models/small-final.json similarity index 77% rename from config/models/medium.json rename to config/models/small-final.json index 45c0de6..1d40dd9 100644 --- a/config/models/medium.json +++ b/config/models/small-final.json @@ -3,9 +3,9 @@ "n_audio_ctx": 1500, "n_audio_state": 512, "n_audio_head": 8, - "n_audio_layer": 12, + "n_audio_layer": 6, "n_text_ctx": 4096, "n_text_state": 512, "n_text_head": 8, - "n_text_layer": 12 + "n_text_layer": 6 } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a74c9cf..696fb40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ aria @ git+https://github.com/EleutherAI/aria.git -torch >= 2.1 +torch >= 2.2 torchaudio accelerate +psutil mido tqdm orjson -mir_eval \ No newline at end of file +mir_eval diff --git a/scripts/eval/dedupe.py b/scripts/eval/dedupe.py new file mode 100644 index 0000000..8d67d81 --- /dev/null +++ b/scripts/eval/dedupe.py @@ -0,0 +1,72 @@ +import os +import hashlib +import argparse +import multiprocessing + +from pydub import AudioSegment + + +def hash_audio_file(file_path): + """Hash the audio content of an MP3 file.""" + try: + audio = AudioSegment.from_mp3(file_path) + raw_data = audio.raw_data + except Exception as e: + print(e) + return file_path, -1 + else: + return file_path, hashlib.sha256(raw_data).hexdigest() + + +def find_duplicates(root_dir): + """Find and remove duplicate MP3 files in the directory and its subdirectories.""" + duplicates = [] + mp3_paths = [] + for root, _, files in os.walk(root_dir): + for file in files: + if file.endswith(".mp3"): + mp3_paths.append(os.path.join(root, file)) + + with multiprocessing.Pool() as pool: + hashes = pool.map(hash_audio_file, mp3_paths) + + seen_hash = {} + for p, h in hashes: + if seen_hash.get(h, False) is True: + print("Seen dupe") + duplicates.append(p) + else: + print("Seen orig") + seen_hash[h] = True + + return duplicates + + +def remove_duplicates(duplicate_files): + """Remove the duplicate files.""" + for file in duplicate_files: + os.remove(file) + print(f"Removed duplicate file: {file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Remove duplicate MP3 files based on audio content." + ) + parser.add_argument( + "dir", type=str, help="Directory to scan for duplicate MP3 files." + ) + args = parser.parse_args() + + root_directory = args.dir + duplicates = find_duplicates(root_directory) + + if duplicates: + print(f"Found {len(duplicates)} duplicates. Removing...") + remove_duplicates(duplicates) + else: + print("No duplicates found.") + + +if __name__ == "__main__": + main() diff --git a/scripts/eval/dtw.py b/scripts/eval/dtw.py new file mode 100644 index 0000000..41e5754 --- /dev/null +++ b/scripts/eval/dtw.py @@ -0,0 +1,211 @@ +# pip install git+https://github.com/alex2awesome/djitw.git + +import argparse +import csv +import librosa +import djitw +import pretty_midi +import scipy +import random +import multiprocessing +import os +import warnings +import functools +import glob +import numpy as np + +from multiprocessing.dummy import Pool as ThreadPool + +# Audio/CQT parameters +FS = 22050.0 +NOTE_START = 36 +N_NOTES = 48 +HOP_LENGTH = 1024 + +# DTW parameters +GULLY = 0.96 + + +def compute_cqt(audio_data): + """Compute the CQT and frame times for some audio data""" + # Compute CQT + cqt = librosa.cqt( + audio_data, + sr=FS, + fmin=librosa.midi_to_hz(NOTE_START), + n_bins=N_NOTES, + hop_length=HOP_LENGTH, + tuning=0.0, + ) + # Compute the time of each frame + times = librosa.frames_to_time( + np.arange(cqt.shape[1]), sr=FS, hop_length=HOP_LENGTH + ) + # Compute log-amplitude + cqt = librosa.amplitude_to_db(cqt, ref=cqt.max()) + # Normalize and return + return librosa.util.normalize(cqt, norm=2).T, times + + +# Had to change this to average chunks for large audio files for cpu reasons +def load_and_run_dtw(args): + def calc_score(_midi_cqt, _audio_cqt): + # Nearly all high-performing systems used cosine distance + distance_matrix = scipy.spatial.distance.cdist( + _midi_cqt, _audio_cqt, "cosine" + ) + + # Get lowest cost path + p, q, score = djitw.dtw( + distance_matrix, + GULLY, # The gully for all high-performing systems was near 1 + np.median( + distance_matrix + ), # The penalty was also near 1.0*median(distance_matrix) + inplace=False, + ) + # Normalize by path length, normalize by distance matrix submatrix within path + score = score / len(p) + score = ( + score / distance_matrix[p.min() : p.max(), q.min() : q.max()].mean() + ) + + return score + + audio_file, midi_file = args + # Load in the audio data + audio_data, _ = librosa.load(audio_file, sr=FS) + audio_cqt, audio_times = compute_cqt(audio_data) + + midi_object = pretty_midi.PrettyMIDI(midi_file) + midi_audio = midi_object.fluidsynth(fs=FS) + midi_cqt, midi_times = compute_cqt(midi_audio) + + # Truncate to save on compute time for long tracks + MAX_LEN = 10000 + total_len = midi_cqt.shape[0] + if total_len > MAX_LEN: + idx = 0 + scores = [] + while idx < total_len: + scores.append( + calc_score( + _midi_cqt=midi_cqt[idx : idx + MAX_LEN, :], + _audio_cqt=audio_cqt[idx : idx + MAX_LEN, :], + ) + ) + idx += MAX_LEN + + max_score = max(scores) + avg_score = sum(scores) / len(scores) if scores else 1.0 + + else: + avg_score = calc_score(_midi_cqt=midi_cqt, _audio_cqt=audio_cqt) + max_score = avg_score + + return midi_file, avg_score, max_score + + +# I changed wav with mp3 in here :/ +def get_matched_files(audio_dir: str, mid_dir: str): + # We assume that the files have the same path relative to their directory + res = [] + wav_paths = glob.glob(os.path.join(audio_dir, "**/*.mp3"), recursive=True) + print(f"found {len(wav_paths)} mp3 files") + + for wav_path in wav_paths: + input_rel_path = os.path.relpath(wav_path, audio_dir) + mid_path = os.path.join( + mid_dir, os.path.splitext(input_rel_path)[0] + ".mid" + ) + if os.path.isfile(mid_path): + res.append((wav_path, mid_path)) + + print(f"found {len(res)} matched mp3-midi pairs") + + return res + + +def abortable_worker(func, *args, **kwargs): + timeout = kwargs.get("timeout", None) + p = ThreadPool(1) + res = p.apply_async(func, args=args) + try: + out = res.get(timeout) + return out + except multiprocessing.TimeoutError: + return None, None, None + except Exception as e: + print(e) + return None, None, None + finally: + p.close() + p.join() + + +if __name__ == "__main__": + multiprocessing.set_start_method("fork") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="amplitude_to_db was called on complex input", + ) + parser = argparse.ArgumentParser() + parser.add_argument("-audio_dir", help="dir containing .wav files") + parser.add_argument( + "-mid_dir", help="dir containing .mid files", default=None + ) + parser.add_argument( + "-output_file", help="path to output file", default=None + ) + args = parser.parse_args() + + matched_files = get_matched_files( + audio_dir=args.audio_dir, mid_dir=args.mid_dir + ) + + results = {} + if os.path.exists(args.output_file): + with open(args.output_file, "r") as f: + reader = csv.DictReader(f) + for row in reader: + results[row["mid_path"]] = { + "avg_score": row["avg_score"], + "max_score": row["max_score"], + } + + matched_files = [ + (audio_path, mid_path) + for audio_path, mid_path in matched_files + if mid_path not in results.keys() + ] + random.shuffle(matched_files) + print(f"loaded {len(results)} results") + print(f"calculating scores for {len(matched_files)}") + + score_csv = open(args.output_file, "a") + csv_writer = csv.writer(score_csv) + csv_writer.writerow(["mid_path", "avg_score", "max_score"]) + + with multiprocessing.Pool() as pool: + abortable_func = functools.partial( + abortable_worker, load_and_run_dtw, timeout=15000 + ) + scores = pool.imap_unordered(abortable_func, matched_files) + + skipped = 0 + processed = 0 + for mid_path, avg_score, max_score in scores: + if avg_score is not None and max_score is not None: + csv_writer.writerow([mid_path, avg_score, max_score]) + score_csv.flush() + else: + print(f"timeout") + skipped += 1 + + processed += 1 + if processed % 10 == 0: + print(f"PROCESSED: {processed}/{len(matched_files)}") + print(f"***") + + print(f"skipped: {skipped}") diff --git a/scripts/eval/dtw.sh b/scripts/eval/dtw.sh new file mode 100644 index 0000000..1d25b71 --- /dev/null +++ b/scripts/eval/dtw.sh @@ -0,0 +1,5 @@ +python /home/loubb/work/aria-amt/scripts/eval/dtw.py \ + -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \ + -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid \ + -output_file /mnt/ssd1/amt/transcribed_data/0/aria-mid.csv + diff --git a/scripts/eval/mir.sh b/scripts/eval/mir.sh new file mode 100644 index 0000000..0364d4c --- /dev/null +++ b/scripts/eval/mir.sh @@ -0,0 +1,5 @@ +python /home/loubb/work/aria-amt/amt/evaluate.py \ + --est-dir /home/loubb/work/aria-amt/maestro-ft \ + --ref-dir /mnt/ssd1/data/mp3/raw/maestro-mp3 \ + --output-stats-file out.json + \ No newline at end of file diff --git a/scripts/eval/prune.py b/scripts/eval/prune.py new file mode 100644 index 0000000..1b8f919 --- /dev/null +++ b/scripts/eval/prune.py @@ -0,0 +1,91 @@ +import argparse +import csv +import os +import shutil + + +# Calculate percentiles without using numpy +def calculate_percentiles(data, percentiles): + data_sorted = sorted(data) + n = len(data_sorted) + results = [] + for percentile in percentiles: + k = (n - 1) * percentile / 100 + f = int(k) + c = k - f + if f + 1 < n: + result = data_sorted[f] + c * (data_sorted[f + 1] - data_sorted[f]) + else: + result = data_sorted[f] + results.append(result) + return results + + +def main(mid_dir, output_dir, score_file, max_score, dry): + if os.path.isdir(output_dir) is False: + os.makedirs(output_dir) + + scores = {} + with open(score_file, "r") as f: + reader = csv.DictReader(f) + failures = 0 + for row in reader: + try: + if 0.0 < float(row["avg_score"]) < 1.0: + scores[row["mid_path"]] = float(row["avg_score"]) + except Exception as e: + pass + + print(f"{failures} failures") + print(f"found {len(scores.items())} mid-score pairs") + + print("top 50 by score:") + for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)[ + :50 + ]: + print(f"{v}: {k}") + print("bottom 50 by score:") + for k, v in sorted(scores.items(), key=lambda item: item[1])[:50]: + print(f"{v}: {k}") + + # Define the percentiles to calculate + percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90] + floats = [v for k, v in scores.items()] + + # Calculate the percentiles + print(f"percentiles: {calculate_percentiles(floats, percentiles)}") + + cnt = 0 + for mid_path, score in scores.items(): + mid_rel_path = os.path.relpath(mid_path, mid_dir) + output_path = os.path.join(output_dir, mid_rel_path) + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + + if score < max_score: + if args.dry is not True: + shutil.copyfile(mid_path, output_path) + else: + cnt += 1 + + print(f"excluded {cnt}/{len(scores.items())} files") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-mid_dir", help="dir containing .mid files", default=None + ) + parser.add_argument( + "-output_dir", help="dir containing .mid files", default=None + ) + parser.add_argument("-score_file", help="path to output file", default=None) + parser.add_argument( + "-max_score", type=float, help="path to output file", default=None + ) + parser.add_argument("-dry", action="store_true", help="path to output file") + args = parser.parse_args() + + main( + args.mid_dir, args.output_dir, args.score_file, args.max_score, args.dry + ) diff --git a/scripts/eval/prune.sh b/scripts/eval/prune.sh new file mode 100644 index 0000000..9b3f89a --- /dev/null +++ b/scripts/eval/prune.sh @@ -0,0 +1,6 @@ +python /home/loubb/work/aria-amt/scripts/eval/prune.py \ + -mid_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid \ + -output_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid-pruned \ + -score_file /mnt/ssd1/amt/transcribed_data/0/pijama-mid.csv \ + -max_score 0.42 \ + # -dry \ No newline at end of file diff --git a/scripts/eval/req-eval.txt b/scripts/eval/req-eval.txt new file mode 100644 index 0000000..d3ed177 --- /dev/null +++ b/scripts/eval/req-eval.txt @@ -0,0 +1,4 @@ +djitw @ git+https://github.com/alex2awesome/djitw.git +librosa +pretty_midi +pyfluidsynth \ No newline at end of file diff --git a/scripts/eval/split.py b/scripts/eval/split.py new file mode 100644 index 0000000..c912cbc --- /dev/null +++ b/scripts/eval/split.py @@ -0,0 +1,57 @@ +import csv +import random +import glob +import argparse +import os + + +def get_matched_paths(audio_dir: str, mid_dir: str): + # Assume that the files have the same path relative to their directory + res = [] + mid_paths = glob.glob(os.path.join(mid_dir, "**/*.mid"), recursive=True) + print(f"found {len(mid_paths)} mid files") + + audio_dir_last = os.path.basename(audio_dir) + mid_dir_last = os.path.basename(mid_dir) + + for mid_path in mid_paths: + input_rel_path = os.path.relpath(mid_path, mid_dir) + + mp3_rel_path = os.path.splitext(input_rel_path)[0] + ".mp3" + mp3_path = os.path.join(audio_dir, mp3_rel_path) + + # Check if the corresponding .mp3 file exists + if os.path.isfile(mp3_path): + matched_mid_path = os.path.join(mid_dir_last, input_rel_path) + matched_mp3_path = os.path.join(audio_dir_last, mp3_rel_path) + + res.append((matched_mp3_path, matched_mid_path)) + + print(f"found {len(res)} matched mp3-midi pairs") + assert len(mid_paths) == len(res), "audio files missing" + + return res + + +def create_csv(matched_paths, csv_path): + split_csv = open(csv_path, "w") + csv_writer = csv.writer(split_csv) + csv_writer.writerow(["mid_path", "audio_path", "split"]) + + for audio_path, mid_path in matched_paths: + if random.random() < 0.1: + csv_writer.writerow([mid_path, audio_path, "test"]) + else: + csv_writer.writerow([mid_path, audio_path, "train"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-mid_dir", type=str) + parser.add_argument("-audio_dir", type=str) + parser.add_argument("-csv_path", type=str) + args = parser.parse_args() + + matched_paths = get_matched_paths(args.audio_dir, args.mid_dir) + + create_csv(matched_paths, args.csv_path) diff --git a/scripts/eval/split.sh b/scripts/eval/split.sh new file mode 100644 index 0000000..05faece --- /dev/null +++ b/scripts/eval/split.sh @@ -0,0 +1,4 @@ +python /home/loubb/work/aria-amt/scripts/eval/split.py \ + -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid-pruned \ + -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \ + -csv_path aria-pruned-split.csv diff --git a/tests/test_data.py b/tests/test_data.py index f69117e..8e1010c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -18,7 +18,7 @@ if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") -MAESTRO_PATH = "/weka/proj-aria/aria-amt/data/train.jsonl" +MAESTRO_PATH = "/mnt/ssd1/amt/training_data/train.txt" def plot_spec(mel: torch.Tensor, name: str | int): @@ -57,7 +57,7 @@ def test_build(self): dataset = AmtDataset("tests/test_results/dataset.jsonl") tokenizer = AmtTokenizer() - for idx, (wav, src, tgt) in enumerate(dataset): + for idx, (wav, src, tgt, idx) in enumerate(dataset): print(wav.shape, src.shape, tgt.shape) src_decoded = tokenizer.decode(src) tgt_decoded = tokenizer.decode(tgt) @@ -76,11 +76,11 @@ def test_maestro(self): audio_transform = AudioTransform() dataset = AmtDataset(load_path=MAESTRO_PATH) print(f"Dataset length: {len(dataset)}") - for idx, (wav, src, tgt) in enumerate(dataset): + for idx, (wav, src, tgt, __idx) in enumerate(dataset): src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt) - if (idx + 1) % 100 == 0: - break - if idx % 7 == 0: + + if idx % 7 == 0 and idx < 100: + print(idx) src_mid_dict = tokenizer._detokenize_midi_dict( src_dec, len_ms=30000, @@ -91,6 +91,11 @@ def test_maestro(self): torchaudio.save( f"tests/test_results/wav_{idx}.wav", wav.unsqueeze(0), 16000 ) + torchaudio.save( + f"tests/test_results/wav_aug_{idx}.wav", + audio_transform.aug_wav(wav.unsqueeze(0)), + 16000, + ) plot_spec( audio_transform(wav.unsqueeze(0)).squeeze(0), f"mel_{idx}" )