Update inference and scripts #37

Merged · 12 commits · May 24, 2024
8 changes: 2 additions & 6 deletions README.md
@@ -14,12 +14,8 @@ pip install -e .

Download the preliminary model weights:

-Piano (not final)
-```
-wget https://storage.googleapis.com/aria-checkpoints/amt/guitar-temp.safetensors
-```
+Piano (v1)

-Classical guitar (not final)
```
wget https://storage.googleapis.com/aria-checkpoints/amt/piano-temp.safetensors
```
@@ -36,7 +32,7 @@ You can then transcribe using the cli:

```
aria-amt transcribe \
-small-final \
+medium-stacked \
<path-to-checkpoint> \
-load_path <path-to-audio> \
-save_dir <path-to-save-dir> \
7 changes: 5 additions & 2 deletions amt/audio.py
@@ -131,6 +131,9 @@ def __init__(
self.register_buffer(f"applause_{i}", applause)
self.num_applause += 1

+# 256 - 0-8000 2048-256
+# 512 - 30-8000 2048-384 30-8000 800-128
+# 764 - 30-8000 4096-384 30-8000 2048-256 30-4000 768-128
self.spec_transform_large = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft_large,
hop_length=self.config["hop_len"],
@@ -416,8 +419,8 @@ def log_mel(

# Norm
concat_mel = torch.cat(
-(mel_spec_large, mel_spec_med, mel_spec_small),
-# (mel_spec_large, mel_spec_small),
+# (mel_spec_large, mel_spec_med, mel_spec_small),
+(mel_spec_large, mel_spec_small),
dim=1,
)
log_mel = self.norm_mel(concat_mel)
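For context, a minimal sketch of what the `log_mel` hunk changes: log-mel features computed at two FFT resolutions are concatenated along the mel-bin axis before normalization, with the medium resolution dropped. The class internals, config values, and normalization below are illustrative assumptions, not the project's actual transform:

```
import torch
import torchaudio

SAMPLE_RATE = 16000  # assumption: typical AMT sample rate
HOP_LEN = 160        # assumption: stands in for config["hop_len"]
N_MELS = 128         # assumption: mel bins per resolution

def two_resolution_log_mel(wav: torch.Tensor) -> torch.Tensor:
    mels = []
    for n_fft in (2048, 512):  # "large" and "small" analysis windows
        spec = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=HOP_LEN, power=1.0
        )(wav)
        mel = torchaudio.transforms.MelScale(
            n_mels=N_MELS, sample_rate=SAMPLE_RATE, n_stft=n_fft // 2 + 1
        )(spec)
        mels.append(mel)

    # Concatenate along the mel-bin dimension, as in the hunk above
    concat_mel = torch.cat(mels, dim=1)
    log_mel = torch.log(torch.clamp(concat_mel, min=1e-5))
    # Stand-in for norm_mel: crude per-example standardization
    return (log_mel - log_mel.mean()) / (log_mel.std() + 1e-5)

print(two_resolution_log_mel(torch.randn(1, SAMPLE_RATE)).shape)
# -> torch.Size([1, 256, 101])
```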
6 changes: 3 additions & 3 deletions amt/inference/transcribe.py
@@ -221,8 +221,8 @@ def process_segments(
),
)

-# logits[:, 389] *= 1.05
-# next_tok_ids = torch.argmax(logits, dim=-1)
+logits[:, 389] *= 1.05
+next_tok_ids = torch.argmax(logits, dim=-1)

next_tok_ids = recalculate_tok_ids(
logits=logits,
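The first hunk enables a previously commented-out decoding tweak: the logit for token id 389 (presumably a specific event token in the model's vocabulary; which one is not stated here) is scaled by 1.05 before greedy selection. A minimal standalone sketch with illustrative shapes:

```
import torch

batch, vocab_size = 4, 1024        # illustrative sizes
logits = torch.randn(batch, vocab_size)

# Scale token 389's logit by 1.05; for a positive logit this is a boost,
# making the token slightly more likely to win the argmax below.
logits[:, 389] *= 1.05
next_tok_ids = torch.argmax(logits, dim=-1)  # greedy decoding step
print(next_tok_ids.shape)  # torch.Size([4])
```

Note that a multiplicative scale behaves differently from an additive bias: it pushes a negative logit further down rather than up.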
@@ -544,7 +544,7 @@ def get_silent_intervals(wav: torch.Tensor):
wav.numpy(),
frame_length=FRAME_LEN,
hop_length=HOP_LEN,
-top_db=30,
+top_db=45,
ref=np.max,
)
non_silent = np.concatenate(([True], non_silent, [True]))
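The second hunk raises the silence threshold used when splitting audio: with `top_db=45`, a frame must be more than 45 dB below the peak (rather than 30 dB) to count as silent, so fewer, more conservatively chosen silent intervals are detected. A hedged sketch using librosa's public `split`, which takes the same parameters (the `FRAME_LEN` and `HOP_LEN` values here are assumptions):

```
import librosa
import numpy as np

FRAME_LEN, HOP_LEN = 2048, 512  # assumptions; the project defines its own

wav = np.random.randn(16000).astype(np.float32)  # stand-in for real audio

# Returns sample-index [start, end) intervals that are NOT silent;
# the gaps between them are the silent intervals.
non_silent_intervals = librosa.effects.split(
    wav,
    top_db=45,
    ref=np.max,
    frame_length=FRAME_LEN,
    hop_length=HOP_LEN,
)
print(non_silent_intervals)
```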
246 changes: 246 additions & 0 deletions amt/mir.py
@@ -0,0 +1,246 @@
import glob
from tqdm.auto import tqdm
import pretty_midi
import numpy as np
import mir_eval
import json
import os

from aria.data.midi import MidiDict, get_duration_ms

pretty_midi.pretty_midi.MAX_TICK = 1e10


def midi_to_intervals_and_pitches(midi_file_path):
mid_dict = MidiDict.from_midi(midi_file_path)
mid_dict.resolve_pedal()

intervals, pitches, velocities = [], [], []
for note_msg in mid_dict.note_msgs:
pitch = note_msg["data"]["pitch"]
onset_s = (
get_duration_ms(
start_tick=0,
end_tick=note_msg["data"]["start"],
tempo_msgs=mid_dict.tempo_msgs,
ticks_per_beat=mid_dict.ticks_per_beat,
)
* 1e-3
)
offset_s = (
get_duration_ms(
start_tick=0,
end_tick=note_msg["data"]["end"],
tempo_msgs=mid_dict.tempo_msgs,
ticks_per_beat=mid_dict.ticks_per_beat,
)
* 1e-3
)
velocity = note_msg["data"]["velocity"]

if onset_s >= offset_s:
print("Skipping duration zero note")
continue

intervals.append([onset_s, offset_s])
pitches.append(pitch)
velocities.append(velocity)

return np.array(intervals), np.array(pitches), np.array(velocities)


def midi_to_hz(note, shift=0):
"""
Convert MIDI to HZ.

Shift, if != 0, is subtracted from the MIDI note.
Use "2" for the hFT augmented model transcriptions, else pitches won't match.
"""
# the one used in hFT transformer
return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12)
# a = 440 # frequency of A (common value is 440Hz)
# return (a / 32) * (2 ** ((note - 9) / 12))
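    # Quick sanity checks for the formula above (standard equal temperament):
    #   midi_to_hz(np.array([69]))          -> 440.0   (A4)
    #   midi_to_hz(np.array([60]))          -> ~261.63 (middle C)
    #   midi_to_hz(np.array([71]), shift=2) -> 440.0   (shift applied first)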


def get_matched_files(est_dir: str, ref_dir: str):
# We assume that the files have the same path relative to their directory

res = []
est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True)
if len(est_paths) == 0:
est_paths = glob.glob(
os.path.join(est_dir, "**/*.midi"), recursive=True
)
print(f"found {len(est_paths)} est files")

for est_path in est_paths:
est_rel_path = os.path.relpath(est_path, est_dir)
ref_path = os.path.join(
ref_dir, os.path.splitext(est_rel_path)[0] + ".mid"
)
if os.path.isfile(ref_path):
res.append((est_path, ref_path))
else:
ref_path = os.path.join(
ref_dir, os.path.splitext(est_rel_path)[0] + ".midi"
)
if os.path.isfile(ref_path):
res.append((est_path, ref_path))

print(f"found {len(res)} matched est-ref pairs")

return res


def get_matched_files_direct(est_dir: str, ref_dir: str):
# Helper to extract filenames with normalized extensions
def get_filenames(paths):
normalized_files = {}
for path in paths:
basename = os.path.basename(path)
name, ext = os.path.splitext(basename)

name = name[:-12] if name.endswith("_transcribed") else name

if ext in [".mid", ".midi"]:
normalized_files[name] = path
return normalized_files

# Gather all potential MIDI files in both directories
est_files = glob.glob(os.path.join(est_dir, "**/*.*"), recursive=True)
ref_files = glob.glob(os.path.join(ref_dir, "**/*.*"), recursive=True)

# Map filenames to their full paths with normalized extensions
est_file_map = get_filenames(est_files)
ref_file_map = get_filenames(ref_files)

# Find matching files by filename disregarding extension differences
matched_files = []
for filename, ref_path in ref_file_map.items():
if filename in est_file_map:
matched_files.append((est_file_map[filename], ref_path))

print(f"found {len(est_file_map)} MIDI files in estimation directory")
print(f"found {len(ref_file_map)} MIDI files in reference directory")
print(f"found {len(matched_files)} matched MIDI file pairs")

return matched_files
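# Example of the normalization above: an estimated file
# "out/song1_transcribed.midi" is paired with a reference "ref/song1.mid",
# since the "_transcribed" suffix and the .mid/.midi extension difference
# are both ignored when matching basenames.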


def get_avg_scores(scores):
totals = {}
counts = {}
for d in scores:
for key, value in d.items():
if key == "f_name":
continue
totals[key] = totals.get(key, 0) + value
counts[key] = counts.get(key, 0) + 1
averages = {f"{key}_avg": totals[key] / counts[key] for key in totals}
return averages
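# e.g. get_avg_scores([{"F-measure": 0.8, "f_name": "a.mid"},
#                      {"F-measure": 0.6, "f_name": "b.mid"}])
# -> {"F-measure_avg": 0.7}; "f_name" entries are excluded from averaging.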


def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
"""
Evaluate the estimated pitches against the reference pitches using mir_eval.
"""

est_ref_pairs = get_matched_files(est_dir, ref_dir)
if len(est_ref_pairs) == 0:
print("Failed to find files, trying direct search")
est_ref_pairs = get_matched_files_direct(est_dir, ref_dir)

output_fhandle = (
open(output_stats_file, "w") if output_stats_file is not None else None
)

res = []
for est_file, ref_file in tqdm(est_ref_pairs):
ref_intervals, ref_pitches, ref_velocities = (
midi_to_intervals_and_pitches(ref_file)
)
est_intervals, est_pitches, est_velocities = (
midi_to_intervals_and_pitches(est_file)
)
ref_pitches_hz = midi_to_hz(ref_pitches)
est_pitches_hz = midi_to_hz(est_pitches, est_shift)

scores = mir_eval.transcription.evaluate(
ref_intervals,
ref_pitches_hz,
est_intervals,
est_pitches_hz,
)

prec_vel, recall_vel, f1_vel, _ = (
mir_eval.transcription_velocity.precision_recall_f1_overlap(
ref_intervals=ref_intervals,
ref_pitches=ref_pitches,
ref_velocities=ref_velocities,
est_intervals=est_intervals,
est_pitches=est_pitches,
est_velocities=est_velocities,
)
)

scores["Precision_vel"] = prec_vel
scores["Recall_vel"] = recall_vel
scores["F1_vel"] = f1_vel
scores["f_name"] = est_file
res.append(scores)

    avg_scores = get_avg_scores(res)
    print(json.dumps(avg_scores))

    # Guard against output_stats_file=None, in which case no file was opened
    if output_fhandle is not None:
        output_fhandle.write(json.dumps(avg_scores))
        output_fhandle.write("\n")

        res.sort(key=lambda x: x["F-measure"])
        for s in res:
            output_fhandle.write(json.dumps(s))
            output_fhandle.write("\n")
        output_fhandle.close()


def evaluate_single(est_file, ref_file):
ref_intervals, ref_pitches, ref_velocities = midi_to_intervals_and_pitches(
ref_file
)
est_intervals, est_pitches, est_velocities = midi_to_intervals_and_pitches(
est_file
)
ref_pitches_hz = midi_to_hz(ref_pitches)
est_pitches_hz = midi_to_hz(est_pitches)

return mir_eval.transcription.evaluate(
ref_intervals,
ref_pitches_hz,
est_intervals,
est_pitches_hz,
)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(usage="evaluate <command> [<args>]")
parser.add_argument(
"--est-dir",
type=str,
help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed.",
)
parser.add_argument(
"--ref-dir",
type=str,
help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw).",
)
parser.add_argument(
"--output-stats-file",
default=None,
type=str,
help="Path to the file to save the evaluation stats",
)

args = parser.parse_args()
evaluate_mir_eval(
args.est_dir,
args.ref_dir,
args.output_stats_file,
)
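
For reference, the new evaluation script can be invoked directly; the flags match the argparse definitions above, and the paths are placeholders:

```
python amt/mir.py \
    --est-dir <path-to-transcribed-midis> \
    --ref-dir <path-to-reference-midis> \
    --output-stats-file stats.jsonl
```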