diff --git a/aria/config.py b/aria/config.py index 8295bca..7037994 100644 --- a/aria/config.py +++ b/aria/config.py @@ -3,10 +3,13 @@ import os import json +from functools import lru_cache + CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "config") +@lru_cache(maxsize=1) def load_config(): """Returns a dictionary loaded from the config.json file.""" with open(os.path.join(CONFIG_DIR, "config.json")) as f: diff --git a/aria/datasets.py b/aria/datasets.py index 8575a86..210c9da 100644 --- a/aria/datasets.py +++ b/aria/datasets.py @@ -21,7 +21,7 @@ from collections import defaultdict from aria.config import load_config -from aria.tokenizer import SeparatedAbsTokenizer +from aria.tokenizer import InferenceAbsTokenizer from ariautils.tokenizer import Tokenizer from ariautils.midi import ( MidiDict, @@ -473,7 +473,7 @@ def __init__(self, tokenizer: Tokenizer): def build(**kwargs): raise NotImplementedError - def get_loss_mask(self, tokenized_seq: list): + def get_loss_mask(self, src_seq: list, tgt_seq: list): # Should returns a bool Tensor with False indicating a masked loss raise NotImplementedError @@ -602,7 +602,7 @@ def _format(tok): src = seq tgt = seq[1:] + [self.tokenizer.pad_tok] - mask = self.get_loss_mask(tgt) + mask = self.get_loss_mask(src_seq=src, tgt_seq=tgt) return ( torch.tensor(self.tokenizer.encode(src)), @@ -692,7 +692,11 @@ def _new_transform(x): raise ValueError("Must provide function or list of functions.") -def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer): +def _get_seqs( + _entry: MidiDict | dict, + _tokenizer: Tokenizer, + _tokenize_fn: Callable | None = None, +): logger = setup_logger() if isinstance(_entry, str): @@ -705,7 +709,10 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer): raise Exception try: - _tokenized_seq = _tokenizer.tokenize(_midi_dict) + if _tokenize_fn is not None: + _tokenized_seq = _tokenize_fn(_midi_dict) + else: + _tokenized_seq = _tokenizer.tokenize(_midi_dict) except Exception as e: print(e) logger.info(f"Skipping midi_dict: {e}") @@ -719,6 +726,7 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer): def get_seqs( tokenizer: Tokenizer, midi_dict_iter: Iterable, + tokenize_fn: Callable | None = None, ): # Can't pickle geneator object when start method is spawn if multiprocessing.get_start_method() == "spawn": @@ -729,7 +737,10 @@ def get_seqs( with multiprocessing.Pool() as pool: results = pool.imap_unordered( - functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter + functools.partial( + _get_seqs, _tokenizer=tokenizer, _tokenize_fn=tokenize_fn + ), + midi_dict_iter, ) yield from results @@ -782,9 +793,9 @@ def __init__(self, dir_paths: List[str] | str, tokenizer: Tokenizer): def __len__(self): return len(self.index) - def get_loss_mask(self, tokenized_seq: list): + def get_loss_mask(self, src_seq: list, tgt_seq: list): return torch.tensor( - [tok != self.tokenizer.pad_tok for tok in tokenized_seq], + [tok != self.tokenizer.pad_tok for tok in tgt_seq], dtype=torch.bool, ) @@ -876,10 +887,8 @@ def _build_epoch(_save_path, _midi_dataset): f"Finished building, saved PretrainingDataset to {save_dir}" ) - return cls(dir_paths=save_dir, tokenizer=tokenizer) - -# TODO: Improve this logic so it supports MIDI files with multiple tempo_msgs +# TODO: Refactor for readability def _get_combined_mididict( clean_midi_dict: MidiDict, noisy_midi_dict: MidiDict, @@ -887,7 +896,7 @@ def _get_combined_mididict( max_noisy_ms: int, min_clean_ms: int, max_clean_ms: int, -): +) -> MidiDict: # NOTE: We adopt 
the tempo/ticks_per_beat of the clean_midi_dict, and # adjust the noisy note messages accordingly. assert len(clean_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" @@ -904,24 +913,30 @@ def _get_combined_mididict( noisy_intervals = [] clean_intervals = [] prev_ms = -1 + add_noisy_next = random.choice([True, False]) while True: - # Add noise interval - noisy_end_ms = random.randint( - prev_ms + min_noisy_ms, prev_ms + max_noisy_ms - ) - noisy_intervals.append([prev_ms + 1, noisy_end_ms]) - prev_ms = noisy_end_ms - if prev_ms > total_length_ms: - break - - # Add clean interval - clean_end_ms = random.randint( - prev_ms + min_clean_ms, prev_ms + max_clean_ms - ) - clean_intervals.append([prev_ms + 1, clean_end_ms]) - prev_ms = clean_end_ms - if prev_ms > total_length_ms: - break + if add_noisy_next is True: + # Add noisy interval + noisy_end_ms = random.randint( + prev_ms + min_noisy_ms, prev_ms + max_noisy_ms + ) + noisy_intervals.append([prev_ms + 1, noisy_end_ms]) + prev_ms = noisy_end_ms + if prev_ms > total_length_ms: + break + else: + add_noisy_next = False + else: + # Add clean interval + clean_end_ms = random.randint( + prev_ms + min_clean_ms, prev_ms + max_clean_ms + ) + clean_intervals.append([prev_ms + 1, clean_end_ms]) + prev_ms = clean_end_ms + if prev_ms > total_length_ms: + break + else: + add_noisy_next = True # Merge note_msgs clean_ms_to_tick = (clean_midi_dict.ticks_per_beat * 1e3) / ( @@ -930,20 +945,12 @@ def _get_combined_mididict( comb_note_msgs = [] for _note_msg in noisy_midi_dict.note_msgs: - onset_time_ms = get_duration_ms( - start_tick=0, - end_tick=_note_msg["data"]["start"], - tempo_msgs=noisy_midi_dict.tempo_msgs, - ticks_per_beat=noisy_midi_dict.ticks_per_beat, - ) + onset_time_ms = noisy_midi_dict.tick_to_ms(_note_msg["data"]["start"]) for _interval_start_ms, _interval_end_ms in noisy_intervals: if _interval_start_ms < onset_time_ms < _interval_end_ms: - offset_time_ms = get_duration_ms( - start_tick=0, - end_tick=_note_msg["data"]["end"], - tempo_msgs=noisy_midi_dict.tempo_msgs, - ticks_per_beat=noisy_midi_dict.ticks_per_beat, + offset_time_ms = noisy_midi_dict.tick_to_ms( + _note_msg["data"]["end"] ) _adj_note_msg = copy.deepcopy(_note_msg) _adj_onset_tick = int(onset_time_ms * clean_ms_to_tick) @@ -956,21 +963,13 @@ def _get_combined_mididict( break for _note_msg in clean_midi_dict.note_msgs: - onset_time_ms = get_duration_ms( - start_tick=0, - end_tick=_note_msg["data"]["start"], - tempo_msgs=clean_midi_dict.tempo_msgs, - ticks_per_beat=clean_midi_dict.ticks_per_beat, - ) + onset_time_ms = clean_midi_dict.tick_to_ms(_note_msg["data"]["start"]) for _interval_start_ms, _interval_end_ms in clean_intervals: if _interval_start_ms < onset_time_ms < _interval_end_ms: comb_note_msgs.append(_note_msg) break - # Redundant sort - comb_note_msgs = sorted(comb_note_msgs, key=lambda msg: msg["tick"]) - comb_metadata = deepcopy(clean_midi_dict.metadata) comb_metadata["noisy_intervals"] = noisy_intervals @@ -986,7 +985,7 @@ def _get_combined_mididict( ) -# TODO: Move hyperparams into config.json (and TEST) +# TODO: Refactor this function for readability def _noise_midi_dict(midi_dict: MidiDict, config: dict): def _get_velocity_adjusted_msg( __note_msg: dict, @@ -1149,51 +1148,58 @@ def _get_onset_adjusted_msg( ) -def _get_mixed_dataset( - _clean_dataset: Iterable, - _noisy_datasets: list[Iterable], +def export_inference_abs_build_tokenize_fn( + midi_dict: MidiDict, tokenizer: InferenceAbsTokenizer ): finetuning_config = load_config()["data"]["finetuning"] - 
ACTIVATION_PROB = finetuning_config["noising"]["activation_prob"] + GUIDANCE_PROB = finetuning_config["guidance_prob"] + NOISING_PROB = finetuning_config["noising"]["activation_prob"] MIN_NOISY_MS = finetuning_config["min_noisy_interval_ms"] MAX_NOISY_MS = finetuning_config["max_noisy_interval_ms"] MIN_CLEAN_MS = finetuning_config["min_clean_interval_ms"] MAX_CLEAN_MS = finetuning_config["max_clean_interval_ms"] - comb_midi_dicts = [] - _noisy_dataset_itt = random_selection_itt(_noisy_datasets) - for clean, noisy in zip(_clean_dataset, _noisy_dataset_itt): - assert ( - os.path.splitext(os.path.basename(clean.metadata["abs_path"]))[0] - == os.path.splitext(os.path.basename(noisy.metadata["abs_path"]))[0] - ), f"file order mismatch: {clean.metadata['abs_path']}; {noisy.metadata['abs_path']}" - - if random.random() < ACTIVATION_PROB: - noisy = _noise_midi_dict(noisy, config=finetuning_config["noising"]) - - comb_midi_dicts.append( - _get_combined_mididict( - clean, - noisy, - min_noisy_ms=MIN_NOISY_MS, - max_noisy_ms=MAX_NOISY_MS, - min_clean_ms=MIN_CLEAN_MS, - max_clean_ms=MAX_CLEAN_MS, - ) + if random.random() <= NOISING_PROB: + noisy_midi_dict = _noise_midi_dict( + midi_dict, config=finetuning_config["noising"] + ) + midi_dict_for_tokenization = _get_combined_mididict( + clean_midi_dict=midi_dict, + noisy_midi_dict=noisy_midi_dict, + min_noisy_ms=MIN_NOISY_MS, + max_noisy_ms=MAX_NOISY_MS, + min_clean_ms=MIN_CLEAN_MS, + max_clean_ms=MAX_CLEAN_MS, ) + else: + midi_dict_for_tokenization = midi_dict - return MidiDataset(comb_midi_dicts) + if random.random() <= GUIDANCE_PROB: + return tokenizer.tokenize( + midi_dict=midi_dict_for_tokenization, + prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( + "noisy_intervals", [] + ), + guidance_midi_dict=midi_dict, + ) + else: + return tokenizer.tokenize( + midi_dict=midi_dict_for_tokenization, + prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( + "noisy_intervals", [] + ), + ) class FinetuningDataset(TrainingDataset): """Torch dataset object yielding sequences formatted for fine-tuning.""" def __init__( - self, dir_paths: List[str] | str, tokenizer: SeparatedAbsTokenizer + self, dir_paths: List[str] | str, tokenizer: InferenceAbsTokenizer ): super().__init__(tokenizer=tokenizer) - assert tokenizer.name == "separated_abs", "invalid tokenizer" + assert tokenizer.name == "inference_abs", "invalid tokenizer" if isinstance(dir_paths, str): dir_paths = [dir_paths] @@ -1205,31 +1211,33 @@ def __init__( def __len__(self): return len(self.index) - def get_loss_mask(self, tokenized_seq: list): - mask = [True] * len(tokenized_seq) - inside_inst = False + def get_loss_mask(self, src_seq: list, tgt_seq: list): + mask = [False] * len(tgt_seq) + inside_target = True - for idx, token in enumerate(tokenized_seq): - if token == self.tokenizer.inst_start_tok: - mask[idx] = False - inside_inst = True - elif token == self.tokenizer.inst_end_tok: - mask[idx] = False - inside_inst = False - elif inside_inst: - mask[idx] = False + for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)): + if src_tok == self.tokenizer.guidance_start_tok: + inside_target = False + elif src_tok == self.tokenizer.guidance_end_tok: + inside_target = True + elif tgt_tok == self.tokenizer.prompt_start_tok: + inside_target = False + elif src_tok == self.tokenizer.prompt_end_tok: + inside_target = True + + if inside_target is True and tgt_tok != self.tokenizer.pad_tok: + mask[idx] = True return torch.tensor(mask, dtype=torch.bool) @classmethod def build( cls, - tokenizer: 
Tokenizer, + tokenizer: InferenceAbsTokenizer, save_dir: str, max_seq_len: int, num_epochs: int, - clean_dataset_path: str, - noisy_dataset_paths: str, + midi_dataset_path: str, ): def _build_epoch(_save_path, _midi_dataset): @@ -1244,7 +1252,17 @@ def _build_epoch(_save_path, _midi_dataset): ) _idx = 0 - for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 10): + for entry in reservoir( + get_seqs( + tokenizer, + _midi_dataset, + tokenize_fn=functools.partial( + export_inference_abs_build_tokenize_fn, + tokenizer=tokenizer, + ), + ), + 10, + ): for _entry in tokenizer.split(entry, max_seq_len): writer.write(_entry) @@ -1252,12 +1270,14 @@ def _build_epoch(_save_path, _midi_dataset): if _idx % 250 == 0: logger.info(f"Finished processing {_idx}") + # DEBUG + if _idx == 1000: + break + logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" assert num_epochs > 0, "num_epochs must be greater than 0" - assert os.path.isfile(clean_dataset_path), "file not found" - for __path in noisy_dataset_paths: - assert os.path.isfile(__path), "file not found" + assert os.path.isfile(midi_dataset_path), "file not found" if multiprocessing.get_start_method() == "spawn": logger.warning( 'The current multiprocessing start method is "spawn", this ' @@ -1284,21 +1304,14 @@ def _build_epoch(_save_path, _midi_dataset): f"tokenizer_name={tokenizer.name}" ) - clean_dataset = MidiDataset.load(clean_dataset_path) - noisy_datasets = [ - MidiDataset.load(_path) for _path in noisy_dataset_paths - ] - for idx in range(num_epochs): logger.info(f"Building epoch {idx}/{num_epochs - 1}...") # Reload the combined dataset for each epoch - combined_dataset = _get_mixed_dataset(clean_dataset, noisy_datasets) + midi_dataset = MidiDataset.get_generator(midi_dataset_path) _build_epoch( _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), - _midi_dataset=combined_dataset, + _midi_dataset=midi_dataset, ) logger.info(f"Finished building, saved FinetuningDataset to {save_dir}") - - return cls(dir_paths=save_dir, tokenizer=tokenizer) diff --git a/aria/inference/model.py b/aria/inference/model.py index 94d4d3e..d4bf439 100644 --- a/aria/inference/model.py +++ b/aria/inference/model.py @@ -44,8 +44,17 @@ def __init__(self, model_config: ModelConfig): model_config.d_model, model_config.vocab_size, bias=False ) - def forward(self, idxs: torch.Tensor, input_pos: torch.Tensor): - hidden_states = self.model(idxs=idxs, input_pos=input_pos) + def forward( + self, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, + ): + hidden_states = self.model( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) logits = self.lm_head(hidden_states) return logits @@ -98,10 +107,15 @@ def forward( self, idxs: torch.Tensor, input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, ): assert self.freqs_cis is not None, "Caches must be initialized first" mask = self.causal_mask[None, None, input_pos] + + if pad_idxs is not None: + mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) + freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idxs) diff --git a/aria/run.py b/aria/run.py index 054d052..cfb4bc5 100644 --- a/aria/run.py +++ b/aria/run.py @@ -11,23 +11,26 @@ def _parse_sample_args(): argp.add_argument("-m", help="name of model config file") argp.add_argument("-c", help="path to model checkpoint") argp.add_argument("-p", help="path to midi file") - argp.add_argument( - "-pt", help="sample using the pretrained model", action="store_true" - ) argp.add_argument( 
"-temp", - help="change temp value", + help="sampling temperature value", type=float, required=False, default=0.95, ) argp.add_argument( "-top_p", - help="change top_p value", + help="sampling top_p value", type=float, required=False, default=0.95, ) + argp.add_argument( + "-cfg", + help="sampling cfg gamma value", + type=float, + required=False, + ) argp.add_argument( "-metadata", nargs=2, @@ -49,38 +52,26 @@ def _parse_sample_args(): ) argp.add_argument("-e", action="store_true", help="enable force end") argp.add_argument("-l", type=int, help="generation length", default=1024) - argp.add_argument("-noise", action="store_true", help="add noise to prompt") + argp.add_argument( + "-guidance_path", type=str, help="path to guidance MIDI", required=False + ) + argp.add_argument( + "-guidance_start_ms", + help="guidance interval start (ms)", + type=int, + required=False, + ) + argp.add_argument( + "-guidance_end_ms", + help="guidance interval end (ms)", + type=int, + required=False, + ) argp.add_argument("-compile", action="store_true", help="compile cudagraph") return argp.parse_args(sys.argv[2:]) -def _get_model_name(name: str | None, state: dict): - if name is not None: - return name - - print("Model name is not provided. Trying to infer from checkpoint...") - _defaults = { - 16: "small", - 32: "medium", - 64: "large", - } - try: - pattern = re.compile(r"encode_layers\.(\d+)\.") - layer_keys = [pattern.search(k) for k in state.keys()] - layer_keys = set(p.group(1) for p in layer_keys if p is not None) - for i in range(len(layer_keys)): - assert str(i) in layer_keys - - if len(layer_keys) in _defaults: - print(f"Selecting model name: {_defaults[len(layer_keys)]}") - return _defaults[len(layer_keys)] - assert False - except: - raise ValueError("Model name is not provided and cannot be inferred.") - - -# TODO: Add support for sampling from the pretrained model def sample(args): """Entrypoint for sampling""" @@ -88,9 +79,12 @@ def sample(args): from aria.inference import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config, load_config - from ariautils.tokenizer import AbsTokenizer - from aria.tokenizer import SeparatedAbsTokenizer - from aria.sample import greedy_sample, get_pt_prompt, get_inst_prompt + from aria.tokenizer import InferenceAbsTokenizer + from aria.sample import ( + sample_batch_cfg, + sample_batch, + get_inference_prompt, + ) from ariautils.midi import MidiDict from aria.utils import _load_weight @@ -116,11 +110,7 @@ def sample(args): force_end = args.e model_name = args.m - if args.pt == True: - tokenizer = AbsTokenizer() - else: - tokenizer = SeparatedAbsTokenizer() - + tokenizer = InferenceAbsTokenizer() model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) model_config.grad_checkpoint = False @@ -133,17 +123,6 @@ def sample(args): "Failed to load model_state. This is likely due to an incompatibility " "between the checkpoint file (-c) and model name/config (-m)." ) - if args.pt: - print( - "When using the -pt flag make sure you provide a checkpoint for " - "the pretrained model." - ) - else: - print( - "When not using the -pt flag make sure you provide a checkpoint " - " for the instuct-finetuned (inst) model." - ) - raise e assert args.l > 0, "Generation length must be positive." 
@@ -151,6 +130,10 @@ def sample(args): # Load and format prompts and metadata midi_dict = MidiDict.from_midi(mid_path=args.p) + if args.guidance_path: + guidance_midi_dict = MidiDict.from_midi(mid_path=args.guidance_path) + else: + guidance_midi_dict = None for k, v in manual_metadata.items(): midi_dict.metadata[k] = v @@ -160,42 +143,48 @@ def sample(args): f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}" ) - if args.pt: - if args.noise: - print("Noising not supported with pretrained model") + prompt_seq, guidance_seq = get_inference_prompt( + tokenizer=tokenizer, + midi_dict=midi_dict, + truncate_len=truncate_len, + guidance_start_ms=args.guidance_start_ms, + guidance_end_ms=args.guidance_end_ms, + guidance_midi_dict=guidance_midi_dict, + ) - prompt_seq = get_pt_prompt( - tokenizer=tokenizer, - midi_dict=midi_dict, - truncate_len=truncate_len, - ) - else: - prompt_seq = get_inst_prompt( - tokenizer=tokenizer, - midi_dict=midi_dict, - truncate_len=truncate_len, - noise=args.noise, + if guidance_seq: + tokenizer.detokenize(guidance_seq).to_midi().save( + os.path.join(samples_dir, f"guidance.mid") ) - - prompts = [prompt_seq for _ in range(num_variations)] if len(prompt_seq) + args.l > model_config.max_seq_len: print( "WARNING: Required context exceeds max_seq_len supported by model" ) + prompts = [prompt_seq for _ in range(num_variations)] - print(prompt_seq) - - results = greedy_sample( - model=model, - tokenizer=tokenizer, - prompts=prompts, - max_new_tokens=max_new_tokens, - force_end=force_end, - # cfg_gamma=args.cfg, - temperature=args.temp, - top_p=args.top_p, - compile=args.compile, - ) + if args.cfg is not None: + results = sample_batch_cfg( + model=model, + tokenizer=tokenizer, + prompts=prompts, + max_new_tokens=max_new_tokens, + cfg_gamma=args.cfg, + force_end=force_end, + temperature=args.temp, + top_p=args.top_p, + compile=args.compile, + ) + else: + results = sample_batch( + model=model, + tokenizer=tokenizer, + prompts=prompts, + max_new_tokens=max_new_tokens, + force_end=force_end, + temperature=args.temp, + top_p=args.top_p, + compile=args.compile, + ) samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples") if os.path.isdir(samples_dir) is False: @@ -204,7 +193,7 @@ def sample(args): for idx, tokenized_seq in enumerate(results): res_midi_dict = tokenizer.detokenize(tokenized_seq) res_midi = res_midi_dict.to_midi() - res_midi.save(f"samples/res_{idx + 1}.mid") + res_midi.save(os.path.join(samples_dir, f"res_{idx + 1}.mid")) print("Results saved to samples/") @@ -277,7 +266,7 @@ def build_pretraining_dataset(args): elif args.tokenizer_name == "rel": tokenizer = RelTokenizer() - dataset = PretrainingDataset.build( + PretrainingDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, max_seq_len=args.l, @@ -289,13 +278,8 @@ def build_pretraining_dataset(args): def _parse_finetune_dataset_args(): argp = argparse.ArgumentParser(prog="aria finetune-dataset") argp.add_argument( - "-clean_load_path", - help="path to the clean midi_dict dataset", - ) - argp.add_argument( - "-noisy_load_paths", - nargs="+", - help="one or more paths to noisy midi_dict datasets", + "-midi_dataset_path", + help="path to midi_dict dataset", ) argp.add_argument("-save_dir", help="path to save dataset") argp.add_argument("-l", help="max sequence length", type=int, default=4096) @@ -305,17 +289,16 @@ def _parse_finetune_dataset_args(): def build_finetune_dataset(args): - from aria.tokenizer import SeparatedAbsTokenizer + from 
aria.tokenizer import InferenceAbsTokenizer from aria.datasets import FinetuningDataset - tokenizer = SeparatedAbsTokenizer() + tokenizer = InferenceAbsTokenizer() FinetuningDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, max_seq_len=args.l, num_epochs=args.e, - clean_dataset_path=args.clean_load_path, - noisy_dataset_paths=args.noisy_load_paths, + midi_dataset_path=args.midi_dataset_path, ) diff --git a/aria/sample.py b/aria/sample.py index 8c70af4..4546283 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -1,5 +1,6 @@ """Contains generation/sampling code""" +import copy import torch import torch._dynamo.config import torch._inductor.config @@ -8,7 +9,8 @@ from tqdm import tqdm from aria.inference import TransformerLM -from ariautils.tokenizer import Tokenizer +from aria.tokenizer import InferenceAbsTokenizer +from ariautils.tokenizer import Tokenizer, AbsTokenizer from ariautils.midi import MidiDict torch._inductor.config.coordinate_descent_tuning = True @@ -16,16 +18,45 @@ torch._inductor.config.fx_graph_cache = True +def get_cfg_prompt(prompts: list, pad_tok: str, guidance_end_tok: str): + cfg_prompts = [] + for prompt in prompts: + prompt_no_guidance = prompt[prompt.index(guidance_end_tok) + 1 :] + prompt_no_guidance = [pad_tok] * ( + len(prompt) - len(prompt_no_guidance) + ) + prompt_no_guidance + cfg_prompts.append(prompt) + cfg_prompts.append(prompt_no_guidance) + + return cfg_prompts + + @torch.inference_mode() -def prefill(model, idxs: torch.Tensor, input_pos: torch.Tensor): - logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1] +def decode_one( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +): + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] return logits @torch.inference_mode() -def decode_one(model, idxs: torch.Tensor, input_pos: torch.Tensor): - logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1] +def prefill( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +): + logits = model.forward(idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs)[ + :, -1 + ] return logits @@ -62,33 +93,32 @@ def update_seq_ids_( seq[:, idx] = next_token_ids -# TODO: Add CFG back into this when working -# TODO: Check that unexpected instrument warnings still working +# TODO: Not working @torch.autocast( "cuda", dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, ) @torch.inference_mode() -def greedy_sample( +def sample_batch( model: TransformerLM, tokenizer: Tokenizer, prompts: List[list], max_new_tokens: int, force_end=False, - cfg_gamma: float | None = 1.05, temperature: float = 0.95, top_p: float = 0.95, - compile: bool = True, + compile: bool = False, ): - """Performs greedy (top_p) auto-regressive sampling on a batch of prompts.""" - if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" _prompt_len = len(prompts[0]) _num_prompts = len(prompts) + assert all([len(p) == _prompt_len for p in prompts]) model.eval() + dim_tok_inserted = [False for _ in range(_num_prompts)] + eos_tok_seen = [False for _ in range(_num_prompts)] total_len = _prompt_len + max_new_tokens seq = torch.stack( [ @@ -98,15 +128,12 @@ def greedy_sample( for p in prompts ] ).cuda() - dim_tok_inserted = [False for _ in range(_num_prompts)] - eos_tok_seen = [False for _ in range(_num_prompts)] if compile is True: global decode_one decode_one = torch.compile( decode_one, 
mode="reduce-overhead", - # mode="max-autotune", fullgraph=True, ) @@ -119,7 +146,7 @@ def greedy_sample( ) print( - f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" + f"Using hyperparams: temp={temperature}, top_p={top_p}, gen_len={max_new_tokens}" ) for idx in ( @@ -145,8 +172,8 @@ def greedy_sample( ), ) - if tokenizer.name == "separated_abs": - logits[:, tokenizer.tok_to_id[tokenizer.inst_start_tok]] = float( + if tokenizer.name == "inference_abs": + logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( "-inf" ) @@ -183,6 +210,134 @@ def greedy_sample( return decoded_results +@torch.autocast( + "cuda", + dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, +) +@torch.inference_mode() +def sample_batch_cfg( + model: TransformerLM, + tokenizer: InferenceAbsTokenizer, + prompts: List[list], + max_new_tokens: int, + cfg_gamma: float, + force_end=False, + temperature: float = 0.95, + top_p: float = 0.95, + compile: bool = False, +): + assert 0.0 <= cfg_gamma <= 2.0 + assert 0.0 <= temperature <= 2.0 + assert 0.5 <= top_p <= 1.0 + assert tokenizer.name == "inference_abs" + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompts = get_cfg_prompt( + prompts, tokenizer.pad_tok, tokenizer.guidance_end_tok + ) + + _prompt_len = len(prompts[0]) + _num_prompts = len(prompts) + assert all([len(p) == _prompt_len for p in prompts]) + + model.eval() + total_len = _prompt_len + max_new_tokens + seq = torch.stack( + [ + torch.tensor( + tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) + ) + for p in prompts + ] + ).cuda() + dim_tok_inserted = [False for _ in range(_num_prompts)] + eos_tok_seen = [False for _ in range(_num_prompts)] + + if compile is True: + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + model.setup_cache( + batch_size=_num_prompts, + max_seq_len=total_len, + dtype=( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ), + ) + + print( + f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" + ) + + for idx in ( + pbar := tqdm( + range(_prompt_len, total_len), + total=total_len - _prompt_len, + leave=False, + ) + ): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + if idx == _prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange(0, idx, device=seq.device), + pad_idxs=(seq == tokenizer.pad_id), + ) + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [idx - 1], device=seq.device, dtype=torch.int + ), + pad_idxs=(seq == tokenizer.pad_id), + ) + + logits_cfg = cfg_gamma * logits[::2] + (1 - cfg_gamma) * logits[1::2] + logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( + "-inf" + ) + + if temperature > 0.0: + probs = torch.softmax(logits_cfg / temperature, dim=-1) + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() + + next_token_ids = next_token_ids.repeat_interleave(2) + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] + 
decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + def sample_top_p(probs, p): probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) @@ -194,32 +349,39 @@ def sample_top_p(probs, p): return next_token -# TODO: Clean up a bit and get rid of footguns -def get_inst_prompt( - tokenizer: Tokenizer, +def get_inference_prompt( + tokenizer: InferenceAbsTokenizer, midi_dict: MidiDict, truncate_len: int, - noise: bool, + guidance_start_ms: int, + guidance_end_ms: int, + guidance_midi_dict: MidiDict | None = None, ): - from aria.datasets import _noise_midi_dict - from aria.config import load_config - - midi_dict.metadata["noisy_intervals"] = [[0, truncate_len * 1e3]] - - if noise == True: - midi_dict = _noise_midi_dict( - midi_dict, load_config()["data"]["finetuning"]["noising"] + assert tokenizer.name == "inference_abs" + + if guidance_midi_dict is not None: + assert guidance_start_ms is not None and guidance_start_ms >= 0 + assert guidance_end_ms is not None and guidance_end_ms >= 0 + assert ( + tokenizer._config["guidance"]["min_ms"] + <= guidance_end_ms - guidance_start_ms + <= tokenizer._config["guidance"]["max_ms"] ) - prompt_seq = tokenizer.tokenize(midi_dict=midi_dict) + prompt_seq = tokenizer.tokenize( + midi_dict=midi_dict, + prompt_intervals_ms=( + [[0, truncate_len * 1e3]] if truncate_len > 0 else [] + ), + guidance_midi_dict=guidance_midi_dict, + guidance_start_ms=guidance_start_ms, + guidance_end_ms=guidance_end_ms, + ) - if tokenizer.inst_end_tok in prompt_seq: - prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.inst_end_tok) + 1] - elif tokenizer.eos_tok in prompt_seq: - # TODO: This is a workaround for a bug where is not inserted if - # the pieces ends. 
- assert prompt_seq[-1] == tokenizer.eos_tok - prompt_seq[-1] = tokenizer.inst_end_tok + if tokenizer.prompt_end_tok in prompt_seq: + prompt_seq = prompt_seq[ + : prompt_seq.index(tokenizer.prompt_end_tok) + 1 + ] else: print("No notes found in prompt region") prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1] @@ -227,21 +389,13 @@ def get_inst_prompt( if tokenizer.dim_tok in prompt_seq: prompt_seq.remove(tokenizer.dim_tok) - return prompt_seq - - -def get_pt_prompt( - tokenizer: Tokenizer, - midi_dict: MidiDict, - truncate_len: int, -): - prompt_seq = tokenizer.tokenize(midi_dict=midi_dict) - prompt_seq = tokenizer.truncate_by_time( - tokenized_seq=prompt_seq, - trunc_time_ms=truncate_len * 1e3, - ) - - if tokenizer.dim_tok in prompt_seq: - prompt_seq.remove(tokenizer.dim_tok) + if guidance_midi_dict is not None: + guidance_seq = copy.deepcopy(prompt_seq) + guidance_seq = guidance_seq[ + : guidance_seq.index(tokenizer.guidance_end_tok) + ] + guidance_seq[0] = ("prefix", "instrument", "piano") + else: + guidance_seq = None - return prompt_seq + return prompt_seq, guidance_seq diff --git a/aria/tokenizer.py b/aria/tokenizer.py index 2b0ce6a..818fb27 100644 --- a/aria/tokenizer.py +++ b/aria/tokenizer.py @@ -1,90 +1,248 @@ """Tokenizer for MIDI conditioned completions""" import copy +import random +import functools -from ariautils.midi import MidiDict, get_duration_ms +from typing import Callable + +from aria.config import load_config +from ariautils.midi import MidiDict from ariautils.tokenizer import AbsTokenizer as _AbsTokenizer -class SeparatedAbsTokenizer(_AbsTokenizer): +class InferenceAbsTokenizer(_AbsTokenizer): def __init__(self): super().__init__() - self.name = "separated_abs" - self.inst_start_tok = "" - self.inst_end_tok = "" - self.add_tokens_to_vocab([self.inst_start_tok, self.inst_end_tok]) - self.special_tokens.append(self.inst_start_tok) - self.special_tokens.append(self.inst_end_tok) - - def tokenize(self, midi_dict: MidiDict, **kwargs): - def _add_inst_toks(_seq: list, _start_ms: int, _end_ms: int): - res_seq = copy.deepcopy(_seq) - - inst_inserted = False - time_tok_cnt = 0 - curr_time_ms = 0 - for idx, (tok_1, tok_2) in enumerate(zip(_seq, _seq[1:])): - if tok_1 == self.time_tok: - time_tok_cnt += 1 - elif ( - isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd - ): - assert isinstance(tok_2, tuple) and tok_2[0] == "onset" + self.name = "inference_abs" + self._config = load_config()["tokenizer"]["inference_abs"] + + self.prompt_start_tok = "" + self.prompt_end_tok = "" + self.guidance_start_tok = "" + self.guidance_end_tok = "" + + self.add_tokens_to_vocab( + [ + self.prompt_start_tok, + self.prompt_end_tok, + self.guidance_start_tok, + self.guidance_end_tok, + ] + ) + self.special_tokens.append(self.prompt_start_tok) + self.special_tokens.append(self.prompt_end_tok) + self.special_tokens.append(self.guidance_start_tok) + self.special_tokens.append(self.guidance_end_tok) - # Adjust time - _curr_time = ( - self.config["abs_time_step_ms"] * time_tok_cnt - ) + tok_2[1] + def _get_guidance_interval_ms(self, guidance_midi_dict: MidiDict): + first_note_onset_ms = guidance_midi_dict.tick_to_ms( + guidance_midi_dict.note_msgs[0]["tick"] + ) + last_note_onset_ms = guidance_midi_dict.tick_to_ms( + guidance_midi_dict.note_msgs[-1]["tick"] + ) + guidance_segment_length_ms = random.randint( + self._config["guidance"]["min_ms"], + min(self._config["guidance"]["max_ms"], last_note_onset_ms), + ) + guidance_start_ms = random.randint( + first_note_onset_ms, + 
last_note_onset_ms - guidance_segment_length_ms, + ) + guidance_end_ms = guidance_start_ms + guidance_segment_length_ms + + return guidance_start_ms, guidance_end_ms - assert _curr_time >= curr_time_ms - curr_time_ms = _curr_time + def _get_guidance_seq( + self, + guidance_midi_dict: MidiDict, + guidance_start_ms: int | None = None, + guidance_end_ms: int | None = None, + ): + assert guidance_midi_dict.note_msgs is not None - if curr_time_ms >= _start_ms and inst_inserted == False: - res_seq.insert(idx, self.inst_start_tok) - inst_inserted = True - if curr_time_ms > _end_ms and inst_inserted == True: - res_seq.insert(idx + 1, self.inst_end_tok) - break + # Need to validate these numbers + if guidance_start_ms is None: + assert guidance_end_ms is None + guidance_start_ms, guidance_end_ms = self._get_guidance_interval_ms( + guidance_midi_dict=guidance_midi_dict + ) - return res_seq + slice_note_msgs = [] + for note_msg in guidance_midi_dict.note_msgs: + start_ms = guidance_midi_dict.tick_to_ms(note_msg["data"]["start"]) + if guidance_start_ms <= start_ms <= guidance_end_ms: + slice_note_msgs.append(note_msg) - if midi_dict.metadata.get("noisy_intervals") is None: - print("noisy_intervals metadata not present") - return super().tokenize(midi_dict, **kwargs) + slice_midi_dict = copy.deepcopy(guidance_midi_dict) + slice_midi_dict.note_msgs = slice_note_msgs - seq = super().tokenize(midi_dict, **kwargs) + if len(slice_midi_dict.note_msgs) == 0: + # Catches not note in interval + return [] - # This logic is required as the tokenizer removes proceeding silence - first_note_ms = get_duration_ms( - start_tick=0, - end_tick=midi_dict.note_msgs[0]["data"]["start"], - tempo_msgs=midi_dict.tempo_msgs, - ticks_per_beat=midi_dict.ticks_per_beat, + guidance_seq = self._tokenize_midi_dict( + midi_dict=slice_midi_dict, + remove_preceding_silence=True, ) - noisy_intervals = [ - [ - ival[0] - first_note_ms, - ival[1] - first_note_ms, - ] - for ival in midi_dict.metadata.get("noisy_intervals") - if ival[1] >= first_note_ms + + if self.dim_tok in guidance_seq: + guidance_seq.remove(self.dim_tok) + + guidance_seq = guidance_seq[ + guidance_seq.index(self.bos_tok) + + 1 : guidance_seq.index(self.eos_tok) ] - for start_ms, end_ms in noisy_intervals: - seq = _add_inst_toks(seq, start_ms, end_ms) + return ( + [self.guidance_start_tok] + guidance_seq + [self.guidance_end_tok] + ) + + def _add_prompt_tokens( + self, seq: list, prompt_start_ms: int, prompt_end_ms: int + ): + res = copy.deepcopy(seq) + prompt_tok_inserted = False + time_tok_cnt = 0 + curr_time_ms = 0 + for idx, (tok_1, tok_2) in enumerate(zip(seq, seq[1:])): + if tok_1 == self.time_tok: + time_tok_cnt += 1 + elif isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd: + assert isinstance(tok_2, tuple) and tok_2[0] == "onset" + + # Adjust time + curr_time_ms = (self.abs_time_step_ms * time_tok_cnt) + tok_2[1] + + if ( + curr_time_ms >= prompt_start_ms + and prompt_tok_inserted == False + ): + res.insert(idx, self.prompt_start_tok) + prompt_tok_inserted = True + elif ( + curr_time_ms > prompt_end_ms and prompt_tok_inserted == True + ): + res.insert(idx + 1, self.prompt_end_tok) + break - return seq + return res - def detokenize(self, midi_dict: MidiDict, **kwargs): - return super().detokenize(midi_dict, **kwargs) + def tokenize( + self, + midi_dict: MidiDict, + prompt_intervals_ms: list[tuple[int, int]], + guidance_midi_dict: MidiDict | None = None, + guidance_start_ms: int | None = None, + guidance_end_ms: int | None = None, + ): + seq = 
self._tokenize_midi_dict( + midi_dict=midi_dict, remove_preceding_silence=True + ) + first_note_ms = midi_dict.tick_to_ms( + midi_dict.note_msgs[0]["data"]["start"] + ) + + for prompt_start_ms, prompt_end_ms in prompt_intervals_ms: + if prompt_end_ms > first_note_ms: + seq = self._add_prompt_tokens( + seq, + prompt_start_ms=prompt_start_ms - first_note_ms, + prompt_end_ms=prompt_end_ms - first_note_ms, + ) + + if guidance_midi_dict is not None: + guidance_seq = self._get_guidance_seq( + guidance_midi_dict=guidance_midi_dict, + guidance_start_ms=guidance_start_ms, + guidance_end_ms=guidance_end_ms, + ) + else: + guidance_seq = [] + + return guidance_seq + seq + + def detokenize(self, tokenized_seq: list, **kwargs): + if self.guidance_end_tok in tokenized_seq: + seq = tokenized_seq[tokenized_seq.index(self.guidance_end_tok) :] + else: + seq = tokenized_seq + + return super()._detokenize_midi_dict(seq, **kwargs) def export_data_aug(self): return [ - self.export_pitch_aug(5), - self.export_velocity_aug(1), + self.export_guidance_tempo_aug(max_tempo_aug=0.25, mixup=True), + self.export_guidance_pitch_aug(4), + self.export_guidance_velocity_aug(2), ] + def export_guidance_aug_fn(self, aug_fn): + """Transforms augmentation function to only apply to guidance seq""" + + def _guidance_seq_aug_fn( + src: list, + _aug_fn: Callable, + pad_tok: str, + **kwargs, + ) -> list: + + initial_seq_len = len(src) + if self.guidance_start_tok in src and self.guidance_end_tok in src: + guidance_seq = src[ + src.index(self.guidance_start_tok) + + 1 : src.index(self.guidance_end_tok) + ] + seq = src[src.index(self.guidance_end_tok) + 1 :] + + if len(guidance_seq) == 0: + return src + else: + return src + + augmented_guidance_seq = _aug_fn(guidance_seq) + res = ( + [self.guidance_start_tok] + + augmented_guidance_seq + + [self.guidance_end_tok] + + seq + ) + + # Pad or truncate to original sequence length as necessary + res = res[:initial_seq_len] + res += [pad_tok] * (initial_seq_len - len(res)) + + return res + + return functools.partial( + _guidance_seq_aug_fn, + _aug_fn=aug_fn, + pad_tok=self.pad_tok, + ) + + def export_guidance_pitch_aug(self, max_pitch_aug: int): + """Apply pitch augmentation to the guidance sequence""" + + return self.export_guidance_aug_fn( + self.export_pitch_aug(max_pitch_aug=max_pitch_aug) + ) + + def export_guidance_velocity_aug(self, max_num_aug_steps: int): + """Apply velocity augmentation to the guidance sequence""" + + return self.export_guidance_aug_fn( + self.export_velocity_aug(max_num_aug_steps=max_num_aug_steps) + ) + + def export_guidance_tempo_aug(self, max_tempo_aug: int, mixup: bool): + """Apply tempo augmentation to the guidance sequence""" + + return self.export_guidance_aug_fn( + self.export_tempo_aug(max_tempo_aug=max_tempo_aug, mixup=mixup) + ) + def split(self, seq: list, seq_len: int): def _process_chunk(_chunk: list): # Ensure first token is note token @@ -99,21 +257,28 @@ def _process_chunk(_chunk: list): else: _chunk.pop(0) - # Insert inst_start_tok if it is missing (but required) + # Insert prompt_start_tok if it is missing (but required) for idx in range(len(_chunk)): tok = _chunk[idx] - if tok == self.inst_start_tok: + if tok == self.prompt_start_tok: break - elif tok == self.inst_end_tok: + elif tok == self.prompt_end_tok: if _chunk[0] == self.bos_tok: - _chunk.insert(1, self.inst_start_tok) + _chunk.insert(1, self.prompt_start_tok) else: - _chunk.insert(0, self.inst_start_tok) + _chunk.insert(0, self.prompt_start_tok) break return _chunk + guidance = [] + if 
self.guidance_start_tok in seq: + guidance_start = seq.index(self.guidance_start_tok) + guidance_end = seq.index(self.guidance_end_tok) + guidance = seq[guidance_start : guidance_end + 1] + seq = seq[guidance_end + 1 :] + prefix = [] while seq: tok = seq[0] @@ -122,7 +287,6 @@ def _process_chunk(_chunk: list): else: break - # Generate chunks chunks = [ _process_chunk(seq[idx : idx + seq_len]) for idx in range(0, len(seq) - 100, seq_len) @@ -130,12 +294,10 @@ def _process_chunk(_chunk: list): res = [] for chunk in chunks: - if self.inst_start_tok not in chunk: - continue - - sub_seq = prefix + chunk + sub_seq = guidance + prefix + chunk sub_seq = sub_seq[:seq_len] sub_seq += [self.pad_tok] * (seq_len - len(sub_seq)) + res.append(sub_seq) return res diff --git a/aria/train.py b/aria/train.py index 6d08371..6a2c74b 100644 --- a/aria/train.py +++ b/aria/train.py @@ -20,7 +20,7 @@ from aria.config import load_model_config from aria.model import ModelConfig, TransformerLM from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer -from aria.tokenizer import SeparatedAbsTokenizer +from aria.tokenizer import InferenceAbsTokenizer from aria.datasets import ( TrainingDataset, PretrainingDataset, @@ -196,7 +196,7 @@ def get_optim( num_epochs: int, steps_per_epoch: int, ): - LR = 3e-4 + LR = 3e-5 END_RATIO = 0.1 WARMUP_STEPS = 200 @@ -363,8 +363,11 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0): ) # Transpose for CrossEntropyLoss loss = loss_fn(logits, tgt) - loss = loss * mask - loss = loss[loss != 0.0].mean() # != 0.0 here is important + if mask.sum() == 0: + loss = (loss * 0).sum() + else: + loss = loss * mask + loss = loss[loss != 0.0].mean() # Calculate statistics loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) @@ -423,8 +426,11 @@ def val_loop(dataloader, _epoch: int): logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss loss = loss_fn(logits, tgt) - loss = loss * mask - loss = loss[loss != 0.0].mean() + if mask.sum() == 0: + loss = (loss * 0).sum() + else: + loss = loss * mask + loss = loss[loss != 0.0].mean() # Logging loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) @@ -530,8 +536,8 @@ def resume_train( tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "separated_abs": - tokenizer = SeparatedAbsTokenizer() + elif tokenizer_name == "inference_abs": + tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: @@ -656,8 +662,8 @@ def train( tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "separated_abs": - tokenizer = SeparatedAbsTokenizer() + elif tokenizer_name == "inference_abs": + tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: diff --git a/config/config.json b/config/config.json index dbbea32..7e0a736 100644 --- a/config/config.json +++ b/config/config.json @@ -2,7 +2,7 @@ "data": { "tests": { "note_density_in_interval":{ - "run": true, + "run": false, "args": { "test_params_list": [ @@ -30,20 +30,20 @@ } }, "note_timing_entropy":{ - "run": true, + "run": false, "args": { "min_length_entropy": 2.5, "min_onset_delta_entropy": 0.0 } }, "note_pitch_entropy":{ - "run": true, + "run": false, "args": { "min_entropy": 3.0 } }, "unique_pitch_count_in_interval":{ - "run": true, + "run": false, "args": { "test_params_list": [ @@ 
-54,58 +54,59 @@ } }, "unique_pitch_count":{ - "run": true, + "run": false, "args": { "min_num_unique_pitches": 12 } }, "silent_interval":{ - "run": true, + "run": false, "args": { "max_silence_s": 20 } }, "mean_note_velocity":{ - "run": true, + "run": false, "args": { "min_mean_velocity": 20, "max_mean_velocity": 105 } }, "max_programs":{ - "run": true, + "run": false, "args": { "max": 12 } }, "max_instruments":{ - "run": true, + "run": false, "args": { "max": 7 } }, "total_note_frequency":{ - "run": true, + "run": false, "args": { "min_per_second": 1.5, "max_per_second": 30 } }, "note_frequency_per_instrument":{ - "run": true, + "run": false, "args": { "min_per_second": 1.0, "max_per_second": 25 } }, - "min_length":{ - "run": true, + "length":{ + "run": false, "args": { - "min_seconds": 30 + "min_length_s": 30, + "max_length_s": 7200 } }, "repetitive_content":{ - "run": true, + "run": false, "args": { "min_length_m": 20, "num_chunks": 5, @@ -138,6 +139,10 @@ }, "metadata": { "functions": { + "aria_midi_json": { + "run": true, + "args": {} + }, "composer_filename": { "run": false, "args": { @@ -171,16 +176,17 @@ } }, "finetuning": { + "guidance_prob": 0.5, "min_noisy_interval_ms": 5000, "max_noisy_interval_ms": 60000, "min_clean_interval_ms": 60000, "max_clean_interval_ms": 200000, "noising": { - "activation_prob": 0.95, + "activation_prob": 0.8, "remove_notes": { - "activation_prob": 0.75, - "min_ratio": 0.1, - "max_ratio": 0.4 + "activation_prob": 0.5, + "min_ratio": 0.0, + "max_ratio": 0.3 }, "adjust_velocity": { "activation_prob": 0.3, @@ -190,11 +196,11 @@ "min_ratio": 0.30 }, "adjust_onsets": { - "activation_prob": 0.5, - "min_adjust_s": 0.03, + "activation_prob": 0.25, + "min_adjust_s": 0.01, "max_adjust_s": 0.07, "max_ratio": 0.15, - "min_ratio": 0.5 + "min_ratio": 0.3 }, "quantize_onsets": { "activation_prob": 0.15, @@ -203,6 +209,17 @@ "max_vel_delta": 45 } } + } + }, + "tokenizer": { + "inference_abs": { + "guidance": { + "min_ms": 5000, + "max_ms": 30000 + } + + } } + }
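
Illustrative note on the rewritten interval logic in _get_combined_mididict (aria/datasets.py): noisy and clean interval lengths are now drawn alternately, with a coin flip deciding which type opens the piece, and the resulting noisy_intervals/clean_intervals decide whether each note onset is taken from the noisy or the clean performance. Below is a minimal, self-contained sketch of that alternation only; the function name is hypothetical and the millisecond bounds are illustrative values taken from the finetuning section of config.json.

    # Sketch (assumptions labelled above): draw noisy/clean interval lengths
    # alternately, starting with a randomly chosen type, until the piece is covered.
    import random

    def build_intervals(total_length_ms: int,
                        min_noisy_ms: int, max_noisy_ms: int,
                        min_clean_ms: int, max_clean_ms: int):
        noisy_intervals, clean_intervals = [], []
        prev_ms = -1
        add_noisy_next = random.choice([True, False])
        while prev_ms <= total_length_ms:
            if add_noisy_next:
                end_ms = random.randint(prev_ms + min_noisy_ms, prev_ms + max_noisy_ms)
                noisy_intervals.append([prev_ms + 1, end_ms])
            else:
                end_ms = random.randint(prev_ms + min_clean_ms, prev_ms + max_clean_ms)
                clean_intervals.append([prev_ms + 1, end_ms])
            prev_ms = end_ms
            add_noisy_next = not add_noisy_next
        return noisy_intervals, clean_intervals

    if __name__ == "__main__":
        random.seed(0)
        noisy, clean = build_intervals(
            total_length_ms=180_000,
            min_noisy_ms=5_000, max_noisy_ms=60_000,
            min_clean_ms=60_000, max_clean_ms=200_000,
        )
        print(noisy, clean)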
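
Illustrative note on FinetuningDataset.get_loss_mask: the new mask keeps loss only on target positions that fall outside the guidance and prompt regions and are not padding. The snippet below is a simplified, self-contained variant that tracks all four region markers on the source sequence; the token strings are placeholders rather than the repository's actual special tokens, and the real implementation above also inspects the target sequence when looking for the prompt-start marker.

    # Simplified sketch of a guidance/prompt-aware loss mask (placeholder tokens).
    import torch

    GUIDANCE_START, GUIDANCE_END = "<g_start>", "<g_end>"
    PROMPT_START, PROMPT_END = "<p_start>", "<p_end>"
    PAD = "<pad>"

    def loss_mask(src_seq: list, tgt_seq: list) -> torch.Tensor:
        mask = [False] * len(tgt_seq)
        inside_target = True
        for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)):
            if src_tok in (GUIDANCE_START, PROMPT_START):
                inside_target = False
            elif src_tok in (GUIDANCE_END, PROMPT_END):
                inside_target = True
            if inside_target and tgt_tok != PAD:
                mask[idx] = True
        return torch.tensor(mask, dtype=torch.bool)

    if __name__ == "__main__":
        src = [GUIDANCE_START, "n1", GUIDANCE_END, "n2", "n3", PAD]
        tgt = src[1:] + [PAD]
        # Guidance region and padding are excluded from the loss.
        print(loss_mask(src, tgt).tolist())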
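
Illustrative note on sample_batch_cfg (aria/sample.py): get_cfg_prompt duplicates every prompt, pairing the guided version with a copy whose guidance prefix is replaced by left padding so both rows have identical length and token positions, and the paired logits are then mixed with cfg_gamma. The following is a minimal, runnable sketch of just that mixing convention, using toy logits and a hypothetical helper name.

    # Sketch of the interleaved CFG batch: row 2k holds the guided prompt, row
    # 2k+1 its unguided (padded) copy; logits are mixed pairwise and the chosen
    # token is written back to both rows of the pair.
    import torch

    def mix_cfg_logits(logits: torch.Tensor, cfg_gamma: float) -> torch.Tensor:
        # logits: (2 * num_prompts, vocab) with guided/unguided rows interleaved
        guided, unguided = logits[::2], logits[1::2]
        return cfg_gamma * guided + (1 - cfg_gamma) * unguided

    if __name__ == "__main__":
        torch.manual_seed(0)
        num_prompts, vocab = 2, 8
        logits = torch.randn(2 * num_prompts, vocab)

        mixed = mix_cfg_logits(logits, cfg_gamma=1.3)    # (num_prompts, vocab)
        next_ids = torch.argmax(mixed, dim=-1)           # greedy, for the sketch
        # Mirror next_token_ids.repeat_interleave(2) from the diff above.
        next_ids_for_batch = next_ids.repeat_interleave(2)
        print(mixed.shape, next_ids_for_batch.tolist())

With cfg_gamma = 1.0 the mix reduces to the guided logits; larger values extrapolate away from the unguided distribution, which is consistent with the 0.0 <= cfg_gamma <= 2.0 assertion in sample_batch_cfg.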