From 7fc0908b929dcb7f76c8f677e869aca7dc724eb7 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 14 Dec 2023 17:42:31 +0000 Subject: [PATCH] Update tokenizers (#80) * not working * fix dict_to_midi * aug not working * add mp spawn warnings * abs soundfont path * add data aug and tests * update .gitignore * update make format * update dataset and tokenizers to use rel/abs * fix sampling cli * upgrade train.py to use both tokenizers * update req * fix acc_convert * fix entrypoint * add comment * fix test * small fixes * rmv compile * update profile flops * typo --------- Co-authored-by: Louis --- .gitignore | 6 +- Makefile | 1 + aria/data/datasets.py | 87 ++- aria/data/midi.py | 10 +- aria/run.py | 55 +- aria/sample.py | 2 +- aria/tokenizer/__init__.py | 2 +- aria/tokenizer/tokenizer.py | 958 ++++++++++++++++++++++++----- aria/train.py | 133 ++-- aria/utils.py | 6 +- config/config.json | 50 +- requirements.txt | 1 + scripts/midi_to_audio.py | 13 + tests/reference_implementations.py | 2 +- tests/test_data.py | 143 +++-- tests/test_models.py | 24 +- tests/test_tokenizers.py | 283 ++++++++- tests/test_training.py | 13 +- 18 files changed, 1448 insertions(+), 341 deletions(-) create mode 100644 scripts/midi_to_audio.py diff --git a/.gitignore b/.gitignore index 63ef896..dbc6464 100644 --- a/.gitignore +++ b/.gitignore @@ -161,7 +161,9 @@ cython_debug/ # Project specific tools/ -data/ +./data/ +fluidsynth/ +*.DS_Store tests/test_results lightning_logs/ -.vscode/ \ No newline at end of file +.vscode/ diff --git a/Makefile b/Makefile index a8a98fe..774c3fc 100644 --- a/Makefile +++ b/Makefile @@ -6,3 +6,4 @@ test: .PHONY: format format: black --line-length 80 ./aria + black --line-length 80 ./tests diff --git a/aria/data/datasets.py b/aria/data/datasets.py index 0d5a649..508aca0 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -14,10 +14,10 @@ from pathlib import Path from typing import Callable, Iterable from collections import defaultdict -from multiprocessing import Pool, Process, Queue +from multiprocessing import Pool, Process, Queue, get_start_method from aria.config import load_config -from aria.tokenizer import Tokenizer, TokenizerLazy +from aria.tokenizer import Tokenizer from aria.data.midi import MidiDict, get_test_fn @@ -277,6 +277,11 @@ def _get_mididicts_mp(_paths): yield mid_dict logger = setup_logger() + if get_start_method() == "spawn": + logger.warning( + 'The current multiprocessing start method is "spawn", this ' + "will slow down dataset building" + ) paths = [] if recur is True: @@ -332,6 +337,27 @@ def init_epoch(self, epoch_num: int | None = None): def build(**kwargs): raise NotImplementedError + @classmethod + def get_config_from_path(cls, path: str): + """Returns config dict from dataset file/directory. + + If a directory provided, it is assumed t""" + + def _get_config_from_fp(_path): + # Finetuning Dataset + return FinetuningDataset.get_config_from_path(path=_path) + + def _get_config_from_dir(_path): + # Pretraining Dataset + return PretrainingDataset.get_config_from_path(path=_path) + + if os.path.isfile(path): + return _get_config_from_fp(path) + elif os.path.isdir(path): + return _get_config_from_dir(path) + else: + raise FileNotFoundError("Invalid path provided") + def close(self): if self.file_buff: self.file_buff.close() @@ -362,6 +388,7 @@ def _format(tok): src = seq tgt = seq[1:] + [self.tokenizer.pad_tok] + # Fine till here return self.tokenizer.encode(src), self.tokenizer.encode(tgt) def check_config(self): @@ -475,10 +502,6 @@ def get_seqs( tokenizer: Tokenizer, midi_dict_iter: Iterable, ): - # TokenizerLazy is the only supported tokenizer due to the truncate - # and stride logic in _get_tokenized_seqs - assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer" - iq = Queue() oq = Queue() @@ -520,6 +543,19 @@ def __init__(self, dir_path: str, tokenizer: Tokenizer): def __len__(self): return len(self.index) + @classmethod + def get_config_from_path(cls, path: str): + """Returns config dict from dataset directory. + + Note that this will return the config corresponding to epoch0.jsonl. + """ + assert os.path.isdir(path), "directory not found" + assert os.path.isfile( + epoch0_path := os.path.join(path, "epoch0.jsonl") + ), "epoch file not found" + with open(epoch0_path) as f: + return json.loads(f.readline()) + def init_epoch(self, idx: int | None = None): if idx is None: idx = self.curr_epoch + 1 @@ -551,7 +587,6 @@ def _get_epoch_files(self): os.path.join(self.dir_path, file_name) for file_name in file_names ] - # Check correct formatting present_epochs = [] for file_name in file_names: if not re.match(r"^epoch\d+\.jsonl$", file_name): @@ -606,6 +641,7 @@ def _build_epoch(_save_path, _midi_dataset): ) buffer = [] + # TODO: Profile why mp takes a while to spit up for entry in get_seqs(tokenizer, _midi_dataset): if entry is not None: buffer += entry @@ -617,6 +653,11 @@ def _build_epoch(_save_path, _midi_dataset): logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" assert num_epochs > 0, "num_epochs must be greater than 0" + if get_start_method() == "spawn": + logger.warning( + 'The current multiprocessing start method is "spawn", this ' + "will slow down dataset building" + ) if os.path.isdir(save_dir) and os.listdir(save_dir): print( @@ -632,6 +673,7 @@ def _build_epoch(_save_path, _midi_dataset): if not os.path.exists(save_dir): os.mkdir(save_dir) + # TODO: This is very slow right now if not midi_dataset: midi_dataset = MidiDataset.load(midi_dataset_path) else: @@ -639,7 +681,7 @@ def _build_epoch(_save_path, _midi_dataset): logger.info( f"Building PretrainingDataset with config: " - f"max_seq_len={max_seq_len} " + f"max_seq_len={max_seq_len}, " f"tokenizer_name={tokenizer.name}" ) _num_proc = os.cpu_count() @@ -647,7 +689,7 @@ def _build_epoch(_save_path, _midi_dataset): logger.warning( "Number of processes is close to the number of MidiDicts " "in the dataset. This can result in shuffling not working " - "as intended when building different epochs." + "as intended when building different epochs" ) for idx in range(num_epochs): logger.info(f"Building epoch {idx}/{num_epochs - 1}...") @@ -655,6 +697,7 @@ def _build_epoch(_save_path, _midi_dataset): _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), _midi_dataset=midi_dataset, ) + # TODO: This is very slow for large datasets midi_dataset.shuffle() logger.info( @@ -679,6 +722,13 @@ def __init__(self, file_path: str, tokenizer: Tokenizer): def __len__(self): return len(self.index) + @classmethod + def get_config_from_path(cls, path: str): + """Returns config dict from dataset file""" + assert os.path.isfile(path), "dataset file not found" + with open(path) as f: + return json.loads(f.readline()) + # Do nothing in this case def init_epoch(self, idx: int | None = None): self.logger.info(f"Successful initiated epoch {idx}") @@ -693,8 +743,9 @@ def build( midi_dataset: MidiDataset = None, midi_dataset_path: str = None, ): - """Builds and returns PretrainingDataset.""" + """Builds and returns FinetuningDataset.""" + # This function should be made more robust in the future def _truncate_and_stride(_tokenized_seq: list): prefix = [] @@ -720,13 +771,12 @@ def _truncate_and_stride(_tokenized_seq: list): # Checks that next start note will not be cutoff midway while idx < seq_len: - # Break loop when a non 'wait' or 'dur' is seen if _tokenized_seq[idx] in tokenizer.special_tokens: break - elif _tokenized_seq[idx][0] in {"wait", "dur"}: - idx += 1 - else: + elif _tokenized_seq[idx][0] in tokenizer.instruments_wd: break + else: + idx += 1 # Add the last sequence _seq = prefix + _tokenized_seq[idx : idx + max_seq_len - prefix_len] @@ -748,8 +798,8 @@ def _build(_midi_dataset): ) logger.info( f"Building FinetuningDataset with config: " - f"tokenizer_name=tokenizer.name" - f"max_seq_len={max_seq_len} " + f"tokenizer_name={tokenizer.name}, " + f"max_seq_len={max_seq_len}, " f"stride_len={stride_len}" ) @@ -760,6 +810,11 @@ def _build(_midi_dataset): logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" + if get_start_method() == "spawn": + logger.warning( + 'The current multiprocessing start method is "spawn", this ' + "will slow down dataset building" + ) if os.path.isfile(save_path): print( diff --git a/aria/data/midi.py b/aria/data/midi.py index 8720a56..863da6d 100644 --- a/aria/data/midi.py +++ b/aria/data/midi.py @@ -393,6 +393,14 @@ def dict_to_midi(mid_data: dict): Returns: mido.MidiFile: The MIDI parsed from the input data. """ + + # Magic sorting function + def _sort_fn(msg): + if hasattr(msg, "velocity"): + return (msg.time, msg.velocity) + else: + return (msg.time, 1000) + assert mid_data.keys() == { "meta_msgs", "tempo_msgs", @@ -475,7 +483,7 @@ def dict_to_midi(mid_data: dict): ) # Sort and convert from abs_time -> delta_time - track = sorted(track, key=lambda msg: msg.time) + track = sorted(track, key=_sort_fn) tick = 0 for msg in track: msg.time -= tick diff --git a/aria/run.py b/aria/run.py index 8bcb857..a19f19b 100644 --- a/aria/run.py +++ b/aria/run.py @@ -8,8 +8,15 @@ import warnings +# TODO: Implement a way of inferring the tokenizer name automatically def _parse_sample_args(): argp = argparse.ArgumentParser(prog="aria sample") + argp.add_argument( + "-tok", + help="name of tokenizer", + choices=["abs", "rel"], + required=True, + ) argp.add_argument("-m", help="name of model config file") argp.add_argument("-c", help="path to model checkpoint") argp.add_argument("-p", help="path to midi file") @@ -89,6 +96,8 @@ def _get_midi_path(midi_path: str | None) -> str: return midi_path +# TODO: Add arg for supressing the audio conversion, and commands for changing +# the sampling params from the cli def sample(args): """Entrypoint for sampling""" @@ -96,7 +105,7 @@ def sample(args): from torch.cuda import is_available as cuda_is_available from aria.model import TransformerLM, ModelConfig from aria.config import load_model_config - from aria.tokenizer import TokenizerLazy + from aria.tokenizer import RelTokenizer, AbsTokenizer from aria.sample import greedy_sample from aria.data.midi import MidiDict from aria.utils import midi_to_audio @@ -121,11 +130,21 @@ def sample(args): truncate_len = args.trunc force_end = args.e - tokenizer = TokenizerLazy(return_tensors=True) + if args.tok == "abs": + tokenizer = AbsTokenizer(return_tensors=True) + elif args.tok == "rel": + tokenizer = RelTokenizer(return_tensors=True) + model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config).to(device) - model.load_state_dict(model_state) + try: + model.load_state_dict(model_state) + except: + print( + "Failed to load state_dict, this could be because the wrong " + "tokenizer was selected" + ) if args.q: if device.type != "cpu": warnings.warn( @@ -244,6 +263,9 @@ def _parse_pretrain_dataset_args(): argp = argparse.ArgumentParser(prog="aria pretrain-dataset") argp.add_argument("load_path", help="path midi_dict dataset") argp.add_argument("save_dir", help="path to save dataset") + argp.add_argument( + "tokenizer_name", help="tokenizer name", choices=["abs", "rel"] + ) argp.add_argument("-l", help="max sequence length", type=int, default=2048) argp.add_argument("-e", help="num epochs", type=int, default=1) @@ -251,10 +273,14 @@ def _parse_pretrain_dataset_args(): def build_pretraining_dataset(args): - from aria.tokenizer import TokenizerLazy + from aria.tokenizer import AbsTokenizer, RelTokenizer from aria.data.datasets import PretrainingDataset - tokenizer = TokenizerLazy() + if args.tokenizer_name == "abs": + tokenizer = AbsTokenizer() + elif args.tokenizer_name == "rel": + tokenizer = RelTokenizer() + dataset = PretrainingDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, @@ -268,18 +294,24 @@ def _parse_finetune_dataset_args(): argp = argparse.ArgumentParser(prog="aria finetune-dataset") argp.add_argument("load_path", help="path midi_dict dataset") argp.add_argument("save_path", help="path to save dataset") + argp.add_argument( + "tokenizer_name", help="tokenizer name", choices=["abs", "rel"] + ) argp.add_argument("-l", help="max sequence length", type=int, default=2048) argp.add_argument("-s", help="stride length", type=int, default=512) return argp.parse_args(sys.argv[2:]) -# This might not be correct - double check def build_finetune_dataset(args): - from aria.tokenizer import TokenizerLazy + from aria.tokenizer import AbsTokenizer, RelTokenizer from aria.data.datasets import FinetuningDataset - tokenizer = TokenizerLazy() + if args.tokenizer_name == "abs": + tokenizer = AbsTokenizer() + elif args.tokenizer_name == "rel": + tokenizer = RelTokenizer() + dataset = FinetuningDataset.build( tokenizer=tokenizer, save_path=args.save_path, @@ -295,7 +327,12 @@ def main(): parser.add_argument( "command", help="command to run", - choices=("sample", "midi-dataset", "pretrain-dataset"), + choices=( + "sample", + "midi-dataset", + "pretrain-dataset", + "finetune-dataset", + ), ) # parse_args defaults to [1:] for args, but you need to diff --git a/aria/sample.py b/aria/sample.py index 26a70cd..e1fb713 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -1,6 +1,6 @@ """Contains generation/sampling code""" # This file contains code from https://github.com/facebookresearch/llama which -# is available under the following licence: +# is available under the following license: # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the GNU diff --git a/aria/tokenizer/__init__.py b/aria/tokenizer/__init__.py index 759f215..e2078a6 100644 --- a/aria/tokenizer/__init__.py +++ b/aria/tokenizer/__init__.py @@ -1 +1 @@ -from .tokenizer import Tokenizer, TokenizerLazy +from .tokenizer import Tokenizer, RelTokenizer, AbsTokenizer diff --git a/aria/tokenizer/tokenizer.py b/aria/tokenizer/tokenizer.py index 6f56fdb..990b264 100644 --- a/aria/tokenizer/tokenizer.py +++ b/aria/tokenizer/tokenizer.py @@ -8,7 +8,6 @@ from collections import defaultdict from typing import Callable -from mido.midifiles.units import second2tick from aria.data.midi import MidiDict from aria.config import load_config @@ -46,6 +45,8 @@ def __init__( # These must be implemented in child class (abstract params) self.vocab = () + self.instruments_wd = [] + self.instruments_nd = [] self.config = {} self.tok_to_id = {} self.id_to_tok = {} @@ -57,6 +58,7 @@ def tokenize_midi_dict(self, midi_dict: MidiDict): tokens.""" raise NotImplementedError + # REMEMBER TO USE THIS API IN THE TRAIN SCRIPT def tokenize(self, midi_dict: MidiDict, **kwargs): """Tokenizes a MidiDict object. @@ -79,62 +81,807 @@ def detokenize(self, tokenized_seq: list): """ return self.detokenize_midi_dict(tokenized_seq) + def export_data_aug(cls): + """Abstract method for exporting a list of all data augmentation + functions. + + This function is used when setting data transformation functions in + TrainingDatase, e.g. + + PretrainingDataset.set_transform(Tokenizer.export_data_aug()) + """ + raise NotImplementedError + def encode(self, unencoded_seq: list): """Converts tokenized sequence into a list/torch.Tensor of ids.""" - def _enc_fn(tok): - return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + def _enc_fn(tok): + return self.tok_to_id.get(tok, self.tok_to_id[self.unk_tok]) + + if self.tok_to_id is None: + raise NotImplementedError("tok_to_id") + + if self.return_tensors is True: + encoded_seq = torch.tensor([_enc_fn(tok) for tok in unencoded_seq]) + else: + encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + + return encoded_seq + + def decode(self, encoded_seq: list | torch.Tensor): + """Converts sequence of ids into the corresponding list of tokens.""" + + def _dec_fn(id): + return self.id_to_tok.get(id, self.unk_tok) + + if self.id_to_tok is None: + raise NotImplementedError("id_to_tok") + + if isinstance(encoded_seq, torch.Tensor): + decoded_seq = [_dec_fn(idx) for idx in encoded_seq.tolist()] + else: + decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + + return decoded_seq + + @classmethod + def _find_closest_int(cls, n: int, sorted_list: list): + # Selects closest integer to n from sorted_list + # Time ~ Log(n) + + left, right = 0, len(sorted_list) - 1 + closest = float("inf") + + while left <= right: + mid = (left + right) // 2 + diff = abs(sorted_list[mid] - n) + + if diff < abs(closest - n): + closest = sorted_list[mid] + + if sorted_list[mid] < n: + left = mid + 1 + else: + right = mid - 1 + + return closest + + def _build_pedal_intervals(self, midi_dict: MidiDict): + """Returns pedal-on intervals for each channel.""" + channel_to_pedal_intervals = defaultdict(list) + pedal_status = {} + + for pedal_msg in midi_dict.pedal_msgs: + tick = pedal_msg["tick"] + channel = pedal_msg["channel"] + data = pedal_msg["data"] + + if data == 1 and pedal_status.get(channel, None) is None: + pedal_status[channel] = tick + elif data == 0 and pedal_status.get(channel, None) is not None: + # Close pedal interval + _start_tick = pedal_status[channel] + _end_tick = tick + channel_to_pedal_intervals[channel].append( + [_start_tick, _end_tick] + ) + del pedal_status[channel] + + # Close all unclosed pedals at end of track + final_tick = midi_dict.note_msgs[-1]["data"]["end"] + for channel, start_tick in pedal_status.items(): + channel_to_pedal_intervals[channel].append([start_tick, final_tick]) + + return channel_to_pedal_intervals + + def add_tokens_to_vocab(self, tokens: list | tuple): + for token in tokens: + assert token not in self.vocab + + self.vocab = self.vocab + tuple(tokens) + self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} + self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} + self.vocab_size = len(self.vocab) + + def export_aug_fn_concat(self, aug_fn: Callable): + """Exports a function that splits src before augmenting. + + This is useful for augmentation functions that expect pure sequences + instead of concatenated ones (like those given by PretrainedDataset). + """ + + def _aug_fn_concat( + src: list, + _aug_fn: Callable, + pad_tok: str, + eos_tok: str, + **kwargs, + ): + # Split list on '' + initial_seq_len = len(src) + src_sep = [] + prev_idx = 0 + for curr_idx, tok in enumerate(src, start=1): + if tok == eos_tok: + src_sep.append(src[prev_idx:curr_idx]) + prev_idx = curr_idx + + # Last sequence + if prev_idx != curr_idx: + src_sep.append(src[prev_idx:]) + + # Augment + src_sep = [ + _aug_fn( + _src, + **kwargs, + ) + for _src in src_sep + ] + + # Concatenate + src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] + + # Pad or truncate to original sequence length as necessary + src_aug_concat = src_aug_concat[:initial_seq_len] + src_aug_concat += [pad_tok] * ( + initial_seq_len - len(src_aug_concat) + ) + + return src_aug_concat + + return functools.partial( + _aug_fn_concat, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + eos_tok=self.eos_tok, + ) + + +class AbsTokenizer(Tokenizer): + """MidiDict tokenizer implemented with absolute onset timings""" + + def __init__(self, return_tensors: bool = False): + super().__init__(return_tensors) + self.config = load_config()["tokenizer"]["abs"] + self.name = "abs" + + # Calculate time quantizations (in ms) + self.abs_time_step = self.config["abs_time_step_ms"] + self.max_dur = self.config["max_dur_ms"] + self.time_step = self.config["time_step_ms"] + + self.dur_time_quantizations = [ + self.time_step * i + for i in range((self.max_dur // self.time_step) + 1) + ] + self.onset_time_quantizations = [ + self.time_step * i for i in range((self.max_dur // self.time_step)) + ] + + # Calculate velocity quantizations + self.velocity_step = self.config["velocity_quantization"]["step"] + self.velocity_quantizations = [ + i * self.velocity_step + for i in range(int(127 / self.velocity_step) + 1) + ] + self.max_velocity = self.velocity_quantizations[-1] + + # _nd = no drum; _wd = with drum + self.instruments_nd = [ + k + for k, v in self.config["ignore_instruments"].items() + if v is False + ] + self.instruments_wd = self.instruments_nd + ["drum"] + + # Prefix tokens + self.prefix_tokens = [ + ("prefix", "instrument", x) for x in self.instruments_wd + ] + self.composer_names = self.config["composer_names"] + self.prefix_tokens += [ + ("prefix", "composer", x) for x in self.composer_names + ] + + # Build vocab + self.time_tok = "" + self.onset_tokens = [ + ("onset", i) for i in self.onset_time_quantizations + ] + self.dur_tokens = [("dur", i) for i in self.dur_time_quantizations] + self.drum_tokens = [("drum", i) for i in range(35, 82)] + + self.note_tokens = list( + itertools.product( + self.instruments_nd, + [i for i in range(128)], + self.velocity_quantizations, + ) + ) + + self.special_tokens.append(self.time_tok) + self.add_tokens_to_vocab( + self.special_tokens + + self.prefix_tokens + + self.note_tokens + + self.drum_tokens + + self.dur_tokens + + self.onset_tokens + ) + self.pad_id = self.tok_to_id[self.pad_tok] + + def export_data_aug(self): + return [ + self.export_tempo_aug(tempo_aug_range=0.2, mixup=True), + self.export_pitch_aug(5), + self.export_velocity_aug(1), + ] + + def _quantize_dur(self, time: int): + # This function will return values res >= 0 (inc. 0) + return self._find_closest_int(time, self.dur_time_quantizations) + + def _quantize_onset(self, time: int): + # This function will return values res >= 0 (inc. 0) + return self._find_closest_int(time, self.onset_time_quantizations) + + def _quantize_velocity(self, velocity: int): + # This function will return values in the range 0 < res =< 127 + velocity_quantized = self._find_closest_int( + velocity, self.velocity_quantizations + ) + + if velocity_quantized == 0 and velocity != 0: + return self.velocity_step + else: + return velocity_quantized + + def _format(self, prefix: list, unformatted_seq: list): + # If unformatted_seq is longer than 150 tokens insert diminish tok + idx = -100 + random.randint(-10, 10) + if len(unformatted_seq) > 150: + if ( + unformatted_seq[idx][0] == "onset" + ): # Don't want: note, , onset, due + unformatted_seq.insert(idx - 1, self.dim_tok) + elif ( + unformatted_seq[idx][0] == "dur" + ): # Don't want: note, onset, , dur + unformatted_seq.insert(idx - 2, self.dim_tok) + else: + unformatted_seq.insert(idx, self.dim_tok) + + res = prefix + [self.bos_tok] + unformatted_seq + [self.eos_tok] + + return res + + def tokenize_midi_dict(self, midi_dict: MidiDict): + ticks_per_beat = midi_dict.ticks_per_beat + midi_dict.remove_instruments(self.config["ignore_instruments"]) + + if len(midi_dict.note_msgs) == 0: + raise Exception("note_msgs is empty after ignoring instruments") + + channel_to_pedal_intervals = self._build_pedal_intervals(midi_dict) + + channels_used = {msg["channel"] for msg in midi_dict.note_msgs} + + channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + if msg["channel"] != 9 # Exclude drums + } + # If non-drum channel is missing from instrument_msgs, default to piano + for c in channels_used: + if channel_to_instrument.get(c) is None and c != 9: + channel_to_instrument[c] = "piano" + + # Add non-drums to present_instruments (prefix) + prefix = [ + ("prefix", "instrument", x) + for x in set(channel_to_instrument.values()) + ] + if 9 in channels_used: + prefix.append(("prefix", "instrument", "drum")) + + composer = midi_dict.metadata.get("composer") + if composer and (composer in self.composer_names): + prefix.insert(0, ("prefix", "composer", composer)) + + # NOTE: Any preceding silence is removed implicitly + tokenized_seq = [] + initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + curr_time_since_onset = 0 + for _, msg in enumerate(midi_dict.note_msgs): + # Extract msg data + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + _velocity = msg["data"]["velocity"] + _start_tick = msg["data"]["start"] + _end_tick = msg["data"]["end"] + + # Calculate time data + prev_time_since_onset = curr_time_since_onset + curr_time_since_onset = get_duration_ms( + start_tick=initial_onset_tick, + end_tick=_start_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Add abs time token if necessary + time_toks_to_append = ( + curr_time_since_onset // self.abs_time_step + ) - (prev_time_since_onset // self.abs_time_step) + if time_toks_to_append > 0: + for _ in range(time_toks_to_append): + tokenized_seq.append(self.time_tok) + + # Special case instrument is a drum. This occurs exclusively when + # MIDI channel is 9 when 0 indexing + if _channel == 9: + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + tokenized_seq.append(("drum", _pitch)) + tokenized_seq.append(("onset", _note_onset)) + + else: # Non drum case (i.e. an instrument note) + _instrument = channel_to_instrument[_channel] + + # Update _end_tick if affected by pedal + for pedal_interval in channel_to_pedal_intervals[_channel]: + pedal_start, pedal_end = ( + pedal_interval[0], + pedal_interval[1], + ) + if ( + pedal_start <= _start_tick < pedal_end + and _end_tick < pedal_end + ): + _end_tick = pedal_end + + _note_duration = get_duration_ms( + start_tick=_start_tick, + end_tick=_end_tick, + tempo_msgs=midi_dict.tempo_msgs, + ticks_per_beat=ticks_per_beat, + ) + + # Quantize + _velocity = self._quantize_velocity(_velocity) + _note_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step + ) + _note_duration = self._quantize_dur(_note_duration) + if _note_duration == 0: + _note_duration = self.time_step + + tokenized_seq.append((_instrument, _pitch, _velocity)) + tokenized_seq.append(("onset", _note_onset)) + tokenized_seq.append(("dur", _note_duration)) + + return self._format( + prefix=prefix, + unformatted_seq=tokenized_seq, + ) + + def detokenize_midi_dict(self, tokenized_seq: list): + instrument_programs = self.config["instrument_programs"] + # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to + # skip converting between ticks and ms + TICKS_PER_BEAT = 500 + TEMPO = 500000 + + # Set message tempos + tempo_msgs = [{"type": "tempo", "data": TEMPO, "tick": 0}] + meta_msgs = [] + pedal_msgs = [] + instrument_msgs = [] + + instrument_to_channel = {} + + # Add non-drum instrument_msgs, breaks at first note token + channel_idx = 0 + for idx, tok in enumerate(tokenized_seq): + if channel_idx == 9: # Skip channel reserved for drums + channel_idx += 1 + + if tok in self.special_tokens: + # Skip special tokens + continue + elif ( + tok[0] == "prefix" + and tok[1] == "instrument" + and tok[2] in self.instruments_wd + ): + # Process instrument prefix tokens + if tok[2] in instrument_to_channel.keys(): + logging.warning(f"Duplicate prefix {tok[2]}") + continue + elif tok[2] == "drum": + instrument_msgs.append( + { + "type": "instrument", + "data": 0, + "tick": 0, + "channel": 9, + } + ) + instrument_to_channel["drum"] = 9 + else: + instrument_msgs.append( + { + "type": "instrument", + "data": instrument_programs[tok[2]], + "tick": 0, + "channel": channel_idx, + } + ) + instrument_to_channel[tok[2]] = channel_idx + channel_idx += 1 + elif tok[0] == "prefix": + # Skip all other prefix tokens + continue + else: + # Note, wait, or duration token + start = idx + break + + # Note messages + note_msgs = [] + curr_tick = 0 + for tok_1, tok_2, tok_3 in zip( + tokenized_seq[start:], + tokenized_seq[start + 1 :], + tokenized_seq[start + 2 :], + ): + if tok_1 in self.special_tokens: + _tok_type_1 = "special" + else: + _tok_type_1 = tok_1[0] + if tok_2 in self.special_tokens: + _tok_type_2 = "special" + else: + _tok_type_2 = tok_2[0] + if tok_3 in self.special_tokens: + _tok_type_3 = "special" + else: + _tok_type_3 = tok_3[0] + + if tok_1 == self.time_tok: + curr_tick += self.abs_time_step + + elif ( + _tok_type_1 == "special" + or _tok_type_1 == "prefix" + or _tok_type_1 == "onset" + or _tok_type_1 == "dur" + ): + continue + elif _tok_type_1 == "drum" and _tok_type_2 == "onset": + _start_tick = curr_tick + tok_2[1] + _end_tick = _start_tick + self.time_step + _pitch = tok_1[1] + _channel = instrument_to_channel["drum"] + _velocity = self.config["drum_velocity"] + + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + + elif ( + _tok_type_1 in self.instruments_nd + and _tok_type_2 == "onset" + and _tok_type_3 == "dur" + ): + _pitch = tok_1[1] + _channel = instrument_to_channel.get(tok_1[0], None) + _velocity = tok_1[2] + _start_tick = curr_tick + tok_2[1] + _end_tick = _start_tick + tok_3[1] + + if _channel is None: + logging.warning( + "Tried to decode note message for unexpected instrument" + ) + else: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": _pitch, + "start": _start_tick, + "end": _end_tick, + "velocity": _velocity, + }, + "tick": _start_tick, + "channel": _channel, + } + ) + + else: + logging.warning( + f"Unexpected token sequence: {tok_1}, {tok_2}, {tok_3}" + ) + + return MidiDict( + meta_msgs=meta_msgs, + tempo_msgs=tempo_msgs, + pedal_msgs=pedal_msgs, + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=TICKS_PER_BEAT, + metadata={}, + ) + + def export_pitch_aug(self, aug_range: int): + """Exports a function that augments the pitch of all note tokens. + + Note that notes which fall out of the range (0, 127) will be replaced + with the unknown token ''. + + Args: + aug_range (int): Returned function will randomly augment the pitch + from a value in the range (-aug_range, aug_range). + + Returns: + Callable[list]: Exported function. + """ + + def pitch_aug_seq( + src: list, + unk_tok: str, + _aug_range: float, + ): + def pitch_aug_tok(tok, _pitch_aug): + if isinstance(tok, str): # Stand in for special tokens + _tok_type = "special" + else: + _tok_type = tok[0] + + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + # Return augmented tok + (_instrument, _pitch, _velocity) = tok - if self.tok_to_id is None: - raise NotImplementedError("tok_to_id") + if 0 <= _pitch + _pitch_aug <= 127: + return (_instrument, _pitch + _pitch_aug, _velocity) + else: + return unk_tok - if self.return_tensors is True: - encoded_seq = torch.tensor([_enc_fn(tok) for tok in unencoded_seq]) - else: - encoded_seq = [_enc_fn(tok) for tok in unencoded_seq] + pitch_aug = random.randint(-_aug_range, _aug_range) + return [pitch_aug_tok(x, pitch_aug) for x in src] - return encoded_seq + # See functools.partial docs + return functools.partial( + self.export_aug_fn_concat(aug_fn=pitch_aug_seq), + unk_tok=self.unk_tok, + _aug_range=aug_range, + ) - def decode(self, encoded_seq: list | torch.Tensor): - """Converts sequence of ids into the corresponding list of tokens.""" + def export_velocity_aug(self, aug_steps_range: int): + """Exports a function which augments the velocity of all pitch tokens. - def _dec_fn(id): - return self.id_to_tok.get(id, self.unk_tok) + This augmentation truncated such that it returns a valid note token. - if self.id_to_tok is None: - raise NotImplementedError("id_to_tok") + Args: + aug_steps_range (int): Returned function will randomly augment + velocity in the range aug_steps_range * (-self.velocity_step, + self.velocity step). - if isinstance(encoded_seq, torch.Tensor): - decoded_seq = [_dec_fn(idx) for idx in encoded_seq.tolist()] - else: - decoded_seq = [_dec_fn(idx) for idx in encoded_seq] + Returns: + Callable[str]: Exported function. + """ - return decoded_seq + def velocity_aug_seq( + src: list, + velocity_step: int, + max_velocity: int, + _aug_steps_range: int, + ): + def velocity_aug_tok(tok, _velocity_aug): + if isinstance(tok, str): # Stand in for special tokens + _tok_type = "special" + else: + _tok_type = tok[0] - def add_tokens_to_vocab(self, tokens: list | tuple): - for token in tokens: - assert token not in self.vocab + if ( + _tok_type == "special" + or _tok_type == "prefix" + or _tok_type == "dur" + or _tok_type == "drum" + or _tok_type == "onset" + ): + # Return without changing + return tok + else: + # Return augmented tok + (_instrument, _pitch, _velocity) = tok - self.vocab = self.vocab + tuple(tokens) - self.tok_to_id = {tok: idx for idx, tok in enumerate(self.vocab)} - self.id_to_tok = {v: k for k, v in self.tok_to_id.items()} - self.vocab_size = len(self.vocab) + # Check it doesn't go out of bounds + if _velocity + _velocity_aug >= max_velocity: + return (_instrument, _pitch, max_velocity) + elif _velocity + _velocity_aug <= velocity_step: + return (_instrument, _pitch, velocity_step) + + return (_instrument, _pitch, _velocity + _velocity_aug) + velocity_aug = velocity_step * random.randint( + -_aug_steps_range, _aug_steps_range + ) + return [velocity_aug_tok(x, velocity_aug) for x in src] -class TokenizerLazy(Tokenizer): - """Lazy MidiDict Tokenizer""" + # See functools.partial docs + return functools.partial( + self.export_aug_fn_concat(aug_fn=velocity_aug_seq), + velocity_step=self.velocity_step, + max_velocity=self.max_velocity, + _aug_steps_range=aug_steps_range, + ) - def __init__( - self, - return_tensors: bool = False, - ): + def export_tempo_aug(self, tempo_aug_range, mixup: bool): + # Chord mix up will randomly reorder concurrent notes. A concurrent + # notes are those which occur at the onset. + def tempo_aug( + src: list, + abs_time_step: int, + max_dur: int, + time_step: int, + unk_tok: str, + time_tok: str, + dim_tok: str, + start_tok: str, + end_tok: str, + instruments_wd: list, + _tempo_aug_range: float, + _mixup: bool, + ): + """This must be used with export_aug_fn_concat in order to work + properly for concatenated sequences.""" + + def _quantize_time(_n: int): + return round(_n / time_step) * time_step + + tempo_aug = random.uniform( + 1 - _tempo_aug_range, 1 + _tempo_aug_range + ) + + src_time_tok_cnt = 0 + dim_tok_seen = None + res = [] + note_buffer = None + buffer = defaultdict(lambda: defaultdict(list)) + for tok_1, tok_2, tok_3 in zip(src, src[1:], src[2:]): + if tok_1 == time_tok: + _tok_type = "time" + elif tok_1 == unk_tok: + _tok_type = "unk" + elif tok_1 == start_tok: + res.append(tok_1) + continue + elif tok_1 == dim_tok and note_buffer: + dim_tok_seen = (src_time_tok_cnt, note_buffer["onset"][1]) + continue + elif tok_1[0] == "prefix": + res.append(tok_1) + continue + elif tok_1[0] in instruments_wd: + _tok_type = tok_1[0] + else: + # This only triggers for incomplete notes at the beginning, + # e.g. an onset token before a note token is seen + continue + + if _tok_type == "time": + src_time_tok_cnt += 1 + elif _tok_type == "drum": + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": None, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + else: # unk or in instruments_wd + note_buffer = { + "note": tok_1, + "onset": tok_2, + "dur": tok_3, + } + buffer[src_time_tok_cnt][tok_2[1]].append(note_buffer) + + prev_tgt_time_tok_cnt = 0 + for src_time_tok_cnt, interval_notes in sorted(buffer.items()): + for src_onset, notes_by_onset in sorted(interval_notes.items()): + src_time = src_time_tok_cnt * abs_time_step + src_onset + tgt_time = round(src_time * tempo_aug) + curr_tgt_time_tok_cnt = tgt_time // abs_time_step + curr_tgt_onset = _quantize_time(tgt_time % abs_time_step) + + for _ in range( + curr_tgt_time_tok_cnt - prev_tgt_time_tok_cnt + ): + res.append(time_tok) + prev_tgt_time_tok_cnt = curr_tgt_time_tok_cnt + + if _mixup == True: + random.shuffle(notes_by_onset) + + for note in notes_by_onset: + _src_note_tok = note["note"] + _src_dur_tok = note["dur"] + + if _src_dur_tok is not None: + tgt_dur = _quantize_time( + round(_src_dur_tok[1] * tempo_aug) + ) + tgt_dur = min(tgt_dur, max_dur) + else: + tgt_dur = None + + res.append(_src_note_tok) + res.append(("onset", curr_tgt_onset)) + if tgt_dur: + res.append(("dur", tgt_dur)) + + if dim_tok_seen is not None and dim_tok_seen == ( + src_time_tok_cnt, + src_onset, + ): + res.append(dim_tok) + dim_tok_seen = None + + if src[-1] == end_tok: + res.append(end_tok) + + return res + + return functools.partial( + self.export_aug_fn_concat(aug_fn=tempo_aug), + abs_time_step=self.abs_time_step, + max_dur=self.max_dur, + time_step=self.time_step, + unk_tok=self.unk_tok, + time_tok=self.time_tok, + dim_tok=self.dim_tok, + end_tok=self.eos_tok, + start_tok=self.bos_tok, + instruments_wd=self.instruments_wd, + _tempo_aug_range=tempo_aug_range, + _mixup=mixup, + ) + + +class RelTokenizer(Tokenizer): + """MidiDict tokenizer implemented with relative onset timings""" + + def __init__(self, return_tensors: bool = False): super().__init__(return_tensors) - self.config = load_config()["tokenizer"]["lazy"] - self.name = "lazy" + self.config = load_config()["tokenizer"]["rel"] + self.name = "rel" # Calculate time quantizations (in ms) self.num_time_step = self.config["time_quantization"]["num_steps"] - self.min_time_step = self.config["time_quantization"]["min_step"] + self.min_time_step = self.config["time_quantization"]["step"] self.time_step_quantizations = [ self.min_time_step * i for i in range(self.num_time_step) ] @@ -186,68 +933,23 @@ def __init__( + self.dur_tokens + self.wait_tokens ) - self.pad_id = self.tok_to_id[self.pad_tok] - def _build_pedal_intervals(self, midi_dict: MidiDict): - """Returns pedal-on intervals for each channel.""" - channel_to_pedal_intervals = defaultdict(list) - pedal_status = {} - - for pedal_msg in midi_dict.pedal_msgs: - tick = pedal_msg["tick"] - channel = pedal_msg["channel"] - data = pedal_msg["data"] - - if data == 1 and pedal_status.get(channel, None) is None: - pedal_status[channel] = tick - elif data == 0 and pedal_status.get(channel, None) is not None: - # Close pedal interval - _start_tick = pedal_status[channel] - _end_tick = tick - channel_to_pedal_intervals[channel].append( - [_start_tick, _end_tick] - ) - del pedal_status[channel] - - # Close all unclosed pedals at end of track - final_tick = midi_dict.note_msgs[-1]["data"]["end"] - for channel, start_tick in pedal_status.items(): - channel_to_pedal_intervals[channel].append([start_tick, final_tick]) - - return channel_to_pedal_intervals - - @classmethod - def _find_closest_int(cls, n: int, sorted_list: list): - # Selects closest integer to n from sorted_list - # Time ~ Log(n) - - left, right = 0, len(sorted_list) - 1 - closest = float("inf") - - while left <= right: - mid = (left + right) // 2 - diff = abs(sorted_list[mid] - n) - - if diff < abs(closest - n): - closest = sorted_list[mid] - - if sorted_list[mid] < n: - left = mid + 1 - else: - right = mid - 1 - - return closest + def export_data_aug(self): + return [ + self.export_chord_mixup(), + self.export_tempo_aug(tempo_aug_range=0.2), + self.export_pitch_aug(5), + self.export_velocity_aug(1), + ] def _quantize_time(self, time: int): # This function will return values res >= 0 (inc. 0) - return TokenizerLazy._find_closest_int( - time, self.time_step_quantizations - ) + return self._find_closest_int(time, self.time_step_quantizations) def _quantize_velocity(self, velocity: int): # This function will return values in the range 0 < res =< 127 - velocity_quantized = TokenizerLazy._find_closest_int( + velocity_quantized = self._find_closest_int( velocity, self.velocity_quantizations ) @@ -376,7 +1078,9 @@ def tokenize_midi_dict(self, midi_dict: MidiDict): def detokenize_midi_dict(self, tokenized_seq: list): instrument_programs = self.config["instrument_programs"] - TICKS_PER_BEAT = 480 + # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to + # skip converting between ticks and ms + TICKS_PER_BEAT = 500 TEMPO = 500000 # Set message tempos @@ -457,13 +1161,7 @@ def detokenize_midi_dict(self, tokenized_seq: list): ): continue elif _curr_tok_type == "wait": - curr_tick += int( - second2tick( - second=1e-3 * curr_tok[1], - ticks_per_beat=TICKS_PER_BEAT, - tempo=TEMPO, - ) - ) + curr_tick += curr_tok[1] elif _curr_tok_type == "drum": _tick = curr_tick _pitch = curr_tok[1] @@ -502,13 +1200,7 @@ def detokenize_midi_dict(self, tokenized_seq: list): _pitch = curr_tok[1] _velocity = curr_tok[2] _start_tick = curr_tick - _end_tick = curr_tick + int( - second2tick( - second=1e-3 * duration, - ticks_per_beat=TICKS_PER_BEAT, - tempo=TEMPO, - ) - ) + _end_tick = curr_tick + duration note_msgs.append( { @@ -539,60 +1231,6 @@ def detokenize_midi_dict(self, tokenized_seq: list): metadata={}, ) - def export_aug_fn_concat(self, aug_fn: Callable): - """Exports a function that splits src before augmenting. - - This is useful for augmentation functions that expect pure sequences - instead of concatenated ones (like those given by PretrainedDataset). - """ - - def _aug_fn_concat( - src: list, - _aug_fn: Callable, - pad_tok: str, - eos_tok: str, - **kwargs, - ): - # Split list on '' - initial_seq_len = len(src) - src_sep = [] - prev_idx = 0 - for curr_idx, tok in enumerate(src, start=1): - if tok == eos_tok: - src_sep.append(src[prev_idx:curr_idx]) - prev_idx = curr_idx - - # Last sequence - if prev_idx != curr_idx: - src_sep.append(src[prev_idx:]) - - # Augment - src_sep = [ - _aug_fn( - _src, - **kwargs, - ) - for _src in src_sep - ] - - # Concatenate - src_aug_concat = [tok for src_aug in src_sep for tok in src_aug] - - # Pad or truncate to original sequence length as necessary - src_aug_concat = src_aug_concat[:initial_seq_len] - src_aug_concat += [pad_tok] * ( - initial_seq_len - len(src_aug_concat) - ) - - return src_aug_concat - - return functools.partial( - _aug_fn_concat, - _aug_fn=aug_fn, - pad_tok=self.pad_tok, - eos_tok=self.eos_tok, - ) - def export_pitch_aug(self, aug_range: int): """Exports a function that augments the pitch of all note tokens. diff --git a/aria/train.py b/aria/train.py index a676bd8..618ae81 100644 --- a/aria/train.py +++ b/aria/train.py @@ -17,8 +17,12 @@ from aria.config import load_model_config from aria.model import ModelConfig, TransformerLM -from aria.tokenizer import TokenizerLazy -from aria.data.datasets import PretrainingDataset, FinetuningDataset +from aria.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer +from aria.data.datasets import ( + TrainingDataset, + PretrainingDataset, + FinetuningDataset, +) # ----- USAGE ----- @@ -83,6 +87,21 @@ def setup_logger(project_dir: str): return get_logger(__name__) # using accelerate.logging.get_logger() +def get_tokenizer_name( + train_data_path: str, + val_data_path: str, +): + """This will throw an error if there is a tokenizer mismatch""" + train_config = TrainingDataset.get_config_from_path(train_data_path) + val_config = TrainingDataset.get_config_from_path(val_data_path) + + assert ( + train_config["tokenizer_name"] == val_config["tokenizer_name"] + ), "Dataset tokenizers don't match" + + return train_config["tokenizer_name"] + + def setup_project_dir(project_dir: str | None): if not project_dir: # Create project directory @@ -208,7 +227,7 @@ def get_finetune_optim( def get_pretrain_dataloaders( train_data_dir: str, val_data_dir: str, - tokenizer: TokenizerLazy, + tokenizer: Tokenizer, batch_size: int, num_workers: int, init_epoch: int | None = None, @@ -238,14 +257,7 @@ def get_pretrain_dataloaders( ), "val-data directory should only contain one epoch" if apply_aug: - train_dataset.set_transform( - [ - tokenizer.export_chord_mixup(), - tokenizer.export_velocity_aug(1), - tokenizer.export_pitch_aug(5), - tokenizer.export_tempo_aug(0.2), - ] - ) + train_dataset.set_transform(tokenizer.export_data_aug()) train_dataloader = DataLoader( train_dataset, @@ -266,7 +278,7 @@ def get_pretrain_dataloaders( def get_finetune_dataloaders( train_data_path: str, val_data_path: str, - tokenizer: TokenizerLazy, + tokenizer: Tokenizer, batch_size: int, num_workers: int, apply_aug: bool = True, @@ -282,14 +294,7 @@ def get_finetune_dataloaders( ) if apply_aug: - train_dataset.set_transform( - [ - tokenizer.export_chord_mixup(), - tokenizer.export_velocity_aug(1), - tokenizer.export_pitch_aug(5), - tokenizer.export_tempo_aug(0.2), - ] - ) + train_dataset.set_transform(tokenizer.export_data_aug()) train_dataloader = DataLoader( train_dataset, @@ -360,19 +365,13 @@ def _bench(): f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " "parameters" ) - logger.info("Profiling FLOP/s") + logger.info("Profiling FLOP") _bench() with flop_counter: _bench() - total_flops = sum(flop_counter.get_flop_counts()["Global"].values()) - ms_per_iter = do_bench(_bench) - iters_per_second = 1e3 / ms_per_iter - - logger.info( - f"{total_flops / 1e12} TF, " - f"{iters_per_second * total_flops / 1e12} TF/s (not warm)" - ) + total_flop = sum(flop_counter.get_flop_counts()["Global"].values()) + logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF") def make_checkpoint(_accelerator, _epoch: int, _step: int): checkpoint_dir = os.path.join( @@ -586,6 +585,14 @@ def resume_train( else: raise Exception + tokenizer_name = get_tokenizer_name(train_data_path, val_data_path) + if tokenizer_name == "abs": + tokenizer = AbsTokenizer(return_tensors=True) + elif tokenizer_name == "rel": + tokenizer = RelTokenizer(return_tensors=True) + else: + raise Exception("Invalid tokenizer name") + # TODO: Add support for verifying the resume_step and epoch, keep these # save these variables as part of the state during checkpointing project_dir = setup_project_dir(project_dir) @@ -614,10 +621,9 @@ def resume_train( logger.info(f"Creating checkpoints every {steps_per_checkpoint}") # Init model - tokenizer = TokenizerLazy(return_tensors=True) model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = torch.compile(TransformerLM(model_config), mode="default") + model = TransformerLM(model_config) if mode == "pretrain": train_dataloader, val_dataloader = get_pretrain_dataloaders( @@ -672,10 +678,17 @@ def resume_train( scheduler, ) - accelerator.load_state(checkpoint_dir) + try: + accelerator.load_state(checkpoint_dir) + except Exception as e: + raise Exception( + f"Failed to load checkpoint: {e}\n" + "This could be due to a mismatch between the tokenizer used " + "to build the pre-training and fine-tuning datasets" + ) logger.info(f"Loaded checkpoint at {checkpoint_dir}") - logger.info("Starting train job") + _train( epochs=epochs, accelerator=accelerator, @@ -724,6 +737,14 @@ def train( else: raise Exception + tokenizer_name = get_tokenizer_name(train_data_path, val_data_path) + if tokenizer_name == "abs": + tokenizer = AbsTokenizer(return_tensors=True) + elif tokenizer_name == "rel": + tokenizer = RelTokenizer(return_tensors=True) + else: + raise Exception("Invalid tokenizer name") + project_dir = setup_project_dir(project_dir) accelerator = accelerate.Accelerator(project_dir=project_dir) logger = setup_logger(project_dir) @@ -738,17 +759,22 @@ def train( ) # Init model - tokenizer = TokenizerLazy(return_tensors=True) model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) logger.info(f"Loaded model with config: {load_model_config(model_name)}") if mode == "finetune": - model.load_state_dict(torch.load(finetune_cp_path)) + try: + model.load_state_dict(torch.load(finetune_cp_path)) + except Exception as e: + raise Exception( + f"Failed to load checkpoint: {e}\n" + "This could be due to a mismatch between the tokenizer used " + "to build the pre-training and fine-tuning datasets" + ) logger.info( f"Loaded finetune checkpoint located at: {finetune_cp_path}" ) - model = torch.compile(model, mode="default") if mode == "pretrain": train_dataloader, val_dataloader = get_pretrain_dataloaders( @@ -828,23 +854,22 @@ def convert_cp_from_safetensors(checkpoint_path: str, save_path: str): def convert_cp_from_accelerate( model_name: str, checkpoint_dir: str, save_path: str ): - # Converts a compiled model checkpoint into one that can be loaded directly - logger = get_logger(__name__) + def _load_state_dict(_tokenizer: Tokenizer): + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(_tokenizer.vocab_size) + model = TransformerLM(model_config) + model = accelerator.prepare(model) + accelerator.load_state(checkpoint_dir) + + return model.state_dict() + accelerator = accelerate.Accelerator() - tokenizer = TokenizerLazy(return_tensors=True) - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(tokenizer.vocab_size) - model = torch.compile(TransformerLM(model_config), mode="default") - - model = accelerator.prepare(model) - accelerator.load_state(checkpoint_dir) - state_dict = model.state_dict() - for key in list(state_dict.keys()): - if key.startswith("_orig_mod."): - new_key = key[len("_orig_mod.") :] - state_dict[new_key] = state_dict.pop(key) - else: - logger.warning(f"Found unexpected key: {key}") + + # Try both tokenizers + try: + state_dict = _load_state_dict(_tokenizer=AbsTokenizer()) + except: + state_dict = _load_state_dict(_tokenizer=RelTokenizer()) torch.save(state_dict, save_path) @@ -872,8 +897,8 @@ def parse_resume_args(): def parse_pretrain_args(): argp = argparse.ArgumentParser(prog="python aria/train.py pretrain") argp.add_argument("model", help="name of model config file") - argp.add_argument("train_dir", help="path to train dir") - argp.add_argument("val_data", help="path to val data") + argp.add_argument("train_data", help="path to train dir") + argp.add_argument("val_data", help="path to val dir") argp.add_argument("-epochs", help="train epochs", type=int, required=True) argp.add_argument("-bs", help="batch size", type=int, default=32) argp.add_argument("-workers", help="number workers", type=int, default=1) diff --git a/aria/utils.py b/aria/utils.py index a689c87..acf1ef6 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -8,7 +8,11 @@ def midi_to_audio(mid_path: str, soundfont_path: str | None = None): - SOUNDFONT_PATH = "fluidsynth/DoreMarkYamahaS6-v1.6.sf2" + SOUNDFONT_PATH = os.path.join( + os.path.dirname(__file__), + "..", + "fluidsynth/DoreMarkYamahaS6-v1.6.sf2", + ) DOWNLOAD_URL = "https://www.dropbox.com/scl/fi/t8gou8stesm42sc559nzu/DoreMarkYamahaS6-v1.6.sf2?rlkey=28ecl63kkjjmwxrkd6hnzsq8f&dl=1" if os.name != "posix": diff --git a/config/config.json b/config/config.json index b4f827d..340f8c1 100644 --- a/config/config.json +++ b/config/config.json @@ -74,7 +74,7 @@ }, "tokenizer": { - "lazy": { + "rel": { "ignore_instruments": { "piano": false, "chromatic": true, @@ -117,9 +117,55 @@ }, "time_quantization": { "num_steps": 500, - "min_step": 10 + "step": 10 }, "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + }, + "abs": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization": { + "step": 15 + }, + "abs_time_step_ms": 5000, + "max_dur_ms": 5000, + "time_step_ms": 10, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] } } } diff --git a/requirements.txt b/requirements.txt index 1464769..c70ea51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ accelerate mido jsonlines pydub +tqdm einops safetensors \ No newline at end of file diff --git a/scripts/midi_to_audio.py b/scripts/midi_to_audio.py new file mode 100644 index 0000000..e63b709 --- /dev/null +++ b/scripts/midi_to_audio.py @@ -0,0 +1,13 @@ +import glob + +from aria.utils import midi_to_audio + + +def main(): + paths = glob.glob("samples/*.mid") + for path in paths: + midi_to_audio(path) + + +if __name__ == "__main__": + main() diff --git a/tests/reference_implementations.py b/tests/reference_implementations.py index 3d3f90f..d653781 100644 --- a/tests/reference_implementations.py +++ b/tests/reference_implementations.py @@ -43,4 +43,4 @@ def apply_rotary_pos_emb_reference(x, cos, sin, interleaved=False): x[..., ro_dim:], ], dim=-1, - ) \ No newline at end of file + ) diff --git a/tests/test_data.py b/tests/test_data.py index 487402e..70facba 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -8,10 +8,27 @@ from aria.data.midi import MidiDict from aria.data import jsonl_zst +TEST_TOKENIZER = "abs" +logger = logging.getLogger(__name__) if not os.path.isdir("tests/test_results"): os.makedirs("tests/test_results") +def setup_logger(): + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + logger.propagate = False + logger.setLevel(logging.INFO) + formatter = logging.Formatter( + "[%(asctime)s] tests.test_data: [%(levelname)s] %(message)s" + ) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + def get_short_seq(): return [ ("prefix", "instrument", "piano"), @@ -95,15 +112,15 @@ def test_data_hash(self): self.assertEqual(mid_1.calculate_hash(), mid_2.calculate_hash()) -# TODO: Fix failing tests and fix the del thing not working correctly. - - class TestPretrainingDataset(unittest.TestCase): def test_build(self): MAX_SEQ_LEN = 512 - tknzr = tokenizer.TokenizerLazy( - return_tensors=False, - ) + if TEST_TOKENIZER == "abs": + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tknzr = tokenizer.RelTokenizer(return_tensors=False) + else: + raise KeyError mididict_dataset = datasets.MidiDataset.build( dir="tests/test_data", recur=True, @@ -132,9 +149,12 @@ def test_build(self): def test_mmap(self): MAX_SEQ_LEN = 512 - tknzr = tokenizer.TokenizerLazy( - return_tensors=False, - ) + if TEST_TOKENIZER == "abs": + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tknzr = tokenizer.RelTokenizer(return_tensors=False) + else: + raise KeyError mididict_dataset = datasets.MidiDataset.build( dir="tests/test_data", recur=True, @@ -153,14 +173,17 @@ def test_mmap(self): self.assertEqual(len({len(_) for _ in raw_entries}), 1) src, tgt = pretrain_dataset[0] - logging.info(f"src: {tknzr.decode(src)[:50]}") - logging.info(f"tgt: {tknzr.decode(tgt)[:50]}") + logger.info(f"src: {tknzr.decode(src)[:50]}") + logger.info(f"tgt: {tknzr.decode(tgt)[:50]}") - def test_augmentation(self): + def test_aug(self): MAX_SEQ_LEN = 512 - tknzr = tokenizer.TokenizerLazy( - return_tensors=False, - ) + if TEST_TOKENIZER == "abs": + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tknzr = tokenizer.RelTokenizer(return_tensors=False) + else: + raise KeyError mididict_dataset = datasets.MidiDataset.build( dir="tests/test_data", recur=True, @@ -174,27 +197,14 @@ def test_augmentation(self): num_epochs=1, midi_dataset=mididict_dataset, ) - pretrain_dataset.set_transform( - [ - tknzr.export_chord_mixup(), - tknzr.export_pitch_aug(5), - tknzr.export_velocity_aug(2), - tknzr.export_tempo_aug(0.5), - ] - ) - - seq = get_short_seq() - seq_augmented = pretrain_dataset._transform(seq) + pretrain_dataset.set_transform(tknzr.export_data_aug()) + for idx, seq in enumerate(tknzr.decode(pretrain_dataset[0][0])): + for _idx, tok in enumerate(seq): + if tok == tknzr.unk_tok: + logger.warning(f"unk_tok seen at seq={idx}, idx={_idx}") - logging.info(f"aug:\n{seq} ->\n{seq_augmented}") - self.assertEqual( - seq_augmented[4][1] - seq[4][1], - seq_augmented[8][1] - seq[8][1], - ) - self.assertEqual( - seq_augmented[4][2] - seq[4][2], - seq_augmented[8][2] - seq[8][2], - ) + logger.info(f"data_aug_1: {tknzr.decode(pretrain_dataset[0][0][:50])}") + logger.info(f"data_aug_2: {tknzr.decode(pretrain_dataset[0][0][:50])}") class TestFinetuningDataset(unittest.TestCase): @@ -202,9 +212,12 @@ class TestFinetuningDataset(unittest.TestCase): def test_build(self): MAX_SEQ_LEN = 512 STRIDE_LEN = 256 - tknzr = tokenizer.TokenizerLazy( - return_tensors=False, - ) + if TEST_TOKENIZER == "abs": + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tknzr = tokenizer.RelTokenizer(return_tensors=False) + else: + raise KeyError mididict_dataset = datasets.MidiDataset.build( dir="tests/test_data", recur=True, @@ -237,28 +250,39 @@ def test_build(self): self.assertEqual(len({len(_) for _ in raw_entries}), 1) src, tgt = finetune_dataset_from_file[0] - logging.info(f"src: {tknzr.decode(src)[:50]}") - logging.info(f"tgt: {tknzr.decode(tgt)[:50]}") - - finetune_dataset_from_file.set_transform( - [ - tknzr.export_pitch_aug(5), - tknzr.export_velocity_aug(2), - tknzr.export_tempo_aug(0.5), - ] - ) - seq = get_short_seq() - seq_augmented = finetune_dataset_from_file._transform(seq) + logger.info(f"src: {tknzr.decode(src)[:50]}") + logger.info(f"tgt: {tknzr.decode(tgt)[:50]}") - logging.info(f"aug:\n{seq} ->\n{seq_augmented}") - self.assertEqual( - seq_augmented[4][1] - seq[4][1], - seq_augmented[8][1] - seq[8][1], + def test_aug(self): + MAX_SEQ_LEN = 512 + STRIDE_LEN = 256 + if TEST_TOKENIZER == "abs": + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tknzr = tokenizer.RelTokenizer(return_tensors=False) + else: + raise KeyError + mididict_dataset = datasets.MidiDataset.build( + dir="tests/test_data", + recur=True, ) - self.assertEqual( - seq_augmented[4][2] - seq[4][2], - seq_augmented[8][2] - seq[8][2], + if os.path.isfile("tests/test_results/finetune_dataset_buff.jsonl"): + os.remove("tests/test_results/finetune_dataset_buff.jsonl") + finetune_dataset = datasets.FinetuningDataset.build( + tokenizer=tknzr, + save_path="tests/test_results/finetune_dataset_buff.jsonl", + max_seq_len=MAX_SEQ_LEN, + stride_len=STRIDE_LEN, + midi_dataset=mididict_dataset, ) + finetune_dataset.set_transform(tknzr.export_data_aug()) + for idx, seq in enumerate(tknzr.decode(finetune_dataset[0][0])): + for _idx, tok in enumerate(seq): + if tok == tknzr.unk_tok: + logger.warning(f"unk_tok seen at seq={idx}, idx={_idx}") + + logger.info(f"data_aug_1: {tknzr.decode(finetune_dataset[0][0][:50])}") + logger.info(f"data_aug_2: {tknzr.decode(finetune_dataset[0][0][:50])}") class TestReaderWriter(unittest.TestCase): @@ -278,9 +302,6 @@ def test_jsonl_zst(self): os.remove(filename) +setup_logger() if __name__ == "__main__": - if os.path.isdir("tests/test_results") is False: - os.mkdir("tests/test_results") - - logging.basicConfig(level=logging.INFO) unittest.main() diff --git a/tests/test_models.py b/tests/test_models.py index 916e65f..1e2ed98 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,14 +7,14 @@ from aria.config import load_model_config from aria.sample import greedy_sample from aria.model.model import YaRNConfig +from aria.tokenizer import AbsTokenizer from aria.model.utils import apply_rotary_pos_emb from .reference_implementations import apply_rotary_pos_emb_reference -from aria.tokenizer import TokenizerLazy class TestModel(unittest.TestCase): def test_yarn_config(self): - tokenizer = TokenizerLazy(return_tensors=True) + tokenizer = AbsTokenizer(return_tensors=True) model_config = ModelConfig(**load_model_config("test")) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) @@ -30,7 +30,9 @@ def test_yarn_config(self): def test_rope_util_fns(self): q = torch.rand(4, 8, 12, 64) - inv_freq = 1 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64)) + inv_freq = 1 / ( + 10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64) + ) t = torch.arange(8, dtype=inv_freq.dtype) freqs = torch.outer(t, inv_freq) cos = torch.cos(freqs) @@ -40,7 +42,7 @@ def test_rope_util_fns(self): assert torch.allclose(q, q_ref, atol=1e-5) def test_attn_mask(self): - tokenizer = TokenizerLazy(return_tensors=True) + tokenizer = AbsTokenizer(return_tensors=True) model_config = ModelConfig(**load_model_config("test")) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) @@ -50,13 +52,19 @@ def test_attn_mask(self): model = TransformerLM(model_config).eval() inp = torch.randint(0, 10000, (1, 10)) - attn_mask = torch.concat([torch.zeros((1, 5), dtype=torch.bool), torch.ones((1, 5), dtype=torch.bool)], dim=-1) + attn_mask = torch.concat( + [ + torch.zeros((1, 5), dtype=torch.bool), + torch.ones((1, 5), dtype=torch.bool), + ], + dim=-1, + ) out = model(inp, attn_mask=attn_mask) out2 = model(inp[:, -5:]) assert torch.allclose(out[:, -5:], out2, atol=1e-5) def test_generation(self): - tokenizer = TokenizerLazy(return_tensors=True) + tokenizer = AbsTokenizer(return_tensors=True) model_config = ModelConfig(**load_model_config("test")) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) @@ -74,7 +82,9 @@ def test_generation(self): device=torch.device("cpu"), max_new_tokens=50, ) - prompts = [[tokenizer.pad_tok] + tokenizer.tokenize(midi_dict=midi_dict)[:50]] * 3 + prompts = [ + [tokenizer.pad_tok] + tokenizer.tokenize(midi_dict=midi_dict)[:50] + ] * 3 out2 = greedy_sample( model, tokenizer, diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index a090ef8..01e3f67 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -3,15 +3,116 @@ import os import time +from typing import Callable + from aria import tokenizer from aria.data.midi import MidiDict +from aria.utils import midi_to_audio if not os.path.isdir("tests/test_results"): os.makedirs("tests/test_results") -def get_short_seq(tknzr: tokenizer.TokenizerLazy): +# TODO: Implement with tokenizer functions +def get_short_seq_abs(tknzr: tokenizer.AbsTokenizer): + return [ + ("prefix", "instrument", "piano"), + ("prefix", "instrument", "drum"), + "", + ("piano", 62, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(0)), + ("dur", tknzr._quantize_dur(50)), + ("drum", 50), + ("onset", tknzr._quantize_onset(100)), + ("piano", 64, tknzr._quantize_velocity(75)), + ("onset", tknzr._quantize_onset(100)), + ("dur", tknzr._quantize_dur(5000)), + "", + "", + "", + ("piano", 65, tknzr._quantize_velocity(75)), + ("onset", tknzr._quantize_onset(170)), + ("dur", tknzr._quantize_dur(100)), + "", + ("piano", 60, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(60)), + "", + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(70)), + ("drum", 50), + ("onset", tknzr._quantize_onset(270)), + "", + ("piano", 80, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(80)), + "", + ] + + +def get_concat_seq_abs(tknzr: tokenizer.AbsTokenizer): + return [ + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(60)), + "", + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(70)), + ("drum", 50), + ("onset", tknzr._quantize_onset(270)), + "", + ("piano", 80, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(80)), + "", + ("prefix", "instrument", "piano"), + ("prefix", "instrument", "drum"), + "", + ("piano", 62, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(0)), + ("dur", tknzr._quantize_dur(50)), + ("drum", 50), + ("onset", tknzr._quantize_onset(100)), + ("piano", 64, tknzr._quantize_velocity(75)), + ("onset", tknzr._quantize_onset(100)), + ("dur", tknzr._quantize_dur(5000)), + "", + "", + "", + ("piano", 65, tknzr._quantize_velocity(75)), + ("onset", tknzr._quantize_onset(170)), + ("dur", tknzr._quantize_dur(100)), + "", + ("piano", 60, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(60)), + "", + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(70)), + ("drum", 50), + ("onset", tknzr._quantize_onset(270)), + "", + ("piano", 80, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(270)), + ("dur", tknzr._quantize_dur(80)), + "", + ("prefix", "instrument", "piano"), + ("prefix", "instrument", "drum"), + "", + ("piano", 62, tknzr._quantize_velocity(45)), + ("onset", tknzr._quantize_onset(0)), + ("dur", tknzr._quantize_dur(50)), + ("drum", 50), + ("onset", tknzr._quantize_onset(100)), + ("piano", 64, tknzr._quantize_velocity(75)), + ("onset", tknzr._quantize_onset(100)), + ("dur", tknzr._quantize_dur(5000)), + "", + "", + ] + + +def get_short_seq_rel(tknzr: tokenizer.RelTokenizer): return [ ("prefix", "instrument", "piano"), ("prefix", "instrument", "drum"), @@ -20,7 +121,7 @@ def get_short_seq(tknzr: tokenizer.TokenizerLazy): ("piano", 62, tknzr._quantize_velocity(50)), ("dur", tknzr._quantize_time(50)), ("wait", tknzr._quantize_time(100)), - ("drum", tknzr._quantize_time(50)), + ("drum", 50), ("piano", 64, tknzr._quantize_velocity(70)), ("dur", tknzr._quantize_time(1000000)), ("wait", tknzr._quantize_time(1000000)), @@ -42,7 +143,7 @@ def get_short_seq(tknzr: tokenizer.TokenizerLazy): ] -def get_concat_seq(tknzr: tokenizer.TokenizerLazy): +def get_concat_seq_rel(tknzr: tokenizer.RelTokenizer): return [ ("dur", tknzr._quantize_time(1000000)), ("wait", tknzr._quantize_time(1000000)), @@ -99,7 +200,145 @@ def get_concat_seq(tknzr: tokenizer.TokenizerLazy): ] -class TestLazyTokenizer(unittest.TestCase): +class TestAbsTokenizer(unittest.TestCase): + def test_tokenize_detokenize_mididict(self): + def tokenize_detokenize(file_name: str): + mid_path = f"tests/test_data/{file_name}" + midi_dict = MidiDict.from_midi(mid_path=mid_path) + tokenized_seq = tknzr.tokenize(midi_dict) + detokenized_midi_dict = tknzr.detokenize(tokenized_seq) + res = detokenized_midi_dict.to_midi() + res.save(f"tests/test_results/{file_name}") + + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + tokenize_detokenize("basic.mid") + tokenize_detokenize("arabesque.mid") + tokenize_detokenize("beethoven.mid") + tokenize_detokenize("bach.mid") + tokenize_detokenize("expressive.mid") + tokenize_detokenize("pop.mid") + tokenize_detokenize("beethoven_moonlight.mid") + + def test_aug(self): + def tokenize_aug_detokenize( + file_name: str, + aug_fn: Callable, + aug_name: str, + audio=False, + ): + mid_path = f"tests/test_data/{file_name}" + midi_dict = MidiDict.from_midi(mid_path=mid_path) + tokenized_seq = tknzr.tokenize(midi_dict) + tokenized_seq_aug = aug_fn(tokenized_seq) + detokenized_midi_dict = tknzr.detokenize(tokenized_seq_aug) + res = detokenized_midi_dict.to_midi() + save_path = f"tests/test_results/abs_{aug_name}_{file_name}" + res.save(save_path) + if audio is True: + midi_to_audio(save_path) + + tknzr = tokenizer.AbsTokenizer(return_tensors=False) + seq = get_short_seq_abs(tknzr) + seq_concat = get_concat_seq_abs(tknzr) + pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) + velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) + tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) + + # Pitch augmentation + seq_pitch_augmented = pitch_aug_fn(get_short_seq_abs(tknzr)) + logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") + tokenize_aug_detokenize("basic.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize("arabesque.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize("beethoven.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize("bach.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize("expressive.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize("pop.mid", pitch_aug_fn, "pitch") + tokenize_aug_detokenize( + "beethoven_moonlight.mid", pitch_aug_fn, "pitch" + ) + + # Velocity augmentation + seq_velocity_augmented = velocity_aug_fn(get_short_seq_abs(tknzr)) + logging.info( + f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" + ) + tokenize_aug_detokenize("basic.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize("arabesque.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize("beethoven.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize("bach.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize("expressive.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize("pop.mid", velocity_aug_fn, "velocity") + tokenize_aug_detokenize( + "beethoven_moonlight.mid", velocity_aug_fn, "velocity" + ) + + # Tempo augmentation + seq_tempo_augmented = tempo_aug_fn(get_short_seq_abs(tknzr)) + logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") + + seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_abs(tknzr)) + logging.info( + f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" + ) + + tokenize_aug_detokenize("basic.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize("arabesque.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize("beethoven.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize("bach.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize("expressive.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize("pop.mid", tempo_aug_fn, "tempo") + tokenize_aug_detokenize( + "beethoven_moonlight.mid", tempo_aug_fn, "tempo" + ) + + def test_aug_time(self): + tknzr = tokenizer.AbsTokenizer() + mid_dict = MidiDict.from_midi("tests/test_data/beethoven.mid") + tokenized_seq = tknzr.tokenize(mid_dict)[:4096] + pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) + velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) + tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) + + # Pitch augmentation + t_start = time.perf_counter() + pitch_aug_fn(tokenized_seq) + t_pitch_aug = (time.perf_counter() - t_start) * 1e3 + logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms") + self.assertLessEqual(t_pitch_aug, 50) + + # Velocity augmentation + t_start = time.perf_counter() + velocity_aug_fn(tokenized_seq) + t_vel_aug = (time.perf_counter() - t_start) * 1e3 + logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms") + self.assertLessEqual(t_vel_aug, 50) + + # Tempo augmentation + t_start = time.perf_counter() + tempo_aug_fn(tokenized_seq) + t_tempo_aug = (time.perf_counter() - t_start) * 1e3 + logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms") + self.assertLessEqual(t_tempo_aug, 50) + + def test_no_unk_token(self): + def _test_no_unk_token(file_name: str): + mid_path = f"tests/test_data/{file_name}" + midi_dict = MidiDict.from_midi(mid_path=mid_path) + seq = tknzr.tokenize(midi_dict) + enc_dec_seq = tknzr.decode(tknzr.encode(seq)) + for tok in enc_dec_seq: + self.assertTrue(tok != tknzr.unk_tok) + + tknzr = tokenizer.AbsTokenizer() + _test_no_unk_token("basic.mid") + _test_no_unk_token("arabesque.mid") + _test_no_unk_token("bach.mid") + _test_no_unk_token("expressive.mid") + _test_no_unk_token("pop.mid") + _test_no_unk_token("beethoven_moonlight.mid") + + +class TestRelTokenizer(unittest.TestCase): def test_tokenize_detokenize_mididict(self): def tokenize_detokenize(file_name: str): mid_path = f"tests/test_data/{file_name}" @@ -109,7 +348,7 @@ def tokenize_detokenize(file_name: str): res = detokenized_midi_dict.to_midi() res.save(f"tests/test_results/{file_name}") - tknzr = tokenizer.TokenizerLazy(return_tensors=False) + tknzr = tokenizer.RelTokenizer(return_tensors=False) tokenize_detokenize("basic.mid") tokenize_detokenize("arabesque.mid") @@ -120,16 +359,16 @@ def tokenize_detokenize(file_name: str): tokenize_detokenize("beethoven_moonlight.mid") def test_aug(self): - tknzr = tokenizer.TokenizerLazy(return_tensors=False) - seq = get_short_seq(tknzr) - seq_concat = get_concat_seq(tknzr) + tknzr = tokenizer.RelTokenizer(return_tensors=False) + seq = get_short_seq_rel(tknzr) + seq_concat = get_concat_seq_rel(tknzr) pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5) + tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.8) chord_mixup_fn = tknzr.export_chord_mixup() # Pitch augmentation - seq_pitch_augmented = pitch_aug_fn(get_short_seq(tknzr)) + seq_pitch_augmented = pitch_aug_fn(get_short_seq_rel(tknzr)) logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") self.assertEqual( seq_pitch_augmented[4][1] - seq[4][1], @@ -137,7 +376,7 @@ def test_aug(self): ) # Velocity augmentation - seq_velocity_augmented = velocity_aug_fn(get_short_seq(tknzr)) + seq_velocity_augmented = velocity_aug_fn(get_short_seq_rel(tknzr)) logging.info( f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" ) @@ -147,25 +386,25 @@ def test_aug(self): ) # Tempo augmentation - seq_tempo_augmented = tempo_aug_fn(get_short_seq(tknzr)) + seq_tempo_augmented = tempo_aug_fn(get_short_seq_rel(tknzr)) logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") - seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq(tknzr)) + seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_rel(tknzr)) logging.info( f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" ) # Chord mix-up augmentation - seq_mixup_augmented = chord_mixup_fn(get_short_seq(tknzr)) + seq_mixup_augmented = chord_mixup_fn(get_short_seq_rel(tknzr)) logging.info(f"chord_mixup_fn:\n{seq} ->\n\n{seq_mixup_augmented}\n") - seq_concat_tempo_augmented = chord_mixup_fn(get_concat_seq(tknzr)) + seq_concat_tempo_augmented = chord_mixup_fn(get_concat_seq_rel(tknzr)) logging.info( f"chord_mixup_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" ) def test_aug_time(self): - tknzr = tokenizer.TokenizerLazy() + tknzr = tokenizer.RelTokenizer() mid_dict = MidiDict.from_midi("tests/test_data/beethoven.mid") tokenized_seq = tknzr.tokenize(mid_dict)[:4096] @@ -203,21 +442,21 @@ def test_aug_time(self): self.assertLessEqual(t_mixup_aug, 50) def test_encode_decode(self): - tknzr = tokenizer.TokenizerLazy(return_tensors=True) - seq = get_short_seq(tknzr) + tknzr = tokenizer.RelTokenizer(return_tensors=True) + seq = get_short_seq_rel(tknzr) enc_dec_seq = tknzr.decode(tknzr.encode(seq)) for x, y in zip(seq, enc_dec_seq): self.assertEqual(x, y) - tknzr = tokenizer.TokenizerLazy(return_tensors=False) - seq = get_short_seq(tknzr) + tknzr = tokenizer.RelTokenizer(return_tensors=False) + seq = get_short_seq_rel(tknzr) enc_dec_seq = tknzr.decode(tknzr.encode(seq)) for x, y in zip(seq, enc_dec_seq): self.assertEqual(x, y) def test_no_unk_token(self): - tknzr = tokenizer.TokenizerLazy() - seq = get_short_seq(tknzr) + tknzr = tokenizer.RelTokenizer() + seq = get_short_seq_rel(tknzr) enc_dec_seq = tknzr.decode(tknzr.encode(seq)) for tok in enc_dec_seq: self.assertTrue(tok != tknzr.unk_tok) diff --git a/tests/test_training.py b/tests/test_training.py index d39140c..bbe7802 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -4,7 +4,7 @@ import logging from aria.train import train, resume_train, convert_cp_from_accelerate -from aria.tokenizer import TokenizerLazy +from aria.tokenizer import RelTokenizer, AbsTokenizer from aria.data.midi import MidiDict from aria.data.datasets import ( MidiDataset, @@ -12,6 +12,7 @@ FinetuningDataset, ) +TEST_TOKENIZER = "abs" PT_TRAIN_DATA_PATH = "tests/test_results/pretrain_dataset_train" PT_VAL_DATA_PATH = "tests/test_results/pretrain_dataset_val" FT_TRAIN_DATA_PATH = "tests/test_results/finetune_dataset_train.jsonl" @@ -36,7 +37,13 @@ def test_training(self): val_mididict = MidiDict.from_midi("tests/test_data/arabesque.mid") train_midi_dataset = MidiDataset([train_mididict]) val_midi_dataset = MidiDataset([val_mididict]) - tokenizer = TokenizerLazy(return_tensors=True) + + if TEST_TOKENIZER == "abs": + tokenizer = AbsTokenizer(return_tensors=False) + elif TEST_TOKENIZER == "rel": + tokenizer = RelTokenizer(return_tensors=False) + else: + raise KeyError # PRETRAINING if os.path.exists(PT_TRAIN_DATA_PATH): @@ -151,6 +158,6 @@ def test_training(self): ) +logging.basicConfig(level=logging.INFO) if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) unittest.main()