Update tokenizers (#80)
* not working

* fix dict_to_midi

* aug not working

* add mp spawn warnings

* abs soundfont path

* add data aug and tests

* update .gitignore

* update make format

* update dataset and tokenizers to use rel/abs

* fix sampling cli

* upgrade train.py to use both tokenizers

* update req

* fix acc_convert

* fix entrypoint

* add comment

* fix test

* small fixes

* rmv compile

* update profile flops

* typo

---------

Co-authored-by: Louis <[email protected]>
loubbrad and Louis authored Dec 14, 2023
1 parent 8205d85 commit 7fc0908
Showing 18 changed files with 1,448 additions and 341 deletions.
6 changes: 4 additions & 2 deletions .gitignore
@@ -161,7 +161,9 @@ cython_debug/

# Project specific
tools/
data/
./data/
fluidsynth/
*.DS_Store
tests/test_results
lightning_logs/
.vscode/
.vscode/
1 change: 1 addition & 0 deletions Makefile
@@ -6,3 +6,4 @@ test:
.PHONY: format
format:
black --line-length 80 ./aria
black --line-length 80 ./tests
87 changes: 71 additions & 16 deletions aria/data/datasets.py
@@ -14,10 +14,10 @@
from pathlib import Path
from typing import Callable, Iterable
from collections import defaultdict
from multiprocessing import Pool, Process, Queue
from multiprocessing import Pool, Process, Queue, get_start_method

from aria.config import load_config
from aria.tokenizer import Tokenizer, TokenizerLazy
from aria.tokenizer import Tokenizer
from aria.data.midi import MidiDict, get_test_fn


@@ -277,6 +277,11 @@ def _get_mididicts_mp(_paths):
yield mid_dict

logger = setup_logger()
if get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
)

paths = []
if recur is True:
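
Aside (not part of the commit): the warning added above fires when Python's multiprocessing start method is "spawn", the default on macOS and Windows. A minimal, hedged sketch of how a caller on a POSIX system could opt into the faster "fork" method before building a dataset:

import multiprocessing

if __name__ == "__main__":
    try:
        # "fork" is unavailable on Windows and no longer the default on macOS
        multiprocessing.set_start_method("fork")
    except (RuntimeError, ValueError):
        # RuntimeError: start method already set; ValueError: "fork" unsupported
        pass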
@@ -332,6 +337,27 @@ def init_epoch(self, epoch_num: int | None = None):
def build(**kwargs):
raise NotImplementedError

@classmethod
def get_config_from_path(cls, path: str):
"""Returns config dict from dataset file/directory.
If a directory provided, it is assumed t"""

def _get_config_from_fp(_path):
# Finetuning Dataset
return FinetuningDataset.get_config_from_path(path=_path)

def _get_config_from_dir(_path):
# Pretraining Dataset
return PretrainingDataset.get_config_from_path(path=_path)

if os.path.isfile(path):
return _get_config_from_fp(path)
elif os.path.isdir(path):
return _get_config_from_dir(path)
else:
raise FileNotFoundError("Invalid path provided")

def close(self):
if self.file_buff:
self.file_buff.close()
@@ -362,6 +388,7 @@ def _format(tok):
src = seq
tgt = seq[1:] + [self.tokenizer.pad_tok]

# Fine till here
return self.tokenizer.encode(src), self.tokenizer.encode(tgt)

def check_config(self):
@@ -475,10 +502,6 @@ def get_seqs(
tokenizer: Tokenizer,
midi_dict_iter: Iterable,
):
# TokenizerLazy is the only supported tokenizer due to the truncate
# and stride logic in _get_tokenized_seqs
assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer"

iq = Queue()
oq = Queue()

@@ -520,6 +543,19 @@ def __init__(self, dir_path: str, tokenizer: Tokenizer):
def __len__(self):
return len(self.index)

@classmethod
def get_config_from_path(cls, path: str):
"""Returns config dict from dataset directory.
Note that this will return the config corresponding to epoch0.jsonl.
"""
assert os.path.isdir(path), "directory not found"
assert os.path.isfile(
epoch0_path := os.path.join(path, "epoch0.jsonl")
), "epoch file not found"
with open(epoch0_path) as f:
return json.loads(f.readline())

def init_epoch(self, idx: int | None = None):
if idx is None:
idx = self.curr_epoch + 1
Expand Down Expand Up @@ -551,7 +587,6 @@ def _get_epoch_files(self):
os.path.join(self.dir_path, file_name) for file_name in file_names
]

# Check correct formatting
present_epochs = []
for file_name in file_names:
if not re.match(r"^epoch\d+\.jsonl$", file_name):
@@ -606,6 +641,7 @@ def _build_epoch(_save_path, _midi_dataset):
)

buffer = []
# TODO: Profile why mp takes a while to spin up
for entry in get_seqs(tokenizer, _midi_dataset):
if entry is not None:
buffer += entry
@@ -617,6 +653,11 @@ def _build_epoch(_save_path, _midi_dataset):
logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
assert num_epochs > 0, "num_epochs must be greater than 0"
if get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
)

if os.path.isdir(save_dir) and os.listdir(save_dir):
print(
@@ -632,29 +673,31 @@ def _build_epoch(_save_path, _midi_dataset):
if not os.path.exists(save_dir):
os.mkdir(save_dir)

# TODO: This is very slow right now
if not midi_dataset and midi_dataset_path:
midi_dataset = MidiDataset.load(midi_dataset_path)
elif not midi_dataset:
raise Exception("Must provide either midi_dataset or midi_dataset_path")

logger.info(
f"Building PretrainingDataset with config: "
f"max_seq_len={max_seq_len} "
f"max_seq_len={max_seq_len}, "
f"tokenizer_name={tokenizer.name}"
)
_num_proc = os.cpu_count()
if 2 * _num_proc > len(midi_dataset):
logger.warning(
"Number of processes is close to the number of MidiDicts "
"in the dataset. This can result in shuffling not working "
"as intended when building different epochs."
"as intended when building different epochs"
)
for idx in range(num_epochs):
logger.info(f"Building epoch {idx}/{num_epochs - 1}...")
_build_epoch(
_save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"),
_midi_dataset=midi_dataset,
)
# TODO: This is very slow for large datasets
midi_dataset.shuffle()

logger.info(
@@ -679,6 +722,13 @@ def __init__(self, file_path: str, tokenizer: Tokenizer):
def __len__(self):
return len(self.index)

@classmethod
def get_config_from_path(cls, path: str):
"""Returns config dict from dataset file"""
assert os.path.isfile(path), "dataset file not found"
with open(path) as f:
return json.loads(f.readline())

# Do nothing in this case
def init_epoch(self, idx: int | None = None):
self.logger.info(f"Successful initiated epoch {idx}")
@@ -693,8 +743,9 @@ def build(
midi_dataset: MidiDataset = None,
midi_dataset_path: str = None,
):
"""Builds and returns PretrainingDataset."""
"""Builds and returns FinetuningDataset."""

# This function should be made more robust in the future
def _truncate_and_stride(_tokenized_seq: list):
prefix = []

@@ -720,13 +771,12 @@ def _truncate_and_stride(_tokenized_seq: list):

# Checks that next start note will not be cutoff midway
while idx < seq_len:
# Break loop when a non 'wait' or 'dur' is seen
if _tokenized_seq[idx] in tokenizer.special_tokens:
break
elif _tokenized_seq[idx][0] in {"wait", "dur"}:
idx += 1
else:
elif _tokenized_seq[idx][0] in tokenizer.instruments_wd:
break
else:
idx += 1

# Add the last sequence
_seq = prefix + _tokenized_seq[idx : idx + max_seq_len - prefix_len]
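
To illustrate the truncate-and-stride windowing that _truncate_and_stride performs, a simplified, hedged sketch that ignores the prefix handling and the note-boundary check shown above (the helper name is hypothetical):

def windows(seq: list, max_len: int, stride: int):
    # Yield overlapping chunks of at most max_len tokens, advancing by stride
    for start in range(0, max(len(seq) - max_len, 0) + 1, stride):
        yield seq[start : start + max_len]

print(list(windows(list(range(10)), max_len=4, stride=2)))
# -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]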
@@ -748,8 +798,8 @@ def _build(_midi_dataset):
)
logger.info(
f"Building FinetuningDataset with config: "
f"tokenizer_name=tokenizer.name"
f"max_seq_len={max_seq_len} "
f"tokenizer_name={tokenizer.name}, "
f"max_seq_len={max_seq_len}, "
f"stride_len={stride_len}"
)

@@ -760,6 +810,11 @@ def _build(_midi_dataset):

logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
if get_start_method() == "spawn":
logger.warning(
'The current multiprocessing start method is "spawn", this '
"will slow down dataset building"
)

if os.path.isfile(save_path):
print(
10 changes: 9 additions & 1 deletion aria/data/midi.py
@@ -393,6 +393,14 @@ def dict_to_midi(mid_data: dict):
Returns:
mido.MidiFile: The MIDI parsed from the input data.
"""

# Sort key: events sharing a tick are ordered by velocity (so velocity-0
# note-offs precede note-ons); messages without a velocity sort last
def _sort_fn(msg):
if hasattr(msg, "velocity"):
return (msg.time, msg.velocity)
else:
return (msg.time, 1000)

assert mid_data.keys() == {
"meta_msgs",
"tempo_msgs",
@@ -475,7 +483,7 @@ def dict_to_midi(mid_data: dict):
)

# Sort and convert from abs_time -> delta_time
track = sorted(track, key=lambda msg: msg.time)
track = sorted(track, key=_sort_fn)
tick = 0
for msg in track:
msg.time -= tick
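
A small, hedged illustration (with made-up messages) of how the new sort key orders events that share an absolute time: note messages are ordered by velocity, so velocity-0 note-offs come before note-ons, and messages without a velocity attribute come after both:

import mido

def _sort_fn(msg):
    if hasattr(msg, "velocity"):
        return (msg.time, msg.velocity)
    else:
        return (msg.time, 1000)

track = [
    mido.MetaMessage("set_tempo", tempo=500000, time=480),
    mido.Message("note_on", note=60, velocity=80, time=480),
    mido.Message("note_on", note=60, velocity=0, time=480),  # note-off
]

print([getattr(msg, "velocity", None) for msg in sorted(track, key=_sort_fn)])
# -> [0, 80, None]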
55 changes: 46 additions & 9 deletions aria/run.py
@@ -8,8 +8,15 @@
import warnings


# TODO: Implement a way of inferring the tokenizer name automatically
def _parse_sample_args():
argp = argparse.ArgumentParser(prog="aria sample")
argp.add_argument(
"-tok",
help="name of tokenizer",
choices=["abs", "rel"],
required=True,
)
argp.add_argument("-m", help="name of model config file")
argp.add_argument("-c", help="path to model checkpoint")
argp.add_argument("-p", help="path to midi file")
@@ -89,14 +96,16 @@ def _get_midi_path(midi_path: str | None) -> str:
return midi_path


# TODO: Add arg for suppressing the audio conversion, and commands for changing
# the sampling params from the cli
def sample(args):
"""Entrypoint for sampling"""

import torch
from torch.cuda import is_available as cuda_is_available
from aria.model import TransformerLM, ModelConfig
from aria.config import load_model_config
from aria.tokenizer import TokenizerLazy
from aria.tokenizer import RelTokenizer, AbsTokenizer
from aria.sample import greedy_sample
from aria.data.midi import MidiDict
from aria.utils import midi_to_audio
@@ -121,11 +130,21 @@ def sample(args):
truncate_len = args.trunc
force_end = args.e

tokenizer = TokenizerLazy(return_tensors=True)
if args.tok == "abs":
tokenizer = AbsTokenizer(return_tensors=True)
elif args.tok == "rel":
tokenizer = RelTokenizer(return_tensors=True)

model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config).to(device)
model.load_state_dict(model_state)
try:
model.load_state_dict(model_state)
except Exception:
print(
"Failed to load state_dict, this could be because the wrong "
"tokenizer was selected"
)
if args.q:
if device.type != "cpu":
warnings.warn(
@@ -244,17 +263,24 @@ def _parse_pretrain_dataset_args():
argp = argparse.ArgumentParser(prog="aria pretrain-dataset")
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_dir", help="path to save dataset")
argp.add_argument(
"tokenizer_name", help="tokenizer name", choices=["abs", "rel"]
)
argp.add_argument("-l", help="max sequence length", type=int, default=2048)
argp.add_argument("-e", help="num epochs", type=int, default=1)

return argp.parse_args(sys.argv[2:])


def build_pretraining_dataset(args):
from aria.tokenizer import TokenizerLazy
from aria.tokenizer import AbsTokenizer, RelTokenizer
from aria.data.datasets import PretrainingDataset

tokenizer = TokenizerLazy()
if args.tokenizer_name == "abs":
tokenizer = AbsTokenizer()
elif args.tokenizer_name == "rel":
tokenizer = RelTokenizer()

dataset = PretrainingDataset.build(
tokenizer=tokenizer,
save_dir=args.save_dir,
@@ -268,18 +294,24 @@ def _parse_finetune_dataset_args():
argp = argparse.ArgumentParser(prog="aria finetune-dataset")
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_path", help="path to save dataset")
argp.add_argument(
"tokenizer_name", help="tokenizer name", choices=["abs", "rel"]
)
argp.add_argument("-l", help="max sequence length", type=int, default=2048)
argp.add_argument("-s", help="stride length", type=int, default=512)

return argp.parse_args(sys.argv[2:])


# This might not be correct - double check
def build_finetune_dataset(args):
from aria.tokenizer import TokenizerLazy
from aria.tokenizer import AbsTokenizer, RelTokenizer
from aria.data.datasets import FinetuningDataset

tokenizer = TokenizerLazy()
if args.tokenizer_name == "abs":
tokenizer = AbsTokenizer()
elif args.tokenizer_name == "rel":
tokenizer = RelTokenizer()

dataset = FinetuningDataset.build(
tokenizer=tokenizer,
save_path=args.save_path,
@@ -295,7 +327,12 @@ def main():
parser.add_argument(
"command",
help="command to run",
choices=("sample", "midi-dataset", "pretrain-dataset"),
choices=(
"sample",
"midi-dataset",
"pretrain-dataset",
"finetune-dataset",
),
)

# parse_args defaults to [1:] for args, but you need to
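
For reference, a hedged sketch of the abs/rel selection that the updated entrypoints repeat; get_tokenizer is a hypothetical helper, not part of the commit:

from aria.tokenizer import AbsTokenizer, RelTokenizer

def get_tokenizer(name: str, return_tensors: bool = False):
    # "abs" and "rel" mirror the choices exposed by the CLI arguments above
    if name == "abs":
        return AbsTokenizer(return_tensors=return_tensors)
    if name == "rel":
        return RelTokenizer(return_tensors=return_tensors)
    raise ValueError(f"unknown tokenizer name: {name}")

With the new positional argument, the dataset builders are invoked as, for example, aria pretrain-dataset <load_path> <save_dir> abs -l 2048 -e 1 or aria finetune-dataset <load_path> <save_path> rel -l 2048 -s 512, assuming the installed aria entrypoint dispatches to main() above.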
2 changes: 1 addition & 1 deletion aria/sample.py
@@ -1,6 +1,6 @@
"""Contains generation/sampling code"""
# This file contains code from https://github.com/facebookresearch/llama which
# is available under the following licence:
# is available under the following license:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU
2 changes: 1 addition & 1 deletion aria/tokenizer/__init__.py
@@ -1 +1 @@
from .tokenizer import Tokenizer, TokenizerLazy
from .tokenizer import Tokenizer, RelTokenizer, AbsTokenizer