Migrate to ariautils and adjust MidiDict tests. (#116)
* migrate to ariautils mididict

* upgrade to ariautils tokenizers and change MidiDict test settings

* fix
loubbrad authored Dec 6, 2024
1 parent 31b07cb commit 3c03782
Showing 12 changed files with 174 additions and 2,921 deletions.
Empty file removed aria/data/__init__.py
1,051 changes: 0 additions & 1,051 deletions aria/data/midi.py

This file was deleted.
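The deleted in-repo module is superseded by the external ariautils package; the same substitution repeats throughout the commit. A minimal sketch of the import change, assuming ariautils is installed (AbsTokenizer shown for illustration):

# Before this commit (in-repo modules):
#   from aria.data.midi import MidiDict, get_test_fn, get_duration_ms
#   from aria.tokenizer import Tokenizer, AbsTokenizer

# After this commit (external package):
from ariautils.midi import MidiDict, get_test_fn, get_duration_ms
from ariautils.tokenizer import Tokenizer, AbsTokenizer

# Repo-specific tokenizers remain under aria.tokenizer:
from aria.tokenizer import SeparatedAbsTokenizer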

57 changes: 43 additions & 14 deletions aria/data/datasets.py → aria/datasets.py
@@ -11,18 +11,24 @@
import torch
import functools
import shutil
import multiprocessing

from mido.midifiles.units import second2tick
from pathlib import Path
from typing import List
from copy import deepcopy
from typing import Callable, Iterable
from collections import defaultdict
from multiprocessing import Pool, get_start_method

from aria.config import load_config
from aria.tokenizer import Tokenizer, SeparatedAbsTokenizer
from aria.data.midi import MidiDict, get_test_fn, get_duration_ms
from aria.tokenizer import SeparatedAbsTokenizer
from ariautils.tokenizer import Tokenizer
from ariautils.midi import (
MidiDict,
get_test_fn,
get_duration_ms,
get_metadata_fn,
)


def setup_logger():
@@ -45,6 +51,8 @@ def setup_logger():
return logger


# TODO: Change the build settings so that it saves the config used for tests
# as json on the first line.
class MidiDataset:
"""Container for datasets of MidiDict objects.
@@ -253,6 +261,23 @@ def _get_mididict(path: Path):
# (bool, (MidiDict, str, Path)) where the first element determines if the
# loaded MidiDict was successfully preprocessed.

def _add_metadata(_mid_dict: MidiDict):
for metadata_process_name, metadata_process_config in config[
"metadata"
]["functions"].items():
if metadata_process_config["run"] is True:
metadata_fn = get_metadata_fn(
metadata_process_name=metadata_process_name
)
fn_args: dict = metadata_process_config["args"]

collected_metadata = metadata_fn(_mid_dict, **fn_args)
if collected_metadata:
for k, v in collected_metadata.items():
_mid_dict.metadata[k] = v

return _mid_dict

def _run_tests(_mid_dict: MidiDict):
failed_tests = []
for test_name, test_config in config["tests"].items():
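Both helpers iterate over the preprocessing config returned by load_config(). A hypothetical snippet of the shape they expect — the metadata-function and test names below are illustrative, and the tests entries are assumed to mirror the same run/args layout:

# Illustrative config shape (names are placeholders, not from this commit).
config = {
    "metadata": {
        "functions": {
            "composer_filename": {            # hypothetical metadata function
                "run": True,
                "args": {"composer_names": ["bach", "chopin"]},
            },
        },
    },
    "tests": {
        "note_density": {                     # hypothetical test
            "run": True,
            "args": {"max_notes_per_second": 50},
        },
    },
}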
Expand Down Expand Up @@ -291,15 +316,17 @@ def _preprocess_mididict(_mid_dict: MidiDict):
logger.error(f"Failed to load MIDI at {path}: {e}")
return False, None

mid_hash = mid_dict.calculate_hash()
failed_tests = _run_tests(mid_dict)
if failed_tests:
logger.info(
f"MIDI at {path} failed preprocessing tests: {failed_tests} "
)
return False, None
else:
return True, (_preprocess_mididict(mid_dict), mid_hash, path)
mid_dict = _preprocess_mididict(mid_dict)
mid_dict = _add_metadata(mid_dict)
mid_hash = mid_dict.calculate_hash()
return True, (mid_dict, mid_hash, path)


def build_mididict_dataset(
Expand Down Expand Up @@ -333,7 +360,7 @@ def build_mididict_dataset(
"""

def _get_mididicts_mp(_paths):
with Pool() as pool:
with multiprocessing.Pool() as pool:
results = pool.imap_unordered(_get_mididict, _paths)
seen_hashes = defaultdict(list)
dupe_cnt = 0
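The hunk is truncated here. A hedged sketch of how the seen_hashes/dupe_cnt bookkeeping is typically completed, i.e. keeping only the first MidiDict per hash — an assumed continuation, not the literal code:

# Sketch only (assumed continuation of the truncated loop above):
for success, payload in results:
    if not success:
        continue
    mid_dict, mid_hash, mid_path = payload
    if seen_hashes[mid_hash]:
        dupe_cnt += 1          # duplicate content, skip it
    else:
        yield mid_dict         # first occurrence of this hash
    seen_hashes[mid_hash].append(str(mid_path))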
Expand Down Expand Up @@ -365,7 +392,7 @@ def _get_mididicts_mp(_paths):
)

logger = setup_logger()
if get_start_method() == "spawn":
if multiprocessing.get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
Expand Down Expand Up @@ -577,7 +604,11 @@ def _format(tok):
tgt = seq[1:] + [self.tokenizer.pad_tok]
mask = self.get_loss_mask(tgt)

return self.tokenizer.encode(src), self.tokenizer.encode(tgt), mask
return (
torch.tensor(self.tokenizer.encode(src)),
torch.tensor(self.tokenizer.encode(tgt)),
mask,
)
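The wrapping in torch.tensor reflects that the ariautils tokenizers return plain Python lists from encode() and no longer accept the old return_tensors flag, so callers convert explicitly:

# Sketch of the new contract (src is a token sequence as used above):
ids = tokenizer.encode(src)      # list[int] with ariautils tokenizers
src_tensor = torch.tensor(ids)   # explicit conversion done by the caller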

def check_config(self, epoch_load_path: str):
def _check_config():
Expand Down Expand Up @@ -689,16 +720,14 @@ def get_seqs(
tokenizer: Tokenizer,
midi_dict_iter: Iterable,
):
num_proc = os.cpu_count()

# Can't pickle a generator object when the start method is spawn
if get_start_method() == "spawn":
if multiprocessing.get_start_method() == "spawn":
logging.info(
"Converting generator to list due to multiprocessing start method"
)
midi_dict_iter = [_ for _ in midi_dict_iter]

with Pool() as pool:
with multiprocessing.Pool() as pool:
results = pool.imap_unordered(
functools.partial(_get_seqs, _tokenizer=tokenizer), midi_dict_iter
)
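As the comment above notes, a generator cannot be pickled when the "spawn" start method is active, so the iterator is materialised into a list first. A hedged workaround sketch, POSIX-only and not part of this commit:

import multiprocessing

# Opt into "fork" before building datasets so the generator does not have
# to be materialised; "fork" is unavailable on Windows and can be fragile
# on macOS, hence only a sketch.
if multiprocessing.get_start_method(allow_none=True) != "fork":
    multiprocessing.set_start_method("fork", force=True)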
Expand Down Expand Up @@ -801,7 +830,7 @@ def _build_epoch(_save_path, _midi_dataset):
logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
assert num_epochs > 0, "num_epochs must be greater than 0"
if get_start_method() == "spawn":
if multiprocessing.get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
Expand Down Expand Up @@ -1229,7 +1258,7 @@ def _build_epoch(_save_path, _midi_dataset):
assert os.path.isfile(clean_dataset_path), "file not found"
for __path in noisy_dataset_paths:
assert os.path.isfile(__path), "file not found"
if get_start_method() == "spawn":
if multiprocessing.get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
24 changes: 12 additions & 12 deletions aria/run.py
@@ -88,18 +88,18 @@ def sample(args):
from aria.inference import TransformerLM
from aria.model import ModelConfig
from aria.config import load_model_config, load_config
from aria.tokenizer import AbsTokenizer, SeparatedAbsTokenizer
from ariautils.tokenizer import AbsTokenizer
from aria.tokenizer import SeparatedAbsTokenizer
from aria.sample import greedy_sample, get_pt_prompt, get_inst_prompt
from aria.data.midi import MidiDict
from aria.data.datasets import _noise_midi_dict
from aria.utils import midi_to_audio, _load_weight
from ariautils.midi import MidiDict
from aria.utils import _load_weight

if not cuda_is_available():
raise Exception("CUDA device is not available.")

model_state = _load_weight(args.c, "cuda")
model_state = {
k: v for k, v in model_state.items() if "rotary_emb" not in k
k.replace("_orig_mod.", ""): v for k, v in model_state.items()
}
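The new key remapping strips the prefix that torch.compile() adds to parameter names when a compiled model's state dict is saved, so the checkpoint loads into an uncompiled model. The pattern in isolation (the checkpoint path and model variable are hypothetical):

import torch

state = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical path
state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
model.load_state_dict(state)  # model: the uncompiled TransformerLM, built as in sample()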

manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {}
@@ -117,9 +117,9 @@ def sample(args):
model_name = args.m

if args.pt == True:
tokenizer = AbsTokenizer(return_tensors=True)
tokenizer = AbsTokenizer()
else:
tokenizer = SeparatedAbsTokenizer(return_tensors=True)
tokenizer = SeparatedAbsTokenizer()

model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
@@ -233,7 +233,7 @@ def _parse_midi_dataset_args():

def build_midi_dataset(args):
"""Entrypoint for building MidiDatasets from a directory"""
from aria.data.datasets import MidiDataset
from aria.datasets import MidiDataset

assert args.dir, "build directory must be provided"
manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {}
@@ -269,8 +269,8 @@ def _parse_pretrain_dataset_args():


def build_pretraining_dataset(args):
from aria.tokenizer import AbsTokenizer, RelTokenizer
from aria.data.datasets import PretrainingDataset
from ariautils.tokenizer import AbsTokenizer, RelTokenizer
from aria.datasets import PretrainingDataset

if args.tokenizer_name == "abs":
tokenizer = AbsTokenizer()
@@ -306,10 +306,10 @@ def _parse_finetune_dataset_args():

def build_finetune_dataset(args):
from aria.tokenizer import SeparatedAbsTokenizer
from aria.data.datasets import FinetuningDataset
from aria.datasets import FinetuningDataset

tokenizer = SeparatedAbsTokenizer()
dataset = FinetuningDataset.build(
FinetuningDataset.build(
tokenizer=tokenizer,
save_dir=args.save_dir,
max_seq_len=args.l,
24 changes: 15 additions & 9 deletions aria/sample.py
@@ -8,22 +8,22 @@
from tqdm import tqdm

from aria.inference import TransformerLM
from aria.tokenizer import Tokenizer
from aria.data.midi import MidiDict
from ariautils.tokenizer import Tokenizer
from ariautils.midi import MidiDict

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True


@torch.no_grad()
@torch.inference_mode()
def prefill(model, idxs: torch.Tensor, input_pos: torch.Tensor):
logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1]

return logits


@torch.no_grad()
@torch.inference_mode()
def decode_one(model, idxs: torch.Tensor, input_pos: torch.Tensor):
logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1]
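Replacing @torch.no_grad() with @torch.inference_mode() opts into the stricter mode: besides disabling gradient tracking it skips view and version-counter bookkeeping, so tensors created inside cannot later participate in autograd, which is slightly faster for pure inference. The equivalent context-manager form, shown for illustration:

with torch.inference_mode():
    logits = model.forward(idxs=idxs, input_pos=input_pos)[:, -1]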

@@ -63,11 +63,12 @@ def update_seq_ids_(


# TODO: Add CFG back into this when working
# TODO: Check that unexpected instrument warnings are still working
@torch.autocast(
"cuda",
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
@torch.no_grad()
@torch.inference_mode()
def greedy_sample(
model: TransformerLM,
tokenizer: Tokenizer,
Expand All @@ -81,7 +82,6 @@ def greedy_sample(
):
"""Performs greedy (top_p) auto-regressive sampling on a batch of prompts."""

assert tokenizer.return_tensors is True, "tokenizer must return tensors."
if force_end:
assert max_new_tokens > 130, "prompt too long to use force_end=True"

@@ -92,7 +92,9 @@
total_len = _prompt_len + max_new_tokens
seq = torch.stack(
[
tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p)))
torch.tensor(
tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p)))
)
for p in prompts
]
).cuda()
@@ -199,8 +201,7 @@ def get_inst_prompt(
truncate_len: int,
noise: bool,
):
from aria.data.datasets import _noise_midi_dict
from aria.data.midi import MidiDict
from aria.datasets import _noise_midi_dict
from aria.config import load_config

midi_dict.metadata["noisy_intervals"] = [[0, truncate_len * 1e3]]
Expand All @@ -214,6 +215,11 @@ def get_inst_prompt(

if tokenizer.inst_end_tok in prompt_seq:
prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.inst_end_tok) + 1]
elif tokenizer.eos_tok in prompt_seq:
# TODO: This is a workaround for a bug where </INST> is not inserted if
# the piece ends.
assert prompt_seq[-1] == tokenizer.eos_tok
prompt_seq[-1] = tokenizer.inst_end_tok
else:
print("No notes found in prompt region")
prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1]
(Remaining changed files not shown.)
