Fix training and inference #24

Merged · 14 commits · Apr 9, 2024
1 change: 0 additions & 1 deletion .gitignore
@@ -1,6 +1,5 @@
# data files
*.csv
*.json
*.xls
*.xlsx
*.pkl
21 changes: 13 additions & 8 deletions amt/audio.py
@@ -191,15 +191,15 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
noise_ratio: float = 0.75,
reverb_ratio: float = 0.75,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.5,
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
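
This first hunk lowers the noise and reverb probabilities from 0.95 to 0.75, disables detuning, and raises the SpecAugment probability to 0.9. A minimal sketch of how such per-example probability gates are typically applied, assuming each `*_ratio` is the chance of applying its effect (the helper below is illustrative, not this module's actual method):

```python
import torch

def maybe_apply(x: torch.Tensor, effect, ratio: float) -> torch.Tensor:
    # Apply `effect` with probability `ratio`; otherwise return the input unchanged.
    if torch.rand(1).item() < ratio:
        return effect(x)
    return x

def add_noise(x: torch.Tensor) -> torch.Tensor:
    # Stand-in effect for the sketch.
    return x + 0.01 * torch.randn_like(x)

# Values mirroring the new defaults in this hunk.
noise_ratio, reverb_ratio, spec_aug_ratio = 0.75, 0.75, 0.9
wav = torch.randn(1, 16000)
wav = maybe_apply(wav, add_noise, noise_ratio)
```
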
@@ -223,7 +223,10 @@ def __init__(
self.detune_ratio = detune_ratio
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio
self.reduction_resample_rate = 6000 # Hardcoded?

self.time_mask_param = 2500
self.freq_mask_param = 15
self.reduction_resample_rate = 6000

# Audio aug
impulse_paths = self._get_paths(
@@ -263,10 +266,10 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=15, iid_masks=True
freq_mask_param=self.freq_mask_param, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=1000, iid_masks=True
time_mask_param=self.time_mask_param, iid_masks=True
),
)
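
The masking parameters are now instance attributes, with the time mask widened from 1000 to 2500 frames. A self-contained sketch of the same torchaudio masking stack applied to a hypothetical batch of spectrograms (shapes here are illustrative, not the project's actual dimensions):

```python
import torch
import torchaudio

spec_aug = torch.nn.Sequential(
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15, iid_masks=True),
    torchaudio.transforms.TimeMasking(time_mask_param=2500, iid_masks=True),
)

# Hypothetical batch of log-mel spectrograms: (batch, channel, n_mels, n_frames).
specs = torch.randn(4, 1, 256, 3000)
masked = spec_aug(specs)  # a random frequency band and time span zeroed per example
```
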

@@ -281,6 +284,8 @@ def get_params(self):
"detune_ratio": self.detune_ratio,
"detune_max_shift": self.detune_max_shift,
"spec_aug_ratio": self.spec_aug_ratio,
"time_mask_param": self.time_mask_param,
"freq_mask_param": self.freq_mask_param,
}

def _get_paths(self, dir_path):
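Since `get_params` now also reports the mask parameters, the full augmentation configuration can be serialized alongside a run. A minimal sketch, assuming `transform` is an instance of the augmentation class edited above (its name is not visible in these hunks):

```python
import json

def save_aug_config(transform, path: str = "aug_params.json") -> None:
    # Persist the augmentation settings (including the new mask params)
    # next to a training run for reproducibility.
    with open(path, "w") as f:
        json.dump(transform.get_params(), f, indent=2)
```
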
92 changes: 84 additions & 8 deletions amt/data.py
@@ -1,6 +1,8 @@
import mmap
import os
import io
import random
import shlex
import base64
import shutil
import orjson
@@ -16,14 +18,22 @@
from amt.audio import pad_or_trim


# Occasionally the worker util goes to 0 for some reason, debug this
def _check_onset_threshold(seq: list, onset: int):
for tok_1, tok_2 in zip(seq, seq[1:]):
if isinstance(tok_1, tuple) and tok_1[0] in ("on", "off"):
_onset = tok_2[1]
if _onset > onset:
return True

return False
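
`_check_onset_threshold` walks consecutive token pairs and reports whether any note event carries an onset later than the given millisecond threshold; segments that fail the check are skipped further down. A small illustration with a hypothetical token layout (the real tokenizer's exact tuple format may differ):

```python
seq_keep = [("on", 60), ("onset", 3000), ("off", 60), ("onset", 4500)]
seq_skip = [("on", 64), ("onset", 1200), ("off", 64), ("onset", 2100)]

print(_check_onset_threshold(seq_keep, 2500))  # True  -> segment kept
print(_check_onset_threshold(seq_skip, 2500))  # False -> segment skipped
```
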


def get_wav_mid_segments(
audio_path: str,
mid_path: str = "",
return_json: bool = False,
stride_factor: int | None = None,
pad_last=False,
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
@@ -61,10 +71,12 @@ def get_wav_mid_segments(

# Create features
total_samples = wav.shape[-1]
pad_factor = 2 if pad_last is True else 1
res = []
for idx in range(
0,
total_samples - (num_samples - num_samples // stride_factor),
total_samples
- (num_samples - pad_factor * (num_samples // stride_factor)),
num_samples // stride_factor,
):
audio_feature = pad_or_trim(wav[idx:], length=num_samples)
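
With `pad_last=True`, the `pad_factor` of 2 pushes the loop's stop value one hop further, so the trailing partial window is still emitted and then filled out by `pad_or_trim`. A worked example of the arithmetic with illustrative numbers (not necessarily the project's defaults):

```python
num_samples = 16000 * 30            # 30 s windows at 16 kHz (assumed)
stride_factor = 6                   # hop of 5 s
total_samples = 16000 * 70          # a 70 s recording
hop = num_samples // stride_factor

def segment_starts(pad_last: bool) -> list[int]:
    pad_factor = 2 if pad_last else 1
    stop = total_samples - (num_samples - pad_factor * hop)
    return list(range(0, stop, hop))

print(len(segment_starts(False)))   # 9 windows
print(len(segment_starts(True)))    # 10 windows -- one extra, padded segment
```
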
@@ -75,6 +87,12 @@ def get_wav_mid_segments(
end_ms=(idx + num_samples) / samples_per_ms,
max_pedal_len_ms=10000,
)

# Hardcoded to 2.5s
if _check_onset_threshold(mid_feature, 2500) is False:
print("No note messages after 2.5s - skipping")
continue

else:
mid_feature = []

@@ -86,6 +104,56 @@ def get_wav_mid_segments(
return res


def pianoteq_cmd_fn(mid_path: str, wav_path: str):
presets = [
"C. Bechstein",
"C. Bechstein Close Mic",
"C. Bechstein Under Lid",
"C. Bechstein 440",
"C. Bechstein Recording",
"C. Bechstein Werckmeister III",
"C. Bechstein Neidhardt III",
"C. Bechstein mesotonic",
"C. Bechstein well tempered",
"HB Steinway D Blues",
"HB Steinway D Pop",
"HB Steinway D New Age",
"HB Steinway D Prelude",
"HB Steinway D Felt I",
"HB Steinway D Felt II",
"HB Steinway Model D",
"HB Steinway D Classical Recording",
"HB Steinway D Jazz Recording",
"HB Steinway D Chamber Recording",
"HB Steinway D Studio Recording",
"HB Steinway D Intimate",
"HB Steinway D Cinematic",
"HB Steinway D Close Mic Classical",
"HB Steinway D Close Mic Jazz",
"HB Steinway D Player Wide",
"HB Steinway D Player Clean",
"HB Steinway D Trio",
"HB Steinway D Duo",
"HB Steinway D Cabaret",
"HB Steinway D Bright",
"HB Steinway D Hyper Bright",
"HB Steinway D Prepared",
"HB Steinway D Honky Tonk",
]

preset = random.choice(presets)

# Safely quote the preset name, MIDI path, and WAV path
safe_preset = shlex.quote(preset)
safe_mid_path = shlex.quote(mid_path)
safe_wav_path = shlex.quote(wav_path)

# Construct the command
command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"

return command
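
`pianoteq_cmd_fn` picks a random Pianoteq preset and shell-quotes the preset, MIDI path, and WAV path with `shlex.quote`; `write_synth_features` below passes it to `get_synth_audio` as `cli_cmd_fn`. A hypothetical invocation (the paths are illustrative, and the Pianoteq binary location is hardcoded in the function):

```python
import subprocess

cmd = pianoteq_cmd_fn("/tmp/example.mid", "/tmp/example.wav")
print(cmd)
# e.g. .../Pianoteq\ 8\ STAGE --preset 'HB Steinway D Pop' --midi /tmp/example.mid --wav /tmp/example.wav

# One way to execute it directly (a sketch only; the project routes this
# through its get_synth_audio helper instead).
subprocess.run(cmd, shell=True, check=True)
```
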


def write_features(audio_path: str, mid_path: str, save_path: str):
features = get_wav_mid_segments(
audio_path=audio_path,
@@ -121,7 +189,7 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):

try:
get_synth_audio(
cli_cmd=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
cli_cmd_fn=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
)
except:
if os.path.isfile(audio_path_temp):
Expand All @@ -133,7 +201,11 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
mid_path=mid_path,
return_json=False,
)
os.remove(audio_path_temp)

if os.path.isfile(audio_path_temp):
os.remove(audio_path_temp)

print(f"Found {len(features)}")

with open(save_path, mode="a") as file:
for wav, seq in features:
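
This hunk fixes the keyword passed to `get_synth_audio` (`cli_cmd_fn=` rather than `cli_cmd=`) and only deletes the temporary WAV if it actually exists. The same guard can also be expressed with `try`/`finally`; a generic sketch of the pattern, not the repository's exact helper:

```python
import os
import subprocess
import tempfile
from typing import Callable

def render_and_extract(cli_cmd_fn: Callable[[str, str], str], mid_path: str):
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        subprocess.run(cli_cmd_fn(mid_path, wav_path), shell=True, check=True)
        return get_wav_mid_segments(audio_path=wav_path, mid_path=mid_path)
    finally:
        # Remove the temp file whether or not synthesis succeeded.
        if os.path.isfile(wav_path):
            os.remove(wav_path)
```
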
@@ -174,7 +246,11 @@ def build_synth_worker_fn(

while not load_path_queue.empty():
mid_path = load_path_queue.get()
write_synth_features(cli_cmd, mid_path, worker_save_path)
try:
write_synth_features(cli_cmd, mid_path, worker_save_path)
except Exception as e:
print("Failed")
print(e)

save_path_queue.put(worker_save_path)
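
Wrapping `write_synth_features` in `try`/`except` keeps a worker alive when a single MIDI file fails to synthesize. The overall shape of the worker loop, sketched with illustrative names (note that `Queue.empty()` is racy, so this pattern assumes the queue is fully populated before the workers start):

```python
from multiprocessing import Queue

def synth_worker(load_q: Queue, save_q: Queue, worker_save_path: str) -> None:
    while not load_q.empty():
        mid_path = load_q.get()
        try:
            write_synth_features(pianoteq_cmd_fn, mid_path, worker_save_path)
        except Exception as e:
            # Log and move on rather than killing the worker process.
            print(f"Failed on {mid_path}: {e}")
    save_q.put(worker_save_path)
```
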

@@ -239,7 +315,7 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx
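
`__getitem__` now returns the sample index alongside the spectrogram and token tensors, so training loops have to unpack four items per batch. A toy stand-in with the same return shape shows that the default collate handles this unchanged:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class _ToyDataset(Dataset):
    # Same return shape as the modified __getitem__: (wav, src, tgt, idx).
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int):
        wav = torch.zeros(16000)
        src = torch.zeros(64, dtype=torch.long)
        tgt = torch.zeros(64, dtype=torch.long)
        return wav, src, tgt, idx

for wav, src, tgt, idx in DataLoader(_ToyDataset(), batch_size=4):
    print(idx)  # tensor of dataset indices, handy for logging bad samples
```
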

def _build_index(self):
self.file_mmap.seek(0)
@@ -254,7 +330,7 @@ def _build_index(self):

return index

def _save_index(self, index: list[int], save_path: str):
def _save_index(self, index: list, save_path: str):
with open(save_path, "w") as file:
for idx in index:
file.write(f"{idx}\n")
@@ -325,7 +401,7 @@ def build(
]
else:
# Build synthetic dataset
assert len(load_paths[0]) == 1, "Invalid load paths"
assert isinstance(load_paths[0], str), "Invalid load paths"
print("Building synthetic dataset")
worker_processes = [
Process(
78 changes: 27 additions & 51 deletions amt/evaluate.py
@@ -42,27 +42,32 @@ def midi_to_hz(note, shift=0):
# return (a / 32) * (2 ** ((note - 9) / 12))


def get_matched_files(est_dir: str, ref_dir: str):
# We assume that the files have the same path relative to their directory

res = []
est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True)
print(f"found {len(est_paths)} est files")

for est_path in est_paths:
est_rel_path = os.path.relpath(est_path, est_dir)
ref_path = os.path.join(
ref_dir, os.path.splitext(est_rel_path)[0] + ".midi"
)
if os.path.isfile(ref_path):
res.append((est_path, ref_path))

print(f"found {len(res)} matched est-ref pairs")

return res
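
`get_matched_files` pairs every estimated `.mid` under `est_dir` with a reference `.midi` at the same relative path under `ref_dir`, replacing the old basename-only matching. A hypothetical layout and call:

```python
# est_dir/2017/track_01.mid   <- transcription
# ref_dir/2017/track_01.midi  <- ground truth at the same relative path
pairs = get_matched_files("est_dir", "ref_dir")
for est_path, ref_path in pairs:
    print(est_path, "<->", ref_path)
```
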


def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
"""
Evaluate the estimated pitches against the reference pitches using mir_eval.
"""
# Evaluate the estimated pitches against the reference pitches
ref_midi_files = glob.glob(f"{ref_dir}/*.mid*")
est_midi_files = glob.glob(f"{est_dir}/*.mid*")

est_ref_pairs = []
for est_fpath in est_midi_files:
ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath))
if ref_fpath in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath))
if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
est_ref_pairs.append(
(est_fpath, ref_fpath.replace(".mid", ".midi"))
)
else:
print(
f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
)

est_ref_pairs = get_matched_files(est_dir, ref_dir)

output_fhandle = (
open(output_stats_file, "w") if output_stats_file is not None else None
@@ -104,38 +109,9 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
help="Path to the file to save the evaluation stats",
)

# add mir_eval and dtw subparsers
subparsers = parser.add_subparsers(help="sub-command help")
mir_eval_parse = subparsers.add_parser(
"run_mir_eval",
help="Run standard mir_eval evaluation on MAESTRO test set.",
)
mir_eval_parse.add_argument(
"--shift",
type=int,
default=0,
help="Shift to apply to the estimated pitches.",
)

# to come
dtw_eval_parse = subparsers.add_parser(
"run_dtw",
help="Run dynamic time warping evaluation on a specified dataset.",
)

args = parser.parse_args()
if not hasattr(args, "command"):
parser.print_help()
print("Unrecognized command")
exit(1)

# todo: should we add an option to run transcription again every time we wish to evaluate?
# that way, we can run both tests with a range of different audio augmentations right here.
# -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

if args.command == "run_mir_eval":
evaluate_mir_eval(
args.est_dir, args.ref_dir, args.output_stats_file, args.shift
)
elif args.command == "run_dtw":
pass
evaluate_mir_eval(
args.est_dir,
args.ref_dir,
args.output_stats_file,
)
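
The entrypoint now calls `evaluate_mir_eval` directly instead of going through subparsers. For reference, this is roughly how a single est/ref pair is scored with `mir_eval`'s transcription metrics; the note-interval extraction from MIDI is not shown in this diff, so the arrays below are toy values:

```python
import numpy as np
import mir_eval

ref_intervals = np.array([[0.00, 0.50], [0.50, 1.00]])  # onsets/offsets in seconds
ref_pitches = np.array([261.63, 293.66])                # pitches in Hz
est_intervals = np.array([[0.01, 0.48], [0.52, 1.02]])
est_pitches = np.array([261.63, 293.66])

p, r, f1, _ = mir_eval.transcription.precision_recall_f1_overlap(
    ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05
)
print(f"P={p:.3f} R={r:.3f} F1={f1:.3f}")
```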