diff --git a/aria/config.py b/aria/config.py
index 8295bca..7037994 100644
--- a/aria/config.py
+++ b/aria/config.py
@@ -3,10 +3,13 @@
import os
import json
+from functools import lru_cache
+
CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "config")
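+# Cache the parsed config so repeated load_config() calls reuse one dict instead of re-reading config.json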
+@lru_cache(maxsize=1)
def load_config():
"""Returns a dictionary loaded from the config.json file."""
with open(os.path.join(CONFIG_DIR, "config.json")) as f:
diff --git a/aria/datasets.py b/aria/datasets.py
index 8575a86..210c9da 100644
--- a/aria/datasets.py
+++ b/aria/datasets.py
@@ -21,7 +21,7 @@
from collections import defaultdict
from aria.config import load_config
-from aria.tokenizer import SeparatedAbsTokenizer
+from aria.tokenizer import InferenceAbsTokenizer
from ariautils.tokenizer import Tokenizer
from ariautils.midi import (
MidiDict,
@@ -473,7 +473,7 @@ def __init__(self, tokenizer: Tokenizer):
def build(**kwargs):
raise NotImplementedError
- def get_loss_mask(self, tokenized_seq: list):
+ def get_loss_mask(self, src_seq: list, tgt_seq: list):
# Should return a bool Tensor with False indicating a masked loss
raise NotImplementedError
@@ -602,7 +602,7 @@ def _format(tok):
src = seq
tgt = seq[1:] + [self.tokenizer.pad_tok]
- mask = self.get_loss_mask(tgt)
+ mask = self.get_loss_mask(src_seq=src, tgt_seq=tgt)
return (
torch.tensor(self.tokenizer.encode(src)),
@@ -692,7 +692,11 @@ def _new_transform(x):
raise ValueError("Must provide function or list of functions.")
-def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer):
+def _get_seqs(
+ _entry: MidiDict | dict,
+ _tokenizer: Tokenizer,
+ _tokenize_fn: Callable | None = None,
+):
logger = setup_logger()
if isinstance(_entry, str):
@@ -705,7 +709,10 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer):
raise Exception
try:
- _tokenized_seq = _tokenizer.tokenize(_midi_dict)
+ if _tokenize_fn is not None:
+ _tokenized_seq = _tokenize_fn(_midi_dict)
+ else:
+ _tokenized_seq = _tokenizer.tokenize(_midi_dict)
except Exception as e:
print(e)
logger.info(f"Skipping midi_dict: {e}")
@@ -719,6 +726,7 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer):
def get_seqs(
tokenizer: Tokenizer,
midi_dict_iter: Iterable,
+ tokenize_fn: Callable | None = None,
):
# Can't pickle generator object when start method is spawn
if multiprocessing.get_start_method() == "spawn":
@@ -729,7 +737,10 @@ def get_seqs(
with multiprocessing.Pool() as pool:
results = pool.imap_unordered(
- functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter
+ functools.partial(
+ _get_seqs, _tokenizer=tokenizer, _tokenize_fn=tokenize_fn
+ ),
+ midi_dict_iter,
)
yield from results
@@ -782,9 +793,9 @@ def __init__(self, dir_paths: List[str] | str, tokenizer: Tokenizer):
def __len__(self):
return len(self.index)
- def get_loss_mask(self, tokenized_seq: list):
+ def get_loss_mask(self, src_seq: list, tgt_seq: list):
return torch.tensor(
- [tok != self.tokenizer.pad_tok for tok in tokenized_seq],
+ [tok != self.tokenizer.pad_tok for tok in tgt_seq],
dtype=torch.bool,
)
@@ -876,10 +887,8 @@ def _build_epoch(_save_path, _midi_dataset):
f"Finished building, saved PretrainingDataset to {save_dir}"
)
- return cls(dir_paths=save_dir, tokenizer=tokenizer)
-
-# TODO: Improve this logic so it supports MIDI files with multiple tempo_msgs
+# TODO: Refactor for readability
def _get_combined_mididict(
clean_midi_dict: MidiDict,
noisy_midi_dict: MidiDict,
@@ -887,7 +896,7 @@ def _get_combined_mididict(
max_noisy_ms: int,
min_clean_ms: int,
max_clean_ms: int,
-):
+) -> MidiDict:
# NOTE: We adopt the tempo/ticks_per_beat of the clean_midi_dict, and
# adjust the noisy note messages accordingly.
assert len(clean_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs"
@@ -904,24 +913,30 @@ def _get_combined_mididict(
noisy_intervals = []
clean_intervals = []
prev_ms = -1
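+ # Randomly choose whether the first interval is noisy or clean; the two then alternate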
+ add_noisy_next = random.choice([True, False])
while True:
- # Add noise interval
- noisy_end_ms = random.randint(
- prev_ms + min_noisy_ms, prev_ms + max_noisy_ms
- )
- noisy_intervals.append([prev_ms + 1, noisy_end_ms])
- prev_ms = noisy_end_ms
- if prev_ms > total_length_ms:
- break
-
- # Add clean interval
- clean_end_ms = random.randint(
- prev_ms + min_clean_ms, prev_ms + max_clean_ms
- )
- clean_intervals.append([prev_ms + 1, clean_end_ms])
- prev_ms = clean_end_ms
- if prev_ms > total_length_ms:
- break
+ if add_noisy_next is True:
+ # Add noisy interval
+ noisy_end_ms = random.randint(
+ prev_ms + min_noisy_ms, prev_ms + max_noisy_ms
+ )
+ noisy_intervals.append([prev_ms + 1, noisy_end_ms])
+ prev_ms = noisy_end_ms
+ if prev_ms > total_length_ms:
+ break
+ else:
+ add_noisy_next = False
+ else:
+ # Add clean interval
+ clean_end_ms = random.randint(
+ prev_ms + min_clean_ms, prev_ms + max_clean_ms
+ )
+ clean_intervals.append([prev_ms + 1, clean_end_ms])
+ prev_ms = clean_end_ms
+ if prev_ms > total_length_ms:
+ break
+ else:
+ add_noisy_next = True
# Merge note_msgs
clean_ms_to_tick = (clean_midi_dict.ticks_per_beat * 1e3) / (
@@ -930,20 +945,12 @@ def _get_combined_mididict(
comb_note_msgs = []
for _note_msg in noisy_midi_dict.note_msgs:
- onset_time_ms = get_duration_ms(
- start_tick=0,
- end_tick=_note_msg["data"]["start"],
- tempo_msgs=noisy_midi_dict.tempo_msgs,
- ticks_per_beat=noisy_midi_dict.ticks_per_beat,
- )
+ onset_time_ms = noisy_midi_dict.tick_to_ms(_note_msg["data"]["start"])
for _interval_start_ms, _interval_end_ms in noisy_intervals:
if _interval_start_ms < onset_time_ms < _interval_end_ms:
- offset_time_ms = get_duration_ms(
- start_tick=0,
- end_tick=_note_msg["data"]["end"],
- tempo_msgs=noisy_midi_dict.tempo_msgs,
- ticks_per_beat=noisy_midi_dict.ticks_per_beat,
+ offset_time_ms = noisy_midi_dict.tick_to_ms(
+ _note_msg["data"]["end"]
)
_adj_note_msg = copy.deepcopy(_note_msg)
_adj_onset_tick = int(onset_time_ms * clean_ms_to_tick)
@@ -956,21 +963,13 @@ def _get_combined_mididict(
break
for _note_msg in clean_midi_dict.note_msgs:
- onset_time_ms = get_duration_ms(
- start_tick=0,
- end_tick=_note_msg["data"]["start"],
- tempo_msgs=clean_midi_dict.tempo_msgs,
- ticks_per_beat=clean_midi_dict.ticks_per_beat,
- )
+ onset_time_ms = clean_midi_dict.tick_to_ms(_note_msg["data"]["start"])
for _interval_start_ms, _interval_end_ms in clean_intervals:
if _interval_start_ms < onset_time_ms < _interval_end_ms:
comb_note_msgs.append(_note_msg)
break
- # Redundant sort
- comb_note_msgs = sorted(comb_note_msgs, key=lambda msg: msg["tick"])
-
comb_metadata = deepcopy(clean_midi_dict.metadata)
comb_metadata["noisy_intervals"] = noisy_intervals
@@ -986,7 +985,7 @@ def _get_combined_mididict(
)
-# TODO: Move hyperparams into config.json (and TEST)
+# TODO: Refactor this function for readability
def _noise_midi_dict(midi_dict: MidiDict, config: dict):
def _get_velocity_adjusted_msg(
__note_msg: dict,
@@ -1149,51 +1148,58 @@ def _get_onset_adjusted_msg(
)
-def _get_mixed_dataset(
- _clean_dataset: Iterable,
- _noisy_datasets: list[Iterable],
+def export_inference_abs_build_tokenize_fn(
+ midi_dict: MidiDict, tokenizer: InferenceAbsTokenizer
):
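+ # Build the fine-tuning sequence for one MidiDict: optionally noise it, splice the noisy/clean intervals back together, and optionally add a guidance prefix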
finetuning_config = load_config()["data"]["finetuning"]
- ACTIVATION_PROB = finetuning_config["noising"]["activation_prob"]
+ GUIDANCE_PROB = finetuning_config["guidance_prob"]
+ NOISING_PROB = finetuning_config["noising"]["activation_prob"]
MIN_NOISY_MS = finetuning_config["min_noisy_interval_ms"]
MAX_NOISY_MS = finetuning_config["max_noisy_interval_ms"]
MIN_CLEAN_MS = finetuning_config["min_clean_interval_ms"]
MAX_CLEAN_MS = finetuning_config["max_clean_interval_ms"]
- comb_midi_dicts = []
- _noisy_dataset_itt = random_selection_itt(_noisy_datasets)
- for clean, noisy in zip(_clean_dataset, _noisy_dataset_itt):
- assert (
- os.path.splitext(os.path.basename(clean.metadata["abs_path"]))[0]
- == os.path.splitext(os.path.basename(noisy.metadata["abs_path"]))[0]
- ), f"file order mismatch: {clean.metadata['abs_path']}; {noisy.metadata['abs_path']}"
-
- if random.random() < ACTIVATION_PROB:
- noisy = _noise_midi_dict(noisy, config=finetuning_config["noising"])
-
- comb_midi_dicts.append(
- _get_combined_mididict(
- clean,
- noisy,
- min_noisy_ms=MIN_NOISY_MS,
- max_noisy_ms=MAX_NOISY_MS,
- min_clean_ms=MIN_CLEAN_MS,
- max_clean_ms=MAX_CLEAN_MS,
- )
+ if random.random() <= NOISING_PROB:
+ noisy_midi_dict = _noise_midi_dict(
+ midi_dict, config=finetuning_config["noising"]
+ )
+ midi_dict_for_tokenization = _get_combined_mididict(
+ clean_midi_dict=midi_dict,
+ noisy_midi_dict=noisy_midi_dict,
+ min_noisy_ms=MIN_NOISY_MS,
+ max_noisy_ms=MAX_NOISY_MS,
+ min_clean_ms=MIN_CLEAN_MS,
+ max_clean_ms=MAX_CLEAN_MS,
)
+ else:
+ midi_dict_for_tokenization = midi_dict
- return MidiDataset(comb_midi_dicts)
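+ # With probability GUIDANCE_PROB, additionally condition on the clean MidiDict as guidance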
+ if random.random() <= GUIDANCE_PROB:
+ return tokenizer.tokenize(
+ midi_dict=midi_dict_for_tokenization,
+ prompt_intervals_ms=midi_dict_for_tokenization.metadata.get(
+ "noisy_intervals", []
+ ),
+ guidance_midi_dict=midi_dict,
+ )
+ else:
+ return tokenizer.tokenize(
+ midi_dict=midi_dict_for_tokenization,
+ prompt_intervals_ms=midi_dict_for_tokenization.metadata.get(
+ "noisy_intervals", []
+ ),
+ )
class FinetuningDataset(TrainingDataset):
"""Torch dataset object yielding sequences formatted for fine-tuning."""
def __init__(
- self, dir_paths: List[str] | str, tokenizer: SeparatedAbsTokenizer
+ self, dir_paths: List[str] | str, tokenizer: InferenceAbsTokenizer
):
super().__init__(tokenizer=tokenizer)
- assert tokenizer.name == "separated_abs", "invalid tokenizer"
+ assert tokenizer.name == "inference_abs", "invalid tokenizer"
if isinstance(dir_paths, str):
dir_paths = [dir_paths]
@@ -1205,31 +1211,33 @@ def __init__(
def __len__(self):
return len(self.index)
- def get_loss_mask(self, tokenized_seq: list):
- mask = [True] * len(tokenized_seq)
- inside_inst = False
+ def get_loss_mask(self, src_seq: list, tgt_seq: list):
+ mask = [False] * len(tgt_seq)
+ inside_target = True
- for idx, token in enumerate(tokenized_seq):
- if token == self.tokenizer.inst_start_tok:
- mask[idx] = False
- inside_inst = True
- elif token == self.tokenizer.inst_end_tok:
- mask[idx] = False
- inside_inst = False
- elif inside_inst:
- mask[idx] = False
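+ # Loss is only computed on target tokens that are outside the guidance/prompt regions and are not padding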
+ for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)):
+ if src_tok == self.tokenizer.guidance_start_tok:
+ inside_target = False
+ elif src_tok == self.tokenizer.guidance_end_tok:
+ inside_target = True
+ elif tgt_tok == self.tokenizer.prompt_start_tok:
+ inside_target = False
+ elif src_tok == self.tokenizer.prompt_end_tok:
+ inside_target = True
+
+ if inside_target is True and tgt_tok != self.tokenizer.pad_tok:
+ mask[idx] = True
return torch.tensor(mask, dtype=torch.bool)
@classmethod
def build(
cls,
- tokenizer: Tokenizer,
+ tokenizer: InferenceAbsTokenizer,
save_dir: str,
max_seq_len: int,
num_epochs: int,
- clean_dataset_path: str,
- noisy_dataset_paths: str,
+ midi_dataset_path: str,
):
def _build_epoch(_save_path, _midi_dataset):
@@ -1244,7 +1252,17 @@ def _build_epoch(_save_path, _midi_dataset):
)
_idx = 0
- for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 10):
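+ # Pass the fine-tuning tokenize_fn so each MidiDict is noised/guided before tokenization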
+ for entry in reservoir(
+ get_seqs(
+ tokenizer,
+ _midi_dataset,
+ tokenize_fn=functools.partial(
+ export_inference_abs_build_tokenize_fn,
+ tokenizer=tokenizer,
+ ),
+ ),
+ 10,
+ ):
for _entry in tokenizer.split(entry, max_seq_len):
writer.write(_entry)
@@ -1252,12 +1270,14 @@ def _build_epoch(_save_path, _midi_dataset):
if _idx % 250 == 0:
logger.info(f"Finished processing {_idx}")
+ # DEBUG
+ if _idx == 1000:
+ break
+
logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
assert num_epochs > 0, "num_epochs must be greater than 0"
- assert os.path.isfile(clean_dataset_path), "file not found"
- for __path in noisy_dataset_paths:
- assert os.path.isfile(__path), "file not found"
+ assert os.path.isfile(midi_dataset_path), "file not found"
if multiprocessing.get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
@@ -1284,21 +1304,14 @@ def _build_epoch(_save_path, _midi_dataset):
f"tokenizer_name={tokenizer.name}"
)
- clean_dataset = MidiDataset.load(clean_dataset_path)
- noisy_datasets = [
- MidiDataset.load(_path) for _path in noisy_dataset_paths
- ]
-
for idx in range(num_epochs):
logger.info(f"Building epoch {idx}/{num_epochs - 1}...")
# Reload the combined dataset for each epoch
- combined_dataset = _get_mixed_dataset(clean_dataset, noisy_datasets)
+ midi_dataset = MidiDataset.get_generator(midi_dataset_path)
_build_epoch(
_save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"),
- _midi_dataset=combined_dataset,
+ _midi_dataset=midi_dataset,
)
logger.info(f"Finished building, saved FinetuningDataset to {save_dir}")
-
- return cls(dir_paths=save_dir, tokenizer=tokenizer)
diff --git a/aria/inference/model.py b/aria/inference/model.py
index 94d4d3e..d4bf439 100644
--- a/aria/inference/model.py
+++ b/aria/inference/model.py
@@ -44,8 +44,17 @@ def __init__(self, model_config: ModelConfig):
model_config.d_model, model_config.vocab_size, bias=False
)
- def forward(self, idxs: torch.Tensor, input_pos: torch.Tensor):
- hidden_states = self.model(idxs=idxs, input_pos=input_pos)
+ def forward(
+ self,
+ idxs: torch.Tensor,
+ input_pos: torch.Tensor,
+ pad_idxs: torch.Tensor | None = None,
+ ):
+ hidden_states = self.model(
+ idxs=idxs,
+ input_pos=input_pos,
+ pad_idxs=pad_idxs,
+ )
logits = self.lm_head(hidden_states)
return logits
@@ -98,10 +107,15 @@ def forward(
self,
idxs: torch.Tensor,
input_pos: torch.Tensor,
+ pad_idxs: torch.Tensor | None = None,
):
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
+
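+ # Additionally mask attention to padded key positions (e.g. pad tokens inserted for CFG prompts)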
+ if pad_idxs is not None:
+ mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1))
+
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idxs)
diff --git a/aria/run.py b/aria/run.py
index 054d052..cfb4bc5 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -11,23 +11,26 @@ def _parse_sample_args():
argp.add_argument("-m", help="name of model config file")
argp.add_argument("-c", help="path to model checkpoint")
argp.add_argument("-p", help="path to midi file")
- argp.add_argument(
- "-pt", help="sample using the pretrained model", action="store_true"
- )
argp.add_argument(
"-temp",
- help="change temp value",
+ help="sampling temperature value",
type=float,
required=False,
default=0.95,
)
argp.add_argument(
"-top_p",
- help="change top_p value",
+ help="sampling top_p value",
type=float,
required=False,
default=0.95,
)
+ argp.add_argument(
+ "-cfg",
+ help="sampling cfg gamma value",
+ type=float,
+ required=False,
+ )
argp.add_argument(
"-metadata",
nargs=2,
@@ -49,38 +52,26 @@ def _parse_sample_args():
)
argp.add_argument("-e", action="store_true", help="enable force end")
argp.add_argument("-l", type=int, help="generation length", default=1024)
- argp.add_argument("-noise", action="store_true", help="add noise to prompt")
+ argp.add_argument(
+ "-guidance_path", type=str, help="path to guidance MIDI", required=False
+ )
+ argp.add_argument(
+ "-guidance_start_ms",
+ help="guidance interval start (ms)",
+ type=int,
+ required=False,
+ )
+ argp.add_argument(
+ "-guidance_end_ms",
+ help="guidance interval end (ms)",
+ type=int,
+ required=False,
+ )
argp.add_argument("-compile", action="store_true", help="compile cudagraph")
return argp.parse_args(sys.argv[2:])
-def _get_model_name(name: str | None, state: dict):
- if name is not None:
- return name
-
- print("Model name is not provided. Trying to infer from checkpoint...")
- _defaults = {
- 16: "small",
- 32: "medium",
- 64: "large",
- }
- try:
- pattern = re.compile(r"encode_layers\.(\d+)\.")
- layer_keys = [pattern.search(k) for k in state.keys()]
- layer_keys = set(p.group(1) for p in layer_keys if p is not None)
- for i in range(len(layer_keys)):
- assert str(i) in layer_keys
-
- if len(layer_keys) in _defaults:
- print(f"Selecting model name: {_defaults[len(layer_keys)]}")
- return _defaults[len(layer_keys)]
- assert False
- except:
- raise ValueError("Model name is not provided and cannot be inferred.")
-
-
-# TODO: Add support for sampling from the pretrained model
def sample(args):
"""Entrypoint for sampling"""
@@ -88,9 +79,12 @@ def sample(args):
from aria.inference import TransformerLM
from aria.model import ModelConfig
from aria.config import load_model_config, load_config
- from ariautils.tokenizer import AbsTokenizer
- from aria.tokenizer import SeparatedAbsTokenizer
- from aria.sample import greedy_sample, get_pt_prompt, get_inst_prompt
+ from aria.tokenizer import InferenceAbsTokenizer
+ from aria.sample import (
+ sample_batch_cfg,
+ sample_batch,
+ get_inference_prompt,
+ )
from ariautils.midi import MidiDict
from aria.utils import _load_weight
@@ -116,11 +110,7 @@ def sample(args):
force_end = args.e
model_name = args.m
- if args.pt == True:
- tokenizer = AbsTokenizer()
- else:
- tokenizer = SeparatedAbsTokenizer()
-
+ tokenizer = InferenceAbsTokenizer()
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model_config.grad_checkpoint = False
@@ -133,17 +123,6 @@ def sample(args):
"Failed to load model_state. This is likely due to an incompatibility "
"between the checkpoint file (-c) and model name/config (-m)."
)
- if args.pt:
- print(
- "When using the -pt flag make sure you provide a checkpoint for "
- "the pretrained model."
- )
- else:
- print(
- "When not using the -pt flag make sure you provide a checkpoint "
- " for the instuct-finetuned (inst) model."
- )
-
raise e
assert args.l > 0, "Generation length must be positive."
@@ -151,6 +130,10 @@ def sample(args):
# Load and format prompts and metadata
midi_dict = MidiDict.from_midi(mid_path=args.p)
+ if args.guidance_path:
+ guidance_midi_dict = MidiDict.from_midi(mid_path=args.guidance_path)
+ else:
+ guidance_midi_dict = None
for k, v in manual_metadata.items():
midi_dict.metadata[k] = v
@@ -160,42 +143,48 @@ def sample(args):
f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}"
)
- if args.pt:
- if args.noise:
- print("Noising not supported with pretrained model")
+ prompt_seq, guidance_seq = get_inference_prompt(
+ tokenizer=tokenizer,
+ midi_dict=midi_dict,
+ truncate_len=truncate_len,
+ guidance_start_ms=args.guidance_start_ms,
+ guidance_end_ms=args.guidance_end_ms,
+ guidance_midi_dict=guidance_midi_dict,
+ )
- prompt_seq = get_pt_prompt(
- tokenizer=tokenizer,
- midi_dict=midi_dict,
- truncate_len=truncate_len,
- )
- else:
- prompt_seq = get_inst_prompt(
- tokenizer=tokenizer,
- midi_dict=midi_dict,
- truncate_len=truncate_len,
- noise=args.noise,
+ if guidance_seq:
+ tokenizer.detokenize(guidance_seq).to_midi().save(
+ os.path.join(samples_dir, f"guidance.mid")
)
-
- prompts = [prompt_seq for _ in range(num_variations)]
if len(prompt_seq) + args.l > model_config.max_seq_len:
print(
"WARNING: Required context exceeds max_seq_len supported by model"
)
+ prompts = [prompt_seq for _ in range(num_variations)]
- print(prompt_seq)
-
- results = greedy_sample(
- model=model,
- tokenizer=tokenizer,
- prompts=prompts,
- max_new_tokens=max_new_tokens,
- force_end=force_end,
- # cfg_gamma=args.cfg,
- temperature=args.temp,
- top_p=args.top_p,
- compile=args.compile,
- )
+ if args.cfg is not None:
+ results = sample_batch_cfg(
+ model=model,
+ tokenizer=tokenizer,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens,
+ cfg_gamma=args.cfg,
+ force_end=force_end,
+ temperature=args.temp,
+ top_p=args.top_p,
+ compile=args.compile,
+ )
+ else:
+ results = sample_batch(
+ model=model,
+ tokenizer=tokenizer,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens,
+ force_end=force_end,
+ temperature=args.temp,
+ top_p=args.top_p,
+ compile=args.compile,
+ )
samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
if os.path.isdir(samples_dir) is False:
@@ -204,7 +193,7 @@ def sample(args):
for idx, tokenized_seq in enumerate(results):
res_midi_dict = tokenizer.detokenize(tokenized_seq)
res_midi = res_midi_dict.to_midi()
- res_midi.save(f"samples/res_{idx + 1}.mid")
+ res_midi.save(os.path.join(samples_dir, f"res_{idx + 1}.mid"))
print("Results saved to samples/")
@@ -277,7 +266,7 @@ def build_pretraining_dataset(args):
elif args.tokenizer_name == "rel":
tokenizer = RelTokenizer()
- dataset = PretrainingDataset.build(
+ PretrainingDataset.build(
tokenizer=tokenizer,
save_dir=args.save_dir,
max_seq_len=args.l,
@@ -289,13 +278,8 @@ def build_pretraining_dataset(args):
def _parse_finetune_dataset_args():
argp = argparse.ArgumentParser(prog="aria finetune-dataset")
argp.add_argument(
- "-clean_load_path",
- help="path to the clean midi_dict dataset",
- )
- argp.add_argument(
- "-noisy_load_paths",
- nargs="+",
- help="one or more paths to noisy midi_dict datasets",
+ "-midi_dataset_path",
+ help="path to midi_dict dataset",
)
argp.add_argument("-save_dir", help="path to save dataset")
argp.add_argument("-l", help="max sequence length", type=int, default=4096)
@@ -305,17 +289,16 @@ def _parse_finetune_dataset_args():
def build_finetune_dataset(args):
- from aria.tokenizer import SeparatedAbsTokenizer
+ from aria.tokenizer import InferenceAbsTokenizer
from aria.datasets import FinetuningDataset
- tokenizer = SeparatedAbsTokenizer()
+ tokenizer = InferenceAbsTokenizer()
FinetuningDataset.build(
tokenizer=tokenizer,
save_dir=args.save_dir,
max_seq_len=args.l,
num_epochs=args.e,
- clean_dataset_path=args.clean_load_path,
- noisy_dataset_paths=args.noisy_load_paths,
+ midi_dataset_path=args.midi_dataset_path,
)
diff --git a/aria/sample.py b/aria/sample.py
index 8c70af4..4546283 100644
--- a/aria/sample.py
+++ b/aria/sample.py
@@ -1,5 +1,6 @@
"""Contains generation/sampling code"""
+import copy
import torch
import torch._dynamo.config
import torch._inductor.config
@@ -8,7 +9,8 @@
from tqdm import tqdm
from aria.inference import TransformerLM
-from ariautils.tokenizer import Tokenizer
+from aria.tokenizer import InferenceAbsTokenizer
+from ariautils.tokenizer import Tokenizer, AbsTokenizer
from ariautils.midi import MidiDict
torch._inductor.config.coordinate_descent_tuning = True
@@ -16,16 +18,45 @@
torch._inductor.config.fx_graph_cache = True
+def get_cfg_prompt(prompts: list, pad_tok: str, guidance_end_tok: str):
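+ # For each prompt, also append an unguided copy with the guidance prefix replaced by pad tokens; even indices are guided, odd indices are unguided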
+ cfg_prompts = []
+ for prompt in prompts:
+ prompt_no_guidance = prompt[prompt.index(guidance_end_tok) + 1 :]
+ prompt_no_guidance = [pad_tok] * (
+ len(prompt) - len(prompt_no_guidance)
+ ) + prompt_no_guidance
+ cfg_prompts.append(prompt)
+ cfg_prompts.append(prompt_no_guidance)
+
+ return cfg_prompts
+
+
@torch.inference_mode()
-def prefill(model, idxs: torch.Tensor, input_pos: torch.Tensor):
- logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1]
+def decode_one(
+ model: TransformerLM,
+ idxs: torch.Tensor,
+ input_pos: torch.Tensor,
+ pad_idxs: torch.Tensor | None = None,
+):
+ logits = model.forward(
+ idxs=idxs,
+ input_pos=input_pos,
+ pad_idxs=pad_idxs,
+ )[:, -1]
return logits
@torch.inference_mode()
-def decode_one(model, idxs: torch.Tensor, input_pos: torch.Tensor):
- logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1]
+def prefill(
+ model: TransformerLM,
+ idxs: torch.Tensor,
+ input_pos: torch.Tensor,
+ pad_idxs: torch.Tensor | None = None,
+):
+ logits = model.forward(idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs)[
+ :, -1
+ ]
return logits
@@ -62,33 +93,32 @@ def update_seq_ids_(
seq[:, idx] = next_token_ids
-# TODO: Add CFG back into this when working
-# TODO: Check that unexpected instrument warnings still working
+# TODO: Not working
@torch.autocast(
"cuda",
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
@torch.inference_mode()
-def greedy_sample(
+def sample_batch(
model: TransformerLM,
tokenizer: Tokenizer,
prompts: List[list],
max_new_tokens: int,
force_end=False,
- cfg_gamma: float | None = 1.05,
temperature: float = 0.95,
top_p: float = 0.95,
- compile: bool = True,
+ compile: bool = False,
):
- """Performs greedy (top_p) auto-regressive sampling on a batch of prompts."""
-
if force_end:
assert max_new_tokens > 130, "prompt too long to use force_end=True"
_prompt_len = len(prompts[0])
_num_prompts = len(prompts)
+ assert all([len(p) == _prompt_len for p in prompts])
model.eval()
+ dim_tok_inserted = [False for _ in range(_num_prompts)]
+ eos_tok_seen = [False for _ in range(_num_prompts)]
total_len = _prompt_len + max_new_tokens
seq = torch.stack(
[
@@ -98,15 +128,12 @@ def greedy_sample(
for p in prompts
]
).cuda()
- dim_tok_inserted = [False for _ in range(_num_prompts)]
- eos_tok_seen = [False for _ in range(_num_prompts)]
if compile is True:
global decode_one
decode_one = torch.compile(
decode_one,
mode="reduce-overhead",
- # mode="max-autotune",
fullgraph=True,
)
@@ -119,7 +146,7 @@ def greedy_sample(
)
print(
- f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}"
+ f"Using hyperparams: temp={temperature}, top_p={top_p}, gen_len={max_new_tokens}"
)
for idx in (
@@ -145,8 +172,8 @@ def greedy_sample(
),
)
- if tokenizer.name == "separated_abs":
- logits[:, tokenizer.tok_to_id[tokenizer.inst_start_tok]] = float(
+ if tokenizer.name == "inference_abs":
+ logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float(
"-inf"
)
@@ -183,6 +210,134 @@ def greedy_sample(
return decoded_results
+@torch.autocast(
+ "cuda",
+ dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+)
+@torch.inference_mode()
+def sample_batch_cfg(
+ model: TransformerLM,
+ tokenizer: InferenceAbsTokenizer,
+ prompts: List[list],
+ max_new_tokens: int,
+ cfg_gamma: float,
+ force_end=False,
+ temperature: float = 0.95,
+ top_p: float = 0.95,
+ compile: bool = False,
+):
+ assert 0.0 <= cfg_gamma <= 2.0
+ assert 0.0 <= temperature <= 2.0
+ assert 0.5 <= top_p <= 1.0
+ assert tokenizer.name == "inference_abs"
+ if force_end:
+ assert max_new_tokens > 130, "prompt too long to use force_end=True"
+
+ prompts = get_cfg_prompt(
+ prompts, tokenizer.pad_tok, tokenizer.guidance_end_tok
+ )
+
+ _prompt_len = len(prompts[0])
+ _num_prompts = len(prompts)
+ assert all([len(p) == _prompt_len for p in prompts])
+
+ model.eval()
+ total_len = _prompt_len + max_new_tokens
+ seq = torch.stack(
+ [
+ torch.tensor(
+ tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p)))
+ )
+ for p in prompts
+ ]
+ ).cuda()
+ dim_tok_inserted = [False for _ in range(_num_prompts)]
+ eos_tok_seen = [False for _ in range(_num_prompts)]
+
+ if compile is True:
+ global decode_one
+ decode_one = torch.compile(
+ decode_one,
+ mode="reduce-overhead",
+ fullgraph=True,
+ )
+
+ model.setup_cache(
+ batch_size=_num_prompts,
+ max_seq_len=total_len,
+ dtype=(
+ torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ ),
+ )
+
+ print(
+ f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}"
+ )
+
+ for idx in (
+ pbar := tqdm(
+ range(_prompt_len, total_len),
+ total=total_len - _prompt_len,
+ leave=False,
+ )
+ ):
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ if idx == _prompt_len:
+ logits = prefill(
+ model,
+ idxs=seq[:, :idx],
+ input_pos=torch.arange(0, idx, device=seq.device),
+ pad_idxs=(seq == tokenizer.pad_id),
+ )
+ else:
+ logits = decode_one(
+ model,
+ idxs=seq[:, idx - 1 : idx],
+ input_pos=torch.tensor(
+ [idx - 1], device=seq.device, dtype=torch.int
+ ),
+ pad_idxs=(seq == tokenizer.pad_id),
+ )
+
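+ # Classifier-free guidance: interpolate logits from the guided (even rows) and unguided (odd rows) copies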
+ logits_cfg = cfg_gamma * logits[::2] + (1 - cfg_gamma) * logits[1::2]
+ logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float(
+ "-inf"
+ )
+
+ if temperature > 0.0:
+ probs = torch.softmax(logits_cfg / temperature, dim=-1)
+ next_token_ids = sample_top_p(probs, top_p).flatten()
+ else:
+ next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten()
+
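+ # Duplicate each sampled token so the guided and unguided copies of a sequence stay identical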
+ next_token_ids = next_token_ids.repeat_interleave(2)
+ update_seq_ids_(
+ seq=seq,
+ idx=idx,
+ next_token_ids=next_token_ids,
+ dim_tok_inserted=dim_tok_inserted,
+ eos_tok_seen=eos_tok_seen,
+ max_len=total_len,
+ force_end=force_end,
+ tokenizer=tokenizer,
+ )
+
+ if all(seen_eos is True for seen_eos in eos_tok_seen):
+ break
+
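+ # Keep only the guided (even-index) sequences; the unguided CFG copies are discarded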
+ decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2]
+ decoded_results = [
+ (
+ res[: res.index(tokenizer.eos_tok) + 1]
+ if tokenizer.eos_tok in res
+ else res
+ )
+ for res in decoded_results
+ ]
+
+ return decoded_results
+
+
def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
@@ -194,32 +349,39 @@ def sample_top_p(probs, p):
return next_token
-# TODO: Clean up a bit and get rid of footguns
-def get_inst_prompt(
- tokenizer: Tokenizer,
+def get_inference_prompt(
+ tokenizer: InferenceAbsTokenizer,
midi_dict: MidiDict,
truncate_len: int,
- noise: bool,
+ guidance_start_ms: int,
+ guidance_end_ms: int,
+ guidance_midi_dict: MidiDict | None = None,
):
- from aria.datasets import _noise_midi_dict
- from aria.config import load_config
-
- midi_dict.metadata["noisy_intervals"] = [[0, truncate_len * 1e3]]
-
- if noise == True:
- midi_dict = _noise_midi_dict(
- midi_dict, load_config()["data"]["finetuning"]["noising"]
+ assert tokenizer.name == "inference_abs"
+
+ if guidance_midi_dict is not None:
+ assert guidance_start_ms is not None and guidance_start_ms >= 0
+ assert guidance_end_ms is not None and guidance_end_ms >= 0
+ assert (
+ tokenizer._config["guidance"]["min_ms"]
+ <= guidance_end_ms - guidance_start_ms
+ <= tokenizer._config["guidance"]["max_ms"]
)
- prompt_seq = tokenizer.tokenize(midi_dict=midi_dict)
+ prompt_seq = tokenizer.tokenize(
+ midi_dict=midi_dict,
+ prompt_intervals_ms=(
+ [[0, truncate_len * 1e3]] if truncate_len > 0 else []
+ ),
+ guidance_midi_dict=guidance_midi_dict,
+ guidance_start_ms=guidance_start_ms,
+ guidance_end_ms=guidance_end_ms,
+ )
- if tokenizer.inst_end_tok in prompt_seq:
- prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.inst_end_tok) + 1]
- elif tokenizer.eos_tok in prompt_seq:
- # TODO: This is a workaround for a bug where is not inserted if
- # the pieces ends.
- assert prompt_seq[-1] == tokenizer.eos_tok
- prompt_seq[-1] = tokenizer.inst_end_tok
+ if tokenizer.prompt_end_tok in prompt_seq:
+ prompt_seq = prompt_seq[
+ : prompt_seq.index(tokenizer.prompt_end_tok) + 1
+ ]
else:
print("No notes found in prompt region")
prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1]
@@ -227,21 +389,13 @@ def get_inst_prompt(
if tokenizer.dim_tok in prompt_seq:
prompt_seq.remove(tokenizer.dim_tok)
- return prompt_seq
-
-
-def get_pt_prompt(
- tokenizer: Tokenizer,
- midi_dict: MidiDict,
- truncate_len: int,
-):
- prompt_seq = tokenizer.tokenize(midi_dict=midi_dict)
- prompt_seq = tokenizer.truncate_by_time(
- tokenized_seq=prompt_seq,
- trunc_time_ms=truncate_len * 1e3,
- )
-
- if tokenizer.dim_tok in prompt_seq:
- prompt_seq.remove(tokenizer.dim_tok)
+ if guidance_midi_dict is not None:
+ guidance_seq = copy.deepcopy(prompt_seq)
+ guidance_seq = guidance_seq[
+ : guidance_seq.index(tokenizer.guidance_end_tok)
+ ]
+ guidance_seq[0] = ("prefix", "instrument", "piano")
+ else:
+ guidance_seq = None
- return prompt_seq
+ return prompt_seq, guidance_seq
diff --git a/aria/tokenizer.py b/aria/tokenizer.py
index 2b0ce6a..818fb27 100644
--- a/aria/tokenizer.py
+++ b/aria/tokenizer.py
@@ -1,90 +1,248 @@
"""Tokenizer for MIDI conditioned completions"""
import copy
+import random
+import functools
-from ariautils.midi import MidiDict, get_duration_ms
+from typing import Callable
+
+from aria.config import load_config
+from ariautils.midi import MidiDict
from ariautils.tokenizer import AbsTokenizer as _AbsTokenizer
-class SeparatedAbsTokenizer(_AbsTokenizer):
+class InferenceAbsTokenizer(_AbsTokenizer):
def __init__(self):
super().__init__()
- self.name = "separated_abs"
- self.inst_start_tok = ""
- self.inst_end_tok = ""
- self.add_tokens_to_vocab([self.inst_start_tok, self.inst_end_tok])
- self.special_tokens.append(self.inst_start_tok)
- self.special_tokens.append(self.inst_end_tok)
-
- def tokenize(self, midi_dict: MidiDict, **kwargs):
- def _add_inst_toks(_seq: list, _start_ms: int, _end_ms: int):
- res_seq = copy.deepcopy(_seq)
-
- inst_inserted = False
- time_tok_cnt = 0
- curr_time_ms = 0
- for idx, (tok_1, tok_2) in enumerate(zip(_seq, _seq[1:])):
- if tok_1 == self.time_tok:
- time_tok_cnt += 1
- elif (
- isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd
- ):
- assert isinstance(tok_2, tuple) and tok_2[0] == "onset"
+ self.name = "inference_abs"
+ self._config = load_config()["tokenizer"]["inference_abs"]
+
+ self.prompt_start_tok = ""
+ self.prompt_end_tok = ""
+ self.guidance_start_tok = ""
+ self.guidance_end_tok = ""
+
+ self.add_tokens_to_vocab(
+ [
+ self.prompt_start_tok,
+ self.prompt_end_tok,
+ self.guidance_start_tok,
+ self.guidance_end_tok,
+ ]
+ )
+ self.special_tokens.append(self.prompt_start_tok)
+ self.special_tokens.append(self.prompt_end_tok)
+ self.special_tokens.append(self.guidance_start_tok)
+ self.special_tokens.append(self.guidance_end_tok)
- # Adjust time
- _curr_time = (
- self.config["abs_time_step_ms"] * time_tok_cnt
- ) + tok_2[1]
+ def _get_guidance_interval_ms(self, guidance_midi_dict: MidiDict):
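+ # Sample a random guidance window, with length bounded by the tokenizer's guidance min/max config and the piece length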
+ first_note_onset_ms = guidance_midi_dict.tick_to_ms(
+ guidance_midi_dict.note_msgs[0]["tick"]
+ )
+ last_note_onset_ms = guidance_midi_dict.tick_to_ms(
+ guidance_midi_dict.note_msgs[-1]["tick"]
+ )
+ guidance_segment_length_ms = random.randint(
+ self._config["guidance"]["min_ms"],
+ min(self._config["guidance"]["max_ms"], last_note_onset_ms),
+ )
+ guidance_start_ms = random.randint(
+ first_note_onset_ms,
+ last_note_onset_ms - guidance_segment_length_ms,
+ )
+ guidance_end_ms = guidance_start_ms + guidance_segment_length_ms
+
+ return guidance_start_ms, guidance_end_ms
- assert _curr_time >= curr_time_ms
- curr_time_ms = _curr_time
+ def _get_guidance_seq(
+ self,
+ guidance_midi_dict: MidiDict,
+ guidance_start_ms: int | None = None,
+ guidance_end_ms: int | None = None,
+ ):
+ assert guidance_midi_dict.note_msgs is not None
- if curr_time_ms >= _start_ms and inst_inserted == False:
- res_seq.insert(idx, self.inst_start_tok)
- inst_inserted = True
- if curr_time_ms > _end_ms and inst_inserted == True:
- res_seq.insert(idx + 1, self.inst_end_tok)
- break
+ # Need to validate these numbers
+ if guidance_start_ms is None:
+ assert guidance_end_ms is None
+ guidance_start_ms, guidance_end_ms = self._get_guidance_interval_ms(
+ guidance_midi_dict=guidance_midi_dict
+ )
- return res_seq
+ slice_note_msgs = []
+ for note_msg in guidance_midi_dict.note_msgs:
+ start_ms = guidance_midi_dict.tick_to_ms(note_msg["data"]["start"])
+ if guidance_start_ms <= start_ms <= guidance_end_ms:
+ slice_note_msgs.append(note_msg)
- if midi_dict.metadata.get("noisy_intervals") is None:
- print("noisy_intervals metadata not present")
- return super().tokenize(midi_dict, **kwargs)
+ slice_midi_dict = copy.deepcopy(guidance_midi_dict)
+ slice_midi_dict.note_msgs = slice_note_msgs
- seq = super().tokenize(midi_dict, **kwargs)
+ if len(slice_midi_dict.note_msgs) == 0:
+ # Catches the case where no notes fall in the interval
+ return []
- # This logic is required as the tokenizer removes proceeding silence
- first_note_ms = get_duration_ms(
- start_tick=0,
- end_tick=midi_dict.note_msgs[0]["data"]["start"],
- tempo_msgs=midi_dict.tempo_msgs,
- ticks_per_beat=midi_dict.ticks_per_beat,
+ guidance_seq = self._tokenize_midi_dict(
+ midi_dict=slice_midi_dict,
+ remove_preceding_silence=True,
)
- noisy_intervals = [
- [
- ival[0] - first_note_ms,
- ival[1] - first_note_ms,
- ]
- for ival in midi_dict.metadata.get("noisy_intervals")
- if ival[1] >= first_note_ms
+
+ if self.dim_tok in guidance_seq:
+ guidance_seq.remove(self.dim_tok)
+
+ guidance_seq = guidance_seq[
+ guidance_seq.index(self.bos_tok)
+ + 1 : guidance_seq.index(self.eos_tok)
]
- for start_ms, end_ms in noisy_intervals:
- seq = _add_inst_toks(seq, start_ms, end_ms)
+ return (
+ [self.guidance_start_tok] + guidance_seq + [self.guidance_end_tok]
+ )
+
+ def _add_prompt_tokens(
+ self, seq: list, prompt_start_ms: int, prompt_end_ms: int
+ ):
+ res = copy.deepcopy(seq)
+ prompt_tok_inserted = False
+ time_tok_cnt = 0
+ curr_time_ms = 0
+ for idx, (tok_1, tok_2) in enumerate(zip(seq, seq[1:])):
+ if tok_1 == self.time_tok:
+ time_tok_cnt += 1
+ elif isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd:
+ assert isinstance(tok_2, tuple) and tok_2[0] == "onset"
+
+ # Adjust time
+ curr_time_ms = (self.abs_time_step_ms * time_tok_cnt) + tok_2[1]
+
+ if (
+ curr_time_ms >= prompt_start_ms
+ and prompt_tok_inserted == False
+ ):
+ res.insert(idx, self.prompt_start_tok)
+ prompt_tok_inserted = True
+ elif (
+ curr_time_ms > prompt_end_ms and prompt_tok_inserted == True
+ ):
+ res.insert(idx + 1, self.prompt_end_tok)
+ break
- return seq
+ return res
- def detokenize(self, midi_dict: MidiDict, **kwargs):
- return super().detokenize(midi_dict, **kwargs)
+ def tokenize(
+ self,
+ midi_dict: MidiDict,
+ prompt_intervals_ms: list[tuple[int, int]],
+ guidance_midi_dict: MidiDict | None = None,
+ guidance_start_ms: int | None = None,
+ guidance_end_ms: int | None = None,
+ ):
+ seq = self._tokenize_midi_dict(
+ midi_dict=midi_dict, remove_preceding_silence=True
+ )
+ first_note_ms = midi_dict.tick_to_ms(
+ midi_dict.note_msgs[0]["data"]["start"]
+ )
+
+ for prompt_start_ms, prompt_end_ms in prompt_intervals_ms:
+ if prompt_end_ms > first_note_ms:
+ seq = self._add_prompt_tokens(
+ seq,
+ prompt_start_ms=prompt_start_ms - first_note_ms,
+ prompt_end_ms=prompt_end_ms - first_note_ms,
+ )
+
+ if guidance_midi_dict is not None:
+ guidance_seq = self._get_guidance_seq(
+ guidance_midi_dict=guidance_midi_dict,
+ guidance_start_ms=guidance_start_ms,
+ guidance_end_ms=guidance_end_ms,
+ )
+ else:
+ guidance_seq = []
+
+ return guidance_seq + seq
+
+ def detokenize(self, tokenized_seq: list, **kwargs):
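+ # Drop any guidance tokens before converting back to a MidiDict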
+ if self.guidance_end_tok in tokenized_seq:
+ seq = tokenized_seq[tokenized_seq.index(self.guidance_end_tok) :]
+ else:
+ seq = tokenized_seq
+
+ return super()._detokenize_midi_dict(seq, **kwargs)
def export_data_aug(self):
return [
- self.export_pitch_aug(5),
- self.export_velocity_aug(1),
+ self.export_guidance_tempo_aug(max_tempo_aug=0.25, mixup=True),
+ self.export_guidance_pitch_aug(4),
+ self.export_guidance_velocity_aug(2),
]
+ def export_guidance_aug_fn(self, aug_fn):
+ """Transforms augmentation function to only apply to guidance seq"""
+
+ def _guidance_seq_aug_fn(
+ src: list,
+ _aug_fn: Callable,
+ pad_tok: str,
+ **kwargs,
+ ) -> list:
+
+ initial_seq_len = len(src)
+ if self.guidance_start_tok in src and self.guidance_end_tok in src:
+ guidance_seq = src[
+ src.index(self.guidance_start_tok)
+ + 1 : src.index(self.guidance_end_tok)
+ ]
+ seq = src[src.index(self.guidance_end_tok) + 1 :]
+
+ if len(guidance_seq) == 0:
+ return src
+ else:
+ return src
+
+ augmented_guidance_seq = _aug_fn(guidance_seq)
+ res = (
+ [self.guidance_start_tok]
+ + augmented_guidance_seq
+ + [self.guidance_end_tok]
+ + seq
+ )
+
+ # Pad or truncate to original sequence length as necessary
+ res = res[:initial_seq_len]
+ res += [pad_tok] * (initial_seq_len - len(res))
+
+ return res
+
+ return functools.partial(
+ _guidance_seq_aug_fn,
+ _aug_fn=aug_fn,
+ pad_tok=self.pad_tok,
+ )
+
+ def export_guidance_pitch_aug(self, max_pitch_aug: int):
+ """Apply pitch augmentation to the guidance sequence"""
+
+ return self.export_guidance_aug_fn(
+ self.export_pitch_aug(max_pitch_aug=max_pitch_aug)
+ )
+
+ def export_guidance_velocity_aug(self, max_num_aug_steps: int):
+ """Apply velocity augmentation to the guidance sequence"""
+
+ return self.export_guidance_aug_fn(
+ self.export_velocity_aug(max_num_aug_steps=max_num_aug_steps)
+ )
+
+ def export_guidance_tempo_aug(self, max_tempo_aug: float, mixup: bool):
+ """Apply tempo augmentation to the guidance sequence"""
+
+ return self.export_guidance_aug_fn(
+ self.export_tempo_aug(max_tempo_aug=max_tempo_aug, mixup=mixup)
+ )
+
def split(self, seq: list, seq_len: int):
def _process_chunk(_chunk: list):
# Ensure first token is note token
@@ -99,21 +257,28 @@ def _process_chunk(_chunk: list):
else:
_chunk.pop(0)
- # Insert inst_start_tok if it is missing (but required)
+ # Insert prompt_start_tok if it is missing (but required)
for idx in range(len(_chunk)):
tok = _chunk[idx]
- if tok == self.inst_start_tok:
+ if tok == self.prompt_start_tok:
break
- elif tok == self.inst_end_tok:
+ elif tok == self.prompt_end_tok:
if _chunk[0] == self.bos_tok:
- _chunk.insert(1, self.inst_start_tok)
+ _chunk.insert(1, self.prompt_start_tok)
else:
- _chunk.insert(0, self.inst_start_tok)
+ _chunk.insert(0, self.prompt_start_tok)
break
return _chunk
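+ # Pull the guidance prefix off the front so it can be re-attached to every chunk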
+ guidance = []
+ if self.guidance_start_tok in seq:
+ guidance_start = seq.index(self.guidance_start_tok)
+ guidance_end = seq.index(self.guidance_end_tok)
+ guidance = seq[guidance_start : guidance_end + 1]
+ seq = seq[guidance_end + 1 :]
+
prefix = []
while seq:
tok = seq[0]
@@ -122,7 +287,6 @@ def _process_chunk(_chunk: list):
else:
break
- # Generate chunks
chunks = [
_process_chunk(seq[idx : idx + seq_len])
for idx in range(0, len(seq) - 100, seq_len)
@@ -130,12 +294,10 @@ def _process_chunk(_chunk: list):
res = []
for chunk in chunks:
- if self.inst_start_tok not in chunk:
- continue
-
- sub_seq = prefix + chunk
+ sub_seq = guidance + prefix + chunk
sub_seq = sub_seq[:seq_len]
sub_seq += [self.pad_tok] * (seq_len - len(sub_seq))
+
res.append(sub_seq)
return res
diff --git a/aria/train.py b/aria/train.py
index 6d08371..6a2c74b 100644
--- a/aria/train.py
+++ b/aria/train.py
@@ -20,7 +20,7 @@
from aria.config import load_model_config
from aria.model import ModelConfig, TransformerLM
from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer
-from aria.tokenizer import SeparatedAbsTokenizer
+from aria.tokenizer import InferenceAbsTokenizer
from aria.datasets import (
TrainingDataset,
PretrainingDataset,
@@ -196,7 +196,7 @@ def get_optim(
num_epochs: int,
steps_per_epoch: int,
):
- LR = 3e-4
+ LR = 3e-5
END_RATIO = 0.1
WARMUP_STEPS = 200
@@ -363,8 +363,11 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
) # Transpose for CrossEntropyLoss
loss = loss_fn(logits, tgt)
- loss = loss * mask
- loss = loss[loss != 0.0].mean() # != 0.0 here is important
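+ # Guard against fully-masked batches, which would otherwise give NaN from mean() over an empty tensor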
+ if mask.sum() == 0:
+ loss = (loss * 0).sum()
+ else:
+ loss = loss * mask
+ loss = loss[loss != 0.0].mean()
# Calculate statistics
loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
@@ -423,8 +426,11 @@ def val_loop(dataloader, _epoch: int):
logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss
loss = loss_fn(logits, tgt)
- loss = loss * mask
- loss = loss[loss != 0.0].mean()
+ if mask.sum() == 0:
+ loss = (loss * 0).sum()
+ else:
+ loss = loss * mask
+ loss = loss[loss != 0.0].mean()
# Logging
loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
@@ -530,8 +536,8 @@ def resume_train(
tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path)
if tokenizer_name == "abs":
tokenizer = AbsTokenizer()
- elif tokenizer_name == "separated_abs":
- tokenizer = SeparatedAbsTokenizer()
+ elif tokenizer_name == "inference_abs":
+ tokenizer = InferenceAbsTokenizer()
elif tokenizer_name == "rel":
tokenizer = RelTokenizer()
else:
@@ -656,8 +662,8 @@ def train(
tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path)
if tokenizer_name == "abs":
tokenizer = AbsTokenizer()
- elif tokenizer_name == "separated_abs":
- tokenizer = SeparatedAbsTokenizer()
+ elif tokenizer_name == "inference_abs":
+ tokenizer = InferenceAbsTokenizer()
elif tokenizer_name == "rel":
tokenizer = RelTokenizer()
else:
diff --git a/config/config.json b/config/config.json
index dbbea32..7e0a736 100644
--- a/config/config.json
+++ b/config/config.json
@@ -2,7 +2,7 @@
"data": {
"tests": {
"note_density_in_interval":{
- "run": true,
+ "run": false,
"args": {
"test_params_list":
[
@@ -30,20 +30,20 @@
}
},
"note_timing_entropy":{
- "run": true,
+ "run": false,
"args": {
"min_length_entropy": 2.5,
"min_onset_delta_entropy": 0.0
}
},
"note_pitch_entropy":{
- "run": true,
+ "run": false,
"args": {
"min_entropy": 3.0
}
},
"unique_pitch_count_in_interval":{
- "run": true,
+ "run": false,
"args": {
"test_params_list":
[
@@ -54,58 +54,59 @@
}
},
"unique_pitch_count":{
- "run": true,
+ "run": false,
"args": {
"min_num_unique_pitches": 12
}
},
"silent_interval":{
- "run": true,
+ "run": false,
"args": {
"max_silence_s": 20
}
},
"mean_note_velocity":{
- "run": true,
+ "run": false,
"args": {
"min_mean_velocity": 20,
"max_mean_velocity": 105
}
},
"max_programs":{
- "run": true,
+ "run": false,
"args": {
"max": 12
}
},
"max_instruments":{
- "run": true,
+ "run": false,
"args": {
"max": 7
}
},
"total_note_frequency":{
- "run": true,
+ "run": false,
"args": {
"min_per_second": 1.5,
"max_per_second": 30
}
},
"note_frequency_per_instrument":{
- "run": true,
+ "run": false,
"args": {
"min_per_second": 1.0,
"max_per_second": 25
}
},
- "min_length":{
- "run": true,
+ "length":{
+ "run": false,
"args": {
- "min_seconds": 30
+ "min_length_s": 30,
+ "max_length_s": 7200
}
},
"repetitive_content":{
- "run": true,
+ "run": false,
"args": {
"min_length_m": 20,
"num_chunks": 5,
@@ -138,6 +139,10 @@
},
"metadata": {
"functions": {
+ "aria_midi_json": {
+ "run": true,
+ "args": {}
+ },
"composer_filename": {
"run": false,
"args": {
@@ -171,16 +176,17 @@
}
},
"finetuning": {
+ "guidance_prob": 0.5,
"min_noisy_interval_ms": 5000,
"max_noisy_interval_ms": 60000,
"min_clean_interval_ms": 60000,
"max_clean_interval_ms": 200000,
"noising": {
- "activation_prob": 0.95,
+ "activation_prob": 0.8,
"remove_notes": {
- "activation_prob": 0.75,
- "min_ratio": 0.1,
- "max_ratio": 0.4
+ "activation_prob": 0.5,
+ "min_ratio": 0.0,
+ "max_ratio": 0.3
},
"adjust_velocity": {
"activation_prob": 0.3,
@@ -190,11 +196,11 @@
"min_ratio": 0.30
},
"adjust_onsets": {
- "activation_prob": 0.5,
- "min_adjust_s": 0.03,
+ "activation_prob": 0.25,
+ "min_adjust_s": 0.01,
"max_adjust_s": 0.07,
"max_ratio": 0.15,
- "min_ratio": 0.5
+ "min_ratio": 0.3
},
"quantize_onsets": {
"activation_prob": 0.15,
@@ -203,6 +209,17 @@
"max_vel_delta": 45
}
}
+ }
+ },
+ "tokenizer": {
+ "inference_abs": {
+ "guidance": {
+ "min_ms": 5000,
+ "max_ms": 30000
+ }
+
+
}
}
+
}