diff --git a/amt/audio.py b/amt/audio.py index a0d5d4c..6c37f3f 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -194,10 +194,11 @@ def __init__( noise_ratio: float = 0.95, reverb_ratio: float = 0.95, applause_ratio: float = 0.01, - bandpass_ratio: float = 0.1, + bandpass_ratio: float = 0.15, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, - codecs_ratio: float = 0.01, + detune_ratio: float = 0.1, + detune_max_shift: float = 0.15, spec_aug_ratio: float = 0.5, ): super().__init__() @@ -219,8 +220,9 @@ def __init__( self.bandpass_ratio = bandpass_ratio self.distort_ratio = distort_ratio self.reduce_ratio = reduce_ratio + self.detune_ratio = detune_ratio + self.detune_max_shift = detune_max_shift self.spec_aug_ratio = spec_aug_ratio - self.codecs_ratio = codecs_ratio self.reduction_resample_rate = 6000 # Hardcoded? # Audio aug @@ -268,6 +270,19 @@ def __init__( ), ) + def get_params(self): + return { + "noise_ratio": self.noise_ratio, + "reverb_ratio": self.reverb_ratio, + "applause_ratio": self.applause_ratio, + "bandpass_ratio": self.bandpass_ratio, + "distort_ratio": self.distort_ratio, + "reduce_ratio": self.reduce_ratio, + "detune_ratio": self.detune_ratio, + "detune_max_shift": self.detune_max_shift, + "spec_aug_ratio": self.spec_aug_ratio, + } + def _get_paths(self, dir_path): os.makedirs(dir_path, exist_ok=True) @@ -399,21 +414,7 @@ def distortion_aug_cpu(self, wav: torch.Tensor): return wav - def apply_codec(self, wav: torch.tensor): - """ - Apply different audio codecs to the audio. - """ - format_encoder_pairs = [ - ("wav", "pcm_mulaw"), - ("g722", None), - ("ogg", "vorbis") - ] - for format, encoder in format_encoder_pairs: - encoder = torchaudio.io.AudioEffector(format=format, encoder=encoder) - if random.random() < self.codecs_ratio: - wav = encoder.apply(wav, self.sample_rate) - - def shift_spec(self, specs: torch.Tensor, shift: int): + def shift_spec(self, specs: torch.Tensor, shift: int | float): if shift == 0: return specs @@ -438,9 +439,21 @@ def shift_spec(self, specs: torch.Tensor, shift: int): return shifted_specs + def detune_spec(self, specs: torch.Tensor): + if random.random() < self.detune_ratio: + detune_shift = random.uniform( + -self.detune_max_shift, self.detune_max_shift + ) + detuned_specs = self.shift_spec(specs, shift=detune_shift) + + return (specs + detuned_specs) / 2 + else: + return specs + def aug_wav(self, wav: torch.Tensor): # This function doesn't apply distortion. If distortion is desired it - # should be run before hand on the cpu with distortion_aug_cpu. + # should be run beforehand on the cpu with distortion_aug_cpu. Note + # also that detuning is done to the spectrogram in log_mel, not the wav. # Noise if random.random() < self.noise_ratio: @@ -468,10 +481,17 @@ def norm_mel(self, mel_spec: torch.Tensor): return log_spec - def log_mel(self, wav: torch.Tensor, shift: int | None = None): + def log_mel( + self, wav: torch.Tensor, shift: int | None = None, detune: bool = False + ): spec = self.spec_transform(wav)[..., :-1] - if shift and shift != 0: + + if shift is not None and shift != 0: spec = self.shift_spec(spec, shift) + elif detune is True: + # Don't detune and spec shift at the same time + spec = self.detune_spec(spec) + mel_spec = self.mel_transform(spec) # Norm @@ -483,8 +503,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0): # Noise, and reverb wav = self.aug_wav(wav) - # Spec & pitch shift - log_mel = self.log_mel(wav, shift) + # Spec, detuning & pitch shift + log_mel = self.log_mel(wav, shift, detune=True) # Spec aug if random.random() < self.spec_aug_ratio: diff --git a/amt/data.py b/amt/data.py index 6a4810e..377b760 100644 --- a/amt/data.py +++ b/amt/data.py @@ -1,91 +1,22 @@ import mmap import os +import io +import base64 import shutil import orjson import torch import torchaudio -from multiprocessing import Pool +from multiprocessing import Pool, Queue, Process +from typing import Callable from aria.data.midi import MidiDict from amt.tokenizer import AmtTokenizer from amt.config import load_config from amt.audio import pad_or_trim -from midi2audio import FluidSynth -import random - - -class SyntheticMidiHandler: - def __init__(self, soundfont_path: str, soundfont_prob_dict: dict = None, num_wavs_per_midi: int = 1): - """ - File to load MIDI files and convert them to audio. - - Parameters - ---------- - soundfont_path : str - Path to the directory containing soundfont files. - soundfont_prob_dict : dict, optional - Dictionary containing the probability of using a soundfont file. - The keys are the soundfont file names and the values are the - probability of using the soundfont file. If none is given, then - a uniform distribution is used. - num_wavs_per_midi : int, optional - Number of audio files to generate per MIDI file. - """ - - self.soundfont_path = soundfont_path - self.soundfont_prob_dict = soundfont_prob_dict - self.num_wavs_per_midi = num_wavs_per_midi - - self.fs_objs = self._load_soundfonts() - self.soundfont_cumul_prob_dict = self._get_cumulative_prob_dict() - - def _load_soundfonts(self): - """Loads the soundfonts into fluidsynth objects.""" - fs_files = os.listdir(self.soundfont_path) - fs_objs = {} - for fs_file in fs_files: - fs_objs[fs_file] = FluidSynth(fs_file) - return fs_objs - - def _get_cumulative_prob_dict(self): - """Returns a dictionary with the cumulative probabilities of the soundfonts. - Used for sampling the soundfonts. - """ - if self.soundfont_prob_dict is None: - self.soundfont_prob_dict = {k: 1 / len(self.fs_objs) for k in self.fs_objs.keys()} - self.soundfont_prob_dict = {k: v / sum(self.soundfont_prob_dict.values()) - for k, v in self.soundfont_prob_dict.items()} - cumul_prob_dict = {} - cumul_prob = 0 - for k, v in self.soundfont_prob_dict.items(): - cumul_prob_dict[k] = (cumul_prob, cumul_prob + v) - cumul_prob += v - return cumul_prob_dict - - def _sample_soundfont(self): - """Samples a soundfont file.""" - rand_num = random.random() - for k, (v_s, v_e) in self.soundfont_cumul_prob_dict.items(): - if (rand_num >= v_s) and (rand_num < v_e): - return self.fs_objs[k] - - def get_wav(self, midi_path: str, save_path: str): - """ - Converts a MIDI file to audio. - - Parameters - ---------- - midi_path : str - Path to the MIDI file. - save_path : str - Path to save the audio file. - """ - for i in range(self.num_wavs_per_midi): - soundfont = self._sample_soundfont() - if self.num_wavs_per_midi > 1: - save_path = save_path[:-4] + f"_{i}.wav" - soundfont.midi_to_audio(midi_path, save_path) + + +# Occasionally the worker util goes to 0 for some reason, debug this def get_wav_mid_segments( @@ -133,7 +64,7 @@ def get_wav_mid_segments( res = [] for idx in range( 0, - total_samples - (num_samples - (num_samples // stride_factor)), + total_samples - (num_samples - num_samples // stride_factor), num_samples // stride_factor, ): audio_feature = pad_or_trim(wav[idx:], length=num_samples) @@ -142,6 +73,7 @@ def get_wav_mid_segments( midi_dict=midi_dict, start_ms=idx // samples_per_ms, end_ms=(idx + num_samples) / samples_per_ms, + max_pedal_len_ms=10000, ) else: mid_feature = [] @@ -154,29 +86,97 @@ def get_wav_mid_segments( return res -def write_features(args): - audio_path, mid_path, save_path = args +def write_features(audio_path: str, mid_path: str, save_path: str): features = get_wav_mid_segments( audio_path=audio_path, mid_path=mid_path, return_json=False, ) - dirname, basename = os.path.split(save_path) - proc_save_path = os.path.join(dirname, str(os.getpid()) + basename) - with open(proc_save_path, mode="ab") as file: + # Father forgive me for I have sinned + with open(save_path, mode="a") as file: for wav, seq in features: - file.write( - orjson.dumps( - wav.numpy(), - option=orjson.OPT_SERIALIZE_NUMPY, - ) - ) - file.write(b"\n") - file.write(orjson.dumps(seq)) - file.write(b"\n") + # Encode wav using b64 to avoid newlines + wav_buffer = io.BytesIO() + torch.save(wav, wav_buffer) + wav_buffer.seek(0) + wav_bytes = wav_buffer.read() + wav_str = base64.b64encode(wav_bytes).decode("utf-8") + file.write(wav_str) + file.write("\n") + + seq_bytes = orjson.dumps(seq) + seq_str = base64.b64encode(seq_bytes).decode("utf-8") + file.write(seq_str) + file.write("\n") + + +def get_synth_audio(cli_cmd_fn: str, mid_path: str, wav_path: str): + _cmd = cli_cmd_fn(mid_path, wav_path) + os.system(_cmd) + + +def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str): + audio_path_temp = f"{os.getpid()}_temp.wav" + + try: + get_synth_audio( + cli_cmd=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp + ) + except: + if os.path.isfile(audio_path_temp): + os.remove(audio_path_temp) + return + else: + features = get_wav_mid_segments( + audio_path=audio_path_temp, + mid_path=mid_path, + return_json=False, + ) + os.remove(audio_path_temp) + + with open(save_path, mode="a") as file: + for wav, seq in features: + wav_buffer = io.BytesIO() + torch.save(wav, wav_buffer) + wav_buffer.seek(0) + wav_bytes = wav_buffer.read() + wav_str = base64.b64encode(wav_bytes).decode("utf-8") + file.write(wav_str) + file.write("\n") + + seq_bytes = orjson.dumps(seq) + seq_str = base64.b64encode(seq_bytes).decode("utf-8") + file.write(seq_str) + file.write("\n") + - return proc_save_path +def build_worker_fn(load_path_queue, save_path_queue, _save_path: str): + dirname, basename = os.path.split(_save_path) + worker_save_path = os.path.join(dirname, str(os.getpid()) + basename) + + while not load_path_queue.empty(): + audio_path, mid_path = load_path_queue.get() + write_features(audio_path, mid_path, worker_save_path) + + print("Worker", os.getpid(), "finished") + save_path_queue.put(worker_save_path) + + +def build_synth_worker_fn( + cli_cmd: Callable, + load_path_queue, + save_path_queue, + _save_path: str, +): + dirname, basename = os.path.split(_save_path) + worker_save_path = os.path.join(dirname, str(os.getpid()) + basename) + + while not load_path_queue.empty(): + mid_path = load_path_queue.get() + write_synth_features(cli_cmd, mid_path, worker_save_path) + + save_path_queue.put(worker_save_path) class AmtDataset(torch.utils.data.Dataset): @@ -222,8 +222,10 @@ def _format(tok): self.file_mmap.seek(self.index[idx]) # Load data from line - wav = torch.tensor(orjson.loads(self.file_mmap.readline())) - _seq = orjson.loads(self.file_mmap.readline()) + wav = torch.load( + io.BytesIO(base64.b64decode(self.file_mmap.readline())) + ) + _seq = orjson.loads(base64.b64decode(self.file_mmap.readline())) _seq = [_format(tok) for tok in _seq] # Format seq _seq = self.mixup_fn(_seq) # Data augmentation @@ -267,11 +269,31 @@ def _get_index_path(load_path: str): f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}" ) + def _build_index(self): + self.file_mmap.seek(0) + index = [] + pos = 0 + while True: + pos_buff = pos + + pos = self.file_mmap.find(b"\n", pos) + if pos == -1: + break + pos = self.file_mmap.find(b"\n", pos + 1) + if pos == -1: + break + + index.append(pos_buff) + pos += 1 + + return index + @classmethod def build( cls, - matched_load_paths: list[tuple[str, str]], + load_paths: list, save_path: str, + cli_cmd_fn: Callable | None = None, num_processes: int = 1, ): assert os.path.isfile(save_path) is False, f"{save_path} already exists" @@ -281,18 +303,55 @@ def build( print(f"Removing existing index file at {index_path}") os.remove(AmtDataset._get_index_path(load_path=save_path)) - num_paths = len(matched_load_paths) - with Pool(processes=num_processes) as pool: - sharded_save_paths = [] - res = pool.imap_unordered( - write_features, - ((ap, mp, save_path) for ap, mp in matched_load_paths), - ) - for idx, proc_save_path in enumerate(res): - if idx % 10 == 0 and idx != 0: - print(f"Finished {idx}/{num_paths}") - if proc_save_path not in sharded_save_paths: - sharded_save_paths.append(proc_save_path) + save_path_queue = Queue() + load_path_queue = Queue() + for entry in load_paths: + load_path_queue.put(entry) + + if cli_cmd_fn is None: + # Build matched audio-midi dataset + assert len(load_paths[0]) == 2, "Invalid load paths" + print("Building matched audio-midi dataset") + worker_processes = [ + Process( + target=build_worker_fn, + args=( + load_path_queue, + save_path_queue, + save_path, + ), + ) + for _ in range(num_processes) + ] + else: + # Build synthetic dataset + assert len(load_paths[0]) == 1, "Invalid load paths" + print("Building synthetic dataset") + worker_processes = [ + Process( + target=build_synth_worker_fn, + args=( + cli_cmd_fn, + load_path_queue, + save_path_queue, + save_path, + ), + ) + for _ in range(num_processes) + ] + + for p in worker_processes: + p.start() + for p in worker_processes: + p.join() + + sharded_save_paths = [] + while not save_path_queue.empty(): + try: + _path = save_path_queue.get_nowait() + sharded_save_paths.append(_path) + except Queue.Empty: + break # This is bad, however cat is fast if shutil.which("cat") is None: @@ -311,22 +370,3 @@ def build( # Create index by loading object AmtDataset(load_path=save_path) - - def _build_index(self): - self.file_mmap.seek(0) - index = [] - pos = 0 - while True: - pos_buff = pos - - pos = self.file_mmap.find(b"\n", pos) - if pos == -1: - break - pos = self.file_mmap.find(b"\n", pos + 1) - if pos == -1: - break - - index.append(pos_buff) - pos += 1 - - return index diff --git a/amt/infer.py b/amt/infer.py deleted file mode 100644 index 289d499..0000000 --- a/amt/infer.py +++ /dev/null @@ -1,475 +0,0 @@ -import os -import time -import random -import logging -import torch -import torch.multiprocessing as multiprocessing - -from torch.multiprocessing import Queue -from tqdm import tqdm -from functools import wraps -from torch.cuda import is_bf16_supported - -from amt.model import AmtEncoderDecoder -from amt.tokenizer import AmtTokenizer -from amt.audio import AudioTransform, pad_or_trim -from amt.data import get_wav_mid_segments - - -MAX_SEQ_LEN = 4096 -LEN_MS = 30000 -STRIDE_FACTOR = 3 -CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR -BEAM = 5 -ONSET_TOLERANCE = 61 -VEL_TOLERANCE = 100 - - -def _setup_logger(): - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - - logger.propagate = False - logger.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "[%(asctime)s] %(process)d: [%(levelname)s] %(message)s", - ) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return logging.getLogger(__name__) - - -def calculate_vel( - logits: torch.Tensor, - init_vel: int, - tokenizer: AmtTokenizer = AmtTokenizer(), -): - probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM) - vels = [tokenizer.id_to_tok[idx.item()] for idx in idxs] - - # Get rid of outliers - for idx in range(BEAM): - vel = vels[idx] - if type(vel) is not tuple: - vels[idx] = 0 - probs[idx] = 0.0 - elif vel[0] != "vel": - vels[idx] = 0 - probs[idx] = 0.0 - elif (vel[1] < init_vel - VEL_TOLERANCE / 2) or ( - vel[1] > init_vel + VEL_TOLERANCE / 2 - ): - vels[idx] = vels[idx][1] - probs[idx] = 0.0 - else: - vels[idx] = vels[idx][1] - - vels = torch.tensor(vels).to(probs.device) - new_vel = torch.sum(vels * probs) / torch.sum(probs) - new_vel = round(new_vel.item() / 5) * 5 - - return tokenizer.tok_to_id[("vel", new_vel)] - - -def calculate_onset( - logits: torch.Tensor, - init_onset: int, - tokenizer: AmtTokenizer = AmtTokenizer(), -): - probs, idxs = torch.topk(torch.softmax(logits, dim=-1), BEAM) - onsets = [tokenizer.id_to_tok[idx.item()] for idx in idxs] - - # Get rid of outliers - for idx in range(BEAM): - onset = onsets[idx] - if type(onset) is not tuple: - onsets[idx] = 0 - probs[idx] = 0.0 - elif onset[0] != "onset": - onsets[idx] = 0 - probs[idx] = 0.0 - elif (onset[1] < init_onset - ONSET_TOLERANCE / 2) or ( - onset[1] > init_onset + ONSET_TOLERANCE / 2 - ): - onsets[idx] = onsets[idx][1] - probs[idx] = 0.0 - else: - onsets[idx] = onsets[idx][1] - - onsets = torch.tensor(onsets).to(probs.device) - new_onset = torch.sum(onsets * probs) / torch.sum(probs) - new_onset = round(new_onset.item() / 10) * 10 - - return tokenizer.tok_to_id[("onset", new_onset)] - - -def optional_bf16_autocast(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Assuming 'check_bfloat16_support()' returns True if bfloat16 is supported - if is_bf16_supported(): - with torch.autocast("cuda", dtype=torch.bfloat16): - return func(*args, **kwargs) - else: - # Call the function with float16 if bfloat16 is not supported - with torch.autocast("cuda", dtype=torch.float32): - return func(*args, **kwargs) - - return wrapper - - -@optional_bf16_autocast -def process_segments( - tasks: list, - model: AmtEncoderDecoder, - audio_transform: AudioTransform, - tokenizer: AmtTokenizer, -): - logger = logging.getLogger(__name__) - audio_segs = torch.stack( - [audio_seg for (audio_seg, prefix), _ in tasks] - ).cuda() - log_mels = audio_transform.log_mel(audio_segs) - audio_features = model.embed_audio(mel=log_mels) - - raw_prefixes = [prefix for (audio_seg, prefix), _ in tasks] - prefix_lens = [len(prefix) for prefix in raw_prefixes] - min_prefix_len = min(prefix_lens) - prefixes = [ - tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes - ] - seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda() - end_idxs = [MAX_SEQ_LEN for _ in prefixes] - - kv_cache = model.get_empty_cache() - - # for idx in ( - # pbar := tqdm( - # range(min_prefix_len, MAX_SEQ_LEN - 1), - # total=MAX_SEQ_LEN - (min_prefix_len + 1), - # leave=False, - # ) - # ): - for idx in range(min_prefix_len, MAX_SEQ_LEN - 1): - if idx == min_prefix_len: - logits = model.decoder( - xa=audio_features, - x=seq[:, :idx], - kv_cache=kv_cache, - ) - else: - logits = model.decoder( - xa=audio_features, - x=seq[:, idx - 1 : idx], - kv_cache=kv_cache, - ) - - next_tok_ids = torch.argmax(logits[:, -1], dim=-1) - - for batch_idx in range(logits.shape[0]): - if idx > end_idxs[batch_idx]: - # End already seen, add pad token - tok_id = tokenizer.pad_id - elif idx >= prefix_lens[batch_idx]: - # New token required, recalculated if needed - tok_id = next_tok_ids[batch_idx].item() - tok = tokenizer.id_to_tok[tok_id] - if type(tok) is tuple and tok[0] == "onset": - # If onset token, recalculate - tok_id = calculate_onset(logits[batch_idx, -1], tok[1]) - elif type(tok) is tuple and tok[0] == "vel": - # If velocity token, recalculate - tok_id = calculate_vel(logits[batch_idx, -1], tok[1]) - - else: - # Still in prefix tokens, do nothing - tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]] - - seq[batch_idx, idx] = tok_id - tok = tokenizer.id_to_tok[tok_id] - if tok == tokenizer.eos_tok: - end_idxs[batch_idx] = idx - elif ( - type(tok) is tuple - and tok[0] == "onset" - and tok[1] >= LEN_MS - CHUNK_LEN_MS - ): - end_idxs[batch_idx] = idx - 2 - - if all(_idx <= idx for _idx in end_idxs): - break - - if not all(_idx <= idx for _idx in end_idxs): - logger.warning("Context length overflow when transcribing segment") - - results = [ - tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1]) - for _idx in range(seq.shape[0]) - ] - - return results - - -def gpu_manager( - gpu_task_queue: Queue, - result_queue: Queue, - model: AmtEncoderDecoder, - batch_size: int, -): - model.compile() - logger = _setup_logger() - audio_transform = AudioTransform().cuda() - tokenizer = AmtTokenizer(return_tensors=True) - - wait_for_batch = True - batch = [] - while True: - try: - task, pid = gpu_task_queue.get(timeout=5) - except: - logger.info(f"GPU task timeout") - if len(batch) == 0: - logger.info(f"Finished GPU tasks") - return - else: - wait_for_batch = False - else: - batch.append((task, pid)) - - if len(batch) == batch_size or ( - len(batch) > 0 and wait_for_batch is False - ): - # Process batch on GPU - results = process_segments( - tasks=[task for task in batch], - model=model, - audio_transform=audio_transform, - tokenizer=tokenizer, - ) - for result, (_, pid) in zip(results, batch): - result_queue.put({"result": result, "pid": pid}) - batch.clear() - - -def _shift_onset(seq: list, shift_ms: int): - res = [] - for tok in seq: - if type(tok) is tuple and tok[0] == "onset": - res.append(("onset", tok[1] + shift_ms)) - else: - res.append(tok) - - return res - - -def _truncate_seq( - seq: list, - start_ms: int, - end_ms: int, - tokenizer: AmtTokenizer = AmtTokenizer(), -): - if start_ms == end_ms: - _mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict( - seq, start_ms, return_unclosed_notes=True - ) - random.shuffle(unclosed_notes) - return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok] - else: - _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) - try: - res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) - except Exception: - print("Truncate failed") - return [""] - else: - if res[-1] == tokenizer.eos_tok: - res.pop() - return res - - -def process_file( - file_path, - gpu_task_queue: Queue, - result_queue: Queue, - tokenizer: AmtTokenizer = AmtTokenizer(), -): - logger = logging.getLogger(__name__) - pid = multiprocessing.current_process().pid - - logger.info(f"Getting wav segments") - audio_segments = [ - f - for f, _ in get_wav_mid_segments( - audio_path=file_path, stride_factor=STRIDE_FACTOR - ) - ] - - res = [] - seq = [tokenizer.bos_tok] - concat_seq = [tokenizer.bos_tok] - for idx, audio_seg in enumerate(audio_segments): - init_idx = len(seq) - - # Add to gpu queue and wait for results - gpu_task_queue.put(((audio_seg, seq), pid)) - while True: - gpu_result = result_queue.get() - if gpu_result["pid"] == pid: - seq = gpu_result["result"] - break - else: - result_queue.put(gpu_result) - - concat_seq += _shift_onset( - seq[init_idx:], - idx * CHUNK_LEN_MS, - ) - - if idx == len(audio_segments) - 1: - res.append(concat_seq) - elif concat_seq[-1] == tokenizer.eos_tok: - res.append(concat_seq) - seq = [tokenizer.bos_tok] - concat_seq = [tokenizer.bos_tok] - logger.info(f"Finished segment - eos_tok seen") - else: - seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS) - if len(seq) == 1: - res.append(concat_seq) - seq = [tokenizer.bos_tok] - concat_seq = [tokenizer.bos_tok] - logger.info(f"Exiting early - silence") - - return res - - -def worker( - file_queue: Queue, - gpu_task_queue: Queue, - result_queue: Queue, - save_dir: str, - input_dir: str | None = None, -): - def _save_seq(_seq: list, _save_path: str): - if os.path.exists(_save_path): - logger.info(f"Already exists {_save_path} - overwriting") - - for tok in _seq[::-1]: - if type(tok) is tuple and tok[0] == "onset": - last_onset = tok[1] - break - - try: - mid_dict = tokenizer._detokenize_midi_dict( - tokenized_seq=_seq, len_ms=last_onset - ) - mid = mid_dict.to_midi() - mid.save(_save_path) - except Exception as e: - logger.error(f"Failed to save {_save_path}") - - def _get_save_path(_file_path: str, _idx: int | str = ""): - if input_dir is None: - save_path = os.path.join( - save_dir, - os.path.splitext(os.path.basename(file_path))[0] - + f"{_idx}.mid", - ) - else: - input_rel_path = os.path.relpath(_file_path, input_dir) - save_path = os.path.join( - save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid" - ) - if not os.path.isdir(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path), exist_ok=True) - - return save_path - - logger = _setup_logger() - tokenizer = AmtTokenizer() - files_processed = 0 - while not file_queue.empty(): - file_path = file_queue.get() - - try: - seqs = process_file(file_path, gpu_task_queue, result_queue) - except Exception as e: - logger.error(f"Failed to process {file_path}") - continue - - logger.info(f"Transcribed into {len(seqs)} segment(s)") - for _idx, seq in enumerate(seqs): - _save_seq(seq, _get_save_path(file_path, _idx)) - - files_processed += 1 - logger.info(f"Finished file {files_processed} - {file_path}") - logger.info(f"{file_queue.qsize()} file(s) remaining in queue") - - -def batch_transcribe( - file_paths, # Queue | list, - model: AmtEncoderDecoder, - save_dir: str, - batch_size: int = 16, - gpu_id: int | None = None, - input_dir: str | None = None, -): - if gpu_id is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - - model.cuda() - model.eval() - if isinstance(file_paths, list): - file_queue = Queue() - for file_path in file_paths: - file_queue.put(file_path) - else: - file_queue = file_paths - - gpu_task_queue = Queue() - result_queue = Queue() - - worker_processes = [ - multiprocessing.Process( - target=worker, - args=( - file_queue, - gpu_task_queue, - result_queue, - save_dir, - input_dir, - ), - ) - for _ in range(batch_size + 1) - ] - for p in worker_processes: - p.start() - - time.sleep(10) - gpu_manager_process = multiprocessing.Process( - target=gpu_manager, - args=(gpu_task_queue, result_queue, model, batch_size), - ) - gpu_manager_process.start() - - for p in worker_processes: - p.join() - - gpu_manager_process.join() - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - - return next_token diff --git a/amt/inference/__init__.py b/amt/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/amt/inference/model.py b/amt/inference/model.py new file mode 100644 index 0000000..c302614 --- /dev/null +++ b/amt/inference/model.py @@ -0,0 +1,435 @@ +"""Contains code modified from https://github.com/openai/whisper""" + +import math +import torch +import torch.nn.functional as F + +from torch import Tensor, nn +from dataclasses import dataclass +from typing import Dict, Iterable, Optional + + +@dataclass +class ModelConfig: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + n_vocab: Optional[int] = None + + def set_vocab_size(self, vocab_size: int): + self.n_vocab = vocab_size + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.bfloat16, + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val, v_val: [B, H, L, D] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +def sinusoids( + length: int, channels: int, max_timescale: float = 10000 +) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2) + ) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + + +class EncoderAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + assert n_state % n_head == 0, "n_head does not evenly devide n_state" + + self.n_head = n_head + self.d_head = n_state // n_head + self.query = nn.Linear(n_state, n_state, bias=False) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state, bias=False) + self.out = nn.Linear(n_state, n_state, bias=False) + + def forward( + self, + xa: Tensor, + ): + q = self.query(xa) + k = self.key(xa) + v = self.value(xa) + + # Reshape for correct format + batch_size, source_seq_len, _ = k.shape + batch_size, target_seq_len, _ = q.shape + q = q.view( + batch_size, target_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + k = k.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + v = v.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + + wv = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + is_causal=False, + ) + wv = wv.transpose(1, 2).reshape( + batch_size, + target_seq_len, + self.n_head * self.d_head, + ) + + return self.out(wv) + + +class CrossAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + assert n_state % n_head == 0, "n_head does not evenly devide n_state" + + self.n_head = n_head + self.d_head = n_state // n_head + self.query = nn.Linear(n_state, n_state, bias=False) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state, bias=False) + self.out = nn.Linear(n_state, n_state, bias=False) + self.kv_cache: KVCache | None = None + + def get_kv(self, xa: torch.Tensor, xa_input_pos: Tensor): + assert self.kv_cache is not None, "No kv_cache" + k = self.key(xa[:, xa_input_pos]) + v = self.value(xa[:, xa_input_pos]) + + # Reshape for correct format + batch_size, source_seq_len, _ = k.shape + k = k.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + v = v.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + + k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=xa_input_pos) + + return k, v + + def forward( + self, + x: Tensor, + xa: Tensor, + xa_input_pos: Tensor, + ): + q = self.query(x) + batch_size, target_seq_len, _ = q.shape + q = q.view( + batch_size, target_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + + k, v = self.get_kv(xa, xa_input_pos) + wv = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + is_causal=False, + ) + wv = wv.transpose(1, 2).reshape( + batch_size, + target_seq_len, + self.n_head * self.d_head, + ) + + return self.out(wv) + + +class CausalSelfAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + assert n_state % n_head == 0, "n_head does not evenly devide n_state" + + self.n_state = n_state + self.n_head = n_head + self.d_head = n_state // n_head + self.out = nn.Linear(n_state, n_state, bias=False) + self.kv_cache: KVCache | None = None + + # Add this back after + self.combined_qkv = nn.Linear(n_state, 3 * n_state, bias=False) + self._register_load_state_dict_pre_hook(self.combined_qkv_hook) + + def get_kv(self, k: Tensor, v: Tensor, input_pos: Tensor): + k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) + + return k, v + + def combined_qkv_hook(self, state_dict, prefix, *args): + if prefix + "query.weight" in state_dict: + wq = state_dict.pop(prefix + "query.weight") + wk = state_dict.pop(prefix + "key.weight") + wv = state_dict.pop(prefix + "value.weight") + state_dict[prefix + "combined_qkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ): + q, k, v = self.combined_qkv(x).split( + [self.n_state, self.n_state, self.n_state], dim=-1 + ) + + batch_size, target_seq_len, _ = q.shape + q = q.view( + batch_size, target_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + + batch_size, source_seq_len, _ = k.shape + k = k.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + v = v.view( + batch_size, source_seq_len, self.n_head, self.d_head + ).transpose(1, 2) + + k, v = self.get_kv(k, v, input_pos=input_pos) + wv = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=mask, + ) + + # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d) + wv = wv.transpose(1, 2).reshape( + batch_size, target_seq_len, self.n_head * self.d_head + ) + + return self.out(wv) + + +class EncoderAttentionBlock(nn.Module): + def __init__( + self, n_state: int, n_head: int, cross_attention: bool = False + ): + super().__init__() + self.attn = EncoderAttention(n_state, n_head) + self.attn_ln = nn.LayerNorm(n_state) + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + nn.Linear(n_state, n_mlp, bias=False), + nn.GELU(), + nn.Linear(n_mlp, n_state, bias=False), + ) + self.mlp_ln = nn.LayerNorm(n_state) + + def forward( + self, + xa: Tensor, + ): + xa = xa + self.attn( + self.attn_ln(xa), + ) + xa = xa + self.mlp(self.mlp_ln(xa)) + + return xa + + +class DecoderAttentionBlock(nn.Module): + def __init__( + self, n_state: int, n_head: int, cross_attention: bool = False + ): + super().__init__() + self.attn = CausalSelfAttention(n_state, n_head) + self.attn_ln = nn.LayerNorm(n_state) + self.cross_attn = ( + CrossAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + nn.Linear(n_state, n_mlp, bias=False), + nn.GELU(), + nn.Linear(n_mlp, n_state, bias=False), + ) + self.mlp_ln = nn.LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Tensor, + mask: Optional[Tensor] = None, + x_input_pos: Optional[Tensor] = None, + xa_input_pos: Optional[Tensor] = None, + ): + x = x + self.attn( + self.attn_ln(x), + mask=mask, + input_pos=x_input_pos, + ) + x = x + self.cross_attn(self.cross_attn_ln(x), xa, xa_input_pos) + x = x + self.mlp(self.mlp_ln(x)) + + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d( + n_state, n_state, kernel_size=3, stride=2, padding=1 + ) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[EncoderAttentionBlock] = nn.ModuleList( + [EncoderAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = nn.LayerNorm(n_state) + + def forward(self, xa: Tensor): + xa = F.gelu(self.conv1(xa)) + xa = F.gelu(self.conv2(xa)) + xa = xa.permute(0, 2, 1) + + assert ( + xa.shape[1:] == self.positional_embedding.shape + ), f"incorrect audio shape: {xa.shape[1:]} != {self.positional_embedding.shape}" + xa = (xa + self.positional_embedding).to(xa.dtype) + + for block in self.blocks: + xa = block(xa) + + xa = self.ln_post(xa) + return xa + + +class TextDecoder(nn.Module): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[DecoderAttentionBlock] = nn.ModuleList( + [ + DecoderAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = nn.LayerNorm(n_state) + self.register_buffer("causal_mask", None, persistent=False) + + def forward( + self, + x: Tensor, + xa: Tensor, + x_input_pos: Tensor, + xa_input_pos: Tensor, + ): + mask = self.causal_mask[None, None, x_input_pos] + x = self.token_embedding(x) + self.positional_embedding[x_input_pos] + + for block in self.blocks: + x = block( + x=x, + xa=xa, + mask=mask, + x_input_pos=x_input_pos, + xa_input_pos=xa_input_pos, + ) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + def setup_cache( + self, + batch_size, + max_seq_len=4096, + max_audio_len=1500, + ): + self.causal_mask = torch.tril( + torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) + ) + # Init cache + for b in self.blocks: + b.attn.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_seq_len, + n_heads=8, + head_dim=64, + ).cuda() + b.cross_attn.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_audio_len, + n_heads=8, + head_dim=64, + ).cuda() + + +class AmtEncoderDecoder(nn.Module): + def __init__(self, dims: ModelConfig): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + + def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor: + _buff = self.encoder(mel) + return self.decoder(tokens, _buff) + + @property + def device(self): + return next(self.parameters()).device diff --git a/amt/inference/quantize.py b/amt/inference/quantize.py new file mode 100644 index 0000000..a54b4f7 --- /dev/null +++ b/amt/inference/quantize.py @@ -0,0 +1,153 @@ +"""Contains code modified from https://github.com/pytorch-labs/gpt-fast""" + +import torch + +from torch import nn as nn +from torch.nn import functional as F + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros( + min_val_neg.size(), dtype=torch.int64, device=device + ) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if child.bias is not None: + setattr( + module, + name, + WeightOnlyInt8LinearBias( + child.in_features, child.out_features + ), + ) + else: + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod: torch.nn.Module): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) + cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu") + cur_state_dict[f"{fqn}.scales"] = scales.to( + mod.weight.dtype + ).to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer( + "scales", torch.ones(out_features, dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +# Kinda gross workaround - might not be fused by the compiler +class WeightOnlyInt8LinearBias(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer( + "bias", torch.empty(out_features, dtype=torch.bfloat16) + ) + self.register_buffer( + "scales", torch.ones(out_features, dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return ( + F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + self.bias + ) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py new file mode 100644 index 0000000..9109e6a --- /dev/null +++ b/amt/inference/transcribe.py @@ -0,0 +1,716 @@ +import os +import time +import random +import logging +import traceback +import threading +import torch +import torch.multiprocessing as multiprocessing +import torch._dynamo.config +import torch._inductor.config + +from torch.multiprocessing import Queue +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +from functools import wraps +from torch.cuda import is_bf16_supported + +from amt.inference.model import AmtEncoderDecoder +from amt.tokenizer import AmtTokenizer +from amt.audio import AudioTransform +from amt.data import get_wav_mid_segments + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True + +MAX_SEQ_LEN = 4096 +MAX_BLOCK_LEN = 4096 +LEN_MS = 30000 +STRIDE_FACTOR = 3 +CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR + + +def _setup_logger(): + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(process)d: [%(levelname)s] %(message)s", + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + fh = logging.FileHandler("transcribe.log") + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logging.getLogger(__name__) + + +@torch.jit.script +def get_static_mask(): + # The values are hardcoded here for the pytorch jit - manually update + col_indices = torch.arange(3419, device="cuda").unsqueeze(0) + mask_a = col_indices >= 392 + mask_b = col_indices <= 3418 + return col_indices, mask_a & mask_b + + +@torch.jit.script +def recalculate_tok_ids( + logits: torch.Tensor, + tok_ids: torch.Tensor, +): + probs = torch.softmax(logits, dim=-1) + + # Mask out all non-onset/vel tok_ids + col_indices, interval_mask = get_static_mask() + + # Mask out tok_ids larger than 30ms from original tok_id + tok_ids_expanded = tok_ids.unsqueeze(1) + mask_c = col_indices <= tok_ids_expanded + 3 + mask_d = col_indices >= tok_ids_expanded - 3 + beam_mask = mask_c & mask_d + + # Don't mask out the original tok_id (required for non-onset/vel toks) + tok_id_mask = torch.zeros_like(probs, dtype=torch.bool) + tok_id_mask.scatter_(1, tok_ids_expanded, 1) + + # Combine and calculate probs + combined_mask = (interval_mask & beam_mask) | tok_id_mask + probs[~combined_mask] = 0 + + # Calculate expected value + weighted_idxs = probs * torch.arange( + probs.size(1), device=probs.device + ).float().unsqueeze(0) + idx_evs = ( + (weighted_idxs.sum(dim=1) / (probs.sum(dim=1) + 1e-9)) + .round() + .to(torch.long) + ) + + return idx_evs + + +# Changes seq and eos_idxs in place - tok_ids hardcoded +@torch.jit.script +def update_seq_end_idxs_( + next_tok_ids: torch.Tensor, + seq: torch.Tensor, + eos_idxs: torch.Tensor, + prefix_lens: torch.Tensor, + idx: int, +): + # Update eos_idxs if next tok is eos_tok + eos_mask = next_tok_ids == 1 + eos_idxs[eos_mask] = idx + + # Update eos_idxs if next tok in onset > 20000 + offset_mask = next_tok_ids >= 2418 + eos_idxs[offset_mask] = idx - 2 + + # Don't update toks in prefix or after eos_idx + insert_mask = (prefix_lens <= idx) & (eos_idxs >= idx) + seq[insert_mask, idx] = next_tok_ids[insert_mask] + + +def optional_bf16_autocast(func): + @wraps(func) + def wrapper(*args, **kwargs): + if is_bf16_supported(): + with torch.autocast("cuda", dtype=torch.bfloat16): + return func(*args, **kwargs) + else: + with torch.autocast("cuda", dtype=torch.float32): + return func(*args, **kwargs) + + return wrapper + + +def decode_token( + model: AmtEncoderDecoder, + x: torch.Tensor, + xa: torch.Tensor, + x_input_pos: torch.Tensor, + xa_input_pos: torch.Tensor, +): + logits = model.decoder.forward( + x=x, + xa=xa, + x_input_pos=x_input_pos, + xa_input_pos=xa_input_pos, + )[:, -1] + next_tok_ids = torch.argmax(logits, dim=-1) + + return logits, next_tok_ids + + +@optional_bf16_autocast +@torch.no_grad() +def process_segments( + tasks: list, + model: AmtEncoderDecoder, + audio_transform: AudioTransform, + tokenizer: AmtTokenizer, + logger: logging.Logger, +): + audio_segs = torch.stack( + [audio_seg for (audio_seg, prefix), _ in tasks] + ).cuda() + log_mels = audio_transform.log_mel(audio_segs) + audio_features = model.encoder(xa=log_mels) + + raw_prefixes = [prefix for (audio_seg, prefix), _ in tasks] + prefix_lens = torch.tensor( + [len(prefix) for prefix in raw_prefixes], dtype=torch.int + ) + min_prefix_len = min(prefix_lens).item() + prefixes = [ + tokenizer.trunc_seq(prefix, MAX_BLOCK_LEN) for prefix in raw_prefixes + ] + seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda() + eos_idxs = torch.tensor([MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int) + + # for idx in ( + # pbar := tqdm( + # range(min_prefix_len, MAX_BLOCK_LEN - 1), + # total=MAX_BLOCK_LEN - (min_prefix_len + 1), + # leave=False, + # ) + # ): + for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): + if idx == min_prefix_len: + logits, next_tok_ids = decode_token( + model, + x=seq[:, :idx], + xa=audio_features, + x_input_pos=torch.arange(0, idx, device=seq.device), + xa_input_pos=torch.arange( + 0, audio_features.shape[1], device=seq.device + ), + ) + else: + logits, next_tok_ids = decode_token( + model, + x=seq[:, idx - 1 : idx], + xa=audio_features, + x_input_pos=torch.tensor( + [idx - 1], device=seq.device, dtype=torch.int + ), + xa_input_pos=torch.tensor( + [], device=seq.device, dtype=torch.int + ), + ) + + next_tok_ids = recalculate_tok_ids( + logits=logits, + tok_ids=next_tok_ids, + ) + update_seq_end_idxs_( + next_tok_ids=next_tok_ids, + seq=seq, + eos_idxs=eos_idxs, + prefix_lens=prefix_lens, + idx=idx, + ) + + if all(_idx <= idx for _idx in eos_idxs): + break + + # If there is a context length overflow, we need to have some special logic + # to make sure that a sequence of the correct format is returned. Right now + # it messes things up somehow + if not all(_idx <= idx for _idx in eos_idxs): + logger.warning("Context length overflow when transcribing segment") + + results = [ + tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1]) + for _idx in range(seq.shape[0]) + ] + + return results + + +# There is a memory leak in here somewhere +def gpu_manager( + gpu_batch_queue: Queue, + result_queue: Queue, + model: AmtEncoderDecoder, + batch_size: int, + gpu_id: int | None = None, +): + logger = _setup_logger() + logger.info("Started GPU manager") + + if gpu_id is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + global decode_token, recalculate_tok_ids + model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN) + model.cuda() + model.eval() + if batch_size == 1: + recalculate_tok_ids = torch.compile( + recalculate_tok_ids, mode="max-autotune-no-cudagraphs" + ) + decode_token = torch.compile( + decode_token, + # mode="reduce-overhead", + mode="max-autotune", + fullgraph=True, + ) + + audio_transform = AudioTransform().cuda() + tokenizer = AmtTokenizer(return_tensors=True) + + try: + while True: + try: + batch = gpu_batch_queue.get(timeout=10) + except Exception as e: + logger.info(f"GPU timedout waiting for batch") + break + else: + try: + results = process_segments( + tasks=batch, + model=model, + audio_transform=audio_transform, + tokenizer=tokenizer, + logger=logger, + ) + except Exception as e: + logger.error( + f"Failed to process batch: {traceback.format_exc()}" + ) + raise e + else: + # pid = -1 when its a pad sequence + for result, (_, pid) in zip(results, batch): + if pid != -1: + result_queue.put({"result": result, "pid": pid}) + + except Exception as e: + logger.error(f"GPU manager failed with exception: {e}") + finally: + logger.info(f"GPU manager terminated") + + +def _find_min_diff_batch(tasks: list, batch_size: int): + prefix_lens = [ + (len(prefix), idx) for idx, ((audio_seg, prefix), _) in enumerate(tasks) + ] + prefix_lens.sort(key=lambda x: x[0]) + + min_diff = float("inf") + start_idx = 0 + + # Iterate through the array to find the batch with the min difference + for _idx in range(len(prefix_lens) - batch_size + 1): + current_diff = ( + prefix_lens[_idx + batch_size - 1][0] - prefix_lens[_idx][0] + ) + if current_diff < min_diff: + min_diff = current_diff + start_idx = _idx + + return [ + orig_idx + for prefix_lens, orig_idx in prefix_lens[ + start_idx : start_idx + batch_size + ] + ] + + +def gpu_batch_manager( + gpu_task_queue: Queue, + gpu_batch_queue: Queue, + batch_size: int, +): + logger = _setup_logger() + logger.info("Started batch manager") + try: + tasks = [] + while True: + try: + task, pid = gpu_task_queue.get(timeout=0.2) + except Exception as e: + pass + else: + tasks.append((task, pid)) + continue + + # No tasks in queue -> check gpu batch queue + if gpu_batch_queue.empty() is False: + continue + elif len(tasks) == 0: + continue + + # Get new batch and add to batch queue + if len(tasks) < batch_size: + logger.info("Not enough tasks - padding batch") + while len(tasks) < batch_size: + _pad_task, _pid = tasks[0] + tasks.append((_pad_task, -1)) + + assert len(tasks) >= batch_size, "batch error" + new_batch_idxs = _find_min_diff_batch( + tasks, + batch_size=batch_size, + ) + gpu_batch_queue.put([tasks[_idx] for _idx in new_batch_idxs]) + tasks = [ + task + for _idx, task in enumerate(tasks) + if _idx not in new_batch_idxs + ] + except Exception as e: + logger.error(f"GPU batch manager failed with exception: {e}") + finally: + logger.info(f"GPU batch manager terminated") + + +def _shift_onset(seq: list, shift_ms: int): + res = [] + for tok in seq: + if type(tok) is tuple and tok[0] == "onset": + res.append(("onset", tok[1] + shift_ms)) + else: + res.append(tok) + + return res + + +def _truncate_seq( + seq: list, + start_ms: int, + end_ms: int, + logger: logging.Logger, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + # Truncates and shifts a sequence by retokenizing the underlying midi_dict + if start_ms == end_ms: + _mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict( + seq, start_ms, return_unclosed_notes=True + ) + random.shuffle(unclosed_notes) + return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok] + else: + try: + _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) + res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) + except Exception as e: + logger.error(f"Truncate segment failed: {e}") + return [tokenizer.bos_tok] + else: + if res[-1] == tokenizer.eos_tok: + res.pop() + return res + + +def transcribe_file( + file_path, + gpu_task_queue: Queue, + result_queue: Queue, + pid: int, + tokenizer: AmtTokenizer = AmtTokenizer(), +): + logger = logging.getLogger(__name__) + + logger.info(f"Getting wav segments: {file_path}") + audio_segments = [ + f + for f, _ in get_wav_mid_segments( + audio_path=file_path, stride_factor=STRIDE_FACTOR + ) + ] + + res = [] + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + for idx, audio_seg in enumerate(audio_segments): + init_idx = len(seq) + + # Add to gpu queue and wait for results + gpu_task_queue.put(((audio_seg, seq), pid)) + while True: + # Issue with this logic perhaps + gpu_result = result_queue.get(timeout=300) + if gpu_result["pid"] == pid: + seq = gpu_result["result"] + break + else: + result_queue.put(gpu_result) + + concat_seq += _shift_onset( + seq[init_idx:], + idx * CHUNK_LEN_MS, + ) + + if idx == len(audio_segments) - 1: + res.append(concat_seq) + elif concat_seq[-1] == tokenizer.eos_tok: + res.append(concat_seq) + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + logger.info(f"Finished segment (eos_tok): {file_path}") + else: + # This might need it's logic adjusted + + seq = _truncate_seq( + seq, + CHUNK_LEN_MS, + LEN_MS - CHUNK_LEN_MS, + logger=logger, + ) + + if len(seq) == 1: + logger.error(f"Failed to transcribe segment: {file_path}") + if len(concat_seq) > 500: + res.append(concat_seq) + else: + pass + # logger.info(f"Sequence too short ({len(concat_seq)})") + + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + + return res + + +def process_file( + file_path: str, + file_queue: Queue, + gpu_task_queue: Queue, + result_queue: Queue, + tokenizer: AmtTokenizer, + save_dir: str, + input_dir: str, + logger: logging.Logger, +): + def _save_seq(_seq: list, _save_path: str): + if os.path.exists(_save_path): + logger.info(f"Already exists {_save_path} - overwriting") + + for tok in _seq[::-1]: + if type(tok) is tuple and tok[0] == "onset": + last_onset = tok[1] + break + + try: + mid_dict = tokenizer._detokenize_midi_dict( + tokenized_seq=_seq, len_ms=last_onset + ) + mid = mid_dict.to_midi() + mid.save(_save_path) + except Exception as e: + logger.error(f"Failed to save {_save_path}") + + def _get_save_path(_file_path: str, _idx: int | str = ""): + if input_dir is None: + save_path = os.path.join( + save_dir, + os.path.splitext(os.path.basename(file_path))[0] + + f"{_idx}.mid", + ) + else: + input_rel_path = os.path.relpath(_file_path, input_dir) + save_path = os.path.join( + save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid" + ) + if not os.path.isdir(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + return save_path + + def remove_failures_from_queue_(_queue: Queue, _pid: int): + _buff = [] + while True: + try: + _buff.append(_queue(timout=5)) + except Exception: + break + + num_removed = 0 + for _task, __pid in _buff: + if _pid != __pid: + _queue.put((_task, __pid)) + else: + num_removed += 1 + + return num_removed + + pid = threading.get_ident() + try: + seqs = transcribe_file(file_path, gpu_task_queue, result_queue, pid=pid) + except Exception as e: + logger.error(f"Failed to process {file_path}: {traceback.format_exc()}") + task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid) + res_rmv_cnt = remove_failures_from_queue_(result_queue, pid) + logger.info(f"Removed {task_rmv_cnt} from task queue") + logger.info(f"Removed {res_rmv_cnt} from result queue") + return + + logger.info(f"Finished file: {file_path}") + _idx = 0 + for seq in seqs: + if len(seq) < 1000: + logger.info("Skipping seq - too short") + continue + _save_seq(seq, _get_save_path(file_path, _idx)) + _idx += 1 + + logger.info(f"Transcribed into {_idx} segment(s)") + logger.info(f"{file_queue.qsize()} file(s) remaining in queue") + + +def worker( + file_queue: Queue, + gpu_task_queue: Queue, + result_queue: Queue, + save_dir: str, + input_dir: str | None = None, + tasks_per_worker: int = 1, +): + logger = _setup_logger() + tokenizer = AmtTokenizer() + threads = [] + try: + while not file_queue.empty() or any(t.is_alive() for t in threads): + while len(threads) < tasks_per_worker and not file_queue.empty(): + logging.info("Starting worker") + file_path = file_queue.get() + t = threading.Thread( + target=process_file, + args=( + file_path, + file_queue, + gpu_task_queue, + result_queue, + tokenizer, + save_dir, + input_dir, + logger, + ), + ) + t.start() + threads.append(t) + + threads = [t for t in threads if t.is_alive()] + + time.sleep(0.1) + + for t in threads: + t.join() + + except Exception as e: + logger.error(f"File worker failed with exception: {e}") + finally: + logger.info(f"File worker terminated") + + +# Needs to test this for multi-gpu +def batch_transcribe( + file_paths: list, + model: AmtEncoderDecoder, + save_dir: str, + batch_size: int = 16, + input_dir: str | None = None, + gpu_ids: int | None = None, + quantize: bool = True, +): + torch.multiprocessing.set_start_method("spawn") + num_gpus = len(gpu_ids) if gpu_ids is not None else 1 + logger = _setup_logger() + + if os.path.isfile("transcribe.log"): + os.remove("transcribe.log") + + if quantize is True: + logger.info("Quantising weights to int8") + model = quantize_int8(model) + + gpu_task_queue = Queue() + gpu_batch_queue = Queue() + result_queue = Queue() + file_queue = Queue() + for file_path in file_paths: + file_queue.put(file_path) + + num_workers = min(batch_size * num_gpus, len(file_paths)) + logger.info(f"Creating {num_workers} file worker(s)") + worker_processes = [ + multiprocessing.Process( + target=worker, + args=( + file_queue, + gpu_task_queue, + result_queue, + save_dir, + input_dir, + # Wait for all threads to finish + 4, + ), + ) + for _ in range(num_workers) + ] + gpu_batch_manager_process = multiprocessing.Process( + target=gpu_batch_manager, + args=(gpu_task_queue, gpu_batch_queue, batch_size), + ) + + start_time = time.time() + if num_gpus == 1: + gpu_manager_processes = [ + multiprocessing.Process( + target=gpu_manager, + args=(gpu_batch_queue, result_queue, model, batch_size), + ) + ] + else: + gpu_manager_processes = [ + multiprocessing.Process( + target=gpu_manager, + args=(gpu_batch_queue, result_queue, model, batch_size, gpu_id), + ) + for gpu_id in gpu_ids + ] + + for p in worker_processes: + p.start() + time.sleep(5) + gpu_batch_manager_process.start() + for p in gpu_manager_processes: + p.start() + + # Watch for file workers to finish + for p in worker_processes: + p.join() + for p in gpu_manager_processes: + p.join() + gpu_batch_manager_process.terminate() + + print("Took", (time.time() - start_time) / 60, "mins to transcribe files") + + +def quantize_int8(model: torch.nn.Module): + from amt.inference.quantize import WeightOnlyInt8QuantHandler + + quantizer = WeightOnlyInt8QuantHandler(model) + int8_state_dict = quantizer.create_quantized_state_dict() + _model = quantizer.convert_for_runtime() + _model.load_state_dict(int8_state_dict) + + return _model diff --git a/amt/model.py b/amt/model.py index 9e8ccb2..1b60a46 100644 --- a/amt/model.py +++ b/amt/model.py @@ -50,67 +50,27 @@ def __init__(self, n_state: int, n_head: int): self.n_head = n_head self.d_head = n_state // n_head - self.query = nn.Linear(n_state, n_state) + self.query = nn.Linear(n_state, n_state, bias=False) self.key = nn.Linear(n_state, n_state, bias=False) - self.value = nn.Linear(n_state, n_state) - self.out = nn.Linear(n_state, n_state) + self.value = nn.Linear(n_state, n_state, bias=False) + self.out = nn.Linear(n_state, n_state, bias=False) def forward( self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, - kv_cache: Optional[dict] = None, ): q = self.query(x) - if kv_cache is None: - # Normal forward - if xa is not None: - # Cross att - k = self.key(xa) - v = self.value(xa) - else: - # Self att in encoder/decoder - k = self.key(x) - v = self.value(x) + if xa is not None: + # Cross att + k = self.key(xa) + v = self.value(xa) else: - # Using cache - k_id = f"{id(self)}_k" - v_id = f"{id(self)}_v" - - if xa is not None: - # Cross att - calculate once and reuse - if kv_cache.get(k_id) is None: - # Not recorded yet, calculate and store - k = self.key(xa) - v = self.value(xa) - kv_cache[k_id] = k - kv_cache[v_id] = v - else: - # Already recorded, get - k = kv_cache[k_id] - v = kv_cache[v_id] - else: - # Decoder self att, append each time - if kv_cache.get(k_id) is None: - # Not recorded yet, calculate and store - k = self.key(x) - v = self.value(x) - kv_cache[k_id] = k - kv_cache[v_id] = v - else: - # Already recorded, get and append - k = torch.cat((kv_cache[k_id], self.key(x)), dim=1).detach() - v = torch.cat( - (kv_cache[v_id], self.value(x)), dim=1 - ).detach() - kv_cache[k_id] = k - kv_cache[v_id] = v - - # When using kv_cache for decoder self attention, we don't - # want to use a mask in the self attention calculation - mask = None + # Self att in encoder/decoder + k = self.key(x) + v = self.value(x) # Reshape and transpose for attention calculation batch_size, target_seq_len, _ = q.shape @@ -157,7 +117,9 @@ def __init__( n_mlp = n_state * 4 self.mlp = nn.Sequential( - nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state) + nn.Linear(n_state, n_mlp, bias=False), + nn.GELU(), + nn.Linear(n_mlp, n_state, bias=False), ) self.mlp_ln = nn.LayerNorm(n_state) @@ -166,16 +128,10 @@ def forward( x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, - kv_cache: Optional[dict] = None, ): - x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + x = x + self.attn(self.attn_ln(x), mask=mask)[0] if self.cross_attn: - x = ( - x - + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[ - 0 - ] - ) + x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0] x = x + self.mlp(self.mlp_ln(x)) return x @@ -236,22 +192,18 @@ def __init__( mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) self.register_buffer("mask", mask, persistent=False) - def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + def forward(self, x: Tensor, xa: Tensor): """ x : torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) the encoded audio features to be attended on """ - offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 - x = ( - self.token_embedding(x) - + self.positional_embedding[offset : offset + x.shape[-1]] - ) + x = self.token_embedding(x) + self.positional_embedding[: x.shape[-1]] x = x.to(xa.dtype) for block in self.blocks: - x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + x = block(x, xa, mask=self.mask) x = self.ln(x) logits = ( diff --git a/amt/run.py b/amt/run.py index 7e29ae0..1b2bcc5 100644 --- a/amt/run.py +++ b/amt/run.py @@ -38,6 +38,12 @@ def _add_transcribe_args(subparser): subparser.add_argument( "-multi_gpu", help="use all GPUs", action="store_true", default=False ) + subparser.add_argument( + "-q8", + help="apply int8 quantization on weights", + action="store_true", + default=False, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) @@ -88,19 +94,19 @@ def build_maestro( print(f"Building {train_file}") AmtDataset.build( - matched_load_paths=matched_paths_train, + load_paths=matched_paths_train, save_path=train_file, num_processes=num_procs, ) print(f"Building {val_file}") AmtDataset.build( - matched_load_paths=matched_paths_val, + load_paths=matched_paths_val, save_path=val_file, num_processes=num_procs, ) print(f"Building {test_file}") AmtDataset.build( - matched_load_paths=matched_paths_test, + load_paths=matched_paths_test, save_path=test_file, num_processes=num_procs, ) @@ -114,7 +120,6 @@ def transcribe( load_dir=None, batch_size=16, multi_gpu=False, - augment=None, ): """ Transcribe audio files to midi using the given model and checkpoint. @@ -138,13 +143,11 @@ def transcribe( augment : str Augment the audio files before transcribing. This is used for evaluation. This tests the robustness of the model. """ - import torch from torch.cuda import is_available as cuda_is_available - from torch.multiprocessing import Queue from amt.tokenizer import AmtTokenizer - from amt.infer import batch_transcribe + from amt.inference.transcribe import batch_transcribe from amt.config import load_model_config - from amt.model import ModelConfig, AmtEncoderDecoder + from amt.inference.model import ModelConfig, AmtEncoderDecoder from aria.utils import _load_weight assert cuda_is_available(), "CUDA device not found" @@ -176,7 +179,6 @@ def transcribe( _model_state[k] = v model_state = _model_state model.load_state_dict(model_state) - torch.multiprocessing.set_start_method("spawn") if trans_mode == "batch": found_wav = glob.glob( @@ -196,31 +198,14 @@ def transcribe( int(id) for id in os.getenv("CUDA_VISIBLE_DEVICES").split(",") ] print(f"Visible gpu_ids: {gpu_ids}") - - # Use shared file queue between gpu processes - file_queue = torch.multiprocessing.Queue() - for file_path in file_paths: - file_queue.put(file_path) - - processes = [] - for gpu_id in gpu_ids: - print(f"Starting process on cuda-{gpu_id}") - process = torch.multiprocessing.Process( - target=batch_transcribe, - args=( - file_queue, - model, - save_dir, - batch_size, - gpu_id, - load_dir, - ), - ) - process.start() - processes.append(process) - - for process in processes: - process.join() + batch_transcribe( + file_paths=file_paths, + model=model, + save_dir=save_dir, + batch_size=batch_size, + input_dir=load_dir, + gpu_ids=gpu_ids, + ) else: batch_transcribe( diff --git a/amt/tokenizer.py b/amt/tokenizer.py index d5416a7..c368673 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -11,10 +11,6 @@ from amt.config import load_config -# Instead of doing this, we could calculate beams at inference time, selecting -# the note with the first onset so that we don't miss notes. - - DEBUG = os.getenv("DEBUG") @@ -63,6 +59,12 @@ def __init__(self, return_tensors: bool = False): ) self.pad_id = self.tok_to_id[self.pad_tok] + def _get_inference_ids(self): + return [ + self.tok_to_id[tok] + for tok in self.velocity_tokens + self.onset_tokens + ] + def _quantize_onset(self, time: int): # This function will return values res >= 0 (inc. 0) return self._find_closest_int(time, self.onset_time_quantizations) @@ -86,13 +88,16 @@ def _tokenize_midi_dict( midi_dict: MidiDict, start_ms: int, end_ms: int, + max_pedal_len_ms: int | None = None, ): assert ( end_ms - start_ms <= self.max_onset ), "Invalid values for start_ms, end_ms" - midi_dict.resolve_pedal() # Important !! + if midi_dict.pedal_resolved is False: + midi_dict.resolve_pedal() # Important !! pedal_intervals = midi_dict._build_pedal_intervals() + if len(pedal_intervals.keys()) > 1: print("Warning: midi_dict has more than one pedal channel") if len(midi_dict.instrument_msgs) > 1: @@ -179,6 +184,9 @@ def _tokenize_midi_dict( ticks_per_beat=midi_dict.ticks_per_beat, ) + if max_pedal_len_ms is not None: + pedal_off_ms = min(pedal_off_ms, pedal_on_ms + max_pedal_len_ms) + rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms) rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms) @@ -307,8 +315,7 @@ def _detokenize_midi_dict( if tok_1_type == "prev": notes_to_close[tok_1_data] = (0, self.default_velocity) print("Unexpected token order: 'prev' seen after ''") - if DEBUG: - raise Exception + raise ValueError elif tok_1_type == "pedal": _pedal_data = tok_1_data _tick = tok_2_data @@ -323,8 +330,7 @@ def _detokenize_midi_dict( elif tok_1_type == "on": if (tok_2_type, tok_3_type) != ("onset", "vel"): print("Unexpected token order:", tok_1, tok_2, tok_3) - if DEBUG: - raise Exception + raise ValueError else: notes_to_close[tok_1_data] = (tok_2_data, tok_3_data) elif tok_1_type == "off": diff --git a/amt/train.py b/amt/train.py index 10a3952..eee1b8e 100644 --- a/amt/train.py +++ b/amt/train.py @@ -266,7 +266,6 @@ def rolling_average(prev_avg: float, x_n: float, n: int): return ((prev_avg * (n - 1)) / n) + (x_n / n) -# TODO: Test that loss/backprop is working correctly (look at shapes) def _train( epochs: int, accelerator: accelerate.Accelerator, @@ -281,34 +280,6 @@ def _train( resume_epoch: int | None = None, project_dir: str | None = None, ): - def profile_flops(dataloader: DataLoader): - def _bench(): - for batch in dataloader: - wav, src, tgt, pitch_shift = batch - with torch.no_grad(): - mel = audio_transform.forward(wav, shift=pitch_shift) - logits = model(mel, src) # (b_sz, s_len, v_sz) - logits = logits.transpose(1, 2) - loss = loss_fn(logits, tgt) - - # Backwards step - omit optimizer.step() - accelerator.backward(loss) - optimizer.zero_grad() - break - - logger.info( - f"Model has " - f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " - "parameters" - ) - logger.info("Compiling model...") - _bench() - - # with flop_counter: - # _bench() - # total_flop = sum(flop_counter.get_flop_counts()["Global"].values()) - # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF") - def make_checkpoint(_accelerator, _epoch: int, _step: int): checkpoint_dir = os.path.join( project_dir, @@ -327,7 +298,6 @@ def train_loop( dataloader: DataLoader, _epoch: int, _resume_step: int = 0, - overfit: bool = False, ): avg_train_loss = 0 trailing_loss = 0 @@ -409,7 +379,7 @@ def train_loop( return avg_train_loss - def val_loop(dataloader, _epoch: int): + def val_loop(dataloader, _epoch: int, aug: bool): avg_val_loss = 0 model.eval() for step, batch in ( @@ -421,10 +391,21 @@ def val_loop(dataloader, _epoch: int): ): wav, src, tgt = batch with torch.no_grad(): - mel = audio_transform.log_mel(wav) + if aug == False: + mel = audio_transform.log_mel(wav) + elif aug == True: + # Apply aug without distortion or spec-augment + mel = audio_transform.log_mel( + audio_transform.aug_wav(wav), detune=True + ) + else: + raise TypeError + logits = model(mel, src) - logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss - loss = loss_fn(logits, tgt) + logits = logits.transpose( + 1, 2 + ) # Transpose for CrossEntropyLoss + loss = loss_fn(logits, tgt) # Logging avg_val_loss = rolling_average(avg_val_loss, loss.item(), step) @@ -432,7 +413,8 @@ def val_loop(dataloader, _epoch: int): # EPOCH logger.info( - f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation - " + f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation " + f"{'(aug)' if aug is True else ''} - " f"average_loss={round(avg_val_loss, 4)}" ) @@ -447,7 +429,11 @@ def val_loop(dataloader, _epoch: int): PAD_ID = train_dataloader.dataset.tokenizer.pad_id logger = get_logger(__name__) # Accelerate logger loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID) - profile_flops(dataloader=train_dataloader) + logger.info( + f"Model has " + f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " + "parameters" + ) if accelerator.is_main_process: loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") @@ -455,7 +441,9 @@ def val_loop(dataloader, _epoch: int): loss_writer.writerow(["epoch", "step", "loss"]) epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w") epoch_writer = csv.writer(epoch_csv) - epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"]) + epoch_writer.writerow( + ["epoch", "avg_train_loss", "avg_val_loss", "avg_val_loss_aug"] + ) if resume_epoch is not None: start_epoch = resume_epoch + 1 @@ -477,9 +465,16 @@ def val_loop(dataloader, _epoch: int): _epoch=resume_epoch, _resume_step=resume_step, ) - avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch) + avg_val_loss = val_loop( + dataloader=val_dataloader, _epoch=resume_epoch, aug=False + ) + avg_val_loss_aug = val_loop( + dataloader=val_dataloader, _epoch=resume_epoch, aug=True + ) if accelerator.is_main_process: - epoch_writer.writerow([resume_epoch, avg_train_loss, avg_val_loss]) + epoch_writer.writerow( + [resume_epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug] + ) epoch_csv.flush() make_checkpoint( _accelerator=accelerator, _epoch=start_epoch, _step=0 @@ -487,9 +482,16 @@ def val_loop(dataloader, _epoch: int): for epoch in range(start_epoch, epochs + start_epoch): avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch) - avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch) + avg_val_loss = val_loop( + dataloader=val_dataloader, _epoch=epoch, aug=False + ) + avg_val_loss_aug = val_loop( + dataloader=val_dataloader, _epoch=epoch, aug=True + ) if accelerator.is_main_process: - epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss]) + epoch_writer.writerow( + [epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug] + ) epoch_csv.flush() make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0) @@ -565,6 +567,7 @@ def resume_train( model = torch.compile(model) audio_transform = AudioTransform().to(accelerator.device) logger.info(f"Loaded model with config: {load_model_config(model_name)}") + logger.info(f"Loaded transform with config: {audio_transform.get_params()}") train_dataloader, val_dataloader = get_dataloaders( train_data_path=train_data_path, @@ -682,6 +685,7 @@ def train( model = torch.compile(model) audio_transform = AudioTransform().to(accelerator.device) logger.info(f"Loaded model with config: {load_model_config(model_name)}") + logger.info(f"Loaded transform with config: {audio_transform.get_params()}") if mode == "finetune": try: model.load_state_dict(_load_weight(finetune_cp_path)) diff --git a/config/config.json b/config/config.json index 2fa9fd4..9da2e4e 100644 --- a/config/config.json +++ b/config/config.json @@ -17,7 +17,7 @@ "n_mels": 256 }, "data": { - "stride_factor": 6, + "stride_factor": 12, "max_seq_len": 4096 } } \ No newline at end of file diff --git a/config/models/small.json b/config/models/small.json deleted file mode 100644 index 1c87733..0000000 --- a/config/models/small.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "n_mels": 256, - "n_audio_ctx": 1500, - "n_audio_state": 384, - "n_audio_head": 6, - "n_audio_layer": 8, - "n_text_ctx": 4096, - "n_text_state": 384, - "n_text_head": 6, - "n_text_layer": 8 -} \ No newline at end of file diff --git a/config/models/test.json b/config/models/test.json deleted file mode 100644 index 93c0f16..0000000 --- a/config/models/test.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "n_mels": 256, - "n_audio_ctx": 1500, - "n_audio_state": 64, - "n_audio_head": 4, - "n_audio_layer": 4, - "n_text_ctx": 4096, - "n_text_state": 64, - "n_text_head": 4, - "n_text_layer": 4 -} \ No newline at end of file diff --git a/tests/test_data.py b/tests/test_data.py index 1437472..f69117e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,7 +1,8 @@ import unittest import logging import os -import time +import cProfile +import pstats import torch import torchaudio import matplotlib.pyplot as plt @@ -9,6 +10,7 @@ from amt.data import get_wav_mid_segments, AmtDataset from amt.tokenizer import AmtTokenizer from amt.audio import AudioTransform, log_mel_spectrogram +from amt.train import get_dataloaders from aria.data.midi import MidiDict @@ -16,7 +18,17 @@ if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") -MAESTRO_PATH = "/weka/proj-aria/aria-amt/data/maestro/val.jsonl" +MAESTRO_PATH = "/weka/proj-aria/aria-amt/data/train.jsonl" + + +def plot_spec(mel: torch.Tensor, name: str | int): + plt.figure(figsize=(10, 4)) + plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis") + plt.colorbar(format="%+2.0f dB") + plt.title("(mel)-Spectrogram") + plt.tight_layout() + plt.savefig(f"tests/test_results/{name}.png") + plt.close() # Need to test this properly, have issues turning mel_spec back into audio @@ -32,14 +44,14 @@ def test_wav_mid_segments(self): class TestAmtDataset(unittest.TestCase): def test_build(self): matched_paths = [ - ("tests/test_data/147.wav", "tests/test_data/147.mid") + ("tests/test_data/maestro.wav", "tests/test_data/maestro1.mid") for _ in range(3) ] if os.path.isfile("tests/test_results/dataset.jsonl"): os.remove("tests/test_results/dataset.jsonl") AmtDataset.build( - matched_load_paths=matched_paths, + load_paths=matched_paths, save_path="tests/test_results/dataset.jsonl", ) @@ -61,6 +73,7 @@ def test_maestro(self): return tokenizer = AmtTokenizer() + audio_transform = AudioTransform() dataset = AmtDataset(load_path=MAESTRO_PATH) print(f"Dataset length: {len(dataset)}") for idx, (wav, src, tgt) in enumerate(dataset): @@ -74,8 +87,13 @@ def test_maestro(self): ) src_mid = src_mid_dict.to_midi() - if idx % 10 == 0: - src_mid.save(f"tests/test_results/dataset_{idx}.mid") + src_mid.save(f"tests/test_results/dataset_{idx}.mid") + torchaudio.save( + f"tests/test_results/wav_{idx}.wav", wav.unsqueeze(0), 16000 + ) + plot_spec( + audio_transform(wav.unsqueeze(0)).squeeze(0), f"mel_{idx}" + ) self.assertTrue(tokenizer.unk_tok not in src_dec) self.assertTrue(tokenizer.unk_tok not in tgt_dec) @@ -84,15 +102,6 @@ def test_maestro(self): class TestAug(unittest.TestCase): - def plot_spec(self, mel: torch.Tensor, name: str | int): - plt.figure(figsize=(10, 4)) - plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis") - plt.colorbar(format="%+2.0f dB") - plt.title("(mel)-Spectrogram") - plt.tight_layout() - plt.savefig(f"tests/test_results/{name}.png") - plt.close() - def test_spec(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform() @@ -115,11 +124,11 @@ def test_spec(self): torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE) log_mel = log_mel_spectrogram(wav) - self.plot_spec(log_mel.squeeze(0), "orig") + plot_spec(log_mel.squeeze(0), "orig") _mel = audio_transform.mel_transform(spec) _log_mel = audio_transform.norm_mel(_mel) - self.plot_spec(_log_mel.squeeze(0), "new") + plot_spec(_log_mel.squeeze(0), "new") def test_pitch_aug(self): tokenizer = AmtTokenizer(return_tensors=True) @@ -147,6 +156,36 @@ def test_pitch_aug(self): for src_tok, tgt_tok in zip(src_aug_dec[1:], tgt_aug_dec): self.assertEqual(src_tok, tgt_tok) + def test_detune(self): + SAMPLE_RATE, CHUNK_LEN = 16000, 30 + audio_transform = AudioTransform() + wav, sr = torchaudio.load("tests/test_data/maestro.wav") + wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( + 0, keepdim=True + )[:, : SAMPLE_RATE * CHUNK_LEN] + + griffin_lim = torchaudio.transforms.GriffinLim( + n_fft=2048, + hop_length=160, + power=1, + n_iter=64, + ) + + spec = audio_transform.spec_transform(wav) + shift_spec = audio_transform.detune_spec(spec) + shift_wav = griffin_lim(shift_spec) + gl_wav = griffin_lim(spec) + torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) + torchaudio.save("tests/test_results/orig_gl.wav", gl_wav, SAMPLE_RATE) + torchaudio.save("tests/test_results/detune.wav", shift_wav, SAMPLE_RATE) + + log_mel = log_mel_spectrogram(wav) + plot_spec(log_mel.squeeze(0), "orig") + + _mel = audio_transform.mel_transform(spec) + _log_mel = audio_transform.norm_mel(_mel) + plot_spec(_log_mel.squeeze(0), "new") + def test_mels(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform() @@ -163,7 +202,7 @@ def test_mels(self): wavs = torch.stack((wav[0], wav[0], wav[0])) mels = audio_transform(wavs) for idx in range(mels.shape[0]): - self.plot_spec(mels[idx], idx) + plot_spec(mels[idx], idx) def test_distortion(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 @@ -226,5 +265,27 @@ def test_noise(self): torchaudio.save("tests/test_results/noise.wav", res, SAMPLE_RATE) +class TestDataLoader(unittest.TestCase): + def load_data(self, dataloader, num_batches=100): + for idx, data in enumerate(dataloader): + if idx >= num_batches: + break + + def test_profile_dl(self): + train_dataloader, val_dataloader = get_dataloaders( + train_data_path="/weka/proj-aria/aria-amt/data/train.jsonl", + val_data_path="/weka/proj-aria/aria-amt/data/train.jsonl", + batch_size=16, + num_workers=0, + ) + + profiler = cProfile.Profile() + profiler.enable() + self.load_data(train_dataloader, num_batches=10) + profiler.disable() + stats = pstats.Stats(profiler).sort_stats("cumulative") + stats.print_stats() + + if __name__ == "__main__": unittest.main()