From d25a7faafba6cb94923e61752ffe62f20f027b91 Mon Sep 17 00:00:00 2001
From: Louis
Date: Sat, 9 Mar 2024 19:38:55 +0000
Subject: [PATCH 1/4] Fix distortion and dataset indexing (#16)

* working

* format

* fix distortion bottleneck

* format

* adj
---
 amt/audio.py | 28 ++++++++++------
 amt/data.py | 55 ++++++++++++++++++++++++++++++-
 amt/infer.py | 81 +++++++++++++++++++++++++++++-----------------
 amt/run.py | 30 ++++++++---------
 amt/train.py | 6 ++++
 tests/test_data.py | 6 +++-
 6 files changed, 148 insertions(+), 58 deletions(-)

diff --git a/amt/audio.py b/amt/audio.py
index ed90f65..a67bce9 100644
--- a/amt/audio.py
+++ b/amt/audio.py
@@ -193,10 +193,10 @@ def __init__(
         min_dist_gain: int = 0,
         noise_ratio: float = 0.95,
         reverb_ratio: float = 0.95,
-        applause_ratio: float = 0.01, # CHANGE
+        applause_ratio: float = 0.01,
         distort_ratio: float = 0.15,
         reduce_ratio: float = 0.01,
-        spec_aug_ratio: float = 0.25,
+        spec_aug_ratio: float = 0.5,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -257,7 +257,7 @@ def __init__(
         )
         self.spec_aug = torch.nn.Sequential(
             torchaudio.transforms.FrequencyMasking(
-                freq_mask_param=10, iid_masks=True
+                freq_mask_param=15, iid_masks=True
             ),
             torchaudio.transforms.TimeMasking(
                 time_mask_param=1000, iid_masks=True
@@ -374,6 +374,17 @@ def apply_distortion(self, wav: torch.tensor):

         return AF.overdrive(wav, gain=gain, colour=colour)

+    def distortion_aug_cpu(self, wav: torch.Tensor):
+        # This function should run on the cpu (i.e. in the dataloader collate
+        # function) in order to not be a bottleneck
+
+        if random.random() < self.reduce_ratio:
+            wav = self.apply_reduction(wav)
+        if random.random() < self.distort_ratio:
+            wav = self.apply_distortion(wav)
+
+        return wav
+
     def shift_spec(self, specs: torch.Tensor, shift: int):
         if shift == 0:
             return specs
@@ -400,18 +411,15 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
         return shifted_specs

     def aug_wav(self, wav: torch.Tensor):
+        # This function doesn't apply distortion. If distortion is desired it
+        # should be run beforehand on the cpu with distortion_aug_cpu.
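+        #
+        # A sketch of the intended two-stage call order, assuming a mono
+        # batch shaped [batch_sz, samples] (cpu step in the dataloader
+        # workers, gpu steps afterwards):
+        #
+        #   wav = audio_transform.distortion_aug_cpu(wav)  # cpu: reduce/distort
+        #   mel = audio_transform(wav.cuda(), shift=0)     # gpu: aug_wav + log_mel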
+ # Noise if random.random() < self.noise_ratio: wav = self.apply_noise(wav) if random.random() < self.applause_ratio: wav = self.apply_applause(wav) - # Distortion - if random.random() < self.reduce_ratio: - wav = self.apply_reduction(wav) - elif random.random() < self.distort_ratio: - wav = self.apply_distortion(wav) - # Reverb if random.random() < self.reverb_ratio: return self.apply_reverb(wav) @@ -439,7 +447,7 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None): return log_spec def forward(self, wav: torch.Tensor, shift: int = 0): - # Noise, distortion, and reverb + # Noise, and reverb wav = self.aug_wav(wav) # Spec & pitch shift diff --git a/amt/data.py b/amt/data.py index 71982c5..c5ddfdc 100644 --- a/amt/data.py +++ b/amt/data.py @@ -113,7 +113,17 @@ def __init__(self, load_path: str): self.file_mmap = mmap.mmap( self.file_buff.fileno(), 0, access=mmap.ACCESS_READ ) - self.index = self._build_index() + + index_path = AmtDataset._get_index_path(load_path=load_path) + if os.path.isfile(index_path) is True: + self.index = self._load_index(load_path=index_path) + else: + print("Calculating index...") + self.index = self._build_index() + print( + f"Index of length {len(self.index)} calculated, saving to {index_path}" + ) + self._save_index(index=self.index, save_path=index_path) def close(self): if self.file_buff: @@ -167,6 +177,21 @@ def _build_index(self): return index + def _save_index(self, index: list[int], save_path: str): + with open(save_path, "w") as file: + for idx in index: + file.write(f"{idx}\n") + + def _load_index(self, load_path: str): + with open(load_path, "r") as file: + return [int(line.strip()) for line in file] + + @staticmethod + def _get_index_path(load_path: str): + return ( + f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}" + ) + @classmethod def build( cls, @@ -175,6 +200,12 @@ def build( num_processes: int = 1, ): assert os.path.isfile(save_path) is False, f"{save_path} already exists" + + index_path = AmtDataset._get_index_path(load_path=save_path) + if os.path.isfile(index_path): + print(f"Removing existing index file at {index_path}") + os.remove(AmtDataset._get_index_path(load_path=save_path)) + num_paths = len(matched_load_paths) with Pool(processes=num_processes) as pool: sharded_save_paths = [] @@ -202,3 +233,25 @@ def build( os.system(shell_cmd) for _path in sharded_save_paths: os.remove(_path) + + # Create index by loading object + AmtDataset(load_path=save_path) + + def _build_index(self): + self.file_mmap.seek(0) + index = [] + pos = 0 + while True: + pos_buff = pos + + pos = self.file_mmap.find(b"\n", pos) + if pos == -1: + break + pos = self.file_mmap.find(b"\n", pos + 1) + if pos == -1: + break + + index.append(pos_buff) + pos += 1 + + return index diff --git a/amt/infer.py b/amt/infer.py index 114d75e..8437de8 100644 --- a/amt/infer.py +++ b/amt/infer.py @@ -1,6 +1,7 @@ import os import time import random +import logging import torch import torch.multiprocessing as multiprocessing @@ -21,7 +22,23 @@ VEL_TOLERANCE = 50 -# TODO: Profile and fix gpu util +def _setup_logger(): + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(process)d: [%(levelname)s] %(message)s", + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logging.getLogger(__name__) def calculate_vel( @@ -101,7 +118,7 @@ 
def wrapper(*args, **kwargs): return func(*args, **kwargs) else: # Call the function with float16 if bfloat16 is not supported - with torch.autocast("cuda", dtype=torch.float16): + with torch.autocast("cuda", dtype=torch.float32): return func(*args, **kwargs) return wrapper @@ -114,6 +131,7 @@ def process_segments( audio_transform: AudioTransform, tokenizer: AmtTokenizer, ): + logger = logging.getLogger(__name__) audio_segs = torch.stack( [audio_seg for (audio_seg, prefix), _ in tasks] ).cuda() @@ -131,14 +149,14 @@ def process_segments( kv_cache = model.get_empty_cache() - for idx in ( - pbar := tqdm( - range(min_prefix_len, MAX_SEQ_LEN - 1), - total=MAX_SEQ_LEN - (min_prefix_len + 1), - leave=False, - ) - ): - # for idx in range(min_prefix_len, MAX_SEQ_LEN - 1): + # for idx in ( + # pbar := tqdm( + # range(min_prefix_len, MAX_SEQ_LEN - 1), + # total=MAX_SEQ_LEN - (min_prefix_len + 1), + # leave=False, + # ) + # ): + for idx in range(min_prefix_len, MAX_SEQ_LEN - 1): if idx == min_prefix_len: logits = model.decoder( xa=audio_features, @@ -181,7 +199,7 @@ def process_segments( break if not all(eos_seen): - print("WARNING: OVERFLOW") + logger.warning("Context length overflow when transcribing segment") for _idx in range(seq.shape[0]): if eos_seen[_idx] == False: eos_seen[_idx] = MAX_SEQ_LEN @@ -201,9 +219,9 @@ def gpu_manager( batch_size: int, ): # model.compile() + logger = _setup_logger() audio_transform = AudioTransform().cuda() tokenizer = AmtTokenizer(return_tensors=True) - process_pid = multiprocessing.current_process().pid wait_for_batch = True batch = [] @@ -211,9 +229,9 @@ def gpu_manager( try: task, pid = gpu_task_queue.get(timeout=5) except: - print(f"{process_pid}: GPU task timeout") + logger.info(f"GPU task timeout") if len(batch) == 0: - print(f"{process_pid}: Finished GPU tasks") + logger.info(f"Finished GPU tasks") return else: wait_for_batch = False @@ -274,8 +292,10 @@ def process_file( result_queue: Queue, tokenizer: AmtTokenizer = AmtTokenizer(), ): - process_pid = multiprocessing.current_process().pid - print(f"{process_pid}: Getting wav segments") + logger = logging.getLogger(__name__) + pid = multiprocessing.current_process().pid + + logger.info(f"Getting wav segments") audio_segments = [ f for f, _ in get_wav_mid_segments( @@ -288,10 +308,10 @@ def process_file( init_idx = len(seq) # Add to gpu queue and wait for results - gpu_task_queue.put(((audio_seg, seq), process_pid)) + gpu_task_queue.put(((audio_seg, seq), pid)) while True: gpu_result = result_queue.get() - if gpu_result["pid"] == process_pid: + if gpu_result["pid"] == pid: seq = gpu_result["result"] break else: @@ -307,7 +327,7 @@ def process_file( else: seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS) if len(seq) == 1: - print(f"{process_pid}: exiting early") + logger.info(f"Exiting early") return res return res @@ -336,19 +356,19 @@ def _get_save_path(_file_path: str): return save_path - pid = multiprocessing.current_process().pid + logger = _setup_logger() tokenizer = AmtTokenizer() files_processed = 0 while not file_queue.empty(): file_path = file_queue.get() save_path = _get_save_path(file_path) if os.path.exists(save_path): - print(f"{pid}: {save_path} already exists, overwriting") + logger.info(f"{save_path} already exists, overwriting") try: res = process_file(file_path, gpu_task_queue, result_queue) except Exception as e: - print(f"{pid}: Failed to transcribe {file_path}") + logger.error(f"Failed to transcribe {file_path}") continue files_processed += 1 @@ -365,14 +385,14 @@ def 
_get_save_path(_file_path: str): mid = mid_dict.to_midi() mid.save(save_path) except Exception as e: - print(f"{pid}: Failed to detokenize with error {e}") + logger.error(f"Failed to detokenize with error {e}") else: - print(f"{pid}: Finished file {files_processed} - {file_path}") - print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue") + logger.info(f"Finished file {files_processed} - {file_path}") + logger.info(f"{file_queue.qsize()} file(s) remaining in queue") def batch_transcribe( - file_paths: list, + file_paths, # Queue | list, model: AmtEncoderDecoder, save_dir: str, batch_size: int = 16, @@ -384,9 +404,12 @@ def batch_transcribe( model.cuda() model.eval() - file_queue = Queue() - for file_path in file_paths: - file_queue.put(file_path) + if isinstance(file_paths, list): + file_queue = Queue() + for file_path in file_paths: + file_queue.put(file_path) + else: + file_queue = file_paths gpu_task_queue = Queue() result_queue = Queue() diff --git a/amt/run.py b/amt/run.py index b0815fd..7e29ae0 100644 --- a/amt/run.py +++ b/amt/run.py @@ -140,6 +140,7 @@ def transcribe( """ import torch from torch.cuda import is_available as cuda_is_available + from torch.multiprocessing import Queue from amt.tokenizer import AmtTokenizer from amt.infer import batch_transcribe from amt.config import load_model_config @@ -188,35 +189,30 @@ def transcribe( file_paths = found_mp3 + found_wav else: file_paths = [load_path] + batch_size = 1 if multi_gpu: - # Generate chunks gpu_ids = [ int(id) for id in os.getenv("CUDA_VISIBLE_DEVICES").split(",") ] - num_gpus = len(gpu_ids) print(f"Visible gpu_ids: {gpu_ids}") - chunk_size = (len(file_paths) // num_gpus) + 1 - chunks = [ - file_paths[i : i + chunk_size] - for i in range(0, len(file_paths), chunk_size) - ] - print(f"Split {len(file_paths)} files into {len(chunks)} chunks") + # Use shared file queue between gpu processes + file_queue = torch.multiprocessing.Queue() + for file_path in file_paths: + file_queue.put(file_path) processes = [] - for idx, chunk in enumerate(chunks): - print( - f"Starting process on cuda-{idx}: {len(chunk)} files to process" - ) + for gpu_id in gpu_ids: + print(f"Starting process on cuda-{gpu_id}") process = torch.multiprocessing.Process( target=batch_transcribe, args=( - chunk, + file_queue, model, save_dir, batch_size, - gpu_ids[idx], + gpu_id, load_dir, ), ) @@ -237,9 +233,9 @@ def transcribe( def main(): - # Nested argparse inspired by - https://shorturl.at/kuKW0 - parser = argparse.ArgumentParser(usage="amt []") - subparsers = parser.add_subparsers(help="sub-command help") + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(help="sub-command help", dest="command") + # add maestro and transcribe subparsers subparser_maestro = subparsers.add_parser( "maestro", help="Commands to build the maestro dataset." diff --git a/amt/train.py b/amt/train.py index fcb7e39..10a3952 100644 --- a/amt/train.py +++ b/amt/train.py @@ -223,7 +223,10 @@ def get_dataloaders( # Pitch aug (to the sequence tensors) must be applied in the train # dataloader as it needs to be done to every element in the batch equally. # Having this code running on the main process was causing a bottlekneck. + # Furthermore, distortion runs very slowly on the gpu, so we do it in + # the dataloader instead. 
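+    #
+    # A sketch of how the collate function below is meant to be attached
+    # (worker count and shift range are illustrative, not repo values):
+    #
+    #   import functools
+    #   train_dataloader = DataLoader(
+    #       train_dataset,
+    #       batch_size=32,
+    #       num_workers=8,  # distortion_aug_cpu runs in these workers
+    #       collate_fn=functools.partial(_collate_fn, max_pitch_shift=5),
+    #   )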
tensor_pitch_aug = AmtTokenizer().export_tensor_pitch_aug() + audio_transform = AudioTransform() def _collate_fn(seqs, max_pitch_shift: int): wav, src, tgt = torch.utils.data.default_collate(seqs) @@ -233,6 +236,9 @@ def _collate_fn(seqs, max_pitch_shift: int): src = tensor_pitch_aug(seq=src, shift=pitch_shift) tgt = tensor_pitch_aug(seq=tgt, shift=pitch_shift) + # Distortion + wav = audio_transform.distortion_aug_cpu(wav) + return wav, src, tgt, pitch_shift train_dataloader = DataLoader( diff --git a/tests/test_data.py b/tests/test_data.py index 18c92f5..e1770e4 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,6 +1,7 @@ import unittest import logging import os +import time import torch import torchaudio import matplotlib.pyplot as plt @@ -61,6 +62,7 @@ def test_maestro(self): tokenizer = AmtTokenizer() dataset = AmtDataset(load_path=MAESTRO_PATH) + print(f"Dataset length: {len(dataset)}") for idx, (wav, src, tgt) in enumerate(dataset): src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt) if (idx + 1) % 100 == 0: @@ -152,7 +154,9 @@ def test_mels(self): wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( 0, keepdim=True )[:, : SAMPLE_RATE * CHUNK_LEN] - wav_aug = audio_transform.aug_wav(wav) + wav_aug = audio_transform.aug_wav( + audio_transform.distortion_aug_cpu(wav) + ) torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/aug.wav", wav_aug, SAMPLE_RATE) From 12d249b19789998988c9159b5e9a8b86d04c6068 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 9 Mar 2024 20:16:16 +0000 Subject: [PATCH 2/4] Fix colab error (#17) --- amt/audio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/amt/audio.py b/amt/audio.py index a67bce9..8b3a08b 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -265,6 +265,8 @@ def __init__( ) def _get_paths(self, dir_path): + os.makedirs(dir_path, exist_ok=True) + return [ os.path.join(dir_path, f) for f in os.listdir(dir_path) From d56e8e5eaf84836653e3d4ef1315fd422e5c761f Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 11 Mar 2024 21:05:42 +0000 Subject: [PATCH 3/4] Add pedal msgs to tokenizer (#18) * add pedal msgs to tokenizer * fix eos token * format * improve inference --- amt/infer.py | 77 ++++++++++++++++++--------- amt/tokenizer.py | 115 +++++++++++++++++++++++++++++++--------- tests/test_tokenizer.py | 18 ++++++- 3 files changed, 158 insertions(+), 52 deletions(-) diff --git a/amt/infer.py b/amt/infer.py index 8437de8..382fb65 100644 --- a/amt/infer.py +++ b/amt/infer.py @@ -7,19 +7,22 @@ from torch.multiprocessing import Queue from tqdm import tqdm +from functools import wraps +from torch.cuda import is_bf16_supported from amt.model import AmtEncoderDecoder from amt.tokenizer import AmtTokenizer -from amt.audio import AudioTransform +from amt.audio import AudioTransform, pad_or_trim from amt.data import get_wav_mid_segments + MAX_SEQ_LEN = 4096 LEN_MS = 30000 STRIDE_FACTOR = 3 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR -BEAM = 3 -ONSET_TOLERANCE = 50 -VEL_TOLERANCE = 50 +BEAM = 5 +ONSET_TOLERANCE = 61 +VEL_TOLERANCE = 100 def _setup_logger(): @@ -105,10 +108,6 @@ def calculate_onset( return tokenizer.tok_to_id[("onset", new_onset)] -from functools import wraps -from torch.cuda import is_bf16_supported - - def optional_bf16_autocast(func): @wraps(func) def wrapper(*args, **kwargs): @@ -145,7 +144,7 @@ def process_segments( tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes ] seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda() - eos_seen = 
[False for _ in prefixes] + end_idxs = [MAX_SEQ_LEN for _ in prefixes] kv_cache = model.get_empty_cache() @@ -173,7 +172,7 @@ def process_segments( next_tok_ids = torch.argmax(logits[:, -1], dim=-1) for batch_idx in range(logits.shape[0]): - if eos_seen[batch_idx] is not False: + if idx > end_idxs[batch_idx]: # End already seen, add pad token tok_id = tokenizer.pad_id elif idx >= prefix_lens[batch_idx]: @@ -192,20 +191,24 @@ def process_segments( tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]] seq[batch_idx, idx] = tok_id - if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok: - eos_seen[batch_idx] = idx - - if all(eos_seen): + tok = tokenizer.id_to_tok[tok_id] + if tok == tokenizer.eos_tok: + end_idxs[batch_idx] = idx + elif ( + type(tok) is tuple + and tok[0] == "onset" + and tok[1] >= LEN_MS - CHUNK_LEN_MS + ): + end_idxs[batch_idx] = idx - 2 + + if all(_idx <= idx for _idx in end_idxs): break - if not all(eos_seen): + if not all(_idx <= idx for _idx in end_idxs): logger.warning("Context length overflow when transcribing segment") - for _idx in range(seq.shape[0]): - if eos_seen[_idx] == False: - eos_seen[_idx] = MAX_SEQ_LEN results = [ - tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1]) + tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1]) for _idx in range(seq.shape[0]) ] @@ -218,7 +221,7 @@ def gpu_manager( model: AmtEncoderDecoder, batch_size: int, ): - # model.compile() + model.compile() logger = _setup_logger() audio_transform = AudioTransform().cuda() tokenizer = AmtTokenizer(return_tensors=True) @@ -283,7 +286,7 @@ def _truncate_seq( except: return [""] else: - return res[: res.index(tokenizer.eos_tok)] + return res[: res.index(tokenizer.eos_tok)] # Needs to change def process_file( @@ -302,8 +305,15 @@ def process_file( audio_path=file_path, stride_factor=STRIDE_FACTOR ) ] - seq = [""] - res = [""] + + # Add addtional (padded) final audio segment + _last_seg = audio_segments[-1] + audio_segments.append( + pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :]) + ) + + seq = [tokenizer.bos_tok] + res = [tokenizer.bos_tok] for idx, audio_seg in enumerate(audio_segments): init_idx = len(seq) @@ -318,15 +328,18 @@ def process_file( result_queue.put(gpu_result) res += _shift_onset( - seq[init_idx : seq.index(tokenizer.eos_tok)], + seq[init_idx:], idx * CHUNK_LEN_MS, ) if idx == len(audio_segments) - 1: break + elif res[-1] == tokenizer.eos_tok: + logger.info(f"Exiting early") + break else: - seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS) - if len(seq) == 1: + seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS) + if len(seq) <= 2: logger.info(f"Exiting early") return res @@ -441,3 +454,15 @@ def batch_transcribe( p.join() gpu_manager_process.join() + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + + return next_token diff --git a/amt/tokenizer.py b/amt/tokenizer.py index 67d6072..762a10a 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False): self.prev_tokens = [("prev", i) for i in range(128)] self.note_on_tokens = [("on", i) for i in range(128)] self.note_off_tokens = [("off", i) for i in range(128)] + self.pedal_tokens = [("pedal", 0), (("pedal", 1))] self.velocity_tokens = 
[("vel", i) for i in self.velocity_quantizations] self.onset_tokens = [ ("onset", i) for i in self.onset_time_quantizations @@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False): + self.prev_tokens + self.note_on_tokens + self.note_off_tokens + + self.pedal_tokens + self.velocity_tokens + self.onset_tokens ) @@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int): else: return velocity_quantized - # This method needs to be cleaned up completely, variables renamed + # TODO: + # - I need to make this method more robust, as it will have to handle + # an arbitrary MIDI file + # - Decide whether to put pedal messages as prev tokens def _tokenize_midi_dict( self, midi_dict: MidiDict, @@ -88,6 +93,12 @@ def _tokenize_midi_dict( ), "Invalid values for start_ms, end_ms" midi_dict.resolve_pedal() # Important !! + pedal_intervals = midi_dict._build_pedal_intervals() + if len(pedal_intervals.keys()) > 1: + print("Warning: midi_dict has more than one pedal channel") + pedal_intervals = pedal_intervals[0] + + last_msg_ms = -1 on_off_notes = [] prev_notes = [] for msg in midi_dict.note_msgs: @@ -109,6 +120,9 @@ def _tokenize_midi_dict( ticks_per_beat=midi_dict.ticks_per_beat, ) + if note_end_ms > last_msg_ms: + last_msg_ms = note_end_ms + rel_note_start_ms_q = self._quantize_onset(note_start_ms - start_ms) rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms) velocity_q = self._quantize_velocity(_velocity) @@ -149,35 +163,70 @@ def _tokenize_midi_dict( ("off", _pitch, rel_note_end_ms_q, None) ) - on_off_notes.sort(key=lambda x: (x[2], x[0] == "on")) + on_off_pedal = [] + for pedal_on_tick, pedal_off_tick in pedal_intervals: + pedal_on_ms = get_duration_ms( + start_tick=0, + end_tick=pedal_on_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=midi_dict.ticks_per_beat, + ) + pedal_off_ms = get_duration_ms( + start_tick=0, + end_tick=pedal_off_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=midi_dict.ticks_per_beat, + ) + + rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms) + rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms) + + # On message + if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms: + continue + else: + on_off_pedal.append(("pedal", 1, rel_on_ms_q, None)) + + # Off message + if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms: + continue + else: + on_off_pedal.append(("pedal", 0, rel_off_ms_q, None)) + + on_off_combined = on_off_notes + on_off_pedal + on_off_combined.sort( + key=lambda x: ( + x[2], + (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2), + ) + ) random.shuffle(prev_notes) tokenized_seq = [] - note_status = {} - for pitch in prev_notes: - note_status[pitch] = True - for note in on_off_notes: - _type, _pitch, _onset, _velocity = note + for tok in on_off_combined: + _type, _val, _onset, _velocity = tok if _type == "on": - if note_status.get(_pitch) == True: - # Place holder - we can remove note_status logic now - raise Exception - - tokenized_seq.append(("on", _pitch)) + tokenized_seq.append(("on", _val)) tokenized_seq.append(("onset", _onset)) tokenized_seq.append(("vel", _velocity)) - note_status[_pitch] = True elif _type == "off": - if note_status.get(_pitch) == False: - # Place holder - we can remove note_status logic now - raise Exception - else: - tokenized_seq.append(("off", _pitch)) + tokenized_seq.append(("off", _val)) + tokenized_seq.append(("onset", _onset)) + elif _type == "pedal": + if _val == 0: + tokenized_seq.append(("pedal", _val)) + tokenized_seq.append(("onset", _onset)) + elif _val: + 
tokenized_seq.append(("pedal", _val)) tokenized_seq.append(("onset", _onset)) - note_status[_pitch] = False prefix = [("prev", p) for p in prev_notes] - return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok] + + # Add eos_tok only if segment includes end of midi_dict + if last_msg_ms < end_ms: + return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok] + else: + return prefix + [self.bos_tok] + tokenized_seq def _detokenize_midi_dict( self, @@ -243,16 +292,29 @@ def _detokenize_midi_dict( print("Unexpected token order: 'prev' seen after ''") if DEBUG: raise Exception + elif tok_1_type == "pedal": + # Pedal information contained in note-off messages, so we don't + # need to manually processes them + _pedal_data = tok_1_data + _tick = tok_2_data + note_msgs.append( + { + "type": "pedal", + "data": _pedal_data, + "tick": _tick, + "channel": 0, + } + ) elif tok_1_type == "on": if (tok_2_type, tok_3_type) != ("onset", "vel"): - print("Unexpected token order") + print("Unexpected token order:", tok_1, tok_2, tok_3) if DEBUG: raise Exception else: notes_to_close[tok_1_data] = (tok_2_data, tok_3_data) elif tok_1_type == "off": if tok_2_type != "onset": - print("Unexpected token order") + print("Unexpected token order:", tok_1, tok_2, tok_3) if DEBUG: raise Exception else: @@ -336,9 +398,6 @@ def export_data_aug(self): def export_msg_mixup(self): def msg_mixup(src: list): - def round_to_base(n, base=150): - return base * round(n / base) - # Process bos, eos, and pad tokens orig_len = len(src) seen_pad_tok = False @@ -387,6 +446,9 @@ def round_to_base(n, base=150): elif tok_1_type == "off": _onset = tok_2_data buffer[_onset]["off"].append((tok_1, tok_2)) + elif tok_1_type == "pedal": + _onset = tok_2_data + buffer[_onset]["pedal"].append((tok_1, tok_2)) else: pass @@ -394,6 +456,9 @@ def round_to_base(n, base=150): for k, v in sorted(buffer.items()): random.shuffle(v["on"]) random.shuffle(v["off"]) + for item in v["pedal"]: + res.append(item[0]) # Pedal + res.append(item[1]) # Onset for item in v["off"]: res.append(item[0]) # Pitch res.append(item[1]) # Onset diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 64c1a36..1148c0c 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -37,8 +37,24 @@ def _tokenize_detokenize(mid_name: str, start: int, end: int): _tokenize_detokenize("maestro2.mid", start=START, end=END) _tokenize_detokenize("maestro3.mid", start=START, end=END) + def test_eos_tok(self): + tokenizer = AmtTokenizer() + midi_dict = MidiDict.from_midi(f"tests/test_data/maestro1.mid") + + cnt = 0 + while True: + seq = tokenizer._tokenize_midi_dict( + midi_dict, start_ms=cnt * 10000, end_ms=(cnt * 10000) + 30000 + ) + if len(seq) <= 2: + self.assertEqual(seq[-1], tokenizer.eos_tok) + break + else: + cnt += 1 + def test_pitch_aug(self): tokenizer = AmtTokenizer(return_tensors=True) + tensor_pitch_aug = tokenizer.export_tensor_pitch_aug() midi_dict_1 = MidiDict.from_midi("tests/test_data/maestro1.mid") midi_dict_2 = MidiDict.from_midi("tests/test_data/maestro2.mid") @@ -61,7 +77,7 @@ def test_pitch_aug(self): tokenizer.encode(seq_3), ) ) - aug_seqs = tokenizer.pitch_aug(seqs, shift=2) + aug_seqs = tensor_pitch_aug(seqs, shift=2) midi_dict_1_aug = tokenizer._detokenize_midi_dict( tokenizer.decode(aug_seqs[0]), 30000 From d6fea7f6457c5604864cb870142b5c2703f82997 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 12 Mar 2024 17:24:53 +0000 Subject: [PATCH 4/4] Fix inference, pedal, and add EQ aug (#20) * fix inference and add prev pedal token * add 
bandpass eq --- amt/audio.py | 20 ++++++++-- amt/infer.py | 91 +++++++++++++++++++++++++--------------------- amt/tokenizer.py | 49 ++++++++++++++++--------- tests/test_data.py | 12 ++++++ 4 files changed, 109 insertions(+), 63 deletions(-) diff --git a/amt/audio.py b/amt/audio.py index 8b3a08b..18224f5 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -194,6 +194,7 @@ def __init__( noise_ratio: float = 0.95, reverb_ratio: float = 0.95, applause_ratio: float = 0.01, + bandpass_ratio: float = 0.1, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, spec_aug_ratio: float = 0.5, @@ -214,6 +215,7 @@ def __init__( self.noise_ratio = noise_ratio self.reverb_ratio = reverb_ratio self.applause_ratio = applause_ratio + self.bandpass_ratio = bandpass_ratio self.distort_ratio = distort_ratio self.reduce_ratio = reduce_ratio self.spec_aug_ratio = spec_aug_ratio @@ -350,6 +352,14 @@ def apply_applause(self, wav: torch.tensor): return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs) + def apply_bandpass(self, wav: torch.tensor): + central_freq = random.randint(1000, 3500) + Q = random.uniform(0.707, 1.41) + + return torchaudio.functional.bandpass_biquad( + wav, self.sample_rate, central_freq, Q + ) + def apply_reduction(self, wav: torch.tensor): """ Limit the high-band pass filter, the low-band pass filter and the sample rate @@ -424,9 +434,13 @@ def aug_wav(self, wav: torch.Tensor): # Reverb if random.random() < self.reverb_ratio: - return self.apply_reverb(wav) - else: - return wav + wav = self.apply_reverb(wav) + + # EQ + if random.random() < self.bandpass_ratio: + wav = self.apply_bandpass(wav) + + return wav def norm_mel(self, mel_spec: torch.Tensor): log_spec = torch.clamp(mel_spec, min=1e-10).log10() diff --git a/amt/infer.py b/amt/infer.py index 382fb65..289d499 100644 --- a/amt/infer.py +++ b/amt/infer.py @@ -283,10 +283,13 @@ def _truncate_seq( _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) try: res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) - except: + except Exception: + print("Truncate failed") return [""] else: - return res[: res.index(tokenizer.eos_tok)] # Needs to change + if res[-1] == tokenizer.eos_tok: + res.pop() + return res def process_file( @@ -306,14 +309,9 @@ def process_file( ) ] - # Add addtional (padded) final audio segment - _last_seg = audio_segments[-1] - audio_segments.append( - pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :]) - ) - + res = [] seq = [tokenizer.bos_tok] - res = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] for idx, audio_seg in enumerate(audio_segments): init_idx = len(seq) @@ -327,21 +325,25 @@ def process_file( else: result_queue.put(gpu_result) - res += _shift_onset( + concat_seq += _shift_onset( seq[init_idx:], idx * CHUNK_LEN_MS, ) if idx == len(audio_segments) - 1: - break - elif res[-1] == tokenizer.eos_tok: - logger.info(f"Exiting early") - break + res.append(concat_seq) + elif concat_seq[-1] == tokenizer.eos_tok: + res.append(concat_seq) + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + logger.info(f"Finished segment - eos_tok seen") else: seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS) - if len(seq) <= 2: - logger.info(f"Exiting early") - return res + if len(seq) == 1: + res.append(concat_seq) + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + logger.info(f"Exiting early - silence") return res @@ -353,16 +355,35 @@ def worker( save_dir: str, input_dir: str | None = None, ): - def _get_save_path(_file_path: str): + def _save_seq(_seq: list, 
_save_path: str): + if os.path.exists(_save_path): + logger.info(f"Already exists {_save_path} - overwriting") + + for tok in _seq[::-1]: + if type(tok) is tuple and tok[0] == "onset": + last_onset = tok[1] + break + + try: + mid_dict = tokenizer._detokenize_midi_dict( + tokenized_seq=_seq, len_ms=last_onset + ) + mid = mid_dict.to_midi() + mid.save(_save_path) + except Exception as e: + logger.error(f"Failed to save {_save_path}") + + def _get_save_path(_file_path: str, _idx: int | str = ""): if input_dir is None: save_path = os.path.join( save_dir, - os.path.splitext(os.path.basename(file_path))[0] + ".mid", + os.path.splitext(os.path.basename(file_path))[0] + + f"{_idx}.mid", ) else: input_rel_path = os.path.relpath(_file_path, input_dir) save_path = os.path.join( - save_dir, os.path.splitext(input_rel_path)[0] + ".mid" + save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid" ) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -374,34 +395,20 @@ def _get_save_path(_file_path: str): files_processed = 0 while not file_queue.empty(): file_path = file_queue.get() - save_path = _get_save_path(file_path) - if os.path.exists(save_path): - logger.info(f"{save_path} already exists, overwriting") try: - res = process_file(file_path, gpu_task_queue, result_queue) + seqs = process_file(file_path, gpu_task_queue, result_queue) except Exception as e: - logger.error(f"Failed to transcribe {file_path}") + logger.error(f"Failed to process {file_path}") continue - files_processed += 1 - - for tok in res[::-1]: - if type(tok) is tuple and tok[0] == "onset": - last_onset = tok[1] - break + logger.info(f"Transcribed into {len(seqs)} segment(s)") + for _idx, seq in enumerate(seqs): + _save_seq(seq, _get_save_path(file_path, _idx)) - try: - mid_dict = tokenizer._detokenize_midi_dict( - tokenized_seq=res, len_ms=last_onset - ) - mid = mid_dict.to_midi() - mid.save(save_path) - except Exception as e: - logger.error(f"Failed to detokenize with error {e}") - else: - logger.info(f"Finished file {files_processed} - {file_path}") - logger.info(f"{file_queue.qsize()} file(s) remaining in queue") + files_processed += 1 + logger.info(f"Finished file {files_processed} - {file_path}") + logger.info(f"{file_queue.qsize()} file(s) remaining in queue") def batch_transcribe( diff --git a/amt/tokenizer.py b/amt/tokenizer.py index 762a10a..d5416a7 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -46,7 +46,7 @@ def __init__(self, return_tensors: bool = False): self.prev_tokens = [("prev", i) for i in range(128)] self.note_on_tokens = [("on", i) for i in range(128)] self.note_off_tokens = [("off", i) for i in range(128)] - self.pedal_tokens = [("pedal", 0), (("pedal", 1))] + self.pedal_tokens = [("pedal", 0), ("pedal", 1), ("prev", "pedal")] self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations] self.onset_tokens = [ ("onset", i) for i in self.onset_time_quantizations @@ -81,7 +81,6 @@ def _quantize_velocity(self, velocity: int): # TODO: # - I need to make this method more robust, as it will have to handle # an arbitrary MIDI file - # - Decide whether to put pedal messages as prev tokens def _tokenize_midi_dict( self, midi_dict: MidiDict, @@ -96,11 +95,13 @@ def _tokenize_midi_dict( pedal_intervals = midi_dict._build_pedal_intervals() if len(pedal_intervals.keys()) > 1: print("Warning: midi_dict has more than one pedal channel") + if len(midi_dict.instrument_msgs) > 1: + print("Warning: midi_dict has more than one instrument msg") 
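+
+        # For reference, the sequence this method now produces interleaves
+        # note and pedal events; an illustrative (not real) segment:
+        #
+        #   [("prev", 60), ("prev", "pedal")]       # held over from t < start_ms
+        #   + [self.bos_tok]
+        #   + [("pedal", 0), ("onset", 40),         # pedal released
+        #      ("on", 62), ("onset", 50), ("vel", 45),
+        #      ("off", 62), ("onset", 500)]
+        #   + [self.eos_tok]                        # only if the segment
+        #                                           # reaches the end of file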
pedal_intervals = pedal_intervals[0] last_msg_ms = -1 on_off_notes = [] - prev_notes = [] + prev_toks = [] for msg in midi_dict.note_msgs: _pitch = msg["data"]["pitch"] _velocity = msg["data"]["velocity"] @@ -137,9 +138,9 @@ def _tokenize_midi_dict( if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip continue elif ( - note_start_ms < start_ms and _pitch not in prev_notes + note_start_ms < start_ms and _pitch not in prev_toks ): # Add to prev notes - prev_notes.append(_pitch) + prev_toks.append(_pitch) if note_end_ms < end_ms: on_off_notes.append( ("off", _pitch, rel_note_end_ms_q, None) @@ -182,8 +183,10 @@ def _tokenize_midi_dict( rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms) # On message - if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms: + if pedal_off_ms <= start_ms or pedal_on_ms >= end_ms: continue + elif pedal_on_ms < start_ms and pedal_off_ms >= start_ms: + prev_toks.append("pedal") else: on_off_pedal.append(("pedal", 1, rel_on_ms_q, None)) @@ -200,7 +203,7 @@ def _tokenize_midi_dict( (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2), ) ) - random.shuffle(prev_notes) + random.shuffle(prev_toks) tokenized_seq = [] for tok in on_off_combined: @@ -220,7 +223,7 @@ def _tokenize_midi_dict( tokenized_seq.append(("pedal", _val)) tokenized_seq.append(("onset", _onset)) - prefix = [("prev", p) for p in prev_notes] + prefix = [("prev", p) for p in prev_toks] # Add eos_tok only if segment includes end of midi_dict if last_msg_ms < end_ms: @@ -271,7 +274,21 @@ def _detokenize_midi_dict( if DEBUG: raise Exception - notes_to_close[tok[1]] = (0, self.default_velocity) + if tok[1] == "pedal": + pedal_msgs.append( + { + "type": "pedal", + "data": 1, + "tick": 0, + "channel": 0, + } + ) + elif isinstance(tok[1], int): + notes_to_close[tok[1]] = (0, self.default_velocity) + else: + print(f"Invalid 'prev' token: {tok}") + if DEBUG: + raise Exception else: raise Exception( f"Invalid note sequence at position {idx}: {tok, tokenized_seq[:idx]}" @@ -293,11 +310,9 @@ def _detokenize_midi_dict( if DEBUG: raise Exception elif tok_1_type == "pedal": - # Pedal information contained in note-off messages, so we don't - # need to manually processes them _pedal_data = tok_1_data _tick = tok_2_data - note_msgs.append( + pedal_msgs.append( { "type": "pedal", "data": _pedal_data, @@ -454,13 +469,11 @@ def msg_mixup(src: list): # Shuffle order and re-append to result for k, v in sorted(buffer.items()): + off_pedal_combined = v["off"] + v["pedal"] + random.shuffle(off_pedal_combined) random.shuffle(v["on"]) - random.shuffle(v["off"]) - for item in v["pedal"]: - res.append(item[0]) # Pedal - res.append(item[1]) # Onset - for item in v["off"]: - res.append(item[0]) # Pitch + for item in off_pedal_combined: + res.append(item[0]) # Off or pedal res.append(item[1]) # Onset for item in v["on"]: res.append(item[0]) # Pitch diff --git a/tests/test_data.py b/tests/test_data.py index e1770e4..1437472 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -177,6 +177,18 @@ def test_distortion(self): res = audio_transform.apply_distortion(wav) torchaudio.save("tests/test_results/dist.wav", res, SAMPLE_RATE) + def test_bandpass(self): + SAMPLE_RATE, CHUNK_LEN = 16000, 30 + audio_transform = AudioTransform() + wav, sr = torchaudio.load("tests/test_data/147.wav") + wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( + 0, keepdim=True + )[:, : SAMPLE_RATE * CHUNK_LEN] + + torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) + res = audio_transform.apply_bandpass(wav) + 
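+        # apply_bandpass draws central_freq and Q at random on every call; a
+        # deterministic reference can pin the underlying torchaudio op (a
+        # sketch, with arbitrary values from the sampled ranges):
+        #
+        #   fixed = torchaudio.functional.bandpass_biquad(
+        #       wav, SAMPLE_RATE, central_freq=2000.0, Q=0.707
+        #   )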
torchaudio.save("tests/test_results/bandpass.wav", res, SAMPLE_RATE) + def test_applause(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform()
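
Note on the decoding helpers: PATCH 3 adds sample_top_p to amt/infer.py, but
process_segments still decodes greedily via torch.argmax. If nucleus sampling
were wired in, it would replace the argmax step roughly as follows (a sketch;
the temperature knob is an assumption, not part of this series):

    import torch.nn.functional as F

    temperature, top_p = 0.95, 0.95
    probs = F.softmax(logits[:, -1] / temperature, dim=-1)
    # sample_top_p returns shape [batch, 1]; squeeze to match the argmax output
    next_tok_ids = sample_top_p(probs, top_p).squeeze(-1)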