Fix training and inference (#24)
* fix

* fix batched

* adjust

* transfer

* add cleanup

* working

* move_to_node

* fix config

* fix msg

* add synth dataset gen

* remote changes

* local changes

* add scripts

* fix audio params
loubbrad authored Apr 9, 2024
1 parent 50f0b60 commit 1c9d666
Showing 21 changed files with 1,064 additions and 302 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,6 +1,5 @@
# data files
*.csv
*.json
*.xls
*.xlsx
*.pkl
21 changes: 13 additions & 8 deletions amt/audio.py
@@ -191,15 +191,15 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
noise_ratio: float = 0.75,
reverb_ratio: float = 0.75,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.5,
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -223,7 +223,10 @@ def __init__(
self.detune_ratio = detune_ratio
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio
self.reduction_resample_rate = 6000 # Hardcoded?

self.time_mask_param = 2500
self.freq_mask_param = 15
self.reduction_resample_rate = 6000

# Audio aug
impulse_paths = self._get_paths(
@@ -263,10 +266,10 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=15, iid_masks=True
freq_mask_param=self.freq_mask_param, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=1000, iid_masks=True
time_mask_param=self.time_mask_param, iid_masks=True
),
)

@@ -281,6 +284,8 @@ def get_params(self):
"detune_ratio": self.detune_ratio,
"detune_max_shift": self.detune_max_shift,
"spec_aug_ratio": self.spec_aug_ratio,
"time_mask_param": self.time_mask_param,
"freq_mask_param": self.freq_mask_param,
}

def _get_paths(self, dir_path):
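The amt/audio.py changes above raise spec_aug_ratio to 0.9 and promote the masking widths to attributes (time_mask_param = 2500, freq_mask_param = 15) so they are reported by get_params. A minimal sketch of how such a SpecAugment-style pipeline can be assembled and gated with torchaudio; the probability gate is an assumption based on the spec_aug_ratio default, not the repository's exact code:

import torch
import torchaudio

# Values taken from the diff above; the gating helper below is illustrative.
time_mask_param = 2500
freq_mask_param = 15
spec_aug_ratio = 0.9

spec_aug = torch.nn.Sequential(
    torchaudio.transforms.FrequencyMasking(
        freq_mask_param=freq_mask_param, iid_masks=True
    ),
    torchaudio.transforms.TimeMasking(
        time_mask_param=time_mask_param, iid_masks=True
    ),
)

def maybe_spec_aug(mel: torch.Tensor) -> torch.Tensor:
    # Mask random frequency bands and time spans of a (batch, n_mels, n_frames)
    # spectrogram with probability spec_aug_ratio; otherwise pass it through.
    if torch.rand(1).item() < spec_aug_ratio:
        return spec_aug(mel)
    return mel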
92 changes: 84 additions & 8 deletions amt/data.py
@@ -1,6 +1,8 @@
import mmap
import os
import io
import random
import shlex
import base64
import shutil
import orjson
@@ -16,14 +18,22 @@
from amt.audio import pad_or_trim


# Occasionally the worker util goes to 0 for some reason, debug this
def _check_onset_threshold(seq: list, onset: int):
for tok_1, tok_2 in zip(seq, seq[1:]):
if isinstance(tok_1, tuple) and tok_1[0] in ("on", "off"):
_onset = tok_2[1]
if _onset > onset:
return True

return False


def get_wav_mid_segments(
audio_path: str,
mid_path: str = "",
return_json: bool = False,
stride_factor: int | None = None,
pad_last=False,
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
@@ -61,10 +71,12 @@ def get_wav_mid_segments(

# Create features
total_samples = wav.shape[-1]
pad_factor = 2 if pad_last is True else 1
res = []
for idx in range(
0,
total_samples - (num_samples - num_samples // stride_factor),
total_samples
- (num_samples - pad_factor * (num_samples // stride_factor)),
num_samples // stride_factor,
):
audio_feature = pad_or_trim(wav[idx:], length=num_samples)
@@ -75,6 +87,12 @@ def get_wav_mid_segments(
end_ms=(idx + num_samples) / samples_per_ms,
max_pedal_len_ms=10000,
)

# Hardcoded to 2.5s
if _check_onset_threshold(mid_feature, 2500) is False:
print("No note messages after 2.5s - skipping")
continue

else:
mid_feature = []

@@ -86,6 +104,56 @@ def get_wav_mid_segments(
return res


def pianoteq_cmd_fn(mid_path: str, wav_path: str):
presets = [
"C. Bechstein",
"C. Bechstein Close Mic",
"C. Bechstein Under Lid",
"C. Bechstein 440",
"C. Bechstein Recording",
"C. Bechstein Werckmeister III",
"C. Bechstein Neidhardt III",
"C. Bechstein mesotonic",
"C. Bechstein well tempered",
"HB Steinway D Blues",
"HB Steinway D Pop",
"HB Steinway D New Age",
"HB Steinway D Prelude",
"HB Steinway D Felt I",
"HB Steinway D Felt II",
"HB Steinway Model D",
"HB Steinway D Classical Recording",
"HB Steinway D Jazz Recording",
"HB Steinway D Chamber Recording",
"HB Steinway D Studio Recording",
"HB Steinway D Intimate",
"HB Steinway D Cinematic",
"HB Steinway D Close Mic Classical",
"HB Steinway D Close Mic Jazz",
"HB Steinway D Player Wide",
"HB Steinway D Player Clean",
"HB Steinway D Trio",
"HB Steinway D Duo",
"HB Steinway D Cabaret",
"HB Steinway D Bright",
"HB Steinway D Hyper Bright",
"HB Steinway D Prepared",
"HB Steinway D Honky Tonk",
]

preset = random.choice(presets)

# Safely quote the preset name, MIDI path, and WAV path
safe_preset = shlex.quote(preset)
safe_mid_path = shlex.quote(mid_path)
safe_wav_path = shlex.quote(wav_path)

# Construct the command
command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"

return command
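pianoteq_cmd_fn only builds the command string; the consumer, get_synth_audio, is not loaded in this diff view. A minimal, hypothetical runner could execute it as shown below (the helper name, signature, and timeout are assumptions, not the repository's API):

import shlex
import subprocess

def run_synth_cmd(cli_cmd_fn, mid_path: str, wav_path: str, timeout: int = 300):
    command = cli_cmd_fn(mid_path, wav_path)
    # shlex.split honours the shlex.quote quoting and the escaped spaces in the
    # Pianoteq binary path, so the preset name is passed as a single argument.
    subprocess.run(shlex.split(command), check=True, timeout=timeout)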


def write_features(audio_path: str, mid_path: str, save_path: str):
features = get_wav_mid_segments(
audio_path=audio_path,
@@ -121,7 +189,7 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):

try:
get_synth_audio(
cli_cmd=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
cli_cmd_fn=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
)
except:
if os.path.isfile(audio_path_temp):
@@ -133,7 +201,11 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
mid_path=mid_path,
return_json=False,
)
os.remove(audio_path_temp)

if os.path.isfile(audio_path_temp):
os.remove(audio_path_temp)

print(f"Found {len(features)}")

with open(save_path, mode="a") as file:
for wav, seq in features:
@@ -174,7 +246,11 @@ def build_synth_worker_fn(

while not load_path_queue.empty():
mid_path = load_path_queue.get()
write_synth_features(cli_cmd, mid_path, worker_save_path)
try:
write_synth_features(cli_cmd, mid_path, worker_save_path)
except Exception as e:
print("Failed")
print(e)

save_path_queue.put(worker_save_path)

@@ -239,7 +315,7 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx

def _build_index(self):
self.file_mmap.seek(0)
@@ -254,7 +330,7 @@ def _build_index(self):

return index

def _save_index(self, index: list[int], save_path: str):
def _save_index(self, index: list, save_path: str):
with open(save_path, "w") as file:
for idx in index:
file.write(f"{idx}\n")
@@ -325,7 +401,7 @@ def build(
]
else:
# Build synthetic dataset
assert len(load_paths[0]) == 1, "Invalid load paths"
assert isinstance(load_paths[0], str), "Invalid load paths"
print("Building synthetic dataset")
worker_processes = [
Process(
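One effect of the pad_last flag added to get_wav_mid_segments above: with pad_last=True the loop bound grows by one hop, so the final, partially filled window is still emitted and padded by pad_or_trim instead of being dropped. A worked sketch of the index arithmetic with made-up numbers (num_samples and stride_factor below are illustrative, not the repository's defaults):

num_samples = 10                      # window length in samples (illustrative)
stride_factor = 2                     # hop = num_samples // stride_factor = 5
total_samples = 23

hop = num_samples // stride_factor
pad_factor = 2                        # pad_last=True
stop = total_samples - (num_samples - pad_factor * hop)   # 23 - (10 - 10) = 23
print(list(range(0, stop, hop)))      # [0, 5, 10, 15, 20]

# With pad_last=False (pad_factor = 1): stop = 23 - (10 - 5) = 18, giving
# [0, 5, 10, 15]; the trailing window at 20, which needs padding, is skipped.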
78 changes: 27 additions & 51 deletions amt/evaluate.py
@@ -42,27 +42,32 @@ def midi_to_hz(note, shift=0):
# return (a / 32) * (2 ** ((note - 9) / 12))


def get_matched_files(est_dir: str, ref_dir: str):
# We assume that the files have the same path relative to their directory

res = []
est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True)
print(f"found {len(est_paths)} est files")

for est_path in est_paths:
est_rel_path = os.path.relpath(est_path, est_dir)
ref_path = os.path.join(
ref_dir, os.path.splitext(est_rel_path)[0] + ".midi"
)
if os.path.isfile(ref_path):
res.append((est_path, ref_path))

print(f"found {len(res)} matched est-ref pairs")

return res


def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
"""
Evaluate the estimated pitches against the reference pitches using mir_eval.
"""
# Evaluate the estimated pitches against the reference pitches
ref_midi_files = glob.glob(f"{ref_dir}/*.mid*")
est_midi_files = glob.glob(f"{est_dir}/*.mid*")

est_ref_pairs = []
for est_fpath in est_midi_files:
ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath))
if ref_fpath in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath))
if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
est_ref_pairs.append(
(est_fpath, ref_fpath.replace(".mid", ".midi"))
)
else:
print(
f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
)

est_ref_pairs = get_matched_files(est_dir, ref_dir)

output_fhandle = (
open(output_stats_file, "w") if output_stats_file is not None else None
@@ -104,38 +109,9 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
help="Path to the file to save the evaluation stats",
)

# add mir_eval and dtw subparsers
subparsers = parser.add_subparsers(help="sub-command help")
mir_eval_parse = subparsers.add_parser(
"run_mir_eval",
help="Run standard mir_eval evaluation on MAESTRO test set.",
)
mir_eval_parse.add_argument(
"--shift",
type=int,
default=0,
help="Shift to apply to the estimated pitches.",
)

# to come
dtw_eval_parse = subparsers.add_parser(
"run_dtw",
help="Run dynamic time warping evaluation on a specified dataset.",
)

args = parser.parse_args()
if not hasattr(args, "command"):
parser.print_help()
print("Unrecognized command")
exit(1)

# todo: should we add an option to run transcription again every time we wish to evaluate?
# that way, we can run both tests with a range of different audio augmentations right here.
# -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

if args.command == "run_mir_eval":
evaluate_mir_eval(
args.est_dir, args.ref_dir, args.output_stats_file, args.shift
)
elif args.command == "run_dtw":
pass
evaluate_mir_eval(
args.est_dir,
args.ref_dir,
args.output_stats_file,
)
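get_matched_files now pairs each estimated .mid file with the reference .midi file at the same path relative to its directory, and the simplified CLI always runs evaluate_mir_eval. The scoring loop itself is not loaded in this view; a minimal sketch of how one matched pair could be scored with mir_eval's transcription metrics (intervals in seconds, pitches in Hz; the arrays below are dummy data, not the repository's code):

import numpy as np
import mir_eval

ref_intervals = np.array([[0.50, 1.00], [1.20, 1.90]])   # note on/off times (s)
ref_pitches = np.array([261.63, 329.63])                  # pitches in Hz
est_intervals = np.array([[0.52, 1.02], [1.23, 1.85]])
est_pitches = np.array([261.63, 329.63])

precision, recall, f1, overlap = mir_eval.transcription.precision_recall_f1_overlap(
    ref_intervals, ref_pitches, est_intervals, est_pitches
)
print(f"P={precision:.3f} R={recall:.3f} F1={f1:.3f}")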